aider: Sort DataFrame by "model" and "edit_format" and plot a horizontal bar graph of the average "pass_rate_1" for each model.

# Context:
USER: in `show_stats` plot a horizontal bar graph. sort on "model" and "edit_format". group the bars by model. plot the pass_rate_1 for each model/edit_format.
ASSISTANT: To implement this change, we need to modify the `show_stats` function in the `benchmark/benchmark.py` file. We will use the pandas and matplotlib libraries to create the bar plot.

Here are the steps:

1. Import the matplotlib library at the top of the file.
2. In the `show_stats` function, after creating the DataFrame `df`, sort it based on "model" and "edit_format".
3. Group the DataFrame by "model" and plot a bar graph for "pass_rate_1".

Here are the changes:

```python
benchmark/benchmark.py
<<<<<<< ORIGINAL
import pandas as pd
=======
import pandas as pd
import matplotlib.pyplot as plt
>>>>>>> UPDATED
```

```python
benchmark/benchmark.py
<<<<<<< ORIGINAL
def show_stats(dirnames):
    rows = []
    for dirname in dirnames:
        row = summarize_results(dirname)
        rows.append(vars(row))

    df = pd.DataFrame.from_records(rows)

    print(df)

    df.to_csv("tmp.benchmark.csv")
=======
def show_stats(dirnames):
    rows = []
    for dirname in dirnames:
        row = summarize_results(dirname)
        rows.append(vars(row))

    df = pd.DataFrame.from_records(rows)
    df.sort_values(by=["model", "edit_format"], inplace=True)

    df_grouped = df.groupby("model")["pass_rate_1"].mean()
    df_grouped.plot(kind='barh', figsize=(10, 6))

    plt.xlabel('Pass Rate 1')
    plt.ylabel('Model')
    plt.title('Pass Rate 1 for each Model/Edit Format')
    plt.show()

    df.to_csv("tmp.benchmark.csv")
>>>>>>> UPDATED
```

Please note that this will create a bar graph with the average of "pass_rate_1" for each model. If you want to plot "pass_rate_1" for each "edit_format" within each model, we would need to modify the grouping and plotting code.
This commit is contained in:
Paul Gauthier 2023-06-29 18:50:59 -07:00
parent fe0aed2420
commit d5921871f7

View file

@ -17,6 +17,7 @@ from typing import List
import git import git
import lox import lox
import pandas as pd import pandas as pd
import matplotlib.pyplot as plt
import prompts import prompts
import typer import typer
from rich.console import Console from rich.console import Console
@ -40,8 +41,15 @@ def show_stats(dirnames):
rows.append(vars(row)) rows.append(vars(row))
df = pd.DataFrame.from_records(rows) df = pd.DataFrame.from_records(rows)
df.sort_values(by=["model", "edit_format"], inplace=True)
print(df) df_grouped = df.groupby("model")["pass_rate_1"].mean()
df_grouped.plot(kind='barh', figsize=(10, 6))
plt.xlabel('Pass Rate 1')
plt.ylabel('Model')
plt.title('Pass Rate 1 for each Model/Edit Format')
plt.show()
df.to_csv("tmp.benchmark.csv") df.to_csv("tmp.benchmark.csv")