unified diffs

This commit is contained in:
Paul Gauthier 2023-12-17 12:54:34 -08:00
parent 3aa17c46dd
commit 7113a30271
18 changed files with 243 additions and 96 deletions

View file

@ -32,7 +32,7 @@ from aider.io import InputOutput
BENCHMARK_DNAME = Path(os.environ.get("AIDER_BENCHMARK_DIR", "tmp.benchmarks"))
ORIGINAL_DNAME = BENCHMARK_DNAME / "exercism-python"
EXERCISES_DIR_DEFAULT = "exercism-python"
app = typer.Typer(add_completion=False, pretty_exceptions_enable=False)
@ -83,9 +83,9 @@ def show_stats(dirnames, graphs):
if row.completed_tests < 133:
print(f"Warning: {row.dir_name} is incomplete: {row.completed_tests}")
if "repeat" in row.dir_name:
repeats.append(vars(row))
continue
# if "repeat" in row.dir_name:
# repeats.append(vars(row))
# continue
kind = (row.model, row.edit_format)
if kind in seen:
@ -97,6 +97,7 @@ def show_stats(dirnames, graphs):
rows.append(vars(row))
if repeats:
dump(repeats)
extra = rows[repeat_row]
dump(extra)
repeats.append(extra)
@ -313,6 +314,16 @@ def main(
graphs: bool = typer.Option(False, "--graphs", help="Generate graphs"),
model: str = typer.Option("gpt-3.5-turbo", "--model", "-m", help="Model name"),
edit_format: str = typer.Option(None, "--edit-format", "-e", help="Edit format"),
replay: str = typer.Option(
None,
"--replay",
help="Replay previous .aider.chat.history.md responses from previous benchmark run",
),
max_apply_update_errors: int = typer.Option(
3,
"--max-apply-update-errors",
help="Maximum number of apply update errors before stopping the test",
),
keywords: str = typer.Option(
None, "--keywords", "-k", help="Only run tests that contain keywords (comma sep)"
),
@ -331,6 +342,9 @@ def main(
tries: int = typer.Option(2, "--tries", "-r", help="Number of tries for running tests"),
threads: int = typer.Option(1, "--threads", "-t", help="Number of threads to run in parallel"),
num_tests: int = typer.Option(-1, "--num-tests", "-n", help="Number of tests to run"),
exercises_dir: str = typer.Option(
EXERCISES_DIR_DEFAULT, "--exercises-dir", help="Directory with exercise files"
),
):
repo = git.Repo(search_parent_directories=True)
commit_hash = repo.head.object.hexsha[:7]
@ -363,12 +377,13 @@ def main(
return
assert BENCHMARK_DNAME.exists() and BENCHMARK_DNAME.is_dir(), BENCHMARK_DNAME
assert ORIGINAL_DNAME.exists() and ORIGINAL_DNAME.is_dir(), ORIGINAL_DNAME
original_dname = BENCHMARK_DNAME / exercises_dir
assert original_dname.exists() and original_dname.is_dir(), original_dname
if clean and dirname.exists():
print("Cleaning up and replacing", dirname)
dir_files = set(fn.name for fn in dirname.glob("*"))
original_files = set(fn.name for fn in ORIGINAL_DNAME.glob("*"))
original_files = set(fn.name for fn in original_dname.glob("*"))
if dir_files != original_files:
print("ERROR: will not delete dir that does not look like original tests", dirname)
return
@ -381,8 +396,8 @@ def main(
dirname.rename(dest)
if not dirname.exists():
print(f"Copying {ORIGINAL_DNAME} -> {dirname} ...")
shutil.copytree(ORIGINAL_DNAME, dirname)
print(f"Copying {original_dname} -> {dirname} ...")
shutil.copytree(original_dname, dirname)
print("...done")
test_dnames = sorted(os.listdir(dirname))
@ -399,6 +414,7 @@ def main(
all_results = []
for testname in test_dnames:
results = run_test(
original_dname,
dirname / testname,
model,
edit_format,
@ -407,6 +423,8 @@ def main(
no_aider,
verbose,
commit_hash,
replay,
max_apply_update_errors,
)
all_results.append(results)
@ -415,6 +433,7 @@ def main(
run_test_threaded = lox.thread(threads)(run_test)
for testname in test_dnames:
run_test_threaded.scatter(
original_dname,
dirname / testname,
model,
edit_format,
@ -423,6 +442,8 @@ def main(
no_aider,
verbose,
commit_hash,
replay,
max_apply_update_errors,
)
all_results = run_test_threaded.gather(tqdm=True)
@ -467,6 +488,7 @@ def show_diffs(dirnames):
changed = set(testcases) - unchanged
print()
print("changed:", len(changed), ",".join(sorted(changed)))
print()
print("unchanged:", len(unchanged), ",".join(sorted(unchanged)))
@ -498,6 +520,10 @@ def summarize_results(dirname):
res.user_asks = 0
res.test_timeouts = 0
res.exhausted_context_windows = 0
res.num_malformed_responses = 0
res.syntax_errors = 0
res.indentation_errors = 0
res.lazy_comments = 0
variants = defaultdict(set)
@ -518,6 +544,11 @@ def summarize_results(dirname):
res.error_outputs += results.get("num_error_outputs", 0)
res.user_asks += results.get("num_user_asks", 0)
res.exhausted_context_windows += results.get("num_exhausted_context_windows", 0)
res.num_malformed_responses += results.get("num_malformed_responses", 0)
res.lazy_comments += results.get("lazy_comments", 0)
res.syntax_errors += results.get("syntax_errors", 0)
res.indentation_errors += results.get("indentation_errors", 0)
for key in "model edit_format commit_hash".split():
val = results.get(key)
@ -526,6 +557,9 @@ def summarize_results(dirname):
if not res.completed_tests:
return
# if res.completed_tests < 133:
# return
console = Console(highlight=False)
console.rule(title=str(dirname))
@ -538,14 +572,22 @@ def summarize_results(dirname):
val = ", ".join(map(str, val))
setattr(res, key, val)
console.print(f"{key}: {val}", style=style)
print("num_error_outputs:", res.error_outputs)
print("num_user_asks:", res.user_asks)
style = "red" if res.exhausted_context_windows else None
console.print("num_exhausted_context_windows", res.exhausted_context_windows, style=style)
def show(stat):
val = getattr(res, stat)
style = "red" if val else None
console.print(f"{stat}: {val}", style=style)
style = "red" if res.test_timeouts else None
console.print("test_timeouts:", res.test_timeouts, style=style)
console.print()
show("error_outputs")
show("user_asks")
show("lazy_comments")
show("num_malformed_responses")
show("syntax_errors")
show("indentation_errors")
console.print()
show("exhausted_context_windows")
show("test_timeouts")
console.print()
for i in range(tries):
@ -573,8 +615,35 @@ def summarize_results(dirname):
return res
def get_replayed_content(replay_dname, test_dname):
replay_dname = Path(replay_dname)
test_dname = Path(test_dname)
dump(replay_dname, test_dname)
test_name = test_dname.name
replay_fname = replay_dname / test_name / ".aider.chat.history.md"
dump(replay_fname)
res = replay_fname.read_text()
return res
res = res.splitlines(keepends=True)
res = [line for line in res if not line.startswith("> ") and not line.startswith("#### ")]
return "".join(res)
def run_test(
testdir, model_name, edit_format, tries, no_unit_tests, no_aider, verbose, commit_hash
original_dname,
testdir,
model_name,
edit_format,
tries,
no_unit_tests,
no_aider,
verbose,
commit_hash,
replay,
max_apply_update_errors,
):
if not os.path.isdir(testdir):
print("Not a dir:", testdir)
@ -595,12 +664,17 @@ def run_test(
fnames = []
for fname in testdir.glob("*"):
if "test" not in fname.name and fname.is_file() and fname.name[0] != ".":
if (
"test" not in fname.name
and fname.is_file()
and fname.name[0] != "."
and fname.suffix == ".py"
):
fnames.append(fname)
# restore the original file, in case we interrupted a prev run
# after it had saved changes
original_fname = ORIGINAL_DNAME / testdir.name / fname.name
original_fname = original_dname / testdir.name / fname.name
shutil.copy(original_fname, fname)
file_list = " ".join(fname.name for fname in fnames)
@ -644,17 +718,40 @@ def run_test(
pretty=False,
verbose=verbose,
)
coder.max_apply_update_errors = max_apply_update_errors
timeouts = 0
syntax_errors = 0
indentation_errors = 0
lazy_comments = 0
dur = 0
test_outcomes = []
for i in range(tries):
start = time.time()
if not no_aider:
coder.run(with_message=instructions)
if no_aider:
pass
elif replay:
response = get_replayed_content(replay, testdir)
coder.partial_response_content = response
show = response.splitlines(keepends=True)
show = [">> " + line for line in show]
io.append_chat_history("".join(show))
coder.apply_updates()
else:
response = coder.run(with_message=instructions)
dur += time.time() - start
if not no_aider:
pat = r"^[+]? *[#].* [.][.][.] "
# Count the number of lines that match pat in response
dump(response)
lazy_comments += len(re.findall(pat, response, re.MULTILINE))
dump(lazy_comments)
if coder.last_keyboard_interrupt:
raise KeyboardInterrupt
@ -673,7 +770,14 @@ def run_test(
test_outcomes.append(True)
break
if replay:
io.append_chat_history(errors)
errors = errors.splitlines()
syntax_errors += sum(1 for line in errors if line.startswith("SyntaxError"))
indentation_errors += sum(1 for line in errors if line.startswith("IndentationError"))
print(errors[-1])
errors = errors[:50]
errors = "\n".join(errors)
@ -693,6 +797,10 @@ def run_test(
num_error_outputs=io.num_error_outputs,
num_user_asks=io.num_user_asks,
num_exhausted_context_windows=coder.num_exhausted_context_windows,
num_malformed_responses=coder.num_malformed_responses,
syntax_errors=syntax_errors,
indentation_errors=indentation_errors,
lazy_comments=lazy_comments, # Add the count of pattern matches to the results
chat_hashes=list(
zip(
coder.chat_completion_call_hashes,