From cba75a3bfeb2cf99f914961d57561a876fc3d189 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Thu, 29 Jun 2023 22:08:09 -0700 Subject: [PATCH] prettier --- benchmark/benchmark.py | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 599502372..39e8d23a1 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -17,6 +17,7 @@ from typing import List import git import lox import matplotlib.pyplot as plt +import numpy as np import pandas as pd import prompts import typer @@ -68,12 +69,41 @@ def show_stats(dirnames): # df_grouped1 = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() df_grouped2 = df.groupby(["model", "edit_format"])["pass_rate_2"].mean() - fig, ax = plt.subplots(figsize=(10, 6)) - color_map = {"edit_format1": "blue", "edit_format2": "green", "edit_format3": "red"} # Define your color map here - df_grouped2.unstack().plot(kind="barh", ax=ax, color=[color_map[i] for i in df_grouped2.index]) + dump(df_grouped2.unstack()) - ax.set_xlabel("Percent of passed unittests") - ax.set_ylabel("Model") + plt.rcParams["hatch.linewidth"] = 0.5 + plt.rcParams["hatch.color"] = "#444444" + + fig, ax = plt.subplots(figsize=(10, 6)) + """ + colors = ["#007BFF", "#89CFF0", "#008000", "#32CD32"] + df_grouped2.unstack().plot( + kind="bar", + ax=ax, + color= colors, + ) + """ + df = df_grouped2.unstack() + num_models, num_formats = df.shape + dump(num_models) + dump(num_formats) + pos = np.array(range(num_models)) + width = 0.8 / num_formats + + formats = df.columns + models = df.index + + for i, fmt in enumerate(formats): + dump(df[fmt]) + color = "#b3e6a8" if "diff" in fmt else "#b3d1e6" + hatch = "///" if "func" in fmt else "" + ax.bar(pos + i * width, df[fmt], width * 0.95, label=fmt, color=color, hatch=hatch) + + ax.set_xticks([p + 1.5 * width for p in pos]) + ax.set_xticklabels(models) + + ax.set_ylabel("Percent of passed unittests") + # ax.set_xlabel("Model") ax.set_title("Code editing success rate by model & edit format") ax.legend(title="Edit Format")