mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-31 01:35:00 +00:00
refactor: Restructure benchmark plotting script for improved maintainability
This commit is contained in:
parent
3cfbaa0ed6
commit
6d6d763dd3
1 changed file with 137 additions and 260 deletions
|
@@ -2,277 +2,154 @@ import matplotlib.pyplot as plt
|
|||
import yaml
|
||||
from imgcat import imgcat
|
||||
from matplotlib import rc
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Dict
|
||||
from datetime import date
|
||||
|
||||
from aider.dump import dump # noqa: 401
|
||||
LABEL_FONT_SIZE = 16  # Font size for scatter plot dot labels


@dataclass
class ModelData:
    """A single benchmarked model: its name, release date and pass rate.

    The two properties derive presentation details from the model name:
    ``color`` picks the colour of the series the model is plotted in, and
    ``legend_label`` gives the short text drawn next to that series.
    Both match case-insensitively on substrings of ``name``.
    """

    # Model name, matched (lower-cased) against the substring rules below.
    name: str
    # Release date, used as the x coordinate and as the sort key of a series.
    release_date: date
    # Benchmark pass rate, used as the y coordinate.
    pass_rate: float

    @property
    def color(self) -> str:
        """Return the plot colour for this model's family.

        The first matching rule wins, so the order is significant:
        ``"-4o"`` must be tested before ``"gpt-4"``, otherwise names such
        as ``gpt-4o-...`` would be coloured red instead of purple.
        Names that match no rule get ``"lightblue"``.
        """
        model = self.name.lower()
        if "qwen" in model:
            return "darkblue"
        if "mistral" in model:
            return "cyan"
        if "haiku" in model:
            return "pink"
        if "deepseek" in model:
            return "brown"
        if "sonnet" in model:
            return "orange"
        if "-4o" in model:
            return "purple"
        if "gpt-4" in model:
            return "red"
        if "gpt-3.5" in model:
            return "green"
        return "lightblue"

    @property
    def legend_label(self) -> str:
        """Return the short label for this model's series.

        The first matching rule wins. A name that matches no rule is
        returned as-is in lower case, so the caller always gets a
        non-empty label for a non-empty name.
        """
        model = self.name.lower()
        if "claude-3-sonnet" in model:
            return "Sonnet"
        if "o1-preview" in model:
            return "O1 Preview"
        if "gpt-3.5" in model:
            return "GPT-3.5 Turbo"
        # "gpt-4-" identifies the original GPT-4 line; the extra "-4o"
        # guard keeps GPT-4o names out of this bucket.
        if "gpt-4-" in model and "-4o" not in model:
            return "GPT-4"
        if "qwen" in model:
            return "Qwen"
        if "-4o" in model:
            return "GPT-4o"
        if "haiku" in model:
            return "Haiku"
        if "deepseek" in model:
            return "DeepSeek"
        if "mistral" in model:
            return "Mistral"
        return model
|
||||
|
||||
class BenchmarkPlotter:
    """Plots benchmark pass rates against model release dates."""

    # Font size used for the text label attached to each plotted series.
    LABEL_FONT_SIZE = 16

    def __init__(self):
        """Apply the shared plot styling as soon as a plotter is created."""
        self.setup_plot_style()
|
||||
|
||||
def get_legend_label(model):
    """Return the short display label for a model name.

    Matching is done on the lower-cased name by substring, and the first
    matching rule wins. A name that matches no rule is returned in lower
    case.

    Args:
        model: The model name, in any letter case.

    Returns:
        The label text for the model's series.
    """
    model = model.lower()
    if "claude-3-sonnet" in model:
        return "Sonnet"
    if "o1-preview" in model:
        return "O1 Preview"
    if "gpt-3.5" in model:
        return "GPT-3.5 Turbo"
    # "gpt-4-" identifies the original GPT-4 line; the "-4o" guard keeps
    # GPT-4o names out of this bucket.
    if "gpt-4-" in model and "-4o" not in model:
        return "GPT-4"
    if "qwen" in model:
        return "Qwen"
    if "-4o" in model:
        return "GPT-4o"
    if "haiku" in model:
        return "Haiku"
    if "deepseek" in model:
        return "DeepSeek"
    if "mistral" in model:
        return "Mistral"
    # A second "o1-preview" test used to sit here; it could never fire
    # because the check near the top already returns for those names.
    return model
