Merge pull request #58 from paul-gauthier/encodings

Use an explicit encoding to read/write text files, utf-8 by default
This commit is contained in:
paul-gauthier 2023-07-06 13:40:02 -07:00 committed by GitHub
commit c5f8f564a4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 434 additions and 87 deletions

View file

@ -167,13 +167,13 @@ class Coder:
self.find_common_root()
if main_model.use_repo_map and self.repo and self.gpt_prompts.repo_content_prefix:
rm_io = io if self.verbose else None
self.repo_map = RepoMap(
map_tokens,
self.root,
self.main_model,
rm_io,
io,
self.gpt_prompts.repo_content_prefix,
self.verbose,
)
if self.repo_map.use_ctags:
@ -287,12 +287,21 @@ class Coder:
]
fence = fences[0]
def get_abs_fnames_content(self):
for fname in list(self.abs_fnames):
content = self.io.read_text(fname)
if content is None:
relative_fname = self.get_rel_fname(fname)
self.io.tool_error(f"Dropping {relative_fname} from the chat.")
self.abs_fnames.remove(fname)
else:
yield fname, content
def choose_fence(self):
all_content = ""
for fname in self.abs_fnames:
all_content += Path(fname).read_text() + "\n"
all_content = all_content.splitlines()
for _fname, content in self.get_abs_fnames_content():
all_content += content + "\n"
good = False
for fence_open, fence_close in self.fences:
@ -317,15 +326,15 @@ class Coder:
fnames = self.abs_fnames
prompt = ""
for fname in fnames:
for fname, content in self.get_abs_fnames_content():
relative_fname = self.get_rel_fname(fname)
prompt += utils.quoted_file(fname, relative_fname, fence=self.fence)
return prompt
prompt = "\n"
prompt += relative_fname
prompt += f"\n{self.fence[0]}\n"
prompt += content
prompt += f"{self.fence[1]}\n"
def recheck_abs_fnames(self):
self.abs_fnames = set(
fname for fname in self.abs_fnames if Path(fname).exists() and Path(fname).is_file()
)
return prompt
def get_files_messages(self):
all_content = ""
@ -454,10 +463,6 @@ class Coder:
]
messages += self.done_messages
# notice if files disappear
self.recheck_abs_fnames()
messages += self.get_files_messages()
messages += self.cur_messages
@ -917,8 +922,8 @@ class Coder:
full_path = os.path.abspath(os.path.join(self.root, path))
if full_path in self.abs_fnames:
if not self.dry_run and write_content:
Path(full_path).write_text(write_content)
if write_content:
self.io.write_text(full_path, write_content)
return full_path
if not Path(full_path).exists():
@ -943,8 +948,8 @@ class Coder:
if not self.dry_run:
self.repo.git.add(full_path)
if not self.dry_run and write_content:
Path(full_path).write_text(write_content)
if write_content:
self.io.write_text(full_path, write_content)
return full_path

View file

@ -26,7 +26,10 @@ class EditBlockCoder(Coder):
full_path = self.allowed_to_edit(path)
if not full_path:
continue
if do_replace(full_path, original, updated, self.dry_run):
content = self.io.read_text(full_path)
content = do_replace(full_path, content, original, updated)
if content:
self.io.write_text(full_path, content)
edited.add(path)
continue
self.io.tool_error(f"Failed to apply edit to {path}")
@ -211,7 +214,7 @@ def strip_quoted_wrapping(res, fname=None):
return res
def do_replace(fname, before_text, after_text, dry_run=False):
def do_replace(fname, content, before_text, after_text):
before_text = strip_quoted_wrapping(before_text, fname)
after_text = strip_quoted_wrapping(after_text, fname)
fname = Path(fname)
@ -219,21 +222,18 @@ def do_replace(fname, before_text, after_text, dry_run=False):
# does it want to make a new file?
if not fname.exists() and not before_text.strip():
fname.touch()
content = ""
content = fname.read_text()
if content is None:
return
if not before_text.strip():
# append to existing file, or start a new file
new_content = content + after_text
else:
new_content = replace_most_similar_chunk(content, before_text, after_text)
if not new_content:
return
if not dry_run:
fname.write_text(new_content)
return True
return new_content
ORIGINAL = "<<<<<<< ORIGINAL"

View file

