diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 0b04eb014..8979180f8 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -114,12 +114,17 @@ def show_stats(dirnames): # use the average in the main bar rows[repeat_row]["pass_rate_2"] = repeat_avg + else: + repeat_hi = repeat_lo = repeat_avg = None df = pd.DataFrame.from_records(rows) df.sort_values(by=["model", "edit_format"], inplace=True) dump(df) + plot_outcomes(df, repeats, repeat_hi, repeat_lo, repeat_avg) + +def plot_outcomes(df, repeats, repeat_hi, repeat_lo, repeat_avg): tries = [df.groupby(["model", "edit_format"])["pass_rate_2"].mean()] if True: tries += [df.groupby(["model", "edit_format"])["pass_rate_1"].mean()]