diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index c170e540d..2ad539fa6 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -72,8 +72,9 @@ def show_stats(dirnames): df = pd.DataFrame.from_records(rows) df.sort_values(by=["model", "edit_format"], inplace=True) - df_grouped1 = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() - df_grouped2 = df.groupby(["model", "edit_format"])["pass_rate_2"].mean() + tries = [df.groupby(["model", "edit_format"])["pass_rate_2"].mean()] + if True: + tries += [df.groupby(["model", "edit_format"])["pass_rate_1"].mean()] plt.rcParams["hatch.linewidth"] = 0.5 plt.rcParams["hatch.color"] = "#444444" @@ -84,16 +85,9 @@ def show_stats(dirnames): fig, ax = plt.subplots(figsize=(12, 10)) ax.grid(axis="y", zorder=0, lw=0.2) - """ - colors = ["#007BFF", "#89CFF0", "#008000", "#32CD32"] - df_grouped2.unstack().plot( - kind="bar", - ax=ax, - color= colors, - ) - """ + zorder = 1 - for grouped in (df_grouped2, df_grouped1): + for grouped in tries: zorder += 1 df = grouped.unstack() num_models, num_formats = df.shape