hello, graph

This commit is contained in:
Paul Gauthier 2023-06-29 18:58:21 -07:00
parent c7ba0a5dba
commit e075fe9f0d

View file

@ -16,8 +16,8 @@ from typing import List
import git import git
import lox import lox
import pandas as pd
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd
import prompts import prompts
import typer import typer
from rich.console import Console from rich.console import Console
@ -38,21 +38,29 @@ def show_stats(dirnames):
rows = [] rows = []
for dirname in dirnames: for dirname in dirnames:
row = summarize_results(dirname) row = summarize_results(dirname)
if not row:
continue
if row.model == "gpt-3.5-turbo":
row.model = "gpt-3.5-turbo-0613"
rows.append(vars(row)) rows.append(vars(row))
df = pd.DataFrame.from_records(rows) df = pd.DataFrame.from_records(rows)
df.sort_values(by=["model", "edit_format"], inplace=True) df.sort_values(by=["model", "edit_format"], inplace=True)
df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean()
df_grouped.unstack().plot(kind='barh', figsize=(10, 6)) dump(df_grouped)
dump(df_grouped.unstack())
df_grouped.unstack().plot(kind="barh", figsize=(10, 6))
plt.xlabel('Pass Rate 1') plt.xlabel("Pass Rate 1")
plt.ylabel('Model') plt.ylabel("Model")
plt.title('Pass Rate 1 for each Model/Edit Format') plt.title("Pass Rate 1 for each Model/Edit Format")
plt.legend(title='Edit Format') plt.legend(title="Edit Format")
plt.show() plt.show()
df.to_csv("tmp.benchmark.csv") # df.to_csv("tmp.benchmarks.csv")
def resolve_dirname(dirname, use_single_prior, make_new): def resolve_dirname(dirname, use_single_prior, make_new):