|
||||
def setup_plot_style(self):
    """Set the global matplotlib defaults used by the benchmark plots.

    Configures the hatch line width and colour, the default text colour,
    and a 10pt sans-serif (Helvetica) font. These are process-wide
    matplotlib settings, not per-figure ones.
    """
    plt.rcParams.update(
        {
            "hatch.linewidth": 0.5,
            "hatch.color": "#444444",
            "text.color": "#444444",
        }
    )
    font_settings = {"family": "sans-serif", "sans-serif": ["Helvetica"], "size": 10}
    rc("font", **font_settings)
|
||||
|
||||
def load_data(self, yaml_file: str) -> List[ModelData]:
    """Load benchmark entries from a YAML file as ``ModelData`` records.

    Only entries that carry both a ``released`` and a ``pass_rate_2`` key
    are kept. The model name is the ``model`` value truncated at the
    first ``(`` and stripped of surrounding whitespace.

    Args:
        yaml_file: Path to the YAML file to read.

    Returns:
        One ``ModelData`` per usable entry, in file order. An empty YAML
        document yields an empty list.

    Raises:
        KeyError: If a kept entry has no ``model`` key.
    """
    with open(yaml_file, "r", encoding="utf-8") as file:
        # safe_load returns None for an empty document; fall back to an
        # empty list so the caller gets "no models" instead of a TypeError.
        data = yaml.safe_load(file) or []

    return [
        ModelData(
            name=entry["model"].split("(")[0].strip(),
            release_date=entry["released"],
            pass_rate=entry["pass_rate_2"],
        )
        for entry in data
        if "released" in entry and "pass_rate_2" in entry
    ]
|
||||
|
||||
def get_model_color(model):
|
||||
default = "lightblue"
|
||||
def create_figure(self) -> Tuple[plt.Figure, plt.Axes]:
    """Create the 12x8 figure and its single, lightly styled axes.

    The axes get thin horizontal grid lines and thin light-grey spines.

    Returns:
        The ``(figure, axes)`` pair from ``plt.subplots``.
    """
    figure, axes = plt.subplots(figsize=(12, 8))
    axes.grid(axis="y", zorder=0, lw=0.2)
    for border in axes.spines.values():
        border.set_edgecolor("#DDDDDD")
        border.set_linewidth(0.5)
    return figure, axes
|
||||
|
||||
if model == "gpt-4o-mini":
|
||||
return default
|
||||
def plot_model_series(self, ax: plt.Axes, models: List[ModelData]):
    """Draw one labelled line-and-scatter series per model colour.

    Models are grouped by their ``color``. Each group is sorted by release
    date, drawn as a thin line with a dot per model, and annotated at its
    earliest model with that model's ``legend_label``.

    Args:
        ax: The axes to draw on.
        models: The models to plot; an empty list draws nothing.
    """
    # Group models by colour, keeping first-seen order of the colours.
    color_groups: Dict[str, List[ModelData]] = {}
    for model in models:
        color_groups.setdefault(model.color, []).append(model)

    for color, group in color_groups.items():
        sorted_group = sorted(group, key=lambda m: m.release_date)
        dates = [m.release_date for m in sorted_group]
        rates = [m.pass_rate for m in sorted_group]

        # Connecting line first, then the individual points on top.
        ax.plot(dates, rates, c=color, alpha=0.5, linewidth=1)
        ax.scatter(dates, rates, c=color, alpha=0.5, s=120)

        # Label the series at its earliest point. Groups are created only
        # when a model is added, so sorted_group is never empty here.
        first_model = sorted_group[0]
        ax.annotate(
            first_model.legend_label,
            (first_model.release_date, first_model.pass_rate),
            xytext=(10, 5),
            textcoords="offset points",
            color=color,
            alpha=0.8,
            fontsize=self.LABEL_FONT_SIZE,
        )
