diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 825f77759..e9b0e4425 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -727,8 +727,7 @@ def run_test_real( break try: - # pass test_files in, and choose the test command based on them ai! - errors = run_unit_tests(testdir, history_fname) + errors = run_unit_tests(testdir, history_fname, test_files) except subprocess.TimeoutExpired: errors = "Tests timed out!" timeouts += 1 @@ -787,11 +786,12 @@ def run_test_real( return results -def run_unit_tests(testdir, history_fname): +def run_unit_tests(testdir, history_fname, test_files): timeout = 60 - command = ["pytest"] - command = "cargo test -- --include-ignored".split() + # Choose test command based on test file extensions + is_rust = any(Path(f).suffix == '.rs' for f in test_files) + command = "cargo test -- --include-ignored".split() if is_rust else ["pytest"] print(" ".join(command)) result = subprocess.run(