From 937ee478ade1a108b6a3ce5ca6b9c9297ed3b785 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Thu, 29 Jun 2023 19:04:48 -0700 Subject: [PATCH] show with imgcat --- benchmark/Dockerfile | 2 +- benchmark/benchmark.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/benchmark/Dockerfile b/benchmark/Dockerfile index 3bbc61266..ad79affa6 100644 --- a/benchmark/Dockerfile +++ b/benchmark/Dockerfile @@ -1,7 +1,7 @@ FROM python:3.8-slim RUN apt-get update && apt-get install -y less git COPY requirements.txt /aider/requirements.txt +RUN pip install lox typer pandas matplotlib imgcat aider-chat RUN pip install --upgrade pip && pip install -r /aider/requirements.txt -RUN pip install lox typer pandas matplotlib imgcat WORKDIR /aider diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 3fd8ebb60..a5c9de9b0 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -20,6 +20,7 @@ import matplotlib.pyplot as plt import pandas as pd import prompts import typer +from imgcat import imgcat from rich.console import Console from aider import models @@ -35,23 +36,28 @@ app = typer.Typer(add_completion=False, pretty_exceptions_enable=False) def show_stats(dirnames): - rows = [] + raw_rows = [] for dirname in dirnames: row = summarize_results(dirname) + raw_rows.append(row) + + rows = [] + for row in raw_rows: if not row: continue if row.model == "gpt-3.5-turbo": row.model = "gpt-3.5-turbo-0613" + if row.completed_tests < 133: + print(f"Warning: {row.dir_name} is incomplete: {row.completed_tests}") + rows.append(vars(row)) df = pd.DataFrame.from_records(rows) df.sort_values(by=["model", "edit_format"], inplace=True) df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() - dump(df_grouped) - dump(df_grouped.unstack()) fig, ax = plt.subplots(figsize=(10, 6)) df_grouped.unstack().plot(kind="barh", ax=ax) @@ -59,7 +65,8 @@ def show_stats(dirnames): ax.set_ylabel("Model") ax.set_title("Pass Rate 1 for each Model/Edit Format") ax.legend(title="Edit Format") - plt.show() + + imgcat(fig) # df.to_csv("tmp.benchmarks.csv")