added test for choose_fence

This commit is contained in:
Paul Gauthier 2023-07-06 11:08:13 -07:00
parent b03f2e54cc
commit 610dd223b5
2 changed files with 26 additions and 3 deletions

View file

@ -290,7 +290,7 @@ class Coder:
def get_abs_fnames_content(self): def get_abs_fnames_content(self):
for fname in list(self.abs_fnames): for fname in list(self.abs_fnames):
content = self.io.read_text(fname) content = self.io.read_text(fname)
dump(fname, content)
if content is None: if content is None:
relative_fname = self.get_rel_fname(fname) relative_fname = self.get_rel_fname(fname)
self.io.tool_error(f"Dropping {relative_fname} from the chat.") self.io.tool_error(f"Dropping {relative_fname} from the chat.")
@ -303,8 +303,6 @@ class Coder:
for _fname, content in self.get_abs_fnames_content(): for _fname, content in self.get_abs_fnames_content():
all_content += content + "\n" all_content += content + "\n"
all_content = all_content.splitlines()
good = False good = False
for fence_open, fence_close in self.fences: for fence_open, fence_close in self.fences:
if fence_open in all_content or fence_close in all_content: if fence_open in all_content or fence_close in all_content:

View file

@ -239,5 +239,30 @@ class TestCoder(unittest.TestCase):
coder.run(with_message="hi") coder.run(with_message="hi")
self.assertEqual(len(coder.abs_fnames), 1) self.assertEqual(len(coder.abs_fnames), 1)
def test_choose_fence(self):
# Create a few temporary files
_, file1 = tempfile.mkstemp()
with open(file1, "wb") as f:
f.write(b"this contains ``` backticks")
files = [file1]
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(
models.GPT4, None, io=InputOutput(), openai_api_key="fake_key", fnames=files
)
def mock_send(*args, **kwargs):
coder.partial_response_content = "ok"
coder.partial_response_function_call = dict()
coder.send = MagicMock(side_effect=mock_send)
# Call the run method with a message
coder.run(with_message="hi")
self.assertNotEqual(coder.fence[0], "```")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()