diff --git a/benchmark/over_time.py b/benchmark/over_time.py index 4e5b6a8cf..78937242b 100644 --- a/benchmark/over_time.py +++ b/benchmark/over_time.py @@ -15,6 +15,9 @@ def get_model_color(model): if "qwen" in model.lower(): return "darkblue" + if "mistral" in model.lower(): + return "cyan" + if "haiku" in model.lower(): return "pink" @@ -89,6 +92,7 @@ def plot_over_time(yaml_file): brown_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "brown"] pink_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "pink"] qwen_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "darkblue"] + mistral_points = [(d, r) for d, r, c in zip(dates, pass_rates, colors) if c == "cyan"] # Plot lines for purple, red, green, orange and brown points if purple_points: @@ -112,6 +116,9 @@ def plot_over_time(yaml_file): if qwen_points: qwen_dates, qwen_rates = zip(*sorted(qwen_points)) ax.plot(qwen_dates, qwen_rates, c="darkblue", alpha=0.5, linewidth=1) + if mistral_points: + mistral_dates, mistral_rates = zip(*sorted(mistral_points)) + ax.plot(mistral_dates, mistral_rates, c="cyan", alpha=0.5, linewidth=1) # Plot all points ax.scatter(dates, pass_rates, c=colors, alpha=0.5, s=120)