aider: Implemented the plot_timing function to create a bar chart showing the average duration of each (model, edit_format) combination, with consistent styling as plot_outcomes.

# Aider chat conversation:

ASSISTANT: To implement the `plot_timing` function and style the graph similarly to `plot_outcomes`, we will follow these steps:

1. Define the `plot_timing` function to create a bar chart that shows the average duration for each combination of model and edit format.
2. Use similar styling options as used in `plot_outcomes` for consistency.

Let's start by implementing the `plot_timing` function in the `benchmark/benchmark.py` file.

<source>python
benchmark/benchmark.py
<<<<<<< SEARCH
def plot_timing(df):
    """plot a graph showing the average duration of each (model, edit_format)"""
=======
def plot_timing(df):
    """plot a graph showing the average duration of each (model, edit_format)"""
    plt.rcParams["hatch.linewidth"] = 0.5
    plt.rcParams["hatch.color"] = "#444444"

    from matplotlib import rc
    rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"], "size": 10})

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.grid(axis="y", zorder=0, lw=0.2)

    zorder = 1
    grouped = df.groupby(["model", "edit_format"])["avg_duration"].mean().unstack()
    num_models, num_formats = grouped.shape

    pos = np.array(range(num_models))
    width = 0.8 / num_formats

    formats = grouped.columns
    models = grouped.index

    for i, fmt in enumerate(formats):
        edge = dict(edgecolor="#ffffff", linewidth=1.5)
        color = "#b3e6a8" if "diff" in fmt else "#b3d1e6"
        hatch = "////" if "func" in fmt else ""
        rects = ax.bar(
            pos + i * width,
            grouped[fmt],
            width * 0.95,
            color=color,
            hatch=hatch,
            zorder=zorder + 1,
            **edge,
        )
        ax.bar_label(rects, padding=4, labels=[f"{v:.1f}s" for v in grouped[fmt]], size=6)

    ax.set_xticks([p + 0.5 * width for p in pos])
    ax.set_xticklabels(models)

    ax.set_ylabel("Average duration (seconds)")
    ax.set_title("GPT Code Editing - Timing")
    ax.legend(
        title="Edit Format",
        loc="upper left",
    )
    ax.set_ylim(top=max(grouped.max()) * 1.1)  # Set y-axis limit to 10% more than the max value

    plt.tight_layout()
    plt.savefig("tmp_timing.svg")
    imgcat(fig)
>>>>>>> REPLACE
</source>

This code block adds the implementation of the `plot_timing` function, which creates a bar chart to visualize the average duration of tests for each model and edit format. The styling is kept consistent with the `plot_outcomes` function, using similar colors, hatches, and font settings. The resulting graph is saved as an SVG file and displayed in the terminal using `imgcat`.
This commit is contained in:
Paul Gauthier 2023-11-08 10:07:29 -08:00
parent ab55b6ff0a
commit 207bbbc4df

View file

@ -127,6 +127,54 @@ def show_stats(dirnames):
def plot_timing(df):
"""plot a graph showing the average duration of each (model, edit_format)"""
plt.rcParams["hatch.linewidth"] = 0.5
plt.rcParams["hatch.color"] = "#444444"
from matplotlib import rc
rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"], "size": 10})
fig, ax = plt.subplots(figsize=(6, 4))
ax.grid(axis="y", zorder=0, lw=0.2)
zorder = 1
grouped = df.groupby(["model", "edit_format"])["avg_duration"].mean().unstack()
num_models, num_formats = grouped.shape
pos = np.array(range(num_models))
width = 0.8 / num_formats
formats = grouped.columns
models = grouped.index
for i, fmt in enumerate(formats):
edge = dict(edgecolor="#ffffff", linewidth=1.5)
color = "#b3e6a8" if "diff" in fmt else "#b3d1e6"
hatch = "////" if "func" in fmt else ""
rects = ax.bar(
pos + i * width,
grouped[fmt],
width * 0.95,
color=color,
hatch=hatch,
zorder=zorder + 1,
**edge,
)
ax.bar_label(rects, padding=4, labels=[f"{v:.1f}s" for v in grouped[fmt]], size=6)
ax.set_xticks([p + 0.5 * width for p in pos])
ax.set_xticklabels(models)
ax.set_ylabel("Average duration (seconds)")
ax.set_title("GPT Code Editing - Timing")
ax.legend(
title="Edit Format",
loc="upper left",
)
ax.set_ylim(top=max(grouped.max()) * 1.1) # Set y-axis limit to 10% more than the max value
plt.tight_layout()
plt.savefig("tmp_timing.svg")
imgcat(fig)
def plot_outcomes(df, repeats, repeat_hi, repeat_lo, repeat_avg):