Merge branch 'main' into fix-flake8

This commit is contained in:
Paul Gauthier 2024-06-21 17:01:57 -07:00
commit 10ab57779a
44 changed files with 1249 additions and 188 deletions

View file

@ -1 +1 @@
__version__ = "0.38.1-dev"
__version__ = "0.39.1-dev"

4
aider/__main__.py Normal file
View file

@ -0,0 +1,4 @@
from .main import main
if __name__ == "__main__":
main()

View file

@ -7,11 +7,30 @@ import sys
import configargparse
from aider import __version__, models
from aider.args_formatter import MarkdownHelpFormatter, YamlHelpFormatter
from aider.args_formatter import (
DotEnvFormatter,
MarkdownHelpFormatter,
YamlHelpFormatter,
)
from .dump import dump # noqa: F401
def default_env_file(git_root):
return os.path.join(git_root, ".env") if git_root else ".env"
def get_preparser(git_root):
parser = configargparse.ArgumentParser(add_help=False)
parser.add_argument(
"--env-file",
metavar="ENV_FILE",
default=default_env_file(git_root),
help="Specify the .env file to load (default: .env in git root)",
)
return parser
def get_parser(default_config_files, git_root):
parser = configargparse.ArgumentParser(
description="aider is GPT powered coding in your terminal",
@ -28,10 +47,7 @@ def get_parser(default_config_files, git_root):
help="Log the conversation with the LLM to this file (for example, .aider.llm.history)",
)
group.add_argument(
"files",
metavar="FILE",
nargs="*",
help="files to edit with an LLM (optional)"
"files", metavar="FILE", nargs="*", help="files to edit with an LLM (optional)"
)
group.add_argument(
"--openai-api-key",
@ -60,7 +76,7 @@ def get_parser(default_config_files, git_root):
const=opus_model,
help=f"Use {opus_model} model for the main chat",
)
sonnet_model = "claude-3-sonnet-20240229"
sonnet_model = "claude-3-5-sonnet-20240620"
group.add_argument(
"--sonnet",
action="store_const",
@ -142,12 +158,24 @@ def get_parser(default_config_files, git_root):
env_var="OPENAI_ORGANIZATION_ID",
help="Specify the OpenAI organization ID",
)
group.add_argument(
"--model-settings-file",
metavar="MODEL_SETTINGS_FILE",
default=None,
help="Specify a file with aider model settings for unknown models",
)
group.add_argument(
"--model-metadata-file",
metavar="MODEL_FILE",
metavar="MODEL_METADATA_FILE",
default=None,
help="Specify a file with context window and costs for unknown models",
)
group.add_argument(
"--verify-ssl",
action=argparse.BooleanOptionalAction,
default=True,
help="Verify the SSL cert when connecting to models (default: True)",
)
group.add_argument(
"--edit-format",
metavar="EDIT_FORMAT",
@ -184,11 +212,12 @@ def get_parser(default_config_files, git_root):
" max_chat_history_tokens."
),
)
default_env_file = os.path.join(git_root, ".env") if git_root else ".env"
# This is a duplicate of the argument in the preparser and is a no-op by this time of
# argument parsing, but it's here so that the help is displayed as expected.
group.add_argument(
"--env-file",
metavar="ENV_FILE",
default=default_env_file,
default=default_env_file(git_root),
help="Specify the .env file to load (default: .env in git root)",
)
@ -501,11 +530,27 @@ def get_sample_yaml():
return parser.format_help()
def get_sample_dotenv():
os.environ["COLUMNS"] = "120"
sys.argv = ["aider"]
parser = get_parser([], None)
# This instantiates all the action.env_var values
parser.parse_known_args()
parser.formatter_class = DotEnvFormatter
return argparse.ArgumentParser.format_help(parser)
return parser.format_help()
def main():
arg = sys.argv[1] if len(sys.argv[1:]) else None
if arg == "md":
print(get_md_help())
elif arg == "dotenv":
print(get_sample_dotenv())
else:
print(get_sample_yaml())

