aider: Sort DataFrame by "model" and "edit_format" and plot a horizontal bar graph of the average "pass_rate_1" for each model.

# Context:
USER: in `show_stats` plot a horizontal bar graph. sort on "model" and "edit_format". group the bars by model. plot the pass_rate_1 for each model/edit_format.
ASSISTANT: To implement this change, we need to modify the `show_stats` function in the `benchmark/benchmark.py` file. We will use the pandas and matplotlib libraries to create the bar plot.

Here are the steps:

1. Import the matplotlib library at the top of the file.
2. In the `show_stats` function, after creating the DataFrame `df`, sort it based on "model" and "edit_format".
3. Group the DataFrame by "model" and plot a bar graph for "pass_rate_1".

Here are the changes:

```python
benchmark/benchmark.py
<<<<<<< ORIGINAL
import pandas as pd
=======
import pandas as pd
import matplotlib.pyplot as plt
>>>>>>> UPDATED
```

```python
benchmark/benchmark.py
<<<<<<< ORIGINAL
def show_stats(dirnames):
    rows = []
    for dirname in dirnames:
        row = summarize_results(dirname)
        rows.append(vars(row))

    df = pd.DataFrame.from_records(rows)

    print(df)

    df.to_csv("tmp.benchmark.csv")
=======
def show_stats(dirnames):
    rows = []
    for dirname in dirnames:
        row = summarize_results(dirname)
        rows.append(vars(row))

    df = pd.DataFrame.from_records(rows)
    df.sort_values(by=["model", "edit_format"], inplace=True)

    df_grouped = df.groupby("model")["pass_rate_1"].mean()
    df_grouped.plot(kind='barh', figsize=(10, 6))

    plt.xlabel('Pass Rate 1')
    plt.ylabel('Model')
    plt.title('Pass Rate 1 for each Model/Edit Format')
    plt.show()

    df.to_csv("tmp.benchmark.csv")
>>>>>>> UPDATED
```

Please note that this will create a bar graph with the average of "pass_rate_1" for each model. If you want to plot "pass_rate_1" for each "edit_format" within each model, we would need to modify the grouping and plotting code.
This commit is contained in:
Paul Gauthier 2023-06-29 18:50:59 -07:00
parent fe0aed2420
commit d5921871f7

View file

@ -17,6 +17,7 @@ from typing import List
import git
import lox
import pandas as pd
import matplotlib.pyplot as plt
import prompts
import typer
from rich.console import Console
@ -40,8 +41,15 @@ def show_stats(dirnames):
rows.append(vars(row))
df = pd.DataFrame.from_records(rows)
df.sort_values(by=["model", "edit_format"], inplace=True)
print(df)
df_grouped = df.groupby("model")["pass_rate_1"].mean()
df_grouped.plot(kind='barh', figsize=(10, 6))
plt.xlabel('Pass Rate 1')
plt.ylabel('Model')
plt.title('Pass Rate 1 for each Model/Edit Format')
plt.show()
df.to_csv("tmp.benchmark.csv")