diff --git a/scripts/benchmark.py b/scripts/benchmark.py index 3dd801a9d..0c1fe3e42 100755 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -93,7 +93,7 @@ def main(): dirname = BENCHMARK_DNAME / dirname now = datetime.datetime.now() - now = now.strftime("%Y-%m-%d-%H-%M-%S-") + now = now.strftime("%Y-%m-%d-%H-%M-%S--") if not dirname.exists(): if not re.match(r"\d\d\d\d-\d\d-\d\d-", str(dirname)): @@ -180,6 +180,8 @@ def summarize_results(dirname, all_results, total_tests=None): passed_tests = [0] * retries duration = 0 total_cost = 0 + total_error_outputs = 0 + total_user_asks = 0 variants = defaultdict(set) @@ -196,7 +198,10 @@ def summarize_results(dirname, all_results, total_tests=None): total_cost += results["cost"] duration += results["duration"] - for key in "model edit_format commit_hash num_error_outputs num_user_asks".split(): + total_error_outputs += results.get("num_error_outputs", 0) + total_user_asks += results.get("num_user_asks", 0) + + for key in "model edit_format commit_hash".split(): val = results.get(key) variants[key].add(val) @@ -214,6 +219,8 @@ def summarize_results(dirname, all_results, total_tests=None): style = None val = ", ".join(map(str, val)) console.print(f"{key}: {val}", style=style) + print("num_error_outputs:", total_error_outputs) + print("num_user_asks:", total_user_asks) console.print() for i in range(retries):