style: Organize imports and apply linter formatting

This commit is contained in:
Paul Gauthier (aider) 2024-11-21 14:00:24 -08:00
parent 6d6d763dd3
commit c189a52e5e

View file

@ -1,17 +1,19 @@
from dataclasses import dataclass
from datetime import date
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import yaml
from imgcat import imgcat
from matplotlib import rc
from dataclasses import dataclass
from typing import List, Tuple, Dict
from datetime import date
@dataclass
class ModelData:
name: str
release_date: date
pass_rate: float
@property
def color(self) -> str:
model = self.name.lower()
@ -56,9 +58,10 @@ class ModelData:
return "Mistral"
return model
class BenchmarkPlotter:
LABEL_FONT_SIZE = 16
def __init__(self):
self.setup_plot_style()
@ -71,14 +74,14 @@ class BenchmarkPlotter:
def load_data(self, yaml_file: str) -> List[ModelData]:
with open(yaml_file, "r") as file:
data = yaml.safe_load(file)
models = []
for entry in data:
if "released" in entry and "pass_rate_2" in entry:
model = ModelData(
name=entry["model"].split("(")[0].strip(),
release_date=entry["released"],
pass_rate=entry["pass_rate_2"]
pass_rate=entry["pass_rate_2"],
)
models.append(model)
return models
@ -104,13 +107,13 @@ class BenchmarkPlotter:
sorted_group = sorted(group, key=lambda x: x.release_date)
dates = [m.release_date for m in sorted_group]
rates = [m.pass_rate for m in sorted_group]
# Plot line
ax.plot(dates, rates, c=color, alpha=0.5, linewidth=1)
# Plot points
ax.scatter(dates, rates, c=color, alpha=0.5, s=120)
# Add label for first point
first_model = sorted_group[0]
ax.annotate(
@ -126,9 +129,7 @@ class BenchmarkPlotter:
def set_labels_and_style(self, ax: plt.Axes):
ax.set_xlabel("Model release date", fontsize=18, color="#555")
ax.set_ylabel(
"Aider code editing benchmark,\npercent completed correctly",
fontsize=18,
color="#555"
"Aider code editing benchmark,\npercent completed correctly", fontsize=18, color="#555"
)
ax.set_title("LLM code editing skill by model release date", fontsize=20)
ax.set_ylim(30, 90)
@ -147,9 +148,11 @@ class BenchmarkPlotter:
self.set_labels_and_style(ax)
self.save_and_display(fig)
def main():
plotter = BenchmarkPlotter()
plotter.plot("aider/website/_data/edit_leaderboard.yml")
if __name__ == "__main__":
main()