refactor: clarify Path/str types in commands.py

This commit is contained in:
Antti Kaihola 2024-08-29 19:23:24 +03:00
parent fc5c040b83
commit e52d2da740

View file

@ -5,6 +5,7 @@ import sys
import tempfile import tempfile
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path from pathlib import Path
from typing import Generator
import git import git
import pyperclip import pyperclip
@ -527,21 +528,23 @@ class Commands:
try: try:
if os.path.isabs(pattern): if os.path.isabs(pattern):
# Handle absolute paths # Handle absolute paths
raw_matched_files = [Path(pattern)] raw_matched_files: list[Path] = [Path(pattern)]
else: else:
raw_matched_files = list(Path(self.coder.root).glob(pattern)) raw_matched_files: list[Path] = list(
Path(self.coder.root).glob(pattern)
)
except ValueError as err: except ValueError as err:
self.io.tool_error(f"Error matching {pattern}: {err}") self.io.tool_error(f"Error matching {pattern}: {err}")
raw_matched_files = [] raw_matched_files: list[Path] = []
matched_files = [] matched_files: list[Path] = []
for fn in raw_matched_files: for fn in raw_matched_files:
matched_files += expand_subdir(fn) matched_files += expand_subdir(fn)
matched_files = [ matched_files = [
str(Path(fn).relative_to(self.coder.root)) fn.relative_to(self.coder.root)
for fn in matched_files for fn in matched_files
if Path(fn).is_relative_to(self.coder.root) if fn.is_relative_to(self.coder.root)
] ]
# if repo, filter against it # if repo, filter against it
@ -1076,8 +1079,7 @@ class Commands:
self.io.tool_output(settings) self.io.tool_output(settings)
def expand_subdir(file_path): def expand_subdir(file_path: Path) -> Generator[Path, None, None]:
file_path = Path(file_path)
if file_path.is_file(): if file_path.is_file():
yield file_path yield file_path
return return
@ -1085,7 +1087,7 @@ def expand_subdir(file_path):
if file_path.is_dir(): if file_path.is_dir():
for file in file_path.rglob("*"): for file in file_path.rglob("*"):
if file.is_file(): if file.is_file():
yield str(file) yield file
def parse_quoted_filenames(args): def parse_quoted_filenames(args):