[bin] Add more validation and checks to Stowage

This commit is contained in:
2025-12-13 13:57:12 +00:00
parent 1bb76e4b24
commit 2c404cf540

View File

@@ -52,6 +52,10 @@ def add(args: argparse.Namespace) -> None:
dest_path = package / rest dest_path = package / rest
dest_dir = dest_path.parent dest_dir = dest_path.parent
if dest_path.exists():
logger.error("file already exists in package: %s", dest_path)
sys.exit(1)
if not dest_dir.exists(): if not dest_dir.exists():
logger.info("DIR %s", dest_dir) logger.info("DIR %s", dest_dir)
if not args.dry_run: if not args.dry_run:
@@ -65,7 +69,10 @@ def add(args: argparse.Namespace) -> None:
except OSError as e: except OSError as e:
logger.error("failed to create symlink: %s", e) logger.error("failed to create symlink: %s", e)
# Attempt to restore the file # Attempt to restore the file
try:
shutil.move(str(dest_path), str(file_path)) shutil.move(str(dest_path), str(file_path))
except Exception as restore_error:
logger.error("failed to restore file: %s", restore_error)
sys.exit(1) sys.exit(1)
@@ -96,8 +103,9 @@ def install(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> Non
dest.unlink() dest.unlink()
# Make directory # Make directory
if not dest.exists():
logger.info("DIR %s", dest) logger.info("DIR %s", dest)
if not args.dry_run and not dest.exists(): if not args.dry_run:
dest.mkdir(parents=True, mode=0o755, exist_ok=True) dest.mkdir(parents=True, mode=0o755, exist_ok=True)
# Process files # Process files
@@ -111,7 +119,7 @@ def install(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> Non
continue continue
# Does the file already exist? # Does the file already exist?
if dest_path.is_file(): if dest_path.is_file() or dest_path.is_symlink():
logger.info("UNLINK %s", dest_path) logger.info("UNLINK %s", dest_path)
if not args.dry_run: if not args.dry_run:
dest_path.unlink() dest_path.unlink()
@@ -119,7 +127,10 @@ def install(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> Non
# Link the file # Link the file
logger.info("LINK %s %s", src_path, dest_path) logger.info("LINK %s %s", src_path, dest_path)
if not args.dry_run: if not args.dry_run:
try:
dest_path.symlink_to(src_path) dest_path.symlink_to(src_path)
except OSError as e:
logger.error("failed to create symlink %s: %s", dest_path, e)
def uninstall(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> None: def uninstall(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> None:
@@ -145,25 +156,34 @@ def uninstall(args: argparse.Namespace, is_excluded: Callable[[str], bool]) -> N
for filename in files: for filename in files:
dest_path = dest / filename dest_path = dest / filename
if not dest_path.exists():
logger.debug("does not exist: %s", dest_path)
continue
if dest_path.is_symlink(): if dest_path.is_symlink():
src_path = (root_path / filename).resolve() src_path = (root_path / filename).resolve()
try:
if dest_path.resolve() == src_path: if dest_path.resolve() == src_path:
logger.info("UNLINK %s", dest_path) logger.info("UNLINK %s", dest_path)
if not args.dry_run: if not args.dry_run:
dest_path.unlink() dest_path.unlink()
else: else:
logger.info("SKIP %s", dest_path) logger.info("SKIP %s (points elsewhere)", dest_path)
except (OSError, RuntimeError) as e:
logger.warning("error checking symlink %s: %s", dest_path, e)
else: else:
logger.info("SKIP %s", dest_path) logger.info("SKIP %s (not a symlink)", dest_path)
# Delete the directories if empty. # Delete the directories if empty.
for dir_path in sorted(dirs, key=lambda p: len(str(p)), reverse=True): for dir_path in sorted(dirs, key=lambda p: len(str(p)), reverse=True):
if not dir_path.exists():
continue
try: try:
logger.info("RMDIR %s", dir_path) logger.info("RMDIR %s", dir_path)
if not args.dry_run: if not args.dry_run:
dir_path.rmdir() dir_path.rmdir()
except OSError: except OSError:
pass logger.debug("directory not empty: %s", dir_path)
def make_argparser() -> argparse.ArgumentParser: def make_argparser() -> argparse.ArgumentParser:
@@ -237,6 +257,18 @@ def main() -> None:
stream=sys.stdout stream=sys.stdout
) )
# Validate repository exists
repo_path = Path(args.repository)
if not repo_path.exists():
logger.error("repository does not exist: %s", args.repository)
sys.exit(1)
# Validate target exists
target_path = Path(args.target)
if not target_path.exists():
logger.error("target directory does not exist: %s", args.target)
sys.exit(1)
exclude = [re.compile(fnmatch.translate(pattern)) for pattern in args.exclude] exclude = [re.compile(fnmatch.translate(pattern)) for pattern in args.exclude]
def is_excluded(filename: str) -> bool: def is_excluded(filename: str) -> bool:
@@ -244,14 +276,25 @@ def main() -> None:
return any(pattern.match(filename) for pattern in exclude) return any(pattern.match(filename) for pattern in exclude)
if args.command == "list": if args.command == "list":
repo_path = Path(args.repository) if not repo_path.is_dir():
for item in sorted(repo_path.iterdir()): logger.error("repository is not a directory: %s", args.repository)
if item.is_dir() and not item.name.startswith("."): sys.exit(1)
print(item.name) packages = sorted([
item.name for item in repo_path.iterdir()
if item.is_dir() and not item.name.startswith(".")
])
if packages:
for package in packages:
print(package)
else:
logger.info("no packages found in repository")
elif args.command == "add": elif args.command == "add":
if len(args.packages) > 1: if len(args.packages) > 1:
parser.error("add only works with a single package") parser.error("add only works with a single package")
file_path = Path(args.target, args.file) file_path = Path(args.file)
# Handle both absolute and relative paths
if not file_path.is_absolute():
file_path = Path(args.target) / file_path
if not file_path.is_file(): if not file_path.is_file():
parser.error(f"no such file: {args.file}") parser.error(f"no such file: {args.file}")
args.file = str(file_path) args.file = str(file_path)