diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index bffd874e8..fcfc4512c 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -16,8 +16,8 @@ from typing import List import git import lox -import pandas as pd import matplotlib.pyplot as plt +import pandas as pd import prompts import typer from rich.console import Console @@ -38,21 +38,29 @@ def show_stats(dirnames): rows = [] for dirname in dirnames: row = summarize_results(dirname) + if not row: + continue + + if row.model == "gpt-3.5-turbo": + row.model = "gpt-3.5-turbo-0613" + rows.append(vars(row)) df = pd.DataFrame.from_records(rows) df.sort_values(by=["model", "edit_format"], inplace=True) df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() - df_grouped.unstack().plot(kind='barh', figsize=(10, 6)) + dump(df_grouped) + dump(df_grouped.unstack()) + df_grouped.unstack().plot(kind="barh", figsize=(10, 6)) - plt.xlabel('Pass Rate 1') - plt.ylabel('Model') - plt.title('Pass Rate 1 for each Model/Edit Format') - plt.legend(title='Edit Format') + plt.xlabel("Pass Rate 1") + plt.ylabel("Model") + plt.title("Pass Rate 1 for each Model/Edit Format") + plt.legend(title="Edit Format") plt.show() - df.to_csv("tmp.benchmark.csv") + # df.to_csv("tmp.benchmarks.csv") def resolve_dirname(dirname, use_single_prior, make_new):