From 5c3a635eed613ff4122bc6b22fb26e2ee88473d1 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Sat, 24 Jun 2023 20:41:35 -0700 Subject: [PATCH] added --threads for speed --- scripts/benchmark.py | 50 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/scripts/benchmark.py b/scripts/benchmark.py index 53b2d9931..9598d9862 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -2,12 +2,15 @@ import argparse import datetime import json import os +import random import shutil import subprocess import time from json.decoder import JSONDecodeError from pathlib import Path +import lox + from aider import models from aider.coders import Coder from aider.dump import dump # noqa: F401 @@ -47,6 +50,13 @@ def main(): help="Number of retries for running tests", default=2, ) + parser.add_argument( + "--threads", + "-t", + type=int, + help="Number of threads to run in parallel", + default=1, + ) args = parser.parse_args() @@ -72,18 +82,36 @@ def main(): if args.keyword: test_dnames = [dn for dn in test_dnames if args.keyword in dn] - all_results = [] - for testname in test_dnames: - results = run_test( - dirname / testname, - args.model, - args.edit_format, - args.retries, - args.no_test, - args.verbose, - ) + if args.threads == 1: + all_results = [] + for testname in test_dnames: + results = run_test( + dirname / testname, + args.model, + args.edit_format, + args.retries, + args.no_test, + args.verbose, + ) - all_results.append(results) + all_results.append(results) + summarize_results(all_results) + else: + random.shuffle(test_dnames) + run_test_threaded = lox.thread(args.threads)(run_test) + for testname in test_dnames: + run_test_threaded.scatter( + dirname / testname, + args.model, + args.edit_format, + args.retries, + args.no_test, + args.verbose, + ) + all_results = run_test_threaded.gather(tqdm=True) + print() + print() + print() summarize_results(all_results)