diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 8da1e6f9f..fb9f8c9ac 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -167,13 +167,13 @@ class Coder: self.find_common_root() if main_model.use_repo_map and self.repo and self.gpt_prompts.repo_content_prefix: - rm_io = io if self.verbose else None self.repo_map = RepoMap( map_tokens, self.root, self.main_model, - rm_io, + io, self.gpt_prompts.repo_content_prefix, + self.verbose, ) if self.repo_map.use_ctags: @@ -287,12 +287,21 @@ class Coder: ] fence = fences[0] + def get_abs_fnames_content(self): + for fname in list(self.abs_fnames): + content = self.io.read_text(fname) + + if content is None: + relative_fname = self.get_rel_fname(fname) + self.io.tool_error(f"Dropping {relative_fname} from the chat.") + self.abs_fnames.remove(fname) + else: + yield fname, content + def choose_fence(self): all_content = "" - for fname in self.abs_fnames: - all_content += Path(fname).read_text() + "\n" - - all_content = all_content.splitlines() + for _fname, content in self.get_abs_fnames_content(): + all_content += content + "\n" good = False for fence_open, fence_close in self.fences: @@ -317,15 +326,15 @@ class Coder: fnames = self.abs_fnames prompt = "" - for fname in fnames: + for fname, content in self.get_abs_fnames_content(): relative_fname = self.get_rel_fname(fname) - prompt += utils.quoted_file(fname, relative_fname, fence=self.fence) - return prompt + prompt = "\n" + prompt += relative_fname + prompt += f"\n{self.fence[0]}\n" + prompt += content + prompt += f"{self.fence[1]}\n" - def recheck_abs_fnames(self): - self.abs_fnames = set( - fname for fname in self.abs_fnames if Path(fname).exists() and Path(fname).is_file() - ) + return prompt def get_files_messages(self): all_content = "" @@ -454,10 +463,6 @@ class Coder: ] messages += self.done_messages - - # notice if files disappear - self.recheck_abs_fnames() - messages += self.get_files_messages() messages += self.cur_messages @@ -917,8 +922,8 @@ class Coder: full_path = os.path.abspath(os.path.join(self.root, path)) if full_path in self.abs_fnames: - if not self.dry_run and write_content: - Path(full_path).write_text(write_content) + if write_content: + self.io.write_text(full_path, write_content) return full_path if not Path(full_path).exists(): @@ -943,8 +948,8 @@ class Coder: if not self.dry_run: self.repo.git.add(full_path) - if not self.dry_run and write_content: - Path(full_path).write_text(write_content) + if write_content: + self.io.write_text(full_path, write_content) return full_path diff --git a/aider/coders/editblock_coder.py b/aider/coders/editblock_coder.py index b63f577a2..49ae63674 100644 --- a/aider/coders/editblock_coder.py +++ b/aider/coders/editblock_coder.py @@ -26,7 +26,10 @@ class EditBlockCoder(Coder): full_path = self.allowed_to_edit(path) if not full_path: continue - if do_replace(full_path, original, updated, self.dry_run): + content = self.io.read_text(full_path) + content = do_replace(full_path, content, original, updated) + if content: + self.io.write_text(full_path, content) edited.add(path) continue self.io.tool_error(f"Failed to apply edit to {path}") @@ -211,7 +214,7 @@ def strip_quoted_wrapping(res, fname=None): return res -def do_replace(fname, before_text, after_text, dry_run=False): +def do_replace(fname, content, before_text, after_text): before_text = strip_quoted_wrapping(before_text, fname) after_text = strip_quoted_wrapping(after_text, fname) fname = Path(fname) @@ -219,21 +222,18 @@ def do_replace(fname, before_text, after_text, dry_run=False): # does it want to make a new file? if not fname.exists() and not before_text.strip(): fname.touch() + content = "" - content = fname.read_text() + if content is None: + return if not before_text.strip(): # append to existing file, or start a new file new_content = content + after_text else: new_content = replace_most_similar_chunk(content, before_text, after_text) - if not new_content: - return - if not dry_run: - fname.write_text(new_content) - - return True + return new_content ORIGINAL = "<<<<<<< ORIGINAL" diff --git a/aider/coders/editblock_func_coder.py b/aider/coders/editblock_func_coder.py index 676952938..2c834255e 100644 --- a/aider/coders/editblock_func_coder.py +++ b/aider/coders/editblock_func_coder.py @@ -135,7 +135,10 @@ class EditBlockFunctionCoder(Coder): full_path = self.allowed_to_edit(path) if not full_path: continue - if do_replace(full_path, original, updated, self.dry_run): + content = self.io.read_text(full_path) + content = do_replace(full_path, content, original, updated) + if content: + self.io.write_text(full_path, content) edited.add(path) continue self.io.tool_error(f"Failed to apply edit to {path}") diff --git a/aider/coders/single_wholefile_func_coder.py b/aider/coders/single_wholefile_func_coder.py index b6368311a..a9dd2494a 100644 --- a/aider/coders/single_wholefile_func_coder.py +++ b/aider/coders/single_wholefile_func_coder.py @@ -90,8 +90,11 @@ class SingleWholeFileFunctionCoder(Coder): # ending an existing block full_path = os.path.abspath(os.path.join(self.root, fname)) - with open(full_path, "r") as f: - orig_lines = f.readlines() + content = self.io.read_text(full_path) + if content is None: + orig_lines = [] + else: + orig_lines = content.splitlines() show_diff = diffs.diff_partial_update( orig_lines, diff --git a/aider/coders/wholefile_coder.py b/aider/coders/wholefile_coder.py index c37f531bb..8c8536f04 100644 --- a/aider/coders/wholefile_coder.py +++ b/aider/coders/wholefile_coder.py @@ -53,7 +53,7 @@ class WholeFileCoder(Coder): full_path = (Path(self.root) / fname).absolute() if mode == "diff" and full_path.exists(): - orig_lines = full_path.read_text().splitlines(keepends=True) + orig_lines = self.io.read_text(full_path).splitlines(keepends=True) show_diff = diffs.diff_partial_update( orig_lines, @@ -64,9 +64,8 @@ class WholeFileCoder(Coder): else: if self.allowed_to_edit(fname): edited.add(fname) - if not self.dry_run: - new_lines = "".join(new_lines) - full_path.write_text(new_lines) + new_lines = "".join(new_lines) + self.io.write_text(full_path, new_lines) fname = None new_lines = [] @@ -109,7 +108,7 @@ class WholeFileCoder(Coder): full_path = (Path(self.root) / fname).absolute() if mode == "diff" and full_path.exists(): - orig_lines = full_path.read_text().splitlines(keepends=True) + orig_lines = self.io.read_text(full_path).splitlines(keepends=True) show_diff = diffs.diff_partial_update( orig_lines, @@ -123,8 +122,7 @@ class WholeFileCoder(Coder): full_path = self.allowed_to_edit(fname) if full_path: edited.add(fname) - if not self.dry_run: - new_lines = "".join(new_lines) - Path(full_path).write_text(new_lines) + new_lines = "".join(new_lines) + self.io.write_text(full_path, new_lines) return edited diff --git a/aider/coders/wholefile_func_coder.py b/aider/coders/wholefile_func_coder.py index f45ef48a7..5c7cbd6ae 100644 --- a/aider/coders/wholefile_func_coder.py +++ b/aider/coders/wholefile_func_coder.py @@ -101,8 +101,11 @@ class WholeFileFunctionCoder(Coder): # ending an existing block full_path = os.path.abspath(os.path.join(self.root, fname)) - with open(full_path, "r") as f: - orig_lines = f.readlines() + content = self.io.read_text(full_path) + if content is None: + orig_lines = [] + else: + orig_lines = content.splitlines() show_diff = diffs.diff_partial_update( orig_lines, diff --git a/aider/commands.py b/aider/commands.py index a4541c58d..51e51ef4b 100644 --- a/aider/commands.py +++ b/aider/commands.py @@ -3,12 +3,13 @@ import os import shlex import subprocess import sys +from pathlib import Path import git import tiktoken from prompt_toolkit.completion import Completion -from aider import prompts, utils +from aider import prompts class Commands: @@ -116,8 +117,10 @@ class Commands: # files for fname in self.coder.abs_fnames: relative_fname = self.coder.get_rel_fname(fname) - quoted = utils.quoted_file(fname, relative_fname) - tokens = len(self.tokenizer.encode(quoted)) + content = self.io.read_text(fname) + # approximate + content = f"{relative_fname}\n```\n" + content + "```\n" + tokens = len(self.tokenizer.encode(content)) res.append((tokens, f"{relative_fname}", "use /drop to drop from chat")) self.io.tool_output("Approximate context window usage, in tokens:") @@ -238,8 +241,8 @@ class Commands: ) if create_file: - with open(os.path.join(self.coder.root, word), "w"): - pass + (Path(self.coder.root) / word).touch() + matched_files = [word] if self.coder.repo is not None: self.coder.repo.git.add(os.path.join(self.coder.root, word)) @@ -251,9 +254,11 @@ class Commands: for matched_file in matched_files: abs_file_path = os.path.abspath(os.path.join(self.coder.root, matched_file)) if abs_file_path not in self.coder.abs_fnames: - self.coder.abs_fnames.add(abs_file_path) - self.io.tool_output(f"Added {matched_file} to the chat") - added_fnames.append(matched_file) + content = self.io.read_text(abs_file_path) + if content is not None: + self.coder.abs_fnames.add(abs_file_path) + self.io.tool_output(f"Added {matched_file} to the chat") + added_fnames.append(matched_file) else: self.io.tool_error(f"{matched_file} is already in the chat") diff --git a/aider/diffs.py b/aider/diffs.py index 1330d01dc..784745688 100644 --- a/aider/diffs.py +++ b/aider/diffs.py @@ -11,10 +11,10 @@ def main(): file_orig, file_updated = sys.argv[1], sys.argv[2] - with open(file_orig, "r") as f: + with open(file_orig, "r", encoding="utf-8") as f: lines_orig = f.readlines() - with open(file_updated, "r") as f: + with open(file_updated, "r", encoding="utf-8") as f: lines_updated = f.readlines() for i in range(len(file_updated)): diff --git a/aider/io.py b/aider/io.py index 1d275716e..98a8a62f3 100644 --- a/aider/io.py +++ b/aider/io.py @@ -100,6 +100,8 @@ class InputOutput: user_input_color="blue", tool_output_color=None, tool_error_color="red", + encoding="utf-8", + dry_run=False, ): no_color = os.environ.get("NO_COLOR") if no_color is not None and no_color != "": @@ -124,6 +126,9 @@ class InputOutput: else: self.chat_history_file = None + self.encoding = encoding + self.dry_run = dry_run + if pretty: self.console = Console() else: @@ -132,6 +137,23 @@ class InputOutput: current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.append_chat_history(f"\n# aider chat started at {current_time}\n\n") + def read_text(self, filename): + try: + with open(str(filename), "r", encoding=self.encoding) as f: + return f.read() + except FileNotFoundError: + self.tool_error(f"{filename}: file not found error") + return + except UnicodeError as e: + self.tool_error(f"{filename}: {e}") + return + + def write_text(self, filename, content): + if self.dry_run: + return + with open(str(filename), "w", encoding=self.encoding) as f: + f.write(content) + def get_input(self, root, rel_fnames, addable_rel_fnames, commands): if self.pretty: style = dict(style=self.user_input_color) if self.user_input_color else dict() diff --git a/aider/main.py b/aider/main.py index decf1514f..9e614ecee 100644 --- a/aider/main.py +++ b/aider/main.py @@ -185,6 +185,11 @@ def main(args=None, input=None, output=None): dest="dirty_commits", help="Disable commits when repo is found dirty", ) + parser.add_argument( + "--encoding", + default="utf-8", + help="Specify the encoding to use when reading files (default: utf-8)", + ) parser.add_argument( "--openai-api-key", metavar="OPENAI_API_KEY", @@ -247,6 +252,7 @@ def main(args=None, input=None, output=None): user_input_color=args.user_input_color, tool_output_color=args.tool_output_color, tool_error_color=args.tool_error_color, + dry_run=args.dry_run, ) if args.verbose: @@ -294,8 +300,9 @@ def main(args=None, input=None, output=None): coder.commit(ask=True, which="repo_files") if args.apply: - with open(args.apply, "r") as f: - content = f.read() + content = io.read_text(args.apply) + if content is None: + return coder.apply_updates(content) return diff --git a/aider/repomap.py b/aider/repomap.py index fab94ec27..df739c8f9 100644 --- a/aider/repomap.py +++ b/aider/repomap.py @@ -74,8 +74,10 @@ class RepoMap: main_model=models.GPT4, io=None, repo_content_prefix=None, + verbose=False, ): self.io = io + self.verbose = verbose if not root: root = os.getcwd() @@ -130,7 +132,7 @@ class RepoMap: files_listing = self.get_ranked_tags_map(chat_files, other_files) if files_listing: num_tokens = self.token_count(files_listing) - if self.io: + if self.verbose: self.io.tool_output(f"ctags map: {num_tokens/1024:.1f} k-tokens") ctags_msg = " with selected ctags info" return files_listing, ctags_msg @@ -138,7 +140,7 @@ class RepoMap: files_listing = self.get_simple_files_map(other_files) ctags_msg = "" num_tokens = self.token_count(files_listing) - if self.io: + if self.verbose: self.io.tool_output(f"simple map: {num_tokens/1024:.1f} k-tokens") if num_tokens < self.max_map_tokens: return files_listing, ctags_msg @@ -198,7 +200,7 @@ class RepoMap: with tempfile.TemporaryDirectory() as tempdir: hello_py = os.path.join(tempdir, "hello.py") - with open(hello_py, "w") as f: + with open(hello_py, "w", encoding="utf-8") as f: f.write("def hello():\n print('Hello, world!')\n") self.run_ctags(hello_py) except FileNotFoundError: @@ -237,10 +239,8 @@ class RepoMap: return idents def get_name_identifiers_uncached(self, fname): - try: - with open(fname, "r") as f: - content = f.read() - except UnicodeDecodeError: + content = self.io.read_text(fname) + if content is None: return list() try: diff --git a/aider/utils.py b/aider/utils.py index cd805e7ab..15c3bfb71 100644 --- a/aider/utils.py +++ b/aider/utils.py @@ -1,24 +1,6 @@ -from pathlib import Path - from .dump import dump # noqa: F401 -def quoted_file(fname, display_fname, fence=("```", "```"), number=False): - prompt = "\n" - prompt += display_fname - prompt += f"\n{fence[0]}\n" - - file_content = Path(fname).read_text() - lines = file_content.splitlines() - for i, line in enumerate(lines, start=1): - if number: - prompt += f"{i:4d} " - prompt += line + "\n" - - prompt += f"{fence[1]}\n" - return prompt - - def show_messages(messages, title=None, functions=None): if title: print(title.upper(), "*" * 50) diff --git a/tests/test_coder.py b/tests/test_coder.py index 41ba7c768..b1be19eea 100644 --- a/tests/test_coder.py +++ b/tests/test_coder.py @@ -1,5 +1,7 @@ import os +import tempfile import unittest +from pathlib import Path from unittest.mock import MagicMock, patch import openai @@ -7,6 +9,8 @@ import requests from aider import models from aider.coders import Coder +from aider.dump import dump # noqa: F401 +from aider.io import InputOutput class TestCoder(unittest.TestCase): @@ -177,6 +181,132 @@ class TestCoder(unittest.TestCase): # Assert that print was called once mock_print.assert_called_once() + def test_run_with_file_deletion(self): + # Create a few temporary files -if __name__ == "__main__": - unittest.main() + tempdir = Path(tempfile.mkdtemp()) + + file1 = tempdir / "file1.txt" + file2 = tempdir / "file2.txt" + + file1.touch() + file2.touch() + + files = [file1, file2] + + # Initialize the Coder object with the mocked IO and mocked repo + coder = Coder.create( + models.GPT4, None, io=InputOutput(), openai_api_key="fake_key", fnames=files + ) + + def mock_send(*args, **kwargs): + coder.partial_response_content = "ok" + coder.partial_response_function_call = dict() + + coder.send = MagicMock(side_effect=mock_send) + + # Call the run method with a message + coder.run(with_message="hi") + self.assertEqual(len(coder.abs_fnames), 2) + + file1.unlink() + + # Call the run method again with a message + coder.run(with_message="hi") + self.assertEqual(len(coder.abs_fnames), 1) + + def test_run_with_file_unicode_error(self): + # Create a few temporary files + _, file1 = tempfile.mkstemp() + _, file2 = tempfile.mkstemp() + + files = [file1, file2] + + # Initialize the Coder object with the mocked IO and mocked repo + coder = Coder.create( + models.GPT4, None, io=InputOutput(), openai_api_key="fake_key", fnames=files + ) + + def mock_send(*args, **kwargs): + coder.partial_response_content = "ok" + coder.partial_response_function_call = dict() + + coder.send = MagicMock(side_effect=mock_send) + + # Call the run method with a message + coder.run(with_message="hi") + self.assertEqual(len(coder.abs_fnames), 2) + + # Write some non-UTF8 text into the file + with open(file1, "wb") as f: + f.write(b"\x80abc") + + # Call the run method again with a message + coder.run(with_message="hi") + self.assertEqual(len(coder.abs_fnames), 1) + + def test_choose_fence(self): + # Create a few temporary files + _, file1 = tempfile.mkstemp() + + with open(file1, "wb") as f: + f.write(b"this contains ``` backticks") + + files = [file1] + + # Initialize the Coder object with the mocked IO and mocked repo + coder = Coder.create( + models.GPT4, None, io=InputOutput(), openai_api_key="fake_key", fnames=files + ) + + def mock_send(*args, **kwargs): + coder.partial_response_content = "ok" + coder.partial_response_function_call = dict() + + coder.send = MagicMock(side_effect=mock_send) + + # Call the run method with a message + coder.run(with_message="hi") + + self.assertNotEqual(coder.fence[0], "```") + + def test_run_with_file_utf_unicode_error(self): + "make sure that we honor InputOutput(encoding) and don't just assume utf-8" + # Create a few temporary files + _, file1 = tempfile.mkstemp() + _, file2 = tempfile.mkstemp() + + files = [file1, file2] + + encoding = "utf-16" + + # Initialize the Coder object with the mocked IO and mocked repo + coder = Coder.create( + models.GPT4, + None, + io=InputOutput(encoding=encoding), + openai_api_key="fake_key", + fnames=files, + ) + + def mock_send(*args, **kwargs): + coder.partial_response_content = "ok" + coder.partial_response_function_call = dict() + + coder.send = MagicMock(side_effect=mock_send) + + # Call the run method with a message + coder.run(with_message="hi") + self.assertEqual(len(coder.abs_fnames), 2) + + some_content_which_will_error_if_read_with_encoding_utf8 = "ÅÍÎÏ".encode(encoding) + with open(file1, "wb") as f: + f.write(some_content_which_will_error_if_read_with_encoding_utf8) + + coder.run(with_message="hi") + + # both files should still be here + self.assertEqual(len(coder.abs_fnames), 2) + + if __name__ == "__main__": + unittest.main() diff --git a/tests/test_commands.py b/tests/test_commands.py index 7bd186ad6..7645fef13 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1,10 +1,15 @@ +import codecs import os import shutil +import sys import tempfile +from io import StringIO from unittest import TestCase from aider import models +from aider.coders import Coder from aider.commands import Commands +from aider.dump import dump # noqa: F401 from aider.io import InputOutput @@ -32,3 +37,43 @@ class TestCommands(TestCase): # Check if both files have been created in the temporary directory self.assertTrue(os.path.exists("foo.txt")) self.assertTrue(os.path.exists("bar.txt")) + + def test_cmd_add_bad_encoding(self): + # Initialize the Commands and InputOutput objects + io = InputOutput(pretty=False, yes=True) + from aider.coders import Coder + + coder = Coder.create(models.GPT35, None, io, openai_api_key="deadbeef") + commands = Commands(io, coder) + + # Create a new file foo.bad which will fail to decode as utf-8 + with codecs.open("foo.bad", "w", encoding="iso-8859-15") as f: + f.write("ÆØÅ") # Characters not present in utf-8 + + commands.cmd_add("foo.bad") + + self.assertEqual(coder.abs_fnames, set()) + + def test_cmd_tokens(self): + # Initialize the Commands and InputOutput objects + io = InputOutput(pretty=False, yes=True) + + coder = Coder.create(models.GPT35, None, io, openai_api_key="deadbeef") + commands = Commands(io, coder) + + commands.cmd_add("foo.txt bar.txt") + + # Redirect the standard output to an instance of io.StringIO + stdout = StringIO() + sys.stdout = stdout + + commands.cmd_tokens("") + + # Reset the standard output + sys.stdout = sys.__stdout__ + + # Get the console output + console_output = stdout.getvalue() + + self.assertIn("foo.txt", console_output) + self.assertIn("bar.txt", console_output) diff --git a/tests/test_editblock.py b/tests/test_editblock.py index bed2a5f18..82940e3ef 100644 --- a/tests/test_editblock.py +++ b/tests/test_editblock.py @@ -1,11 +1,26 @@ # flake8: noqa: E501 +import tempfile import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch +from aider import models +from aider.coders import Coder from aider.coders import editblock_coder as eb +from aider.dump import dump # noqa: F401 +from aider.io import InputOutput class TestUtils(unittest.TestCase): + def setUp(self): + self.patcher = patch("aider.coders.base_coder.check_model_availability") + self.mock_check = self.patcher.start() + self.mock_check.return_value = True + + def tearDown(self): + self.patcher.stop() + def test_replace_most_similar_chunk(self): whole = "This is a sample text.\nAnother line of text.\nYet another line.\n" part = "This is a sample text" @@ -223,6 +238,85 @@ These changes replace the `subprocess.run` patches with `subprocess.check_output result = eb.replace_part_with_missing_leading_whitespace(whole, part, replace) self.assertEqual(result, expected_output) + def test_full_edit(self): + # Create a few temporary files + _, file1 = tempfile.mkstemp() + + with open(file1, "w", encoding="utf-8") as f: + f.write("one\ntwo\nthree\n") + + files = [file1] + + # Initialize the Coder object with the mocked IO and mocked repo + coder = Coder.create( + models.GPT4, "diff", io=InputOutput(), openai_api_key="fake_key", fnames=files + ) + + def mock_send(*args, **kwargs): + coder.partial_response_content = f""" +Do this: + +{Path(file1).name} +<<<<<<< ORIGINAL +two +======= +new +>>>>>>> UPDATED + +""" + coder.partial_response_function_call = dict() + + coder.send = MagicMock(side_effect=mock_send) + + # Call the run method with a message + coder.run(with_message="hi") + + content = Path(file1).read_text(encoding="utf-8") + self.assertEqual(content, "one\nnew\nthree\n") + + def test_full_edit_dry_run(self): + # Create a few temporary files + _, file1 = tempfile.mkstemp() + + orig_content = "one\ntwo\nthree\n" + + with open(file1, "w", encoding="utf-8") as f: + f.write(orig_content) + + files = [file1] + + # Initialize the Coder object with the mocked IO and mocked repo + coder = Coder.create( + models.GPT4, + "diff", + io=InputOutput(dry_run=True), + openai_api_key="fake_key", + fnames=files, + dry_run=True, + ) + + def mock_send(*args, **kwargs): + coder.partial_response_content = f""" +Do this: + +{Path(file1).name} +<<<<<<< ORIGINAL +two +======= +new +>>>>>>> UPDATED + +""" + coder.partial_response_function_call = dict() + + coder.send = MagicMock(side_effect=mock_send) + + # Call the run method with a message + coder.run(with_message="hi") + + content = Path(file1).read_text(encoding="utf-8") + self.assertEqual(content, orig_content) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_repomap.py b/tests/test_repomap.py index cba69e2b1..03aa410f8 100644 --- a/tests/test_repomap.py +++ b/tests/test_repomap.py @@ -3,6 +3,7 @@ import tempfile import unittest from unittest.mock import patch +from aider.io import InputOutput from aider.repomap import RepoMap @@ -21,7 +22,8 @@ class TestRepoMap(unittest.TestCase): with open(os.path.join(temp_dir, file), "w") as f: f.write("") - repo_map = RepoMap(root=temp_dir) + io = InputOutput() + repo_map = RepoMap(root=temp_dir, io=io) other_files = [os.path.join(temp_dir, file) for file in test_files] result = repo_map.get_repo_map([], other_files) @@ -65,7 +67,8 @@ print(my_function(3, 4)) with open(os.path.join(temp_dir, test_file3), "w") as f: f.write(file_content3) - repo_map = RepoMap(root=temp_dir) + io = InputOutput() + repo_map = RepoMap(root=temp_dir, io=io) other_files = [ os.path.join(temp_dir, test_file1), os.path.join(temp_dir, test_file2), @@ -83,7 +86,7 @@ print(my_function(3, 4)) def test_check_for_ctags_failure(self): with patch("subprocess.run") as mock_run: mock_run.side_effect = Exception("ctags not found") - repo_map = RepoMap() + repo_map = RepoMap(io=InputOutput()) self.assertFalse(repo_map.has_ctags) def test_check_for_ctags_success(self): @@ -100,7 +103,7 @@ print(my_function(3, 4)) b' status = main()$/", "kind": "variable"}' ), ] - repo_map = RepoMap() + repo_map = RepoMap(io=InputOutput()) self.assertTrue(repo_map.has_ctags) def test_get_repo_map_without_ctags(self): @@ -120,7 +123,7 @@ print(my_function(3, 4)) with open(os.path.join(temp_dir, file), "w") as f: f.write("") - repo_map = RepoMap(root=temp_dir) + repo_map = RepoMap(root=temp_dir, io=InputOutput()) repo_map.has_ctags = False # force it off other_files = [os.path.join(temp_dir, file) for file in test_files] diff --git a/tests/test_wholefile.py b/tests/test_wholefile.py index 0237515cf..de85e04c8 100644 --- a/tests/test_wholefile.py +++ b/tests/test_wholefile.py @@ -3,8 +3,10 @@ import shutil import tempfile import unittest from pathlib import Path +from unittest.mock import MagicMock, patch from aider import models +from aider.coders import Coder from aider.coders.wholefile_coder import WholeFileCoder from aider.io import InputOutput @@ -15,10 +17,16 @@ class TestWholeFileCoder(unittest.TestCase): self.tempdir = tempfile.mkdtemp() os.chdir(self.tempdir) + self.patcher = patch("aider.coders.base_coder.check_model_availability") + self.mock_check = self.patcher.start() + self.mock_check.return_value = True + def tearDown(self): os.chdir(self.original_cwd) shutil.rmtree(self.tempdir, ignore_errors=True) + self.patcher.stop() + def test_update_files(self): # Create a sample file in the temporary directory sample_file = "sample.txt" @@ -198,6 +206,45 @@ after b self.assertEqual(fname_a.read_text(), "after a\n") self.assertEqual(fname_b.read_text(), "after b\n") + def test_full_edit(self): + # Create a few temporary files + _, file1 = tempfile.mkstemp() + + with open(file1, "w", encoding="utf-8") as f: + f.write("one\ntwo\nthree\n") + + files = [file1] + + # Initialize the Coder object with the mocked IO and mocked repo + coder = Coder.create( + models.GPT4, "whole", io=InputOutput(), openai_api_key="fake_key", fnames=files + ) + + # no trailing newline so the response content below doesn't add ANOTHER newline + new_content = "new\ntwo\nthree" + + def mock_send(*args, **kwargs): + coder.partial_response_content = f""" +Do this: + +{Path(file1).name} +``` +{new_content} +``` + +""" + coder.partial_response_function_call = dict() + + coder.send = MagicMock(side_effect=mock_send) + + # Call the run method with a message + coder.run(with_message="hi") + + content = Path(file1).read_text(encoding="utf-8") + + # check for one trailing newline + self.assertEqual(content, new_content + "\n") + if __name__ == "__main__": unittest.main()