refactor: Restructure benchmark plotting script for improved maintainability

This commit is contained in:
Paul Gauthier (aider) 2024-11-21 14:00:20 -08:00
parent 3cfbaa0ed6
commit 6d6d763dd3

View file

@ -2,14 +2,40 @@ import matplotlib.pyplot as plt
import yaml import yaml
from imgcat import imgcat from imgcat import imgcat
from matplotlib import rc from matplotlib import rc
from dataclasses import dataclass
from typing import List, Tuple, Dict
from datetime import date
from aider.dump import dump # noqa: 401 @dataclass
class ModelData:
name: str
release_date: date
pass_rate: float
LABEL_FONT_SIZE = 16 # Font size for scatter plot dot labels @property
def color(self) -> str:
    """Return the matplotlib color name used to draw this model's series.

    The lowercased ``self.name`` is tested against family substrings in
    priority order and the first match wins, so ``"-4o"`` has to be tested
    before the broader ``"gpt-4"``. Names that match no family get
    ``"lightblue"``.

    ``gpt-4o-mini`` is deliberately mapped to the default color instead of
    the ``"-4o"`` family color, matching what ``get_model_color`` did before
    the refactor.
    """
    default = "lightblue"
    model = self.name.lower()

    # Without this guard the "-4o" test below would claim gpt-4o-mini and
    # give it the same color as the full-size GPT-4o models.
    if model == "gpt-4o-mini":
        return default

    # (substring, color) pairs, most specific first.
    family_colors = (
        ("qwen", "darkblue"),
        ("mistral", "cyan"),
        ("haiku", "pink"),
        ("deepseek", "brown"),
        ("sonnet", "orange"),
        ("-4o", "purple"),
        ("gpt-4", "red"),
        ("gpt-3.5", "green"),
    )
    for needle, shade in family_colors:
        if needle in model:
            return shade
    return default
