unified diffs

This commit is contained in:
Paul Gauthier 2023-12-17 12:54:34 -08:00
parent 3aa17c46dd
commit 7113a30271
18 changed files with 243 additions and 96 deletions

View file

@ -2,6 +2,7 @@ from .base_coder import Coder
from .editblock_coder import EditBlockCoder from .editblock_coder import EditBlockCoder
from .editblock_func_coder import EditBlockFunctionCoder from .editblock_func_coder import EditBlockFunctionCoder
from .single_wholefile_func_coder import SingleWholeFileFunctionCoder from .single_wholefile_func_coder import SingleWholeFileFunctionCoder
from .udiff_coder import UnifiedDiffCoder
from .wholefile_coder import WholeFileCoder from .wholefile_coder import WholeFileCoder
from .wholefile_func_coder import WholeFileFunctionCoder from .wholefile_func_coder import WholeFileFunctionCoder
@ -12,4 +13,5 @@ __all__ = [
WholeFileFunctionCoder, WholeFileFunctionCoder,
EditBlockFunctionCoder, EditBlockFunctionCoder,
SingleWholeFileFunctionCoder, SingleWholeFileFunctionCoder,
UnifiedDiffCoder,
] ]

View file

@ -49,7 +49,9 @@ class Coder:
functions = None functions = None
total_cost = 0.0 total_cost = 0.0
num_exhausted_context_windows = 0 num_exhausted_context_windows = 0
num_malformed_responses = 0
last_keyboard_interrupt = None last_keyboard_interrupt = None
max_apply_update_errors = 3
@classmethod @classmethod
def create( def create(
@ -61,7 +63,7 @@ class Coder:
skip_model_availabily_check=False, skip_model_availabily_check=False,
**kwargs, **kwargs,
): ):
from . import EditBlockCoder, WholeFileCoder from . import EditBlockCoder, UnifiedDiffCoder, WholeFileCoder
if not main_model: if not main_model:
main_model = models.GPT4 main_model = models.GPT4
@ -83,6 +85,8 @@ class Coder:
return EditBlockCoder(client, main_model, io, **kwargs) return EditBlockCoder(client, main_model, io, **kwargs)
elif edit_format == "whole": elif edit_format == "whole":
return WholeFileCoder(client, main_model, io, **kwargs) return WholeFileCoder(client, main_model, io, **kwargs)
elif edit_format == "udiff":
return UnifiedDiffCoder(client, main_model, io, **kwargs)
else: else:
raise ValueError(f"Unknown edit format {edit_format}") raise ValueError(f"Unknown edit format {edit_format}")
@ -296,7 +300,13 @@ class Coder:
prompt += "\n" prompt += "\n"
prompt += relative_fname prompt += relative_fname
prompt += f"\n{self.fence[0]}\n" prompt += f"\n{self.fence[0]}\n"
prompt += content prompt += content
# lines = content.splitlines(keepends=True)
# lines = [f"{i+1:03}:{line}" for i, line in enumerate(lines)]
# prompt += "".join(lines)
prompt += f"{self.fence[1]}\n" prompt += f"{self.fence[1]}\n"
return prompt return prompt
@ -346,7 +356,7 @@ class Coder:
new_user_message = self.send_new_user_message(new_user_message) new_user_message = self.send_new_user_message(new_user_message)
if with_message: if with_message:
return return self.partial_response_content
except KeyboardInterrupt: except KeyboardInterrupt:
self.keyboard_interrupt() self.keyboard_interrupt()
@ -456,12 +466,12 @@ class Coder:
# add the reminder anyway # add the reminder anyway
total_tokens = 0 total_tokens = 0
messages += self.cur_messages
# Add the reminder prompt if we still have room to include it. # Add the reminder prompt if we still have room to include it.
if total_tokens < self.main_model.max_context_tokens: if total_tokens < self.main_model.max_context_tokens:
messages += reminder_message messages += reminder_message
messages += self.cur_messages
return messages return messages
def send_new_user_message(self, inp): def send_new_user_message(self, inp):
@ -850,19 +860,19 @@ class Coder:
return set(edit[0] for edit in edits) return set(edit[0] for edit in edits)
def apply_updates(self): def apply_updates(self):
max_apply_update_errors = 3
try: try:
edited = self.update_files() edited = self.update_files()
except ValueError as err: except ValueError as err:
self.num_malformed_responses += 1
err = err.args[0] err = err.args[0]
self.apply_update_errors += 1 self.apply_update_errors += 1
if self.apply_update_errors < max_apply_update_errors: if self.apply_update_errors < self.max_apply_update_errors:
self.io.tool_error(f"Malformed response #{self.apply_update_errors}, retrying...") self.io.tool_error(f"Malformed response #{self.apply_update_errors}, retrying...")
self.io.tool_error(str(err)) self.io.tool_error(str(err))
return None, err return None, err
else: else:
self.io.tool_error(f"Malformed response #{self.apply_update_errors}, aborting.") self.io.tool_error(f"Malformed response #{self.apply_update_errors}, aborting.")
self.io.tool_error(str(err))
return False, None return False, None
except Exception as err: except Exception as err:
@ -870,11 +880,13 @@ class Coder:
print() print()
traceback.print_exc() traceback.print_exc()
self.apply_update_errors += 1 self.apply_update_errors += 1
if self.apply_update_errors < max_apply_update_errors: if self.apply_update_errors < self.max_apply_update_errors:
self.io.tool_error(f"Update exception #{self.apply_update_errors}, retrying...") self.io.tool_error(f"Update exception #{self.apply_update_errors}, retrying...")
self.io.tool_error(str(err))
return None, str(err) return None, str(err)
else: else:
self.io.tool_error(f"Update exception #{self.apply_update_errors}, aborting") self.io.tool_error(f"Update exception #{self.apply_update_errors}, aborting")
self.io.tool_error(str(err))
return False, None return False, None
self.apply_update_errors = 0 self.apply_update_errors = 0

