diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index fcfc4512c..3fd8ebb60 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -52,12 +52,13 @@ def show_stats(dirnames): df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() dump(df_grouped) dump(df_grouped.unstack()) - df_grouped.unstack().plot(kind="barh", figsize=(10, 6)) + fig, ax = plt.subplots(figsize=(10, 6)) + df_grouped.unstack().plot(kind="barh", ax=ax) - plt.xlabel("Pass Rate 1") - plt.ylabel("Model") - plt.title("Pass Rate 1 for each Model/Edit Format") - plt.legend(title="Edit Format") + ax.set_xlabel("Pass Rate 1") + ax.set_ylabel("Model") + ax.set_title("Pass Rate 1 for each Model/Edit Format") + ax.legend(title="Edit Format") plt.show() # df.to_csv("tmp.benchmarks.csv")