full swe bench graph

This commit is contained in:
Paul Gauthier 2024-05-31 08:16:49 -07:00
parent 20ee2bbf36
commit 639121bb45

View file

@ -1,8 +1,10 @@
import sys
from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from imgcat import imgcat from imgcat import imgcat
from matplotlib import rc from matplotlib import rc
import sys
from pathlib import Path
def plot_swe_bench_lite(data_file): def plot_swe_bench_lite(data_file):
with open(data_file, "r") as file: with open(data_file, "r") as file:
@ -32,7 +34,7 @@ def plot_swe_bench_lite(data_file):
rc("font", **font_params) rc("font", **font_params)
plt.rcParams["text.color"] = font_color plt.rcParams["text.color"] = font_color
fig, ax = plt.subplots(figsize=(10, 5)) fig, ax = plt.subplots(figsize=(10, 6))
ax.grid(axis="y", zorder=0, lw=0.2) ax.grid(axis="y", zorder=0, lw=0.2)
for spine in ax.spines.values(): for spine in ax.spines.values():
spine.set_edgecolor("#DDDDDD") spine.set_edgecolor("#DDDDDD")
@ -43,7 +45,7 @@ def plot_swe_bench_lite(data_file):
for model, pass_rate, color in zip(models, pass_rates, colors): for model, pass_rate, color in zip(models, pass_rates, colors):
alpha = 0.9 if "Aider" in model else 0.3 alpha = 0.9 if "Aider" in model else 0.3
hatch = "" hatch = ""
#if "lite" not in data_file: # if "lite" not in data_file:
# hatch = "///" if "(570)" in model else "" # hatch = "///" if "(570)" in model else ""
bar = ax.bar(model, pass_rate, color=color, alpha=alpha, zorder=3, hatch=hatch) bar = ax.bar(model, pass_rate, color=color, alpha=alpha, zorder=3, hatch=hatch)
bars.append(bar[0]) bars.append(bar[0])
@ -78,19 +80,21 @@ def plot_swe_bench_lite(data_file):
color=font_color, color=font_color,
) )
# Add note at the bottom of the graph # Add note at the bottom of the graph
note = "Note: (570) and (2294) refer to the number of instances that were processed by the agent." note = (
"Note: (570) and (2294) refer to the number of SWE Bench instances that were benchmarked."
)
plt.figtext( plt.figtext(
0.5, 0.025, 0.5,
0.025,
note, note,
wrap=True, wrap=True,
horizontalalignment='center', horizontalalignment="center",
fontsize=12, fontsize=12,
color=font_color, color=font_color,
) )
plt.tight_layout(pad=5.0) plt.tight_layout(pad=3.0)
out_fname = Path(data_file) out_fname = Path(data_file)
plt.savefig(out_fname.with_suffix(".jpg").name) plt.savefig(out_fname.with_suffix(".jpg").name)