@ -135,7 +135,10 @@ class EditBlockFunctionCoder(Coder):
full_path = self.allowed_to_edit(path)
if not full_path:
continue
if do_replace(full_path, original, updated, self.dry_run):
content = self.io.read_text(full_path)
content = do_replace(full_path, content, original, updated)
if content:
self.io.write_text(full_path, content)
edited.add(path)
continue
self.io.tool_error(f"Failed to apply edit to {path}")

View file

@ -90,8 +90,11 @@ class SingleWholeFileFunctionCoder(Coder):
# ending an existing block
full_path = os.path.abspath(os.path.join(self.root, fname))
with open(full_path, "r") as f:
orig_lines = f.readlines()
content = self.io.read_text(full_path)
if content is None:
orig_lines = []
else:
orig_lines = content.splitlines()
show_diff = diffs.diff_partial_update(
orig_lines,

View file

@ -53,7 +53,7 @@ class WholeFileCoder(Coder):
full_path = (Path(self.root) / fname).absolute()
if mode == "diff" and full_path.exists():
orig_lines = full_path.read_text().splitlines(keepends=True)
orig_lines = self.io.read_text(full_path).splitlines(keepends=True)
show_diff = diffs.diff_partial_update(
orig_lines,
@ -64,9 +64,8 @@ class WholeFileCoder(Coder):
else:
if self.allowed_to_edit(fname):
edited.add(fname)
if not self.dry_run:
new_lines = "".join(new_lines)
full_path.write_text(new_lines)
new_lines = "".join(new_lines)
self.io.write_text(full_path, new_lines)
fname = None
new_lines = []
@ -109,7 +108,7 @@ class WholeFileCoder(Coder):
full_path = (Path(self.root) / fname).absolute()
if mode == "diff" and full_path.exists():
orig_lines = full_path.read_text().splitlines(keepends=True)
orig_lines = self.io.read_text(full_path).splitlines(keepends=True)
show_diff = diffs.diff_partial_update(
orig_lines,
@ -123,8 +122,7 @@ class WholeFileCoder(Coder):
full_path = self.allowed_to_edit(fname)
if full_path:
edited.add(fname)
if not self.dry_run:
new_lines = "".join(new_lines)
Path(full_path).write_text(new_lines)
new_lines = "".join(new_lines)
self.io.write_text(full_path, new_lines)
return edited

View file

@ -101,8 +101,11 @@ class WholeFileFunctionCoder(Coder):
# ending an existing block
full_path = os.path.abspath(os.path.join(self.root, fname))
with open(full_path, "r") as f:
orig_lines = f.readlines()
content = self.io.read_text(full_path)
if content is None:
orig_lines = []
else:
orig_lines = content.splitlines()
show_diff = diffs.diff_partial_update(
orig_lines,

View file

@ -3,12 +3,13 @@ import os
import shlex
import subprocess
import sys
from pathlib import Path
import git
import tiktoken
from prompt_toolkit.completion import Completion
from aider import prompts, utils
from aider import prompts
class Commands:
@ -116,8 +117,10 @@ class Commands:
# files
for fname in self.coder.abs_fnames:
relative_fname = self.coder.get_rel_fname(fname)
quoted = utils.quoted_file(fname, relative_fname)
tokens = len(self.tokenizer.encode(quoted))
content = self.io.read_text(fname)
# approximate
content = f"{relative_fname}\n```\n" + content + "```\n"
tokens = len(self.tokenizer.encode(content))
res.append((tokens, f"{relative_fname}", "use /drop to drop from chat"))
self.io.tool_output("Approximate context window usage, in tokens:")
@ -238,8 +241,8 @@ class Commands:
)
if create_file:
with open(os.path.join(self.coder.root, word), "w"):
pass
(Path(self.coder.root) / word).touch()
matched_files = [word]
if self.coder.repo is not None:
self.coder.repo.git.add(os.path.join(self.coder.root, word))
@ -251,9 +254,11 @@ class Commands:
for matched_file in matched_files:
abs_file_path = os.path.abspath(os.path.join(self.coder.root, matched_file))
if abs_file_path not in self.coder.abs_fnames:
self.coder.abs_fnames.add(abs_file_path)
self.io.tool_output(f"Added {matched_file} to the chat")
added_fnames.append(matched_file)
content = self.io.read_text(abs_file_path)
if content is not None:
self.coder.abs_fnames.add(abs_file_path)
self.io.tool_output(f"Added {matched_file} to the chat")
added_fnames.append(matched_file)
else:
self.io.tool_error(f"{matched_file} is already in the chat")

View file

@ -11,10 +11,10 @@ def main():
file_orig, file_updated = sys.argv[1], sys.argv[2]
with open(file_orig, "r") as f:
with open(file_orig, "r", encoding="utf-8") as f:
lines_orig = f.readlines()
with open(file_updated, "r") as f:
with open(file_updated, "r", encoding="utf-8") as f:
lines_updated = f.readlines()
for i in range(len(file_updated)):

View file

@ -100,6 +100,8 @@ class InputOutput:
user_input_color="blue",
tool_output_color=None,
tool_error_color="red",
encoding="utf-8",
dry_run=False,
):
no_color = os.environ.get("NO_COLOR")
if no_color is not None and no_color != "":
@ -124,6 +126,9 @@ class InputOutput:
else:
self.chat_history_file = None
self.encoding = encoding
self.dry_run = dry_run
if pretty:
self.console = Console()
else:
@ -132,6 +137,23 @@ class InputOutput:
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.append_chat_history(f"\n# aider chat started at {current_time}\n\n")
def read_text(self, filename):
try:
with open(str(filename), "r", encoding=self.encoding) as f:
return f.read()
except FileNotFoundError:
self.tool_error(f"{filename}: file not found error")
return
except UnicodeError as e:
self.tool_error(f"{filename}: {e}")
return
def write_text(self, filename, content):
if self.dry_run:
return
with open(str(filename), "w", encoding=self.encoding) as f:
f.write(content)
def get_input(self, root, rel_fnames, addable_rel_fnames, commands):
if self.pretty:
style = dict(style=self.user_input_color) if self.user_input_color else dict()

View file

@ -185,6 +185,11 @@ def main(args=None, input=None, output=None):
dest="dirty_commits",
help="Disable commits when repo is found dirty",
)
parser.add_argument(
"--encoding",
default="utf-8",
help="Specify the encoding to use when reading files (default: utf-8)",
)
parser.add_argument(
"--openai-api-key",
metavar="OPENAI_API_KEY",
@ -247,6 +252,7 @@ def main(args=None, input=None, output=None):
user_input_color=args.user_input_color,
tool_output_color=args.tool_output_color,
tool_error_color=args.tool_error_color,
dry_run=args.dry_run,
)
if args.verbose:
@ -294,8 +300,9 @@ def main(args=None, input=None, output=None):
coder.commit(ask=True, which="repo_files")
if args.apply:
with open(args.apply, "r") as f:
content = f.read()
content = io.read_text(args.apply)
if content is None:
return
coder.apply_updates(content)
return

View file

@ -74,8 +74,10 @@ class RepoMap:
main_model=models.GPT4,
io=None,
repo_content_prefix=None,
verbose=False,
):
self.io = io
self.verbose = verbose
if not root:
root = os.getcwd()
@ -130,7 +132,7 @@ class RepoMap:
files_listing = self.get_ranked_tags_map(chat_files, other_files)
if files_listing:
num_tokens = self.token_count(files_listing)
if self.io:
if self.verbose:
self.io.tool_output(f"ctags map: {num_tokens/1024:.1f} k-tokens")
ctags_msg = " with selected ctags info"
return files_listing, ctags_msg
@ -138,7 +140,7 @@ class RepoMap:
files_listing = self.get_simple_files_map(other_files)
ctags_msg = ""
num_tokens = self.token_count(files_listing)
if self.io:
if self.verbose:
self.io.tool_output(f"simple map: {num_tokens/1024:.1f} k-tokens")
if num_tokens < self.max_map_tokens:
return files_listing, ctags_msg
@ -198,7 +200,7 @@ class RepoMap:
with tempfile.TemporaryDirectory() as tempdir:
hello_py = os.path.join(tempdir, "hello.py")
with open(hello_py, "w") as f:
with open(hello_py, "w", encoding="utf-8") as f:
f.write("def hello():\n print('Hello, world!')\n")
self.run_ctags(hello_py)
except FileNotFoundError:
@ -237,10 +239,8 @@ class RepoMap:
return idents
def get_name_identifiers_uncached(self, fname):
try:
with open(fname, "r") as f:
content = f.read()
except UnicodeDecodeError:
content = self.io.read_text(fname)
if content is None:
return list()
try:

