show with imgcat

This commit is contained in:
Paul Gauthier 2023-06-29 19:04:48 -07:00
parent 0f6a90bfed
commit 937ee478ad
2 changed files with 12 additions and 5 deletions

View file

@ -1,7 +1,7 @@
FROM python:3.8-slim FROM python:3.8-slim
RUN apt-get update && apt-get install -y less git RUN apt-get update && apt-get install -y less git
COPY requirements.txt /aider/requirements.txt COPY requirements.txt /aider/requirements.txt
RUN pip install lox typer pandas matplotlib imgcat aider-chat
RUN pip install --upgrade pip && pip install -r /aider/requirements.txt RUN pip install --upgrade pip && pip install -r /aider/requirements.txt
RUN pip install lox typer pandas matplotlib imgcat
WORKDIR /aider WORKDIR /aider

View file

@ -20,6 +20,7 @@ import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
import prompts import prompts
import typer import typer
from imgcat import imgcat
from rich.console import Console from rich.console import Console
from aider import models from aider import models
@ -35,23 +36,28 @@ app = typer.Typer(add_completion=False, pretty_exceptions_enable=False)
def show_stats(dirnames): def show_stats(dirnames):
rows = [] raw_rows = []
for dirname in dirnames: for dirname in dirnames:
row = summarize_results(dirname) row = summarize_results(dirname)
raw_rows.append(row)
rows = []
for row in raw_rows:
if not row: if not row:
continue continue
if row.model == "gpt-3.5-turbo": if row.model == "gpt-3.5-turbo":
row.model = "gpt-3.5-turbo-0613" row.model = "gpt-3.5-turbo-0613"
if row.completed_tests < 133:
print(f"Warning: {row.dir_name} is incomplete: {row.completed_tests}")
rows.append(vars(row)) rows.append(vars(row))
df = pd.DataFrame.from_records(rows) df = pd.DataFrame.from_records(rows)
df.sort_values(by=["model", "edit_format"], inplace=True) df.sort_values(by=["model", "edit_format"], inplace=True)
df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean()
dump(df_grouped)
dump(df_grouped.unstack())
fig, ax = plt.subplots(figsize=(10, 6)) fig, ax = plt.subplots(figsize=(10, 6))
df_grouped.unstack().plot(kind="barh", ax=ax) df_grouped.unstack().plot(kind="barh", ax=ax)
@ -59,7 +65,8 @@ def show_stats(dirnames):
ax.set_ylabel("Model") ax.set_ylabel("Model")
ax.set_title("Pass Rate 1 for each Model/Edit Format") ax.set_title("Pass Rate 1 for each Model/Edit Format")
ax.legend(title="Edit Format") ax.legend(title="Edit Format")
plt.show()
imgcat(fig)
# df.to_csv("tmp.benchmarks.csv") # df.to_csv("tmp.benchmarks.csv")