From 1cdbc769746a838d12f9a9b344dffc5e1868d671 Mon Sep 17 00:00:00 2001 From: "Paul Gauthier (aider)" Date: Wed, 14 Aug 2024 06:27:48 -0700 Subject: [PATCH] feat: Connect model family lines in over_time plot --- benchmark/over_time.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/benchmark/over_time.py b/benchmark/over_time.py index 2c2116628..6215c5411 100644 --- a/benchmark/over_time.py +++ b/benchmark/over_time.py @@ -57,6 +57,25 @@ def plot_over_time(yaml_file): ) for model in models ] + + # Separate data points by color + orange_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "orange"] + red_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "red"] + green_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "green"] + blue_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "blue"] + + # Plot lines for orange, red, and green points + if orange_points: + orange_dates, orange_rates = zip(*sorted(orange_points)) + ax.plot(orange_dates, orange_rates, c="orange", alpha=0.5, linewidth=1) + if red_points: + red_dates, red_rates = zip(*sorted(red_points)) + ax.plot(red_dates, red_rates, c="red", alpha=0.5, linewidth=1) + if green_points: + green_dates, green_rates = zip(*sorted(green_points)) + ax.plot(green_dates, green_rates, c="green", alpha=0.5, linewidth=1) + + # Plot all points ax.scatter(dates, pass_rates, c=colors, alpha=0.5, s=120) for i, model in enumerate(models):