diff --git a/benchmark/over_time.py b/benchmark/over_time.py index 2c2116628..6215c5411 100644 --- a/benchmark/over_time.py +++ b/benchmark/over_time.py @@ -57,6 +57,25 @@ def plot_over_time(yaml_file): ) for model in models ] + + # Separate data points by color + orange_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "orange"] + red_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "red"] + green_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "green"] + blue_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "blue"] + + # Plot lines for orange, red, and green points + if orange_points: + orange_dates, orange_rates = zip(*sorted(orange_points)) + ax.plot(orange_dates, orange_rates, c="orange", alpha=0.5, linewidth=1) + if red_points: + red_dates, red_rates = zip(*sorted(red_points)) + ax.plot(red_dates, red_rates, c="red", alpha=0.5, linewidth=1) + if green_points: + green_dates, green_rates = zip(*sorted(green_points)) + ax.plot(green_dates, green_rates, c="green", alpha=0.5, linewidth=1) + + # Plot all points ax.scatter(dates, pass_rates, c=colors, alpha=0.5, s=120) for i, model in enumerate(models):