diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 63401e9d1..8a84d4086 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -65,10 +65,11 @@ def show_stats(dirnames): df = pd.DataFrame.from_records(rows) df.sort_values(by=["model", "edit_format"], inplace=True) - df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() + df_grouped1 = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() + df_grouped2 = df.groupby(["model", "edit_format"])["pass_rate_2"].mean() fig, ax = plt.subplots(figsize=(10, 6)) - df_grouped.unstack().plot(kind="barh", ax=ax) + df_grouped2.unstack().plot(kind="barh", ax=ax) ax.set_xlabel("Pass Rate 1") ax.set_ylabel("Model")