|
||||
|
||||
if "mistral" in model.lower():
|
||||
return "cyan"
|
||||
|
||||
if "haiku" in model.lower():
|
||||
return "pink"
|
||||
|
||||
if "deepseek" in model.lower():
|
||||
return "brown"
|
||||
|
||||
if "sonnet" in model.lower():
|
||||
return "orange"
|
||||
|
||||
if "-4o" in model:
|
||||
return "purple"
|
||||
|
||||
if "gpt-4" in model:
|
||||
return "red"
|
||||
|
||||
if "gpt-3.5" in model:
|
||||
return "green"
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def plot_over_time(yaml_file):
|
||||
with open(yaml_file, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
dates = []
|
||||
pass_rates = []
|
||||
models = []
|
||||
|
||||
print("Debug: Raw data from YAML file:")
|
||||
print(data)
|
||||
|
||||
for entry in data:
|
||||
if "released" in entry and "pass_rate_2" in entry:
|
||||
dates.append(entry["released"])
|
||||
pass_rates.append(entry["pass_rate_2"])
|
||||
models.append(entry["model"].split("(")[0].strip())
|
||||
|
||||
print("Debug: Processed data:")
|
||||
print("Dates:", dates)
|
||||
print("Pass rates:", pass_rates)
|
||||
print("Models:", models)
|
||||
|
||||
if not dates or not pass_rates:
|
||||
print(
|
||||
"Error: No data to plot. Check if the YAML file is empty or if the data is in the"
|
||||
" expected format."
|
||||
def set_labels_and_style(self, ax: plt.Axes):
    """Set the axis labels, title, y-range and tick styling of the plot.

    Args:
        ax: The axes to label and style.
    """
    ax.set_xlabel("Model release date", fontsize=18, color="#555")
    ax.set_ylabel(
        "Aider code editing benchmark,\npercent completed correctly",
        fontsize=18,
        color="#555",
    )
    ax.set_title("LLM code editing skill by model release date", fontsize=20)
    ax.set_ylim(30, 90)
    # These two calls go through pyplot, so they act on pyplot's current
    # figure/axes rather than on `ax` explicitly.
    plt.xticks(fontsize=14, rotation=45, ha="right")
    plt.tight_layout(pad=1.0)
|
||||
|
||||
plt.rcParams["hatch.linewidth"] = 0.5
|
||||
plt.rcParams["hatch.color"] = "#444444"
|
||||
def save_and_display(self, fig: plt.Figure):
    """Write the figure to PNG and SVG files and show it with imgcat.

    The files are written to ``tmp_over_time.png`` and
    ``tmp_over_time.svg`` in the current working directory.

    Args:
        fig: The figure to save and display.
    """
    # Save through the figure itself so the files always contain `fig`,
    # even if it is not pyplot's current figure.
    fig.savefig("tmp_over_time.png")
    fig.savefig("tmp_over_time.svg")
    imgcat(fig)
|
||||
|
||||
rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"], "size": 10})
|
||||
plt.rcParams["text.color"] = "#444444"
|
||||
def plot(self, yaml_file: str):
    """Run the whole pipeline: load, draw, style, then save and display.

    Args:
        yaml_file: Path to the YAML file holding the benchmark entries.
    """
    model_records = self.load_data(yaml_file)
    figure, axes = self.create_figure()
    self.plot_model_series(axes, model_records)
    self.set_labels_and_style(axes)
    self.save_and_display(figure)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(12, 8)) # Make figure square
|
||||
def main():
    """Plot the edit leaderboard data with a fresh ``BenchmarkPlotter``."""
    BenchmarkPlotter().plot("aider/website/_data/edit_leaderboard.yml")
|
||||
|
||||
print("Debug: Figure created. Plotting data...")
|
||||
ax.grid(axis="y", zorder=0, lw=0.2)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_edgecolor("#DDDDDD")
|
||||
spine.set_linewidth(0.5)
|
||||
|
||||
colors = [get_model_color(model) for model in models]
|
||||
|
||||
# Separate data points by color
|
||||
purple_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "purple"]
|
||||
red_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "red"]
|
||||
green_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "green"]
|
||||
orange_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "orange"]
|
||||
brown_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "brown"]
|
||||
pink_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "pink"]
|
||||
qwen_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "darkblue"]
|
||||
mistral_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "cyan"]
|
||||
|
||||
# Create a mapping of colors to first points and labels
|
||||
color_to_first_point = {}
|
||||
color_to_label = {}
|
||||
|
||||
for date, rate, color, model in sorted(zip(dates, pass_rates, colors, models)):
|
||||
if color not in color_to_first_point:
|
||||
color_to_first_point[color] = (date, rate)
|
||||
color_to_label[color] = get_legend_label(model)
|
||||
|
||||
# Plot lines and add labels at first points
|
||||
if purple_points:
|
||||
purple_dates, purple_rates = zip(*sorted(purple_points))
|
||||
ax.plot(purple_dates, purple_rates, c="purple", alpha=0.5, linewidth=1)
|
||||
if "purple" in color_to_first_point:
|
||||
date, rate = color_to_first_point["purple"]
|
||||
ax.annotate(
|
||||
color_to_label["purple"],
|
||||
(date, rate),
|
||||
xytext=(10, 5),
|
||||
textcoords="offset points",
|
||||
color="purple",
|
||||
alpha=0.8,
|
||||
fontsize=LABEL_FONT_SIZE,
|
||||
)
|
||||
|
||||
if red_points:
|
||||
red_dates, red_rates = zip(*sorted(red_points))
|
||||
ax.plot(red_dates, red_rates, c="red", alpha=0.5, linewidth=1)
|
||||
if "red" in color_to_first_point:
|
||||
date, rate = color_to_first_point["red"]
|
||||
ax.annotate(
|
||||
color_to_label["red"],
|
||||
(date, rate),
|
||||
xytext=(10, 5),
|
||||
textcoords="offset points",
|
||||
color="red",
|
||||
alpha=0.8,
|
||||
fontsize=LABEL_FONT_SIZE,
|
||||
)
|
||||
|
||||
if green_points:
|
||||
green_dates, green_rates = zip(*sorted(green_points))
|
||||
ax.plot(green_dates, green_rates, c="green", alpha=0.5, linewidth=1)
|
||||
if "green" in color_to_first_point:
|
||||
date, rate = color_to_first_point["green"]
|
||||
ax.annotate(
|
||||
color_to_label["green"],
|
||||
(date, rate),
|
||||
xytext=(10, 5),
|
||||
textcoords="offset points",
|
||||
color="green",
|
||||
alpha=0.8,
|
||||
fontsize=LABEL_FONT_SIZE,
|
||||
)
|
||||
|
||||
if orange_points:
|
||||
orange_dates, orange_rates = zip(*sorted(orange_points))
|
||||
ax.plot(orange_dates, orange_rates, c="orange", alpha=0.5, linewidth=1)
|
||||
if "orange" in color_to_first_point:
|
||||
date, rate = color_to_first_point["orange"]
|
||||
ax.annotate(
|
||||
color_to_label["orange"],
|
||||
(date, rate),
|
||||
xytext=(10, 5),
|
||||
textcoords="offset points",
|
||||
color="orange",
|
||||
alpha=0.8,
|
||||
fontsize=LABEL_FONT_SIZE,
|
||||
)
|
||||
|
||||
if brown_points:
|
||||
brown_dates, brown_rates = zip(*sorted(brown_points))
|
||||
ax.plot(brown_dates, brown_rates, c="brown", alpha=0.5, linewidth=1)
|
||||
if "brown" in color_to_first_point:
|
||||
date, rate = color_to_first_point["brown"]
|
||||
ax.annotate(
|
||||
color_to_label["brown"],
|
||||
(date, rate),
|
||||
xytext=(10, -10),
|
||||
textcoords="offset points",
|
||||
color="brown",
|
||||
alpha=0.8,
|
||||
fontsize=LABEL_FONT_SIZE,
|
||||
)
|
||||
|
||||
if pink_points:
|
||||
pink_dates, pink_rates = zip(*sorted(pink_points))
|
||||
ax.plot(pink_dates, pink_rates, c="pink", alpha=0.5, linewidth=1)
|
||||
if "pink" in color_to_first_point:
|
||||
date, rate = color_to_first_point["pink"]
|
||||
ax.annotate(
|
||||
color_to_label["pink"],
|
||||
(date, rate),
|
||||
xytext=(10, 5),
|
||||
textcoords="offset points",
|
||||
color="pink",
|
||||
alpha=0.8,
|
||||
fontsize=LABEL_FONT_SIZE,
|
||||
)
|
||||
|
||||
if qwen_points:
|
||||
qwen_dates, qwen_rates = zip(*sorted(qwen_points))
|
||||
ax.plot(qwen_dates, qwen_rates, c="darkblue", alpha=0.5, linewidth=1)
|
||||
if "darkblue" in color_to_first_point:
|
||||
date, rate = color_to_first_point["darkblue"]
|
||||
ax.annotate(
|
||||
color_to_label["darkblue"],
|
||||
(date, rate),
|
||||
xytext=(10, 5),
|
||||
textcoords="offset points",
|
||||
color="darkblue",
|
||||
alpha=0.8,
|
||||
fontsize=LABEL_FONT_SIZE,
|
||||
)
|
||||
|
||||
if mistral_points:
|
||||
mistral_dates, mistral_rates = zip(*sorted(mistral_points))
|
||||
ax.plot(mistral_dates, mistral_rates, c="cyan", alpha=0.5, linewidth=1)
|
||||
if "cyan" in color_to_first_point:
|
||||
date, rate = color_to_first_point["cyan"]
|
||||
ax.annotate(
|
||||
color_to_label["cyan"],
|
||||
(date, rate),
|
||||
xytext=(10, -10),
|
||||
textcoords="offset points",
|
||||
color="cyan",
|
||||
alpha=0.8,
|
||||
fontsize=LABEL_FONT_SIZE,
|
||||
)
|
||||
|
||||
# Plot points without legend
|
||||
for date, rate, color in zip(dates, pass_rates, colors):
|
||||
ax.scatter([date], [rate], c=[color], alpha=0.5, s=120)
|
||||
|
||||
ax.set_xlabel("Model release date", fontsize=18, color="#555")
|
||||
ax.set_ylabel(
|
||||
"Aider code editing benchmark,\npercent completed correctly", fontsize=18, color="#555"
|
||||
)
|
||||
ax.set_title("LLM code editing skill by model release date", fontsize=20)
|
||||
ax.set_ylim(30, 90) # Adjust y-axis limit to accommodate higher values
|
||||
plt.xticks(fontsize=14, rotation=45, ha="right") # Rotate x-axis labels for better readability
|
||||
plt.tight_layout(pad=1.0) # Adjust layout since we don't need room for legend anymore
|
||||
|
||||
print("Debug: Saving figures...")
|
||||
plt.savefig("tmp_over_time.png")
|
||||
plt.savefig("tmp_over_time.svg")
|
||||
|
||||
print("Debug: Displaying figure with imgcat...")
|
||||
imgcat(fig)
|
||||
|
||||
print("Debug: Figure generation complete.")
|
||||
|
||||
|
||||
# Example usage
|
||||
plot_over_time("aider/website/_data/edit_leaderboard.yml")
|
||||
# Only generate the plot when this file is run as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue