diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index a5c9de9b0..599ec0efb 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -57,7 +57,7 @@ def show_stats(dirnames): df = pd.DataFrame.from_records(rows) df.sort_values(by=["model", "edit_format"], inplace=True) - df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() + df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean().sort_values(ascending=False) fig, ax = plt.subplots(figsize=(10, 6)) df_grouped.unstack().plot(kind="barh", ax=ax)