diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 972a80480..58e33af62 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -146,14 +146,45 @@ def plot_timing(df): formats = grouped.columns models = grouped.index + # Create a mapping for the labels and colors to ensure consistency + format_labels = { + 'diff': 'Diff', + 'func': 'Function', + 'whole': 'Whole', + 'diff-func': 'Diff-Function', + 'diff-whole': 'Diff-Whole', + 'func-whole': 'Function-Whole', + # Add other edit formats here if necessary + } + format_colors = { + 'diff': '#b3e6a8', + 'func': '#b3d1e6', + 'whole': '#e6b3b3', + 'diff-func': '#e6d1b3', + 'diff-whole': '#b3e6d1', + 'func-whole': '#d1b3e6', + # Add other edit format colors here if necessary + } + format_hatches = { + 'diff': '////', + 'func': '....', + 'whole': '', + 'diff-func': '\\\\\\\\', + 'diff-whole': 'xxxx', + 'func-whole': '----', + # Add other edit format hatches here if necessary + } + for i, fmt in enumerate(formats): + label = format_labels.get(fmt, fmt) + color = format_colors.get(fmt, '#b3d1e6') + hatch = format_hatches.get(fmt, '') edge = dict(edgecolor="#ffffff", linewidth=1.5) - color = "#b3e6a8" if "diff" in fmt else "#b3d1e6" - hatch = "////" if "func" in fmt else "" rects = ax.bar( pos + i * width, grouped[fmt], width * 0.95, + label=label, color=color, hatch=hatch, zorder=zorder + 1,