added --threads for speed

This commit is contained in:
Paul Gauthier 2023-06-24 20:41:35 -07:00
parent 5b4fbd9901
commit 5c3a635eed

View file

@ -2,12 +2,15 @@ import argparse
import datetime import datetime
import json import json
import os import os
import random
import shutil import shutil
import subprocess import subprocess
import time import time
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from pathlib import Path from pathlib import Path
import lox
from aider import models from aider import models
from aider.coders import Coder from aider.coders import Coder
from aider.dump import dump # noqa: F401 from aider.dump import dump # noqa: F401
@ -47,6 +50,13 @@ def main():
help="Number of retries for running tests", help="Number of retries for running tests",
default=2, default=2,
) )
parser.add_argument(
"--threads",
"-t",
type=int,
help="Number of threads to run in parallel",
default=1,
)
args = parser.parse_args() args = parser.parse_args()
@ -72,6 +82,7 @@ def main():
if args.keyword: if args.keyword:
test_dnames = [dn for dn in test_dnames if args.keyword in dn] test_dnames = [dn for dn in test_dnames if args.keyword in dn]
if args.threads == 1:
all_results = [] all_results = []
for testname in test_dnames: for testname in test_dnames:
results = run_test( results = run_test(
@ -85,6 +96,23 @@ def main():
all_results.append(results) all_results.append(results)
summarize_results(all_results) summarize_results(all_results)
else:
random.shuffle(test_dnames)
run_test_threaded = lox.thread(args.threads)(run_test)
for testname in test_dnames:
run_test_threaded.scatter(
dirname / testname,
args.model,
args.edit_format,
args.retries,
args.no_test,
args.verbose,
)
all_results = run_test_threaded.gather(tqdm=True)
print()
print()
print()
summarize_results(all_results)
def summarize_results(all_results, total_tests=None): def summarize_results(all_results, total_tests=None):