refactor: Restructure benchmark plotting script for improved maintainability

This commit is contained in:
Paul Gauthier (aider) 2024-11-21 14:00:20 -08:00
parent 3cfbaa0ed6
commit 6d6d763dd3

View file

@ -2,14 +2,40 @@ import matplotlib.pyplot as plt
import yaml
from imgcat import imgcat
from matplotlib import rc
from dataclasses import dataclass
from typing import List, Tuple, Dict
from datetime import date
from aider.dump import dump # noqa: 401
@dataclass
class ModelData:
name: str
release_date: date
pass_rate: float
LABEL_FONT_SIZE = 16 # Font size for scatter plot dot labels
@property
def color(self) -> str:
model = self.name.lower()
if "qwen" in model:
return "darkblue"
if "mistral" in model:
return "cyan"
if "haiku" in model:
return "pink"
if "deepseek" in model:
return "brown"
if "sonnet" in model:
return "orange"
if "-4o" in model:
return "purple"
if "gpt-4" in model:
return "red"
if "gpt-3.5" in model:
return "green"
return "lightblue"
def get_legend_label(model):
model = model.lower()
@property
def legend_label(self) -> str:
model = self.name.lower()
if "claude-3-sonnet" in model:
return "Sonnet"
if "o1-preview" in model:
@ -28,251 +54,102 @@ def get_legend_label(model):
return "DeepSeek"
if "mistral" in model:
return "Mistral"
if "o1-preview" in model:
return "o1-preview"
return model
class BenchmarkPlotter:
LABEL_FONT_SIZE = 16
def get_model_color(model):
default = "lightblue"
if model == "gpt-4o-mini":
return default
if "qwen" in model.lower():
return "darkblue"
if "mistral" in model.lower():
return "cyan"
if "haiku" in model.lower():
return "pink"
if "deepseek" in model.lower():
return "brown"
if "sonnet" in model.lower():
return "orange"
if "-4o" in model:
return "purple"
if "gpt-4" in model:
return "red"
if "gpt-3.5" in model:
return "green"
return default
def plot_over_time(yaml_file):
with open(yaml_file, "r") as file:
data = yaml.safe_load(file)
dates = []
pass_rates = []
models = []
print("Debug: Raw data from YAML file:")
print(data)
for entry in data:
if "released" in entry and "pass_rate_2" in entry:
dates.append(entry["released"])
pass_rates.append(entry["pass_rate_2"])
models.append(entry["model"].split("(")[0].strip())
print("Debug: Processed data:")
print("Dates:", dates)
print("Pass rates:", pass_rates)
print("Models:", models)
if not dates or not pass_rates:
print(
"Error: No data to plot. Check if the YAML file is empty or if the data is in the"
" expected format."
)
return
def __init__(self):
self.setup_plot_style()
def setup_plot_style(self):
plt.rcParams["hatch.linewidth"] = 0.5
plt.rcParams["hatch.color"] = "#444444"
rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"], "size": 10})
plt.rcParams["text.color"] = "#444444"
fig, ax = plt.subplots(figsize=(12, 8)) # Make figure square
def load_data(self, yaml_file: str) -> List[ModelData]:
with open(yaml_file, "r") as file:
data = yaml.safe_load(file)
print("Debug: Figure created. Plotting data...")
models = []
for entry in data:
if "released" in entry and "pass_rate_2" in entry:
model = ModelData(
name=entry["model"].split("(")[0].strip(),
release_date=entry["released"],
pass_rate=entry["pass_rate_2"]
)
models.append(model)
return models
def create_figure(self) -> Tuple[plt.Figure, plt.Axes]:
fig, ax = plt.subplots(figsize=(12, 8))
ax.grid(axis="y", zorder=0, lw=0.2)
for spine in ax.spines.values():
spine.set_edgecolor("#DDDDDD")
spine.set_linewidth(0.5)
return fig, ax
colors = [get_model_color(model) for model in models]
def plot_model_series(self, ax: plt.Axes, models: List[ModelData]):
# Group models by color
color_groups: Dict[str, List[ModelData]] = {}
for model in models:
if model.color not in color_groups:
color_groups[model.color] = []
color_groups[model.color].append(model)
# Separate data points by color
purple_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "purple"]
red_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "red"]
green_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "green"]
orange_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "orange"]
brown_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "brown"]
pink_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "pink"]
qwen_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "darkblue"]
mistral_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "cyan"]
# Plot each color group
for color, group in color_groups.items():
sorted_group = sorted(group, key=lambda x: x.release_date)
dates = [m.release_date for m in sorted_group]
rates = [m.pass_rate for m in sorted_group]
# Create a mapping of colors to first points and labels
color_to_first_point = {}
color_to_label = {}
# Plot line
ax.plot(dates, rates, c=color, alpha=0.5, linewidth=1)
for date, rate, color, model in sorted(zip(dates, pass_rates, colors, models)):
if color not in color_to_first_point:
color_to_first_point[color] = (date, rate)
color_to_label[color] = get_legend_label(model)
# Plot points
ax.scatter(dates, rates, c=color, alpha=0.5, s=120)
# Plot lines and add labels at first points
if purple_points:
purple_dates, purple_rates = zip(*sorted(purple_points))
ax.plot(purple_dates, purple_rates, c="purple", alpha=0.5, linewidth=1)
if "purple" in color_to_first_point:
date, rate = color_to_first_point["purple"]
# Add label for first point
first_model = sorted_group[0]
ax.annotate(
color_to_label["purple"],
(date, rate),
first_model.legend_label,
(first_model.release_date, first_model.pass_rate),
xytext=(10, 5),
textcoords="offset points",
color="purple",
color=color,
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
fontsize=self.LABEL_FONT_SIZE,
)
if red_points:
red_dates, red_rates = zip(*sorted(red_points))
ax.plot(red_dates, red_rates, c="red", alpha=0.5, linewidth=1)
if "red" in color_to_first_point:
date, rate = color_to_first_point["red"]
ax.annotate(
color_to_label["red"],
(date, rate),
xytext=(10, 5),
textcoords="offset points",
color="red",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if green_points:
green_dates, green_rates = zip(*sorted(green_points))
ax.plot(green_dates, green_rates, c="green", alpha=0.5, linewidth=1)
if "green" in color_to_first_point:
date, rate = color_to_first_point["green"]
ax.annotate(
color_to_label["green"],
(date, rate),
xytext=(10, 5),
textcoords="offset points",
color="green",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if orange_points:
orange_dates, orange_rates = zip(*sorted(orange_points))
ax.plot(orange_dates, orange_rates, c="orange", alpha=0.5, linewidth=1)
if "orange" in color_to_first_point:
date, rate = color_to_first_point["orange"]
ax.annotate(
color_to_label["orange"],
(date, rate),
xytext=(10, 5),
textcoords="offset points",
color="orange",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if brown_points:
brown_dates, brown_rates = zip(*sorted(brown_points))
ax.plot(brown_dates, brown_rates, c="brown", alpha=0.5, linewidth=1)
if "brown" in color_to_first_point:
date, rate = color_to_first_point["brown"]
ax.annotate(
color_to_label["brown"],
(date, rate),
xytext=(10, -10),
textcoords="offset points",
color="brown",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if pink_points:
pink_dates, pink_rates = zip(*sorted(pink_points))
ax.plot(pink_dates, pink_rates, c="pink", alpha=0.5, linewidth=1)
if "pink" in color_to_first_point:
date, rate = color_to_first_point["pink"]
ax.annotate(
color_to_label["pink"],
(date, rate),
xytext=(10, 5),
textcoords="offset points",
color="pink",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if qwen_points:
qwen_dates, qwen_rates = zip(*sorted(qwen_points))
ax.plot(qwen_dates, qwen_rates, c="darkblue", alpha=0.5, linewidth=1)
if "darkblue" in color_to_first_point:
date, rate = color_to_first_point["darkblue"]
ax.annotate(
color_to_label["darkblue"],
(date, rate),
xytext=(10, 5),
textcoords="offset points",
color="darkblue",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
if mistral_points:
mistral_dates, mistral_rates = zip(*sorted(mistral_points))
ax.plot(mistral_dates, mistral_rates, c="cyan", alpha=0.5, linewidth=1)
if "cyan" in color_to_first_point:
date, rate = color_to_first_point["cyan"]
ax.annotate(
color_to_label["cyan"],
(date, rate),
xytext=(10, -10),
textcoords="offset points",
color="cyan",
alpha=0.8,
fontsize=LABEL_FONT_SIZE,
)
# Plot points without legend
for date, rate, color in zip(dates, pass_rates, colors):
ax.scatter([date], [rate], c=[color], alpha=0.5, s=120)
def set_labels_and_style(self, ax: plt.Axes):
ax.set_xlabel("Model release date", fontsize=18, color="#555")
ax.set_ylabel(
"Aider code editing benchmark,\npercent completed correctly", fontsize=18, color="#555"
"Aider code editing benchmark,\npercent completed correctly",
fontsize=18,
color="#555"
)
ax.set_title("LLM code editing skill by model release date", fontsize=20)
ax.set_ylim(30, 90) # Adjust y-axis limit to accommodate higher values
plt.xticks(fontsize=14, rotation=45, ha="right") # Rotate x-axis labels for better readability
plt.tight_layout(pad=1.0) # Adjust layout since we don't need room for legend anymore
ax.set_ylim(30, 90)
plt.xticks(fontsize=14, rotation=45, ha="right")
plt.tight_layout(pad=1.0)
print("Debug: Saving figures...")
def save_and_display(self, fig: plt.Figure):
plt.savefig("tmp_over_time.png")
plt.savefig("tmp_over_time.svg")
print("Debug: Displaying figure with imgcat...")
imgcat(fig)
print("Debug: Figure generation complete.")
def plot(self, yaml_file: str):
models = self.load_data(yaml_file)
fig, ax = self.create_figure()
self.plot_model_series(ax, models)
self.set_labels_and_style(ax)
self.save_and_display(fig)
def main():
plotter = BenchmarkPlotter()
plotter.plot("aider/website/_data/edit_leaderboard.yml")
# Example usage
plot_over_time("aider/website/_data/edit_leaderboard.yml")
if __name__ == "__main__":
main()