diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 12dc1b1e9..c170e540d 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -72,7 +72,7 @@ def show_stats(dirnames): df = pd.DataFrame.from_records(rows) df.sort_values(by=["model", "edit_format"], inplace=True) - # df_grouped1 = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() + df_grouped1 = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() df_grouped2 = df.groupby(["model", "edit_format"])["pass_rate_2"].mean() plt.rcParams["hatch.linewidth"] = 0.5 @@ -93,7 +93,7 @@ def show_stats(dirnames): ) """ zorder = 1 - for grouped in (df_grouped2,): # df_grouped1 + for grouped in (df_grouped2, df_grouped1): zorder += 1 df = grouped.unstack() num_models, num_formats = df.shape