[bin] Improvements to stowage

This commit is contained in:
Andrew Williams
2026-01-08 12:22:53 +00:00
parent f753debbb9
commit 53dc8a090c

View File

@@ -6,7 +6,7 @@ modified by Andrew Williams <https://github.com/nikdoof/>
A dotfile package manager A dotfile package manager
Copyright (c) Keith Gaughan, 2017. Copyright (c) Keith Gaughan, 2017.
Copyright (c) Andrew Williams, 2021. Copyright (c) Andrew Williams, 2021-2025.
Permission is hereby granted, free of charge, to any person obtaining a copy of Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in this software and associated documentation files (the "Software"), to deal in
@@ -32,57 +32,65 @@ import os
import re import re
import shutil import shutil
import sys import sys
from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Callable, List from typing import List
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def add(args: argparse.Namespace) -> None: def add_file_to_package(
file_path: Path, package: str, args: argparse.Namespace
) -> bool:
"""Add a file to a package by moving it and creating a symlink.""" """Add a file to a package by moving it and creating a symlink."""
target = Path(args.target).resolve() target = args.target.resolve()
file_path = Path(args.file).resolve() package = Path(args.repository, package).resolve()
package = Path(args.repository, args.packages[0]).resolve()
# Check the file is under the target directory
if not file_path.is_relative_to(target): if not file_path.is_relative_to(target):
logger.error("'%s' not under '%s'", args.file, args.target) logger.error("'%s' not under '%s'", file_path, args.target)
sys.exit(1) return False
# Calculate the relative path in the package folder, and final destination
# e.g. /home/bin/x -> /home/.dotfiles/package/bin/x
rest = file_path.relative_to(target) rest = file_path.relative_to(target)
dest_path = package / rest dest_path = package / rest
dest_dir = dest_path.parent dest_dir = dest_path.parent
if dest_path.exists(): if dest_path.exists():
logger.error("file already exists in package: %s", dest_path) logger.error("file already exists in package: %s", dest_path)
sys.exit(1) return False
if not dest_dir.exists(): if not dest_dir.exists():
logger.info("DIR %s", dest_dir) logger.info("DIR %s", dest_dir)
if not args.dry_run: if not args.dry_run:
dest_dir.mkdir(parents=True, mode=0o755, exist_ok=True) dest_dir.mkdir(parents=True, mode=0o755, exist_ok=True)
logger.info("SWAP %s %s", dest_path, file_path) logger.info("SWAP %s <-> %s", dest_path, file_path)
if not args.dry_run: if not args.dry_run:
shutil.move(str(file_path), str(dest_path)) shutil.move(str(file_path), str(dest_path))
try: try:
file_path.symlink_to(dest_path) file_path.symlink_to(dest_path)
except OSError as e: except OSError as e:
logger.error("failed to create symlink: %s", e) logger.error("failed to create symlink: %s", e)
# Attempt to restore the file # Attempt to restore the file if symlink creation fails
try: try:
shutil.move(str(dest_path), str(file_path)) shutil.move(str(dest_path), str(file_path))
except Exception as restore_error: except Exception as restore_error:
logger.error("failed to restore file: %s", restore_error) logger.error("failed to restore file: %s", restore_error)
sys.exit(1) return False
return True
def install(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> None: def install_package(
"""Install packages by creating symlinks from repository to target.""" package: str, args: argparse.Namespace, is_excluded: Callable[[str], bool]
for package in args.packages: ) -> bool:
package_dir = Path(args.repository, package) """Install a package by creating symlinks from repository to target."""
package_dir = args.repository / package
if not package_dir.is_dir(): if not package_dir.is_dir():
logger.warning("no such package: %s; skipping", package) logger.warning("no such package: %s; skipping", package)
continue return False
# Walk the package # Walk the package
for root, _, files in os.walk(package_dir, followlinks=True): for root, _, files in os.walk(package_dir, followlinks=True):
@@ -92,7 +100,7 @@ def install(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> Non
continue continue
rest = root_path.relative_to(package_dir) rest = root_path.relative_to(package_dir)
dest = Path(args.target) / rest dest = args.target / rest
# Create the directory path # Create the directory path
if rest != Path("."): if rest != Path("."):
@@ -125,22 +133,26 @@ def install(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> Non
dest_path.unlink() dest_path.unlink()
# Link the file # Link the file
logger.info("LINK %s %s", src_path, dest_path) logger.info("LINK %s -> %s", src_path, dest_path)
if not args.dry_run: if not args.dry_run:
try: try:
dest_path.symlink_to(src_path) dest_path.symlink_to(src_path)
except OSError as e: except OSError as e:
logger.error("failed to create symlink %s: %s", dest_path, e) logger.error("failed to create symlink %s: %s", dest_path, e)
return True
def uninstall(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> None:
"""Uninstall packages by removing symlinks.""" def uninstall_package(
package: str, args: argparse.Namespace, is_excluded: Callable[[str], bool]
) -> bool:
"""Uninstalls a package by removing symlinks."""
dirs: List[Path] = [] dirs: List[Path] = []
for package in args.packages:
package_dir = Path(args.repository, package) package_dir = args.repository / package
if not package_dir.is_dir(): if not package_dir.is_dir():
logger.warning("no such package: %s; skipping", package) logger.warning("no such package: %s; skipping", package)
continue return False
for root, _, files in os.walk(package_dir, followlinks=True): for root, _, files in os.walk(package_dir, followlinks=True):
root_path = Path(root) root_path = Path(root)
@@ -149,7 +161,7 @@ def uninstall(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> N
continue continue
rest = root_path.relative_to(package_dir) rest = root_path.relative_to(package_dir)
dest = Path(args.target) / rest dest = args.target / rest
if rest != Path("."): if rest != Path("."):
dirs.append(dest) dirs.append(dest)
@@ -184,6 +196,72 @@ def uninstall(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> N
dir_path.rmdir() dir_path.rmdir()
except OSError: except OSError:
logger.debug("directory not empty: %s", dir_path) logger.debug("directory not empty: %s", dir_path)
pass
return True
def get_packages(repo_path: Path) -> List[str]:
"""Get a list of packages in the repository."""
packages = []
for entry in repo_path.iterdir():
if entry.is_dir() and not entry.name.startswith("."):
packages.append(entry.name)
return sorted(packages)
def is_package_file(path: Path, package: str, repo_path: Path) -> bool:
"""
Check if a file is part of a package.
This makes a few assumptions:
- The file is a symlink
- The symlink points inside the repository
- Only a single repository is used to link files from
"""
# Is the file a symlink?
if not path.is_symlink():
return False
# Does the symlink point inside the repository?
try:
target_path = path.resolve()
except OSError:
return False
return target_path.is_relative_to(repo_path / package)
def is_broken_symlink(path: Path) -> bool:
"""Check if a path is a broken symlink."""
return path.is_symlink() and not path.exists()
def cleanup_package(package: str, args: argparse.Namespace) -> None:
"""Cleanup any broken symlinks from a installed package.
- discover the directories used in the package
- iterate the directories and remove any broken symlinks that point back to the package
"""
package_dir = args.repository / package
if not package_dir.is_dir():
return
# Walk the source package folder for its structure
for root, _, _ in os.walk(package_dir, followlinks=True):
root_path = Path(root)
# Calculate the path in the target folder
rest = root_path.relative_to(package_dir)
dest = args.target / rest
# Iterate the files in the target folder
for file in os.listdir(dest):
src_path = dest / file
if is_package_file(
src_path, package, args.repository
) and is_broken_symlink(src_path):
logger.info("UNLINK %s", src_path)
if not args.dry_run:
src_path.unlink()
def make_argparser() -> argparse.ArgumentParser: def make_argparser() -> argparse.ArgumentParser:
@@ -194,13 +272,15 @@ def make_argparser() -> argparse.ArgumentParser:
parser.add_argument( parser.add_argument(
"--target", "--target",
"-t", "-t",
default=str(Path.home()), default=Path.home(),
type=Path,
help="Target directory in which to place symlinks", help="Target directory in which to place symlinks",
) )
parser.add_argument( parser.add_argument(
"--repository", "--repository",
"-r", "-r",
default=str(Path.home() / ".dotfiles"), default=Path.home() / ".dotfiles",
type=Path,
help="The location of the dotfile repository", help="The location of the dotfile repository",
) )
parser.add_argument( parser.add_argument(
@@ -223,11 +303,11 @@ def make_argparser() -> argparse.ArgumentParser:
subparsers.add_parser("list", help="List packages in the repository") subparsers.add_parser("list", help="List packages in the repository")
# Add # Add
parser_add = subparsers.add_parser("add", help="Add a file to a package") parser_add = subparsers.add_parser("add", help="Add a files to a package")
parser_add.add_argument("file", metavar="FILE", help="File to stow")
parser_add.add_argument( parser_add.add_argument(
"packages", metavar="PACKAGE", nargs="+", help="Packages to install" "package", metavar="PACKAGE", help="Package to add the files to"
) )
parser_add.add_argument("files", metavar="FILE", nargs="+", help="Files to add")
# Uninstall # Uninstall
parser_uninstall = subparsers.add_parser("uninstall", help="Remove a package") parser_uninstall = subparsers.add_parser("uninstall", help="Remove a package")
@@ -241,6 +321,14 @@ def make_argparser() -> argparse.ArgumentParser:
"packages", metavar="PACKAGE", nargs="+", help="Packages to install" "packages", metavar="PACKAGE", nargs="+", help="Packages to install"
) )
# Cleanup
parser_cleanup = subparsers.add_parser(
"cleanup", help="Cleanup broken symlinks from a package"
)
parser_cleanup.add_argument(
"packages", metavar="PACKAGE", nargs="+", help="Packages to cleanup"
)
return parser return parser
@@ -251,60 +339,74 @@ def main() -> None:
# Configure logging # Configure logging
log_level = logging.INFO if args.verbose or args.dry_run else logging.WARNING log_level = logging.INFO if args.verbose or args.dry_run else logging.WARNING
logging.basicConfig( logging.basicConfig(level=log_level, format="%(message)s", stream=sys.stdout)
level=log_level,
format='%(message)s',
stream=sys.stdout
)
# Validate repository exists # Validate repository exists
repo_path = Path(args.repository) repo_path = args.repository
if not repo_path.exists(): if not repo_path.exists():
logger.error("repository does not exist: %s", args.repository) logger.error("repository (%s) does not exist", args.repository)
sys.exit(1) sys.exit(1)
# Validate target exists # Validate target exists
target_path = Path(args.target) target_path = args.target
if not target_path.exists(): if not target_path.exists():
logger.error("target directory does not exist: %s", args.target) logger.error("target directory (%s) does not exist", args.target)
sys.exit(1) sys.exit(1)
# Compile exclusion patterns
exclude = [re.compile(fnmatch.translate(pattern)) for pattern in args.exclude] exclude = [re.compile(fnmatch.translate(pattern)) for pattern in args.exclude]
def is_excluded(filename: str) -> bool: def is_excluded(filename: str) -> bool:
"""Check if a filename matches any exclusion pattern.""" """Check if a filename matches any exclusion pattern."""
return any(pattern.match(filename) for pattern in exclude) return any(pattern.match(filename) for pattern in exclude)
# Log that we're running in dry-run mode
if args.dry_run:
logger.warning("running in dry-run mode")
# True indicates a successful execution
command_successful = True
if args.command == "list": if args.command == "list":
if not repo_path.is_dir(): if not repo_path.is_dir():
logger.error("repository is not a directory: %s", args.repository) logger.error("repository is not a directory: %s", args.repository)
sys.exit(1) sys.exit(1)
packages = sorted([
item.name for item in repo_path.iterdir() packages = get_packages(repo_path)
if item.is_dir() and not item.name.startswith(".") if len(packages):
]) print(f"Packages in repository: {repo_path}")
if packages:
for package in packages: for package in packages:
print(package) print(f"- {package}")
else: else:
logger.info("no packages found in repository") logger.info("no packages found in repository")
elif args.command == "add": elif args.command == "add":
if len(args.packages) > 1: for fname in args.files:
parser.error("add only works with a single package") file_path = Path(fname)
file_path = Path(args.file)
# Handle both absolute and relative paths # Handle both absolute and relative paths
if not file_path.is_absolute(): if not file_path.is_absolute():
file_path = Path(args.target) / file_path file_path = file_path.resolve()
if not file_path.is_file(): if not add_file_to_package(file_path, args.package, args):
parser.error(f"no such file: {args.file}") command_successful = False
args.file = str(file_path)
add(args)
elif args.command == "install": elif args.command == "install":
install(args, is_excluded) for package in args.packages:
if not install_package(package, args, is_excluded):
command_successful = False
elif args.command == "uninstall": elif args.command == "uninstall":
uninstall(args, is_excluded) for package in args.packages:
if not uninstall_package(package, args, is_excluded):
command_successful = False
elif args.command == "cleanup":
for package in args.packages:
cleanup_package(package, args)
else: else:
parser.print_help() parser.print_help()
if not command_successful:
sys.exit(1) sys.exit(1)