View file

@ -1,24 +1,6 @@
from pathlib import Path
from .dump import dump # noqa: F401
def quoted_file(fname, display_fname, fence=("```", "```"), number=False):
prompt = "\n"
prompt += display_fname
prompt += f"\n{fence[0]}\n"
file_content = Path(fname).read_text()
lines = file_content.splitlines()
for i, line in enumerate(lines, start=1):
if number:
prompt += f"{i:4d} "
prompt += line + "\n"
prompt += f"{fence[1]}\n"
return prompt
def show_messages(messages, title=None, functions=None):
if title:
print(title.upper(), "*" * 50)

View file

@ -1,5 +1,7 @@
import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch
import openai
@ -7,6 +9,8 @@ import requests
from aider import models
from aider.coders import Coder
from aider.dump import dump # noqa: F401
from aider.io import InputOutput
class TestCoder(unittest.TestCase):
@ -177,6 +181,132 @@ class TestCoder(unittest.TestCase):
# Assert that print was called once
mock_print.assert_called_once()
def test_run_with_file_deletion(self):
# Create a few temporary files
if __name__ == "__main__":
unittest.main()
tempdir = Path(tempfile.mkdtemp())
file1 = tempdir / "file1.txt"
file2 = tempdir / "file2.txt"
file1.touch()
file2.touch()
files = [file1, file2]
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(
models.GPT4, None, io=InputOutput(), openai_api_key="fake_key", fnames=files
)
def mock_send(*args, **kwargs):
coder.partial_response_content = "ok"
coder.partial_response_function_call = dict()
coder.send = MagicMock(side_effect=mock_send)
# Call the run method with a message
coder.run(with_message="hi")
self.assertEqual(len(coder.abs_fnames), 2)
file1.unlink()
# Call the run method again with a message
coder.run(with_message="hi")
self.assertEqual(len(coder.abs_fnames), 1)
def test_run_with_file_unicode_error(self):
# Create a few temporary files
_, file1 = tempfile.mkstemp()
_, file2 = tempfile.mkstemp()
files = [file1, file2]
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(
models.GPT4, None, io=InputOutput(), openai_api_key="fake_key", fnames=files
)
def mock_send(*args, **kwargs):
coder.partial_response_content = "ok"
coder.partial_response_function_call = dict()
coder.send = MagicMock(side_effect=mock_send)
# Call the run method with a message
coder.run(with_message="hi")
self.assertEqual(len(coder.abs_fnames), 2)
# Write some non-UTF8 text into the file
with open(file1, "wb") as f:
f.write(b"\x80abc")
# Call the run method again with a message
coder.run(with_message="hi")
self.assertEqual(len(coder.abs_fnames), 1)
def test_choose_fence(self):
# Create a few temporary files
_, file1 = tempfile.mkstemp()
with open(file1, "wb") as f:
f.write(b"this contains ``` backticks")
files = [file1]
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(
models.GPT4, None, io=InputOutput(), openai_api_key="fake_key", fnames=files
)
def mock_send(*args, **kwargs):
coder.partial_response_content = "ok"
coder.partial_response_function_call = dict()
coder.send = MagicMock(side_effect=mock_send)
# Call the run method with a message
coder.run(with_message="hi")
self.assertNotEqual(coder.fence[0], "```")
def test_run_with_file_utf_unicode_error(self):
"make sure that we honor InputOutput(encoding) and don't just assume utf-8"
# Create a few temporary files
_, file1 = tempfile.mkstemp()
_, file2 = tempfile.mkstemp()
files = [file1, file2]
encoding = "utf-16"
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(
models.GPT4,
None,
io=InputOutput(encoding=encoding),
openai_api_key="fake_key",
fnames=files,
)
def mock_send(*args, **kwargs):
coder.partial_response_content = "ok"
coder.partial_response_function_call = dict()
coder.send = MagicMock(side_effect=mock_send)
# Call the run method with a message
coder.run(with_message="hi")
self.assertEqual(len(coder.abs_fnames), 2)
some_content_which_will_error_if_read_with_encoding_utf8 = "ÅÍÎÏ".encode(encoding)
with open(file1, "wb") as f:
f.write(some_content_which_will_error_if_read_with_encoding_utf8)
coder.run(with_message="hi")
# both files should still be here
self.assertEqual(len(coder.abs_fnames), 2)
if __name__ == "__main__":
unittest.main()

