diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index cb264fe68..599502372 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -69,7 +69,8 @@ def show_stats(dirnames): df_grouped2 = df.groupby(["model", "edit_format"])["pass_rate_2"].mean() fig, ax = plt.subplots(figsize=(10, 6)) - df_grouped2.unstack().plot(kind="barh", ax=ax) + color_map = {"edit_format1": "blue", "edit_format2": "green", "edit_format3": "red"} # Define your color map here + df_grouped2.unstack().plot(kind="barh", ax=ax, color=[color_map[i] for i in df_grouped2.index]) ax.set_xlabel("Percent of passed unittests") ax.set_ylabel("Model")