View file

@ -148,7 +148,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
core_group.add_argument( core_group.add_argument(
"--model", "--model",
metavar="MODEL", metavar="MODEL",
default=models.GPT4.name, default=models.GPT4_0613.name,
help=f"Specify the model to use for the main chat (default: {models.GPT4.name})", help=f"Specify the model to use for the main chat (default: {models.GPT4.name})",
) )
core_group.add_argument( core_group.add_argument(
@ -157,6 +157,14 @@ def main(argv=None, input=None, output=None, force_git_root=None):
default=False, default=False,
help="Override to skip model availability check (default: False)", help="Override to skip model availability check (default: False)",
) )
default_4_turbo_model = models.GPT4_1106_PREVIEW
core_group.add_argument(
"--4-turbo",
action="store_const",
dest="model",
const=default_4_turbo_model.name,
help=f"Use {default_4_turbo_model.name} model for the main chat (gpt-4 is better)",
)
default_3_model = models.GPT35_1106 default_3_model = models.GPT35_1106
core_group.add_argument( core_group.add_argument(
"-3", "-3",
@ -380,7 +388,10 @@ def main(argv=None, input=None, output=None, force_git_root=None):
"--message-file", "--message-file",
"-f", "-f",
metavar="MESSAGE_FILE", metavar="MESSAGE_FILE",
help="Specify a file containing the message to send GPT, process reply, then exit (disables chat mode)", help=(
"Specify a file containing the message to send GPT, process reply, then exit (disables"
" chat mode)"
),
) )
other_group.add_argument( other_group.add_argument(
"--encoding", "--encoding",

View file

@ -3,6 +3,8 @@ from .openai import OpenAIModel
from .openrouter import OpenRouterModel from .openrouter import OpenRouterModel
GPT4 = Model.create("gpt-4") GPT4 = Model.create("gpt-4")
GPT4_0613 = Model.create("gpt-4-0613")
GPT4_1106_PREVIEW = Model.create("gpt-4-1106-preview")
GPT35 = Model.create("gpt-3.5-turbo") GPT35 = Model.create("gpt-3.5-turbo")
GPT35_1106 = Model.create("gpt-3.5-turbo-1106") GPT35_1106 = Model.create("gpt-3.5-turbo-1106")
GPT35_16k = Model.create("gpt-3.5-turbo-16k") GPT35_16k = Model.create("gpt-3.5-turbo-16k")

View file

@ -33,7 +33,11 @@ class OpenAIModel(Model):
self.tokenizer = tiktoken.encoding_for_model(name) self.tokenizer = tiktoken.encoding_for_model(name)
if self.is_gpt4(): if self.is_gpt4():
self.edit_format = "diff" if name == "gpt-4-1106-preview":
self.edit_format = "udiff"
else:
self.edit_format = "diff"
self.use_repo_map = True self.use_repo_map = True
self.send_undo_reply = True self.send_undo_reply = True

View file

@ -1,6 +1,68 @@
import os
import tempfile
from pathlib import Path from pathlib import Path
from .dump import dump # noqa: F401 import git
from aider.dump import dump # noqa: F401
class IgnorantTemporaryDirectory:
def __init__(self):
self.temp_dir = tempfile.TemporaryDirectory()
def __enter__(self):
return self.temp_dir.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
try:
self.temp_dir.__exit__(exc_type, exc_val, exc_tb)
except (OSError, PermissionError):
pass # Ignore errors (Windows)
class ChdirTemporaryDirectory(IgnorantTemporaryDirectory):
def __init__(self):
try:
self.cwd = os.getcwd()
except FileNotFoundError:
self.cwd = None
super().__init__()
def __enter__(self):
res = super().__enter__()
os.chdir(self.temp_dir.name)
return res
def __exit__(self, exc_type, exc_val, exc_tb):
if self.cwd:
try:
os.chdir(self.cwd)
except FileNotFoundError:
pass
super().__exit__(exc_type, exc_val, exc_tb)
class GitTemporaryDirectory(ChdirTemporaryDirectory):
def __enter__(self):
dname = super().__enter__()
self.repo = make_repo(dname)
return dname
def __exit__(self, exc_type, exc_val, exc_tb):
del self.repo
super().__exit__(exc_type, exc_val, exc_tb)
def make_repo(path=None):
if not path:
path = "."
repo = git.Repo.init(path)
repo.config_writer().set_value("user", "name", "Test User").release()
repo.config_writer().set_value("user", "email", "testuser@example.com").release()
return repo
def safe_abs_path(res): def safe_abs_path(res):

View file

@ -32,7 +32,7 @@ from aider.io import InputOutput
BENCHMARK_DNAME = Path(os.environ.get("AIDER_BENCHMARK_DIR", "tmp.benchmarks")) BENCHMARK_DNAME = Path(os.environ.get("AIDER_BENCHMARK_DIR", "tmp.benchmarks"))
ORIGINAL_DNAME = BENCHMARK_DNAME / "exercism-python" EXERCISES_DIR_DEFAULT = "exercism-python"
app = typer.Typer(add_completion=False, pretty_exceptions_enable=False) app = typer.Typer(add_completion=False, pretty_exceptions_enable=False)
@ -83,9 +83,9 @@ def show_stats(dirnames, graphs):
if row.completed_tests < 133: if row.completed_tests < 133:
print(f"Warning: {row.dir_name} is incomplete: {row.completed_tests}") print(f"Warning: {row.dir_name} is incomplete: {row.completed_tests}")
if "repeat" in row.dir_name: # if "repeat" in row.dir_name:
repeats.append(vars(row)) # repeats.append(vars(row))
continue # continue
kind = (row.model, row.edit_format) kind = (row.model, row.edit_format)
if kind in seen: if kind in seen:
@ -97,6 +97,7 @@ def show_stats(dirnames, graphs):
rows.append(vars(row)) rows.append(vars(row))
if repeats: if repeats:
dump(repeats)
extra = rows[repeat_row] extra = rows[repeat_row]
dump(extra) dump(extra)
repeats.append(extra) repeats.append(extra)
@ -313,6 +314,16 @@ def main(
graphs: bool = typer.Option(False, "--graphs", help="Generate graphs"), graphs: bool = typer.Option(False, "--graphs", help="Generate graphs"),
model: str = typer.Option("gpt-3.5-turbo", "--model", "-m", help="Model name"), model: str = typer.Option("gpt-3.5-turbo", "--model", "-m", help="Model name"),
edit_format: str = typer.Option(None, "--edit-format", "-e", help="Edit format"), edit_format: str = typer.Option(None, "--edit-format", "-e", help="Edit format"),
replay: str = typer.Option(
None,
"--replay",
help="Replay previous .aider.chat.history.md responses from previous benchmark run",
),
max_apply_update_errors: int = typer.Option(
3,
"--max-apply-update-errors",
help="Maximum number of apply update errors before stopping the test",
),
keywords: str = typer.Option( keywords: str = typer.Option(
None, "--keywords", "-k", help="Only run tests that contain keywords (comma sep)" None, "--keywords", "-k", help="Only run tests that contain keywords (comma sep)"
), ),
@ -331,6 +342,9 @@ def main(
tries: int = typer.Option(2, "--tries", "-r", help="Number of tries for running tests"), tries: int = typer.Option(2, "--tries", "-r", help="Number of tries for running tests"),
threads: int = typer.Option(1, "--threads", "-t", help="Number of threads to run in parallel"), threads: int = typer.Option(1, "--threads", "-t", help="Number of threads to run in parallel"),
num_tests: int = typer.Option(-1, "--num-tests", "-n", help="Number of tests to run"), num_tests: int = typer.Option(-1, "--num-tests", "-n", help="Number of tests to run"),
exercises_dir: str = typer.Option(
EXERCISES_DIR_DEFAULT, "--exercises-dir", help="Directory with exercise files"
),
): ):
repo = git.Repo(search_parent_directories=True) repo = git.Repo(search_parent_directories=True)
commit_hash = repo.head.object.hexsha[:7] commit_hash = repo.head.object.hexsha[:7]
@ -363,12 +377,13 @@ def main(
return return
assert BENCHMARK_DNAME.exists() and BENCHMARK_DNAME.is_dir(), BENCHMARK_DNAME assert BENCHMARK_DNAME.exists() and BENCHMARK_DNAME.is_dir(), BENCHMARK_DNAME
assert ORIGINAL_DNAME.exists() and ORIGINAL_DNAME.is_dir(), ORIGINAL_DNAME original_dname = BENCHMARK_DNAME / exercises_dir
assert original_dname.exists() and original_dname.is_dir(), original_dname
if clean and dirname.exists(): if clean and dirname.exists():
print("Cleaning up and replacing", dirname) print("Cleaning up and replacing", dirname)
dir_files = set(fn.name for fn in dirname.glob("*")) dir_files = set(fn.name for fn in dirname.glob("*"))
original_files = set(fn.name for fn in ORIGINAL_DNAME.glob("*")) original_files = set(fn.name for fn in original_dname.glob("*"))
if dir_files != original_files: if dir_files != original_files:
print("ERROR: will not delete dir that does not look like original tests", dirname) print("ERROR: will not delete dir that does not look like original tests", dirname)
return return
@ -381,8 +396,8 @@ def main(
dirname.rename(dest) dirname.rename(dest)
if not dirname.exists(): if not dirname.exists():
print(f"Copying {ORIGINAL_DNAME} -> {dirname} ...") print(f"Copying {original_dname} -> {dirname} ...")
shutil.copytree(ORIGINAL_DNAME, dirname) shutil.copytree(original_dname, dirname)
print("...done") print("...done")
test_dnames = sorted(os.listdir(dirname)) test_dnames = sorted(os.listdir(dirname))
@ -399,6 +414,7 @@ def main(
all_results = [] all_results = []
for testname in test_dnames: for testname in test_dnames:
results = run_test( results = run_test(
original_dname,
dirname / testname, dirname / testname,
model, model,
edit_format, edit_format,
@ -407,6 +423,8 @@ def main(
no_aider, no_aider,
verbose, verbose,
commit_hash, commit_hash,
replay,
max_apply_update_errors,
) )
all_results.append(results) all_results.append(results)
@ -415,6 +433,7 @@ def main(
run_test_threaded = lox.thread(threads)(run_test) run_test_threaded = lox.thread(threads)(run_test)
for testname in test_dnames: for testname in test_dnames:
run_test_threaded.scatter( run_test_threaded.scatter(
original_dname,
dirname / testname, dirname / testname,
model, model,
edit_format, edit_format,
@ -423,6 +442,8 @@ def main(
no_aider, no_aider,
verbose, verbose,
commit_hash, commit_hash,
replay,
max_apply_update_errors,
) )
all_results = run_test_threaded.gather(tqdm=True) all_results = run_test_threaded.gather(tqdm=True)
@ -467,6 +488,7 @@ def show_diffs(dirnames):
changed = set(testcases) - unchanged changed = set(testcases) - unchanged
print() print()
print("changed:", len(changed), ",".join(sorted(changed))) print("changed:", len(changed), ",".join(sorted(changed)))
print()
print("unchanged:", len(unchanged), ",".join(sorted(unchanged))) print("unchanged:", len(unchanged), ",".join(sorted(unchanged)))
@ -498,6 +520,10 @@ def summarize_results(dirname):
res.user_asks = 0 res.user_asks = 0
res.test_timeouts = 0 res.test_timeouts = 0
res.exhausted_context_windows = 0 res.exhausted_context_windows = 0
res.num_malformed_responses = 0
res.syntax_errors = 0
res.indentation_errors = 0
res.lazy_comments = 0
variants = defaultdict(set) variants = defaultdict(set)
@ -518,6 +544,11 @@ def summarize_results(dirname):
res.error_outputs += results.get("num_error_outputs", 0) res.error_outputs += results.get("num_error_outputs", 0)
res.user_asks += results.get("num_user_asks", 0) res.user_asks += results.get("num_user_asks", 0)
res.exhausted_context_windows += results.get("num_exhausted_context_windows", 0) res.exhausted_context_windows += results.get("num_exhausted_context_windows", 0)
res.num_malformed_responses += results.get("num_malformed_responses", 0)
res.lazy_comments += results.get("lazy_comments", 0)
res.syntax_errors += results.get("syntax_errors", 0)
res.indentation_errors += results.get("indentation_errors", 0)
for key in "model edit_format commit_hash".split(): for key in "model edit_format commit_hash".split():
val = results.get(key) val = results.get(key)
@ -526,6 +557,9 @@ def summarize_results(dirname):
if not res.completed_tests: if not res.completed_tests:
return return
# if res.completed_tests < 133:
# return
console = Console(highlight=False) console = Console(highlight=False)
console.rule(title=str(dirname)) console.rule(title=str(dirname))
@ -538,14 +572,22 @@ def summarize_results(dirname):
val = ", ".join(map(str, val)) val = ", ".join(map(str, val))
setattr(res, key, val) setattr(res, key, val)
console.print(f"{key}: {val}", style=style) console.print(f"{key}: {val}", style=style)
print("num_error_outputs:", res.error_outputs)
print("num_user_asks:", res.user_asks)
style = "red" if res.exhausted_context_windows else None def show(stat):
console.print("num_exhausted_context_windows", res.exhausted_context_windows, style=style) val = getattr(res, stat)
style = "red" if val else None
console.print(f"{stat}: {val}", style=style)
style = "red" if res.test_timeouts else None console.print()
console.print("test_timeouts:", res.test_timeouts, style=style) show("error_outputs")
show("user_asks")
show("lazy_comments")
show("num_malformed_responses")
show("syntax_errors")
show("indentation_errors")
console.print()
show("exhausted_context_windows")
show("test_timeouts")
console.print() console.print()
for i in range(tries): for i in range(tries):
@ -573,8 +615,35 @@ def summarize_results(dirname):
return res return res
def get_replayed_content(replay_dname, test_dname):
replay_dname = Path(replay_dname)
test_dname = Path(test_dname)
dump(replay_dname, test_dname)
test_name = test_dname.name
replay_fname = replay_dname / test_name / ".aider.chat.history.md"
dump(replay_fname)
res = replay_fname.read_text()
return res
res = res.splitlines(keepends=True)
res = [line for line in res if not line.startswith("> ") and not line.startswith("#### ")]
return "".join(res)
def run_test( def run_test(
testdir, model_name, edit_format, tries, no_unit_tests, no_aider, verbose, commit_hash original_dname,
testdir,
model_name,
edit_format,
tries,
no_unit_tests,
no_aider,
verbose,
commit_hash,
replay,
max_apply_update_errors,
): ):
if not os.path.isdir(testdir): if not os.path.isdir(testdir):
print("Not a dir:", testdir) print("Not a dir:", testdir)
@ -595,12 +664,17 @@ def run_test(
fnames = [] fnames = []
for fname in testdir.glob("*"): for fname in testdir.glob("*"):
if "test" not in fname.name and fname.is_file() and fname.name[0] != ".": if (
"test" not in fname.name
and fname.is_file()
and fname.name[0] != "."
and fname.suffix == ".py"
):
fnames.append(fname) fnames.append(fname)
# restore the original file, in case we interrupted a prev run # restore the original file, in case we interrupted a prev run
# after it had saved changes # after it had saved changes
original_fname = ORIGINAL_DNAME / testdir.name / fname.name original_fname = original_dname / testdir.name / fname.name
shutil.copy(original_fname, fname) shutil.copy(original_fname, fname)
file_list = " ".join(fname.name for fname in fnames) file_list = " ".join(fname.name for fname in fnames)
@ -644,17 +718,40 @@ def run_test(
pretty=False, pretty=False,
verbose=verbose, verbose=verbose,
) )
coder.max_apply_update_errors = max_apply_update_errors
timeouts = 0 timeouts = 0
syntax_errors = 0
indentation_errors = 0
lazy_comments = 0
dur = 0 dur = 0
test_outcomes = [] test_outcomes = []
for i in range(tries): for i in range(tries):
start = time.time() start = time.time()
if not no_aider: if no_aider:
coder.run(with_message=instructions) pass
elif replay:
response = get_replayed_content(replay, testdir)
coder.partial_response_content = response
show = response.splitlines(keepends=True)
show = [">> " + line for line in show]
io.append_chat_history("".join(show))
coder.apply_updates()
else:
response = coder.run(with_message=instructions)
dur += time.time() - start dur += time.time() - start
if not no_aider:
pat = r"^[+]? *[#].* [.][.][.] "
# Count the number of lines that match pat in response
dump(response)
lazy_comments += len(re.findall(pat, response, re.MULTILINE))
dump(lazy_comments)
if coder.last_keyboard_interrupt: if coder.last_keyboard_interrupt:
raise KeyboardInterrupt raise KeyboardInterrupt
@ -673,7 +770,14 @@ def run_test(
test_outcomes.append(True) test_outcomes.append(True)
break break
if replay:
io.append_chat_history(errors)
errors = errors.splitlines() errors = errors.splitlines()
syntax_errors += sum(1 for line in errors if line.startswith("SyntaxError"))
indentation_errors += sum(1 for line in errors if line.startswith("IndentationError"))
print(errors[-1]) print(errors[-1])
errors = errors[:50] errors = errors[:50]
errors = "\n".join(errors) errors = "\n".join(errors)
@ -693,6 +797,10 @@ def run_test(
num_error_outputs=io.num_error_outputs, num_error_outputs=io.num_error_outputs,
num_user_asks=io.num_user_asks, num_user_asks=io.num_user_asks,
num_exhausted_context_windows=coder.num_exhausted_context_windows, num_exhausted_context_windows=coder.num_exhausted_context_windows,
num_malformed_responses=coder.num_malformed_responses,
syntax_errors=syntax_errors,
indentation_errors=indentation_errors,
lazy_comments=lazy_comments, # Add the count of pattern matches to the results
chat_hashes=list( chat_hashes=list(
zip( zip(
coder.chat_completion_call_hashes, coder.chat_completion_call_hashes,

View file

@ -2,9 +2,9 @@ instructions_addendum = """
#### ####
Use the above instructions to modify the supplied files: {file_list} Use the above instructions to modify the supplied files: {file_list}
Keep and implement the existing function or class stubs, they will be called from unit tests. Don't change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.
Only use standard python libraries, don't suggest installing any packages. Only use standard python libraries, don't suggest installing any packages.
""" """ # noqa: E501
test_failures = """ test_failures = """

View file

@ -149,7 +149,6 @@ urllib3==2.1.0
virtualenv==20.25.0 virtualenv==20.25.0
# via pre-commit # via pre-commit
wheel==0.42.0 wheel==0.42.0
# via pip-tools
# The following packages are considered to be unsafe in a requirements file: # The following packages are considered to be unsafe in a requirements file:
# pip # pip

View file

@ -19,3 +19,4 @@ packaging
sounddevice sounddevice
soundfile soundfile
PyYAML PyYAML
diff-match-patch

View file

@ -29,6 +29,8 @@ charset-normalizer==3.3.2
# via requests # via requests
configargparse==1.7 configargparse==1.7
# via -r requirements.in # via -r requirements.in
diff-match-patch==20230430
# via -r requirements.in
diskcache==5.6.3 diskcache==5.6.3
# via -r requirements.in # via -r requirements.in
distro==1.8.0 distro==1.8.0

View file

@ -10,7 +10,7 @@ from aider import models
from aider.coders import Coder from aider.coders import Coder
from aider.dump import dump # noqa: F401 from aider.dump import dump # noqa: F401
from aider.io import InputOutput from aider.io import InputOutput
from tests.utils import ChdirTemporaryDirectory, GitTemporaryDirectory from aider.utils import ChdirTemporaryDirectory, GitTemporaryDirectory
class TestCoder(unittest.TestCase): class TestCoder(unittest.TestCase):

View file

@ -14,7 +14,7 @@ from aider.coders import Coder
from aider.commands import Commands from aider.commands import Commands
from aider.dump import dump # noqa: F401 from aider.dump import dump # noqa: F401
from aider.io import InputOutput from aider.io import InputOutput
from tests.utils import ChdirTemporaryDirectory, GitTemporaryDirectory, make_repo from aider.utils import ChdirTemporaryDirectory, GitTemporaryDirectory, make_repo
class TestCommands(TestCase): class TestCommands(TestCase):

View file

@ -4,7 +4,7 @@ from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
from aider.io import AutoCompleter, InputOutput from aider.io import AutoCompleter, InputOutput
from tests.utils import ChdirTemporaryDirectory from aider.utils import ChdirTemporaryDirectory
class TestInputOutput(unittest.TestCase): class TestInputOutput(unittest.TestCase):

View file

@ -13,7 +13,7 @@ from prompt_toolkit.output import DummyOutput
from aider.dump import dump # noqa: F401 from aider.dump import dump # noqa: F401
from aider.io import InputOutput from aider.io import InputOutput
from aider.main import check_gitignore, main, setup_git from aider.main import check_gitignore, main, setup_git
from tests.utils import GitTemporaryDirectory, make_repo from aider.utils import GitTemporaryDirectory, make_repo
class TestMain(TestCase): class TestMain(TestCase):

View file

@ -9,7 +9,7 @@ import git
from aider.dump import dump # noqa: F401 from aider.dump import dump # noqa: F401
from aider.io import InputOutput from aider.io import InputOutput
from aider.repo import GitRepo from aider.repo import GitRepo
from tests.utils import GitTemporaryDirectory from aider.utils import GitTemporaryDirectory
class TestRepo(unittest.TestCase): class TestRepo(unittest.TestCase):

View file

@ -4,7 +4,7 @@ import unittest
from aider.dump import dump # noqa: F401 from aider.dump import dump # noqa: F401
from aider.io import InputOutput from aider.io import InputOutput
from aider.repomap import RepoMap from aider.repomap import RepoMap
from tests.utils import IgnorantTemporaryDirectory from aider.utils import IgnorantTemporaryDirectory
class TestRepoMap(unittest.TestCase): class TestRepoMap(unittest.TestCase):

View file

@ -1,56 +0,0 @@
import os
import tempfile
import git
from aider.dump import dump # noqa: F401
class IgnorantTemporaryDirectory:
def __init__(self):
self.temp_dir = tempfile.TemporaryDirectory()
def __enter__(self):
return self.temp_dir.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
try:
self.temp_dir.__exit__(exc_type, exc_val, exc_tb)
except (OSError, PermissionError):
pass # Ignore errors (Windows)
class ChdirTemporaryDirectory(IgnorantTemporaryDirectory):
def __init__(self):
self.cwd = os.getcwd()
super().__init__()
def __enter__(self):
res = super().__enter__()
os.chdir(self.temp_dir.name)
return res
def __exit__(self, exc_type, exc_val, exc_tb):
os.chdir(self.cwd)
super().__exit__(exc_type, exc_val, exc_tb)
class GitTemporaryDirectory(ChdirTemporaryDirectory):
def __enter__(self):
res = super().__enter__()
self.repo = make_repo()
return res
def __exit__(self, exc_type, exc_val, exc_tb):
del self.repo
super().__exit__(exc_type, exc_val, exc_tb)
def make_repo(path=None):
if not path:
path = "."
repo = git.Repo.init(path)
repo.config_writer().set_value("user", "name", "Test User").release()
repo.config_writer().set_value("user", "email", "testuser@example.com").release()
return repo