View file

@ -1,8 +1,83 @@
import argparse
from aider import urls
from .dump import dump # noqa: F401
class DotEnvFormatter(argparse.HelpFormatter):
def start_section(self, heading):
res = "\n\n"
res += "#" * (len(heading) + 3)
res += f"\n# {heading}"
super().start_section(res)
def _format_usage(self, usage, actions, groups, prefix):
return ""
def _format_text(self, text):
return f"""
##########################################################
# Sample aider .env file.
# Place at the root of your git repo.
# Or use `aider --env <fname>` to specify.
##########################################################
#################
# LLM parameters:
#
# Include xxx_API_KEY parameters and other params needed for your LLMs.
# See {urls.llms} for details.
## OpenAI
#OPENAI_API_KEY=
## Anthropic
#ANTHROPIC_API_KEY=
##...
"""
def _format_action(self, action):
if not action.option_strings:
return ""
if not action.env_var:
return
parts = [""]
default = action.default
if default == argparse.SUPPRESS:
default = ""
elif isinstance(default, str):
pass
elif isinstance(default, list) and not default:
default = ""
elif action.default is not None:
default = "true" if default else "false"
else:
default = ""
if action.help:
parts.append(f"## {action.help}")
if action.env_var:
env_var = action.env_var
if default:
parts.append(f"#{env_var}={default}\n")
else:
parts.append(f"#{env_var}=\n")
return "\n".join(parts) + "\n"
def _format_action_invocation(self, action):
return ""
def _format_args(self, action, default_metavar):
return ""
class YamlHelpFormatter(argparse.HelpFormatter):
def start_section(self, heading):
res = "\n\n"

View file

@ -913,9 +913,9 @@ class Coder:
res = ["", ""]
res.append(f"Model {self.main_model.name} has hit a token limit!")
res.append("")
res.append(f"Input tokens: {input_tokens} of {max_input_tokens}{inp_err}")
res.append(f"Output tokens: {output_tokens} of {max_output_tokens}{out_err}")
res.append(f"Total tokens: {total_tokens} of {max_input_tokens}{tot_err}")
res.append(f"Input tokens: {input_tokens:,} of {max_input_tokens:,}{inp_err}")
res.append(f"Output tokens: {output_tokens:,} of {max_output_tokens:,}{out_err}")
res.append(f"Total tokens: {total_tokens:,} of {max_input_tokens:,}{tot_err}")
if output_tokens >= max_output_tokens:
res.append("")

View file

