diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index f0797f3d4..bffd874e8 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -43,12 +43,13 @@ def show_stats(dirnames): df = pd.DataFrame.from_records(rows) df.sort_values(by=["model", "edit_format"], inplace=True) - df_grouped = df.groupby("model")["pass_rate_1"].mean() - df_grouped.plot(kind='barh', figsize=(10, 6)) + df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() + df_grouped.unstack().plot(kind='barh', figsize=(10, 6)) plt.xlabel('Pass Rate 1') plt.ylabel('Model') plt.title('Pass Rate 1 for each Model/Edit Format') + plt.legend(title='Edit Format') plt.show() df.to_csv("tmp.benchmark.csv")