diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 599ec0efb..63401e9d1 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -41,6 +41,7 @@ def show_stats(dirnames): row = summarize_results(dirname) raw_rows.append(row) + seen = dict() rows = [] for row in raw_rows: if not row: @@ -52,12 +53,20 @@ def show_stats(dirnames): if row.completed_tests < 133: print(f"Warning: {row.dir_name} is incomplete: {row.completed_tests}") + kind = (row.model, row.edit_format) + if kind in seen: + dump(row.dir_name) + dump(seen[kind]) + return + seen[kind] = row.dir_name + rows.append(vars(row)) df = pd.DataFrame.from_records(rows) df.sort_values(by=["model", "edit_format"], inplace=True) - df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean().sort_values(ascending=False) + df_grouped = df.groupby(["model", "edit_format"])["pass_rate_1"].mean() + fig, ax = plt.subplots(figsize=(10, 6)) df_grouped.unstack().plot(kind="barh", ax=ax)