@ -414,16 +414,8 @@ def find_original_update_blocks(content, fence=DEFAULT_FENCE):
processed.append(cur) # original_marker
filename = strip_filename(processed[-2].splitlines()[-1], fence)
try:
if not filename:
filename = strip_filename(processed[-2].splitlines()[-2], fence)
if not filename:
if current_filename:
filename = current_filename
else:
raise ValueError(missing_filename_err.format(fence=fence))
except IndexError:
filename = find_filename(processed[-2].splitlines(), fence)
if not filename:
if current_filename:
filename = current_filename
else:
@ -460,6 +452,35 @@ def find_original_update_blocks(content, fence=DEFAULT_FENCE):
raise ValueError(f"{processed}\n^^^ Error parsing SEARCH/REPLACE block.")
def find_filename(lines, fence):
"""
Deepseek Coder v2 has been doing this:
```python
word_count.py
```
```python
<<<<<<< SEARCH
...
This is a more flexible search back for filenames.
"""
# Go back through the 3 preceding lines
lines.reverse()
lines = lines[:3]
for line in lines:
# If we find a filename, done
filename = strip_filename(line, fence)
if filename:
return filename
# Only continue as long as we keep seeing fences
if not line.startswith(fence[0]):
return
if __name__ == "__main__":
edit = """
Here's the change:

View file

@ -332,7 +332,7 @@ class Commands:
last_commit = self.coder.repo.repo.head.commit
if (
not last_commit.message.startswith("aider:")
not last_commit.author.name.endswith(" (aider)")
or last_commit.hexsha[:7] != self.coder.last_aider_commit_hash
):
self.io.tool_error("The last commit was not made by aider in this chat session.")

View file

@ -5,12 +5,13 @@ import sys
from pathlib import Path
import git
import httpx
from dotenv import load_dotenv
from prompt_toolkit.enums import EditingMode
from streamlit.web import cli
from aider import __version__, models, utils
from aider.args import get_parser
from aider.args import get_parser, get_preparser
from aider.coders import Coder
from aider.commands import SwitchModel
from aider.io import InputOutput
@ -124,12 +125,18 @@ def check_gitignore(git_root, io, ask=True):
def format_settings(parser, args):
show = scrub_sensitive_info(args, parser.format_values())
# clean up the headings for consistency w/ new lines
heading_env = "Environment Variables:"
heading_defaults = "Defaults:"
if heading_env in show:
show = show.replace(heading_env, "\n" + heading_env)
show = show.replace(heading_defaults, "\n" + heading_defaults)
show += "\n"
show += "Option settings:\n"
for arg, val in sorted(vars(args).items()):
if val:
val = scrub_sensitive_info(args, str(val))
show += f" - {arg}: {val}\n"
show += f" - {arg}: {val}\n" # noqa: E221
return show
@ -205,7 +212,48 @@ def parse_lint_cmds(lint_cmds, io):
return
return res
def generate_search_path_list(default_fname, git_root, command_line_file):
files = []
default_file = Path(default_fname)
files.append(Path.home() / default_file) # homedir
if git_root:
files.append(Path(git_root) / default_file) # git root
if command_line_file:
files.append(command_line_file)
files.append(default_file.resolve())
files = list(map(str, files))
files = list(dict.fromkeys(files))
return files
def register_models(git_root, model_settings_fname, io):
model_settings_files = generate_search_path_list(".aider.models.yml", git_root, model_settings_fname)
try:
files_loaded = models.register_models(model_settings_files)
if len(files_loaded) > 0:
io.tool_output(f"Loaded {len(files_loaded)} model settings file(s)")
for file_loaded in files_loaded:
io.tool_output(f" - {file_loaded}")
except Exception as e:
io.tool_error(f"Error loading aider model settings: {e}")
return 1
return None
def register_litellm_models(git_root, model_metadata_fname, io):
model_metatdata_files = generate_search_path_list(".aider.litellm.models.json", git_root, model_metadata_fname)
try:
model_metadata_files_loaded = models.register_litellm_models(model_metatdata_files)
if len(model_metadata_files_loaded) > 0:
io.tool_output(f"Loaded {len(model_metadata_files_loaded)} litellm model file(s)")
for model_metadata_file in model_metadata_files_loaded:
io.tool_output(f" - {model_metadata_file}")
except Exception as e:
io.tool_error(f"Error loading litellm models: {e}")
return 1
def main(argv=None, input=None, output=None, force_git_root=None, return_coder=False):
if argv is None:
argv = sys.argv[1:]
@ -225,9 +273,18 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
default_config_files.append(Path.home() / conf_fname) # homedir
default_config_files = list(map(str, default_config_files))
preparser = get_preparser(git_root)
pre_args, _ = preparser.parse_known_args(argv)
# Load the .env file specified in the arguments
load_dotenv(pre_args.env_file)
parser = get_parser(default_config_files, git_root)
args = parser.parse_args(argv)
if not args.verify_ssl:
litellm.client_session = httpx.Client(verify=False)
if args.gui and not return_coder:
launch_gui(argv)
return
@ -320,9 +377,6 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
cmd_line = scrub_sensitive_info(args, cmd_line)
io.tool_output(cmd_line, log_only=True)
if args.env_file:
load_dotenv(args.env_file)
if args.anthropic_api_key:
os.environ["ANTHROPIC_API_KEY"] = args.anthropic_api_key
@ -337,26 +391,9 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
if args.openai_organization_id:
os.environ["OPENAI_ORGANIZATION"] = args.openai_organization_id
model_def_files = []
model_def_fname = Path(".aider.models.json")
model_def_files.append(Path.home() / model_def_fname) # homedir
if git_root:
model_def_files.append(Path(git_root) / model_def_fname) # git root
if args.model_metadata_file:
model_def_files.append(args.model_metadata_file)
model_def_files.append(model_def_fname.resolve())
model_def_files = list(map(str, model_def_files))
model_def_files = list(dict.fromkeys(model_def_files))
try:
model_metadata_files_loaded = models.register_models(model_def_files)
if len(model_metadata_files_loaded) > 0:
io.tool_output(f"Loaded {len(model_metadata_files_loaded)} model file(s)")
for model_metadata_file in model_metadata_files_loaded:
io.tool_output(f" - {model_metadata_file}")
except Exception as e:
io.tool_error(f"Error loading model info/cost: {e}")
return 1
register_models(git_root, args.model_settings_file, io)
register_litellm_models(git_root, args.model_metadata_file, io)
main_model = models.Model(args.model, weak_model=args.weak_model)
lint_cmds = parse_lint_cmds(args.lint_cmd, io)

View file

@ -1,5 +1,6 @@
import difflib
import json
import yaml
import math
import os
import sys
@ -178,6 +179,43 @@ MODEL_SETTINGS = [
"whole",
weak_model_name="claude-3-haiku-20240307",
),
ModelSettings(
"claude-3-5-sonnet-20240620",
"diff",
weak_model_name="claude-3-haiku-20240307",
use_repo_map=True,
),
ModelSettings(
"anthropic/claude-3-5-sonnet-20240620",
"diff",
weak_model_name="claude-3-haiku-20240307",
use_repo_map=True,
),
ModelSettings(
"openrouter/anthropic/claude-3.5-sonnet",
"diff",
weak_model_name="openrouter/anthropic/claude-3-haiku-20240307",
use_repo_map=True,
),
# Vertex AI Claude models
ModelSettings(
"vertex_ai/claude-3-5-sonnet@20240620",
"diff",
weak_model_name="vertex_ai/claude-3-haiku@20240307",
use_repo_map=True,
),
ModelSettings(
"vertex_ai/claude-3-opus@20240229",
"diff",
weak_model_name="vertex_ai/claude-3-haiku@20240307",
use_repo_map=True,
send_undo_reply=True,
),
ModelSettings(
"vertex_ai/claude-3-sonnet@20240229",
"whole",
weak_model_name="vertex_ai/claude-3-haiku@20240307",
),
# Cohere
ModelSettings(
"command-r-plus",
@ -218,7 +256,7 @@ MODEL_SETTINGS = [
send_undo_reply=True,
),
ModelSettings(
"openai/deepseek-chat",
"deepseek/deepseek-chat",
"diff",
use_repo_map=True,
send_undo_reply=True,
@ -226,7 +264,15 @@ MODEL_SETTINGS = [
reminder_as_sys_msg=True,
),
ModelSettings(
"deepseek/deepseek-chat",
"deepseek/deepseek-coder",
"diff",
use_repo_map=True,
send_undo_reply=True,
examples_as_sys_msg=True,
reminder_as_sys_msg=True,
),
ModelSettings(
"openrouter/deepseek/deepseek-coder",
"diff",
use_repo_map=True,
send_undo_reply=True,
@ -425,23 +471,47 @@ class Model:
return validate_variables(["GROQ_API_KEY"])
return res
def register_models(model_def_fnames):
model_metadata_files_loaded = []
for model_def_fname in model_def_fnames:
if not os.path.exists(model_def_fname):
def register_models(model_settings_fnames):
files_loaded = []
for model_settings_fname in model_settings_fnames:
if not os.path.exists(model_settings_fname):
continue
model_metadata_files_loaded.append(model_def_fname)
try:
with open(model_def_fname, "r") as model_def_file:
with open(model_settings_fname, "r") as model_settings_file:
model_settings_list = yaml.safe_load(model_settings_file)
for model_settings_dict in model_settings_list:
model_settings = ModelSettings(**model_settings_dict)
existing_model_settings = next((ms for ms in MODEL_SETTINGS if ms.name == model_settings.name), None)
if existing_model_settings:
MODEL_SETTINGS.remove(existing_model_settings)
MODEL_SETTINGS.append(model_settings)
except Exception as e:
raise Exception(f"Error loading model settings from {model_settings_fname}: {e}")
files_loaded.append(model_settings_fname)
return files_loaded
def register_litellm_models(model_fnames):
files_loaded = []
for model_fname in model_fnames:
if not os.path.exists(model_fname):
continue
try:
with open(model_fname, "r") as model_def_file:
model_def = json.load(model_def_file)
except json.JSONDecodeError as e:
raise Exception(f"Error loading model definition from {model_def_fname}: {e}")
litellm.register_model(model_def)
except Exception as e:
raise Exception(f"Error loading model definition from {model_fname}: {e}")
files_loaded.append(model_fname)
litellm.register_model(model_def)
return model_metadata_files_loaded
return files_loaded
def validate_variables(vars):

View file

@ -88,11 +88,15 @@ class GitRepo:
else:
cmd += ["-a"]
original_user_name = self.repo.config_reader().get_value("user", "name")
original_committer_name_env = os.environ.get("GIT_COMMITTER_NAME")
committer_name = f"{original_user_name} (aider)"
os.environ["GIT_COMMITTER_NAME"] = committer_name
if aider_edits:
user_name = self.repo.config_reader().get_value("user", "name")
committer_name = f"{user_name} (aider)"
original_committer_name = os.environ.get("GIT_COMMITTER_NAME")
os.environ["GIT_COMMITTER_NAME"] = committer_name
original_auther_name_env = os.environ.get("GIT_AUTHOR_NAME")
os.environ["GIT_AUTHOR_NAME"] = committer_name
self.repo.git.commit(cmd)
commit_hash = self.repo.head.commit.hexsha[:7]
@ -100,10 +104,15 @@ class GitRepo:
# Restore the original GIT_COMMITTER_NAME
if aider_edits:
if original_committer_name is not None:
os.environ["GIT_COMMITTER_NAME"] = original_committer_name
if original_auther_name_env is not None:
os.environ["GIT_AUTHOR_NAME"] = original_auther_name_env
else:
del os.environ["GIT_COMMITTER_NAME"]
del os.environ["GIT_AUTHOR_NAME"]
if original_committer_name_env is not None:
os.environ["GIT_COMMITTER_NAME"] = original_committer_name_env
else:
del os.environ["GIT_COMMITTER_NAME"]
return commit_hash, commit_message

View file

@ -523,16 +523,20 @@ class TestCommands(TestCase):
other_path.write_text("other content")
repo.git.add(str(other_path))
os.environ["GIT_AUTHOR_NAME"] = "Foo (aider)"
# Create and commit a file
filename = "test_file.txt"
file_path = Path(repo_dir) / filename
file_path.write_text("first content")
repo.git.add(filename)
repo.git.commit("-m", "aider: first commit")
repo.git.commit("-m", "first commit")
file_path.write_text("second content")
repo.git.add(filename)
repo.git.commit("-m", "aider: second commit")
repo.git.commit("-m", "second commit")
del os.environ["GIT_AUTHOR_NAME"]
# Store the commit hash
last_commit_hash = repo.head.commit.hexsha[:7]

View file

@ -398,6 +398,32 @@ Hope you like it!
],
)
def test_deepseek_coder_v2_filename_mangling(self):
edit = """
Here's the change:
```python
foo.txt
```
```python
<<<<<<< SEARCH
one
=======
two
>>>>>>> REPLACE
```
Hope you like it!
"""
edits = list(eb.find_original_update_blocks(edit))
self.assertEqual(
edits,
[
("foo.txt", "one\n", "two\n"),
],
)
if __name__ == "__main__":
unittest.main()

View file

@ -1,7 +1,7 @@
import os
import shutil
import subprocess
import tempfile
from io import StringIO
from pathlib import Path
from unittest import TestCase
from unittest.mock import MagicMock, patch
@ -13,24 +13,28 @@ from prompt_toolkit.output import DummyOutput
from aider.dump import dump # noqa: F401
from aider.io import InputOutput
from aider.main import check_gitignore, main, setup_git
from aider.utils import GitTemporaryDirectory, make_repo
from aider.utils import GitTemporaryDirectory, IgnorantTemporaryDirectory, make_repo
class TestMain(TestCase):
def setUp(self):
self.original_env = os.environ.copy()
os.environ["OPENAI_API_KEY"] = "deadbeef"
self.original_cwd = os.getcwd()
self.tempdir = tempfile.mkdtemp()
self.tempdir_obj = IgnorantTemporaryDirectory()
self.tempdir = self.tempdir_obj.name
os.chdir(self.tempdir)
def tearDown(self):
os.chdir(self.original_cwd)
shutil.rmtree(self.tempdir, ignore_errors=True)
self.tempdir_obj.cleanup()
os.environ.clear()
os.environ.update(self.original_env)
def test_main_with_empty_dir_no_files_on_command(self):
main(["--no-git"], input=DummyInput(), output=DummyOutput())
def test_main_with_empty_dir_new_file(self):
def test_main_with_emptqy_dir_new_file(self):
main(["foo.txt", "--yes", "--no-git"], input=DummyInput(), output=DummyOutput())
self.assertTrue(os.path.exists("foo.txt"))
@ -237,3 +241,82 @@ class TestMain(TestCase):
main(["--message", test_message])
args, kwargs = MockInputOutput.call_args
self.assertEqual(args[1], None)
def test_dark_mode_sets_code_theme(self):
# Mock Coder.create to capture the configuration
with patch("aider.coders.Coder.create") as MockCoder:
main(["--dark-mode", "--no-git"], input=DummyInput(), output=DummyOutput())
# Ensure Coder.create was called
MockCoder.assert_called_once()
# Check if the code_theme setting is for dark mode
_, kwargs = MockCoder.call_args
self.assertEqual(kwargs["code_theme"], "monokai")
def test_light_mode_sets_code_theme(self):
# Mock Coder.create to capture the configuration
with patch("aider.coders.Coder.create") as MockCoder:
main(["--light-mode", "--no-git"], input=DummyInput(), output=DummyOutput())
# Ensure Coder.create was called
MockCoder.assert_called_once()
# Check if the code_theme setting is for light mode
_, kwargs = MockCoder.call_args
self.assertEqual(kwargs["code_theme"], "default")
def create_env_file(self, file_name, content):
env_file_path = Path(self.tempdir) / file_name
env_file_path.write_text(content)
return env_file_path
def test_env_file_flag_sets_automatic_variable(self):
env_file_path = self.create_env_file(".env.test", "AIDER_DARK_MODE=True")
with patch("aider.coders.Coder.create") as MockCoder:
main(
["--env-file", str(env_file_path), "--no-git"],
input=DummyInput(),
output=DummyOutput(),
)
MockCoder.assert_called_once()
# Check if the color settings are for dark mode
_, kwargs = MockCoder.call_args
self.assertEqual(kwargs["code_theme"], "monokai")
def test_default_env_file_sets_automatic_variable(self):
self.create_env_file(".env", "AIDER_DARK_MODE=True")
with patch("aider.coders.Coder.create") as MockCoder:
main(["--no-git"], input=DummyInput(), output=DummyOutput())
# Ensure Coder.create was called
MockCoder.assert_called_once()
# Check if the color settings are for dark mode
_, kwargs = MockCoder.call_args
self.assertEqual(kwargs["code_theme"], "monokai")
def test_false_vals_in_env_file(self):
self.create_env_file(".env", "AIDER_SHOW_DIFFS=off")
with patch("aider.coders.Coder.create") as MockCoder:
main(["--no-git"], input=DummyInput(), output=DummyOutput())
MockCoder.assert_called_once()
_, kwargs = MockCoder.call_args
self.assertEqual(kwargs["show_diffs"], False)
def test_true_vals_in_env_file(self):
self.create_env_file(".env", "AIDER_SHOW_DIFFS=on")
with patch("aider.coders.Coder.create") as MockCoder:
main(["--no-git"], input=DummyInput(), output=DummyOutput())
MockCoder.assert_called_once()
_, kwargs = MockCoder.call_args
self.assertEqual(kwargs["show_diffs"], True)
def test_verbose_mode_lists_env_vars(self):
self.create_env_file(".env", "AIDER_DARK_MODE=on")
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
main(["--no-git", "--verbose"], input=DummyInput(), output=DummyOutput())
output = mock_stdout.getvalue()
relevant_output = "\n".join(
line
for line in output.splitlines()
if "AIDER_DARK_MODE" in line or "dark_mode" in line
) # this bit just helps failing assertions to be easier to read
self.assertIn("AIDER_DARK_MODE", relevant_output)
self.assertIn("dark_mode", relevant_output)
self.assertRegex(relevant_output, r"AIDER_DARK_MODE:\s+on")
self.assertRegex(relevant_output, r"dark_mode:\s+True")

View file

@ -1,4 +1,5 @@
import os
import platform
import tempfile
import unittest
from pathlib import Path
@ -141,6 +142,10 @@ class TestRepo(unittest.TestCase):
def test_commit_with_custom_committer_name(self, mock_send):
mock_send.return_value = '"a good commit message"'
# Cleanup of the git temp dir explodes on windows
if platform.system() == "Windows":
return
with GitTemporaryDirectory():
# new repo
raw_repo = git.Repo()
@ -152,7 +157,8 @@ class TestRepo(unittest.TestCase):
raw_repo.git.add(str(fname))
raw_repo.git.commit("-m", "initial commit")
git_repo = GitRepo(InputOutput(), None, None)
io = InputOutput()
git_repo = GitRepo(io, None, None)
# commit a change
fname.write_text("new content")
@ -160,11 +166,23 @@ class TestRepo(unittest.TestCase):
# check the committer name
commit = raw_repo.head.commit
self.assertEqual(commit.author.name, "Test User (aider)")
self.assertEqual(commit.committer.name, "Test User (aider)")
# commit a change without aider_edits
fname.write_text("new content again!")
git_repo.commit(fnames=[str(fname)], aider_edits=False)
# check the committer name
commit = raw_repo.head.commit
self.assertEqual(commit.author.name, "Test User")
self.assertEqual(commit.committer.name, "Test User (aider)")
# check that the original committer name is restored
original_committer_name = os.environ.get("GIT_COMMITTER_NAME")
self.assertIsNone(original_committer_name)
original_author_name = os.environ.get("GIT_AUTHOR_NAME")
self.assertIsNone(original_author_name)
def test_get_tracked_files(self):
# Create a temporary directory

View file

@ -6,3 +6,4 @@ enable_playwright = "https://aider.chat/docs/install/optional.html#enable-playwr
favicon = "https://aider.chat/assets/icons/favicon-32x32.png"
model_warnings = "https://aider.chat/docs/llms/warnings.html"
token_limits = "https://aider.chat/docs/troubleshooting/token-limits.html"
llms = "https://aider.chat/docs/llms.html"

View file

@ -17,11 +17,17 @@ class IgnorantTemporaryDirectory:
return self.temp_dir.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
self.cleanup()
def cleanup(self):
try:
self.temp_dir.__exit__(exc_type, exc_val, exc_tb)
self.temp_dir.cleanup()
except (OSError, PermissionError):
pass # Ignore errors (Windows)
def __getattr__(self, item):
return getattr(self.temp_dir, item)
class ChdirTemporaryDirectory(IgnorantTemporaryDirectory):
def __init__(self):