style: Organize imports and apply linter formatting

This commit is contained in:
Paul Gauthier (aider) 2024-11-21 14:00:24 -08:00
parent 6d6d763dd3
commit c189a52e5e

View file

@ -1,17 +1,19 @@
from dataclasses import dataclass
from datetime import date
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import yaml import yaml
from imgcat import imgcat from imgcat import imgcat
from matplotlib import rc from matplotlib import rc
from dataclasses import dataclass
from typing import List, Tuple, Dict
from datetime import date
@dataclass @dataclass
class ModelData: class ModelData:
name: str name: str
release_date: date release_date: date
pass_rate: float pass_rate: float
@property @property
def color(self) -> str: def color(self) -> str:
model = self.name.lower() model = self.name.lower()
@ -56,9 +58,10 @@ class ModelData:
return "Mistral" return "Mistral"
return model return model
class BenchmarkPlotter: class BenchmarkPlotter:
LABEL_FONT_SIZE = 16 LABEL_FONT_SIZE = 16
def __init__(self): def __init__(self):
self.setup_plot_style() self.setup_plot_style()
@ -71,14 +74,14 @@ class BenchmarkPlotter:
def load_data(self, yaml_file: str) -> List[ModelData]: def load_data(self, yaml_file: str) -> List[ModelData]:
with open(yaml_file, "r") as file: with open(yaml_file, "r") as file:
data = yaml.safe_load(file) data = yaml.safe_load(file)
models = [] models = []
for entry in data: for entry in data:
if "released" in entry and "pass_rate_2" in entry: if "released" in entry and "pass_rate_2" in entry:
model = ModelData( model = ModelData(
name=entry["model"].split("(")[0].strip(), name=entry["model"].split("(")[0].strip(),
release_date=entry["released"], release_date=entry["released"],
pass_rate=entry["pass_rate_2"] pass_rate=entry["pass_rate_2"],
) )
models.append(model) models.append(model)
return models return models
@ -104,13 +107,13 @@ class BenchmarkPlotter:
sorted_group = sorted(group, key=lambda x: x.release_date) sorted_group = sorted(group, key=lambda x: x.release_date)
dates = [m.release_date for m in sorted_group] dates = [m.release_date for m in sorted_group]
rates = [m.pass_rate for m in sorted_group] rates = [m.pass_rate for m in sorted_group]
# Plot line # Plot line
ax.plot(dates, rates, c=color, alpha=0.5, linewidth=1) ax.plot(dates, rates, c=color, alpha=0.5, linewidth=1)
# Plot points # Plot points
ax.scatter(dates, rates, c=color, alpha=0.5, s=120) ax.scatter(dates, rates, c=color, alpha=0.5, s=120)
# Add label for first point # Add label for first point
first_model = sorted_group[0] first_model = sorted_group[0]
ax.annotate( ax.annotate(
@ -126,9 +129,7 @@ class BenchmarkPlotter:
def set_labels_and_style(self, ax: plt.Axes): def set_labels_and_style(self, ax: plt.Axes):
ax.set_xlabel("Model release date", fontsize=18, color="#555") ax.set_xlabel("Model release date", fontsize=18, color="#555")
ax.set_ylabel( ax.set_ylabel(
"Aider code editing benchmark,\npercent completed correctly", "Aider code editing benchmark,\npercent completed correctly", fontsize=18, color="#555"
fontsize=18,
color="#555"
) )
ax.set_title("LLM code editing skill by model release date", fontsize=20) ax.set_title("LLM code editing skill by model release date", fontsize=20)
ax.set_ylim(30, 90) ax.set_ylim(30, 90)
@ -147,9 +148,11 @@ class BenchmarkPlotter:
self.set_labels_and_style(ax) self.set_labels_and_style(ax)
self.save_and_display(fig) self.save_and_display(fig)
def main(): def main():
plotter = BenchmarkPlotter() plotter = BenchmarkPlotter()
plotter.plot("aider/website/_data/edit_leaderboard.yml") plotter.plot("aider/website/_data/edit_leaderboard.yml")
if __name__ == "__main__": if __name__ == "__main__":
main() main()