plot the repeat error bars

This commit is contained in:
Paul Gauthier 2023-07-02 07:10:54 -07:00
parent ecb739ed07
commit db18876db6

View file

@ -54,6 +54,14 @@ def show_stats(dirnames):
if row.edit_format == "diff-func-string":
row.edit_format = "diff-func"
if (
row.model == "gpt-3.5-turbo-0613"
and row.edit_format == "whole"
and "repeat" not in row.dir_name
):
# remember this row, so we can update it with the repeat_avg
repeat_row = len(rows)
pieces = row.model.split("-")
row.model = "-".join(pieces[:3])
if pieces[3:]:
@ -71,9 +79,29 @@ def show_stats(dirnames):
dump(row.dir_name)
dump(seen[kind])
return
seen[kind] = row.dir_name
rows.append(vars(row))
if repeats:
extra = rows[repeat_row]
dump(extra)
repeats.append(extra)
repeats = pd.DataFrame.from_records(repeats)
repeat_max = repeats["pass_rate_2"].max()
repeat_min = repeats["pass_rate_2"].min()
repeat_avg = repeats["pass_rate_2"].mean()
repeat_lo = repeat_avg - repeat_min
repeat_hi = repeat_max - repeat_avg
dump(repeat_max)
dump(repeat_min)
dump(repeat_avg)
# use the average in the main bar
rows[repeat_row]["pass_rate_2"] = repeat_avg
df = pd.DataFrame.from_records(rows)
df.sort_values(by=["model", "edit_format"], inplace=True)
@ -128,14 +156,10 @@ def show_stats(dirnames):
if zorder == 2:
ax.bar_label(rects, padding=8, labels=[f"{v:.0f}%" for v in df[fmt]], size=14)
if repeats:
repeats = pd.DataFrame.from_records(repeats)
repeat_max = repeats["pass_rate_2"].max()
repeat_min = repeats["pass_rate_2"].min()
lo = 44 - repeat_min
hi = repeat_max - 44
ax.errorbar(1.4, 44, yerr=[[lo], [hi]], fmt="none", zorder=5, capsize=5)
if len(repeats):
ax.errorbar(
1.4, repeat_avg, yerr=[[repeat_lo], [repeat_hi]], fmt="none", zorder=5, capsize=5
)
ax.set_xticks([p + 1.5 * width for p in pos])
ax.set_xticklabels(models)