diff --git a/scripts/benchmark.py b/scripts/benchmark.py index 7cb0d7539..a1a950867 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -173,18 +173,28 @@ def run_test(testdir, model_name, edit_format, retries): dur = 0 test_outcomes = [] for i in range(retries): - def run_coder(): - coder.run(with_message=instructions) + def run_coder(stop_event): + try: + coder.run(with_message=instructions) + except Exception as e: + if stop_event.is_set(): + print("Thread stopped due to timeout.") + else: + raise e start = time.time() - coder_thread = threading.Thread(target=run_coder) + stop_event = threading.Event() + coder_thread = threading.Thread(target=run_coder, args=(stop_event,)) coder_thread.start() coder_thread.join(60) # 60 seconds timeout if coder_thread.is_alive(): + stop_event.set() + coder_thread.join() # Wait for the thread to exit gracefully + print("coder.run took longer than 60 seconds and was stopped.") # Handle the case when the coder.run call takes longer than 60 seconds # You can raise an exception or handle it accordingly - raise Exception("coder.run took longer than 60 seconds") + # raise Exception("coder.run took longer than 60 seconds") dur += time.time() - start