diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index c2dd0980d..f0797f3d4 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -17,6 +17,7 @@ from typing import List import git import lox import pandas as pd +import matplotlib.pyplot as plt import prompts import typer from rich.console import Console @@ -40,8 +41,15 @@ def show_stats(dirnames): rows.append(vars(row)) df = pd.DataFrame.from_records(rows) + df.sort_values(by=["model", "edit_format"], inplace=True) - print(df) + df_grouped = df.groupby("model")["pass_rate_1"].mean() + df_grouped.plot(kind='barh', figsize=(10, 6)) + + plt.xlabel('Pass Rate 1') + plt.ylabel('Model') + plt.title('Pass Rate 1 for each Model/Edit Format') + plt.show() df.to_csv("tmp.benchmark.csv")