View file

@ -1,10 +1,15 @@
import codecs
import os
import shutil
import sys
import tempfile
from io import StringIO
from unittest import TestCase
from aider import models
from aider.coders import Coder
from aider.commands import Commands
from aider.dump import dump # noqa: F401
from aider.io import InputOutput
@ -32,3 +37,43 @@ class TestCommands(TestCase):
# Check if both files have been created in the temporary directory
self.assertTrue(os.path.exists("foo.txt"))
self.assertTrue(os.path.exists("bar.txt"))
def test_cmd_add_bad_encoding(self):
# Initialize the Commands and InputOutput objects
io = InputOutput(pretty=False, yes=True)
from aider.coders import Coder
coder = Coder.create(models.GPT35, None, io, openai_api_key="deadbeef")
commands = Commands(io, coder)
# Create a new file foo.bad which will fail to decode as utf-8
with codecs.open("foo.bad", "w", encoding="iso-8859-15") as f:
f.write("ÆØÅ") # Characters not present in utf-8
commands.cmd_add("foo.bad")
self.assertEqual(coder.abs_fnames, set())
def test_cmd_tokens(self):
# Initialize the Commands and InputOutput objects
io = InputOutput(pretty=False, yes=True)
coder = Coder.create(models.GPT35, None, io, openai_api_key="deadbeef")
commands = Commands(io, coder)
commands.cmd_add("foo.txt bar.txt")
# Redirect the standard output to an instance of io.StringIO
stdout = StringIO()
sys.stdout = stdout
commands.cmd_tokens("")
# Reset the standard output
sys.stdout = sys.__stdout__
# Get the console output
console_output = stdout.getvalue()
self.assertIn("foo.txt", console_output)
self.assertIn("bar.txt", console_output)