@property
def get_legend_label(model): def legend_label(self) -> str:
model = model.lower() model = self.name.lower()
if "claude-3-sonnet" in model: if "claude-3-sonnet" in model:
return "Sonnet" return "Sonnet"
if "o1-preview" in model: if "o1-preview" in model:
@ -28,251 +54,102 @@ def get_legend_label(model):
return "DeepSeek" return "DeepSeek"
if "mistral" in model: if "mistral" in model:
return "Mistral" return "Mistral"
if "o1-preview" in model:
return "o1-preview"
return model return model
class BenchmarkPlotter:
LABEL_FONT_SIZE = 16
def get_model_color(model): def __init__(self):
default = "lightblue" self.setup_plot_style()
if model == "gpt-4o-mini":
return default
if "qwen" in model.lower():
return "darkblue"
if "mistral" in model.lower():
return "cyan"
if "haiku" in model.lower():
return "pink"
if "deepseek" in model.lower():
return "brown"
if "sonnet" in model.lower():
return "orange"
if "-4o" in model:
return "purple"
if "gpt-4" in model:
return "red"
if "gpt-3.5" in model:
return "green"
return default
def plot_over_time(yaml_file):
with open(yaml_file, "r") as file:
data = yaml.safe_load(file)
dates = []
pass_rates = []
models = []
print("Debug: Raw data from YAML file:")
print(data)
for entry in data:
if "released" in entry and "pass_rate_2" in entry:
dates.append(entry["released"])
pass_rates.append(entry["pass_rate_2"])
models.append(entry["model"].split("(")[0].strip())
print("Debug: Processed data:")
print("Dates:", dates)
print("Pass rates:", pass_rates)
print("Models:", models)
if not dates or not pass_rates:
print(
"Error: No data to plot. Check if the YAML file is empty or if the data is in the"
" expected format."
)
return
def setup_plot_style(self):
plt.rcParams["hatch.linewidth"] = 0.5 plt.rcParams["hatch.linewidth"] = 0.5
plt.rcParams["hatch.color"] = "#444444" plt.rcParams["hatch.color"] = "#444444"
rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"], "size": 10}) rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"], "size": 10})
plt.rcParams["text.color"] = "#444444" plt.rcParams["text.color"] = "#444444"
fig, ax = plt.subplots(figsize=(12, 8)) # Make figure square def load_data(self, yaml_file: str) -> List[ModelData]:
with open(yaml_file, "r") as file:
data = yaml.safe_load(file)
print("Debug: Figure created. Plotting data...") models = []
for entry in data:
if "released" in entry and "pass_rate_2" in entry:
model = ModelData(
name=entry["model"].split("(")[0].strip(),
release_date=entry["released"],
pass_rate=entry["pass_rate_2"]
)
models.append(model)
return models
def create_figure(self) -> Tuple[plt.Figure, plt.Axes]:
fig, ax = plt.subplots(figsize=(12, 8))
ax.grid(axis="y", zorder=0, lw=0.2) ax.grid(axis="y", zorder=0, lw=0.2)
for spine in ax.spines.values(): for spine in ax.spines.values():
spine.set_edgecolor("#DDDDDD") spine.set_edgecolor("#DDDDDD")
spine.set_linewidth(0.5) spine.set_linewidth(0.5)
return fig, ax
colors = [get_model_color(model) for model in models] def plot_model_series(self, ax: plt.Axes, models: List[ModelData]):
# Group models by color
color_groups: Dict[str, List[ModelData]] = {}
for model in models:
if model.color not in color_groups:
color_groups[model.color] = []
color_groups[model.color].append(model)
# Separate data points by color # Plot each color group
purple_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "purple"] for color, group in color_groups.items():
red_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "red"] sorted_group = sorted(group, key=lambda x: x.release_date)
green_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "green"] dates = [m.release_date for m in sorted_group]
orange_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "orange"] rates = [m.pass_rate for m in sorted_group]
brown_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "brown"]
pink_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "pink"]
qwen_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "darkblue"]
mistral_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "cyan"]
# Create a mapping of colors to first points and labels # Plot line
color_to_first_point = {} ax.plot(dates, rates, c=color, alpha=0.5, linewidth=1)
color_to_label = {}
for date, rate, color, model in sorted(zip(dates, pass_rates, colors, models)): # Plot points
if color not in color_to_first_point: ax.scatter(dates, rates, c=color, alpha=0.5, s=120)
color_to_first_point[color] = (date, rate)
color_to_label[color] = get_legend_label(model)
# Plot lines and add labels at first points # Add label for first point
if purple_points: first_model = sorted_group[0]
purple_dates, purple_rates = zip(*sorted(purple_points))
ax.plot(purple_dates, purple_rates, c="purple", alpha=0.5, linewidth=1)
if "purple" in color_to_first_point:
date, rate = color_to_first_point["purple"]
ax.annotate( ax.annotate(
color_to_label["purple"], first_model.legend_label,
(date, rate), (first_model.release_date, first_model.pass_rate),
xytext=(10, 5), xytext=(10, 5),
textcoords="offset points", textcoords="offset points",
color="purple", color=color,
alpha=0.8, alpha=0.8,
fontsize=LABEL_FONT_SIZE, fontsize=self.LABEL_FONT_SIZE,
) )
if red_points: def set_labels_and_style(self, ax: plt.Axes):
red_dates, red_rates = zip(*sorted(red_points))
ax.plot(red_dates, red_rates, c="red", alpha=0.5, linewidth=1)
if "red" in color_to_first_point:
date, rate = color_to_first_point["red"]
ax.annotate(
color_to_label["red"],
(date, rate),
xytext=(10, 5),
textcoords="offset points",
color="red",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if green_points:
green_dates, green_rates = zip(*sorted(green_points))
ax.plot(green_dates, green_rates, c="green", alpha=0.5, linewidth=1)
if "green" in color_to_first_point:
date, rate = color_to_first_point["green"]
ax.annotate(
color_to_label["green"],
(date, rate),
xytext=(10, 5),
textcoords="offset points",
color="green",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if orange_points:
orange_dates, orange_rates = zip(*sorted(orange_points))
ax.plot(orange_dates, orange_rates, c="orange", alpha=0.5, linewidth=1)
if "orange" in color_to_first_point:
date, rate = color_to_first_point["orange"]
ax.annotate(
color_to_label["orange"],
(date, rate),
xytext=(10, 5),
textcoords="offset points",
color="orange",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if brown_points:
brown_dates, brown_rates = zip(*sorted(brown_points))
ax.plot(brown_dates, brown_rates, c="brown", alpha=0.5, linewidth=1)
if "brown" in color_to_first_point:
date, rate = color_to_first_point["brown"]
ax.annotate(
color_to_label["brown"],
(date, rate),
xytext=(10, -10),
textcoords="offset points",
color="brown",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if pink_points:
pink_dates, pink_rates = zip(*sorted(pink_points))
ax.plot(pink_dates, pink_rates, c="pink", alpha=0.5, linewidth=1)
if "pink" in color_to_first_point:
date, rate = color_to_first_point["pink"]
ax.annotate(
color_to_label["pink"],
(date, rate),
xytext=(10, 5),
textcoords="offset points",
color="pink",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if qwen_points:
qwen_dates, qwen_rates = zip(*sorted(qwen_points))
ax.plot(qwen_dates, qwen_rates, c="darkblue", alpha=0.5, linewidth=1)
if "darkblue" in color_to_first_point:
date, rate = color_to_first_point["darkblue"]
ax.annotate(
color_to_label["darkblue"],
(date, rate),
xytext=(10, 5),
textcoords="offset points",
color="darkblue",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if mistral_points:
mistral_dates, mistral_rates = zip(*sorted(mistral_points))
ax.plot(mistral_dates, mistral_rates, c="cyan", alpha=0.5, linewidth=1)
if "cyan" in color_to_first_point:
date, rate = color_to_first_point["cyan"]
ax.annotate(
color_to_label["cyan"],
(date, rate),
xytext=(10, -10),
textcoords="offset points",
color="cyan",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
# Plot points without legend
for date, rate, color in zip(dates, pass_rates, colors):
ax.scatter([date], [rate], c=[color], alpha=0.5, s=120)
ax.set_xlabel("Model release date", fontsize=18, color="#555") ax.set_xlabel("Model release date", fontsize=18, color="#555")
ax.set_ylabel( ax.set_ylabel(
"Aider code editing benchmark,\npercent completed correctly", fontsize=18, color="#555" "Aider code editing benchmark,\npercent completed correctly",
fontsize=18,
color="#555"
) )
ax.set_title("LLM code editing skill by model release date", fontsize=20) ax.set_title("LLM code editing skill by model release date", fontsize=20)
ax.set_ylim(30, 90) # Adjust y-axis limit to accommodate higher values ax.set_ylim(30, 90)
plt.xticks(fontsize=14, rotation=45, ha="right") # Rotate x-axis labels for better readability plt.xticks(fontsize=14, rotation=45, ha="right")
plt.tight_layout(pad=1.0) # Adjust layout since we don't need room for legend anymore plt.tight_layout(pad=1.0)
print("Debug: Saving figures...") def save_and_display(self, fig: plt.Figure):
plt.savefig("tmp_over_time.png") plt.savefig("tmp_over_time.png")
plt.savefig("tmp_over_time.svg") plt.savefig("tmp_over_time.svg")
print("Debug: Displaying figure with imgcat...")
imgcat(fig) imgcat(fig)
print("Debug: Figure generation complete.") def plot(self, yaml_file: str):
models = self.load_data(yaml_file)
fig, ax = self.create_figure()
self.plot_model_series(ax, models)
self.set_labels_and_style(ax)
self.save_and_display(fig)
def main():
    """Plot the edit leaderboard YAML using a freshly built BenchmarkPlotter."""
    leaderboard_yaml = "aider/website/_data/edit_leaderboard.yml"
    BenchmarkPlotter().plot(leaderboard_yaml)
# Example usage if __name__ == "__main__":
plot_over_time("aider/website/_data/edit_leaderboard.yml") main()