diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index ddd3b0977..d23065415 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -127,6 +127,54 @@ def show_stats(dirnames): def plot_timing(df): """plot a graph showing the average duration of each (model, edit_format)""" + plt.rcParams["hatch.linewidth"] = 0.5 + plt.rcParams["hatch.color"] = "#444444" + + from matplotlib import rc + rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"], "size": 10}) + + fig, ax = plt.subplots(figsize=(6, 4)) + ax.grid(axis="y", zorder=0, lw=0.2) + + zorder = 1 + grouped = df.groupby(["model", "edit_format"])["avg_duration"].mean().unstack() + num_models, num_formats = grouped.shape + + pos = np.array(range(num_models)) + width = 0.8 / num_formats + + formats = grouped.columns + models = grouped.index + + for i, fmt in enumerate(formats): + edge = dict(edgecolor="#ffffff", linewidth=1.5) + color = "#b3e6a8" if "diff" in fmt else "#b3d1e6" + hatch = "////" if "func" in fmt else "" + rects = ax.bar( + pos + i * width, + grouped[fmt], + width * 0.95, + color=color, + hatch=hatch, + zorder=zorder + 1, + **edge, + ) + ax.bar_label(rects, padding=4, labels=[f"{v:.1f}s" for v in grouped[fmt]], size=6) + + ax.set_xticks([p + 0.5 * width for p in pos]) + ax.set_xticklabels(models) + + ax.set_ylabel("Average duration (seconds)") + ax.set_title("GPT Code Editing - Timing") + ax.legend( + title="Edit Format", + loc="upper left", + ) + ax.set_ylim(top=max(grouped.max()) * 1.1) # Set y-axis limit to 10% more than the max value + + plt.tight_layout() + plt.savefig("tmp_timing.svg") + imgcat(fig) def plot_outcomes(df, repeats, repeat_hi, repeat_lo, repeat_avg):