remove chdir

This commit is contained in:
Paul Gauthier 2023-06-24 20:28:59 -07:00
parent 959861ee2e
commit 37c282c7f3

View file

@ -68,8 +68,6 @@ def main():
if not dirname.exists(): if not dirname.exists():
shutil.copytree(ORIGINAL_DNAME, dirname) shutil.copytree(ORIGINAL_DNAME, dirname)
cwd = os.getcwd()
test_dnames = sorted(os.listdir(dirname)) test_dnames = sorted(os.listdir(dirname))
all_results = [] all_results = []
@ -86,7 +84,6 @@ def main():
args.no_test, args.no_test,
args.verbose, args.verbose,
) )
os.chdir(cwd)
all_results.append(results) all_results.append(results)
summarize_results(all_results) summarize_results(all_results)
@ -147,32 +144,25 @@ def run_test(testdir, model_name, edit_format, retries, no_test, verbose):
print("Not a dir:", testdir) print("Not a dir:", testdir)
return return
os.chdir(testdir) testdir = Path(testdir)
history_fname = Path(".aider.chat.history.md") history_fname = testdir / ".aider.chat.history.md"
results_fname = Path(".aider.results.json") results_fname = testdir / ".aider.results.json"
if results_fname.exists(): if results_fname.exists():
try: try:
return json.loads(results_fname.read_text()) return json.loads(results_fname.read_text())
except JSONDecodeError: except JSONDecodeError:
print(f"{testdir}/{results_fname} failed to parse, skipping") print(f"{results_fname} failed to parse, skipping")
return return
started_fname = Path(".aider.started")
if started_fname.exists():
# print(f"{testdir}/{started_fname} exists, skipping")
# return
pass
started_fname.touch()
fnames = [] fnames = []
for fname in os.listdir("."): for fname in testdir.glob("*"):
if "test" not in fname and os.path.isfile(fname) and fname[0] != ".": if "test" not in fname.name and fname.is_file() and fname.name[0] != ".":
fnames.append(fname) fnames.append(fname)
file_list = " ".join(fnames) file_list = " ".join(fname.name for fname in fnames)
instructions = Path(".docs/instructions.md").read_text() instructions = (testdir / ".docs/instructions.md").read_text()
instructions += ( instructions += (
"\n\n=====\n\nModify these files according to the above instructions. Only use standard" "\n\n=====\n\nModify these files according to the above instructions. Only use standard"
" python libraries, don't suggest installing any packages.\n" " python libraries, don't suggest installing any packages.\n"
@ -190,6 +180,7 @@ def run_test(testdir, model_name, edit_format, retries, no_test, verbose):
dump(main_model) dump(main_model)
dump(edit_format) dump(edit_format)
dump(fnames)
coder = Coder.create( coder = Coder.create(
main_model, main_model,
@ -216,7 +207,7 @@ def run_test(testdir, model_name, edit_format, retries, no_test, verbose):
if no_test: if no_test:
return return
errors = run_tests(history_fname) errors = run_pytests(testdir, history_fname)
if errors: if errors:
test_outcomes.append(False) test_outcomes.append(False)
@ -245,13 +236,12 @@ def run_test(testdir, model_name, edit_format, retries, no_test, verbose):
dump(results) dump(results)
results_fname.write_text(json.dumps(results, indent=4)) results_fname.write_text(json.dumps(results, indent=4))
started_fname.unlink()
return results return results
def run_tests(history_fname): def run_pytests(testdir, history_fname):
test_files = [file for file in os.listdir() if file.endswith("_test.py")] test_files = [file for file in testdir.glob("*") if file.name.endswith("_test.py")]
assert len(test_files) assert len(test_files)
all_tests_passed = True all_tests_passed = True