View file

@ -1,11 +1,26 @@
# flake8: noqa: E501
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch
from aider import models
from aider.coders import Coder
from aider.coders import editblock_coder as eb
from aider.dump import dump # noqa: F401
from aider.io import InputOutput
class TestUtils(unittest.TestCase):
def setUp(self):
self.patcher = patch("aider.coders.base_coder.check_model_availability")
self.mock_check = self.patcher.start()
self.mock_check.return_value = True
def tearDown(self):
self.patcher.stop()
def test_replace_most_similar_chunk(self):
whole = "This is a sample text.\nAnother line of text.\nYet another line.\n"
part = "This is a sample text"
@ -223,6 +238,85 @@ These changes replace the `subprocess.run` patches with `subprocess.check_output
result = eb.replace_part_with_missing_leading_whitespace(whole, part, replace)
self.assertEqual(result, expected_output)
def test_full_edit(self):
# Create a few temporary files
_, file1 = tempfile.mkstemp()
with open(file1, "w", encoding="utf-8") as f:
f.write("one\ntwo\nthree\n")
files = [file1]
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(
models.GPT4, "diff", io=InputOutput(), openai_api_key="fake_key", fnames=files
)
def mock_send(*args, **kwargs):
coder.partial_response_content = f"""
Do this:
{Path(file1).name}
<<<<<<< ORIGINAL
two
=======
new
>>>>>>> UPDATED
"""
coder.partial_response_function_call = dict()
coder.send = MagicMock(side_effect=mock_send)
# Call the run method with a message
coder.run(with_message="hi")
content = Path(file1).read_text(encoding="utf-8")
self.assertEqual(content, "one\nnew\nthree\n")
def test_full_edit_dry_run(self):
# Create a few temporary files
_, file1 = tempfile.mkstemp()
orig_content = "one\ntwo\nthree\n"
with open(file1, "w", encoding="utf-8") as f:
f.write(orig_content)
files = [file1]
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(
models.GPT4,
"diff",
io=InputOutput(dry_run=True),
openai_api_key="fake_key",
fnames=files,
dry_run=True,
)
def mock_send(*args, **kwargs):
coder.partial_response_content = f"""
Do this:
{Path(file1).name}
<<<<<<< ORIGINAL
two
=======
new
>>>>>>> UPDATED
"""
coder.partial_response_function_call = dict()
coder.send = MagicMock(side_effect=mock_send)
# Call the run method with a message
coder.run(with_message="hi")
content = Path(file1).read_text(encoding="utf-8")
self.assertEqual(content, orig_content)
if __name__ == "__main__":
unittest.main()

View file

@ -3,6 +3,7 @@ import tempfile
import unittest
from unittest.mock import patch
from aider.io import InputOutput
from aider.repomap import RepoMap
@ -21,7 +22,8 @@ class TestRepoMap(unittest.TestCase):
with open(os.path.join(temp_dir, file), "w") as f:
f.write("")
repo_map = RepoMap(root=temp_dir)
io = InputOutput()
repo_map = RepoMap(root=temp_dir, io=io)
other_files = [os.path.join(temp_dir, file) for file in test_files]
result = repo_map.get_repo_map([], other_files)
@ -65,7 +67,8 @@ print(my_function(3, 4))
with open(os.path.join(temp_dir, test_file3), "w") as f:
f.write(file_content3)
repo_map = RepoMap(root=temp_dir)
io = InputOutput()
repo_map = RepoMap(root=temp_dir, io=io)
other_files = [
os.path.join(temp_dir, test_file1),
os.path.join(temp_dir, test_file2),
@ -83,7 +86,7 @@ print(my_function(3, 4))
def test_check_for_ctags_failure(self):
with patch("subprocess.run") as mock_run:
mock_run.side_effect = Exception("ctags not found")
repo_map = RepoMap()
repo_map = RepoMap(io=InputOutput())
self.assertFalse(repo_map.has_ctags)
def test_check_for_ctags_success(self):
@ -100,7 +103,7 @@ print(my_function(3, 4))
b' status = main()$/", "kind": "variable"}'
),
]
repo_map = RepoMap()
repo_map = RepoMap(io=InputOutput())
self.assertTrue(repo_map.has_ctags)
def test_get_repo_map_without_ctags(self):
@ -120,7 +123,7 @@ print(my_function(3, 4))
with open(os.path.join(temp_dir, file), "w") as f:
f.write("")
repo_map = RepoMap(root=temp_dir)
repo_map = RepoMap(root=temp_dir, io=InputOutput())
repo_map.has_ctags = False # force it off
other_files = [os.path.join(temp_dir, file) for file in test_files]

View file

@ -3,8 +3,10 @@ import shutil
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch
from aider import models
from aider.coders import Coder
from aider.coders.wholefile_coder import WholeFileCoder
from aider.io import InputOutput
@ -15,10 +17,16 @@ class TestWholeFileCoder(unittest.TestCase):
self.tempdir = tempfile.mkdtemp()
os.chdir(self.tempdir)
self.patcher = patch("aider.coders.base_coder.check_model_availability")
self.mock_check = self.patcher.start()
self.mock_check.return_value = True
def tearDown(self):
os.chdir(self.original_cwd)
shutil.rmtree(self.tempdir, ignore_errors=True)
self.patcher.stop()
def test_update_files(self):
# Create a sample file in the temporary directory
sample_file = "sample.txt"
@ -198,6 +206,45 @@ after b
self.assertEqual(fname_a.read_text(), "after a\n")
self.assertEqual(fname_b.read_text(), "after b\n")
def test_full_edit(self):
# Create a few temporary files
_, file1 = tempfile.mkstemp()
with open(file1, "w", encoding="utf-8") as f:
f.write("one\ntwo\nthree\n")
files = [file1]
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(
models.GPT4, "whole", io=InputOutput(), openai_api_key="fake_key", fnames=files
)
# no trailing newline so the response content below doesn't add ANOTHER newline
new_content = "new\ntwo\nthree"
def mock_send(*args, **kwargs):
coder.partial_response_content = f"""
Do this:
{Path(file1).name}
```
{new_content}
```
"""
coder.partial_response_function_call = dict()
coder.send = MagicMock(side_effect=mock_send)
# Call the run method with a message
coder.run(with_message="hi")
content = Path(file1).read_text(encoding="utf-8")
# check for one trailing newline
self.assertEqual(content, new_content + "\n")
if __name__ == "__main__":
unittest.main()