diff --git a/scripts/benchmark.py b/scripts/benchmark.py index 3ce61cf49..6e90d0db3 100755 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -1,6 +1,5 @@ #!/usr/bin/env python -import argparse import datetime import json import os @@ -16,6 +15,7 @@ from pathlib import Path import git import lox from rich.console import Console +import typer from aider import models from aider.coders import Coder @@ -28,66 +28,29 @@ assert BENCHMARK_DNAME.exists() and BENCHMARK_DNAME.is_dir() ORIGINAL_DNAME = BENCHMARK_DNAME / "practice/." assert ORIGINAL_DNAME.exists() and ORIGINAL_DNAME.is_dir() +app = typer.Typer() -def main(): + +@app.command() +def main( + dirname: str = typer.Argument(..., help="Directory name"), + model: str = typer.Option("gpt-3.5-turbo", "--model", "-m", help="Model name"), + edit_format: str = typer.Option(None, "--edit-format", "-e", help="Edit format"), + keyword: str = typer.Option(None, "--keyword", "-k", help="Only run tests that contain keyword"), + clean: bool = typer.Option(False, "--clean", "-c", help="Discard the current testdir and make a clean copy"), + no_test: bool = typer.Option(False, "--no-test", help="Do not run tests"), + verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"), + stats_only: bool = typer.Option(False, "--stats-only", "-s", help="Do not run tests, just collect stats on completed tests"), + retries: int = typer.Option(2, "--retries", "-r", help="Number of retries for running tests"), + threads: int = typer.Option(1, "--threads", "-t", help="Number of threads to run in parallel"), + num_tests: int = typer.Option(-1, "--num-tests", "-n", help="Number of tests to run"), +): repo = git.Repo(search_parent_directories=True) commit_hash = repo.head.object.hexsha[:7] if repo.is_dirty(): commit_hash += "-dirty" - parser = argparse.ArgumentParser(description="Aider Benchmark") - parser.add_argument("dirname", type=str, help="Directory name") - parser.add_argument("--model", "-m", type=str, help="Model name", default="gpt-3.5-turbo") - parser.add_argument("--edit-format", "-e", type=str, help="Edit format") - parser.add_argument("--keyword", "-k", type=str, help="Only run tests that contain keyword") - parser.add_argument( - "--clean", - "-c", - action="store_true", - help="Discard the current testdir and make a clean copy", - ) - parser.add_argument( - "--no-test", - action="store_true", - help="Do not run tests", - ) - parser.add_argument( - "--verbose", - "-v", - action="store_true", - help="Verbose output", - ) - parser.add_argument( - "--stats-only", - "-s", - action="store_true", - help="Do not run tests, just collect stats on completed tests", - ) - parser.add_argument( - "--retries", - "-r", - type=int, - help="Number of retries for running tests", - default=2, - ) - parser.add_argument( - "--threads", - "-t", - type=int, - help="Number of threads to run in parallel", - default=1, - ) - parser.add_argument( - "--num-tests", - "-n", - type=int, - help="Number of tests to run", - default=-1, - ) - - args = parser.parse_args() - - dirname = Path(args.dirname) + dirname = Path(dirname) if len(dirname.parts) == 1: dirname = BENCHMARK_DNAME / dirname @@ -101,7 +64,7 @@ def main(): dump(dirname) - if args.clean and dirname.exists(): + if clean and dirname.exists(): print("Cleaning up and replacing", dirname) dir_files = set(fn.name for fn in dirname.glob("*")) original_files = set(fn.name for fn in ORIGINAL_DNAME.glob("*")) @@ -120,46 +83,46 @@ def main(): test_dnames = sorted(os.listdir(dirname)) - if args.keyword: - test_dnames = [dn for dn in test_dnames if args.keyword in dn] + if keyword: + test_dnames = [dn for dn in test_dnames if keyword in dn] random.shuffle(test_dnames) - if args.num_tests > 0: - test_dnames = test_dnames[: args.num_tests] + if num_tests > 0: + test_dnames = test_dnames[:num_tests] - if args.threads == 1: + if threads == 1: all_results = [] for testname in test_dnames: results = run_test( dirname / testname, - args.model, - args.edit_format, - args.retries, - args.no_test, - args.verbose, - args.stats_only, + model, + edit_format, + retries, + no_test, + verbose, + stats_only, commit_hash, ) all_results.append(results) - if not args.stats_only: + if not stats_only: summarize_results(dirname) else: - run_test_threaded = lox.thread(args.threads)(run_test) + run_test_threaded = lox.thread(threads)(run_test) for testname in test_dnames: run_test_threaded.scatter( dirname / testname, - args.model, - args.edit_format, - args.retries, - args.no_test, - args.verbose, - args.stats_only, + model, + edit_format, + retries, + no_test, + verbose, + stats_only, commit_hash, ) all_results = run_test_threaded.gather(tqdm=True) - if not args.stats_only: + if not stats_only: print() print() print() @@ -406,4 +369,4 @@ def run_pytests(testdir, history_fname): if __name__ == "__main__": - main() + app()