stronger wholefile prompt about file listings; adopt allowed_to_edit in wholefile; tests

This commit is contained in:
Paul Gauthier 2023-06-24 14:45:43 -07:00
parent d00e7c39d1
commit eaf02da46b
4 changed files with 50 additions and 12 deletions

View file

@ -95,6 +95,9 @@ class Coder:
stream=True, stream=True,
use_git=True, use_git=True,
): ):
if not fnames:
fnames = []
self.verbose = verbose self.verbose = verbose
self.abs_fnames = set() self.abs_fnames = set()
self.cur_messages = [] self.cur_messages = []

View file

@ -28,14 +28,12 @@ class WholeFileCoder(Coder):
edited = set() edited = set()
chat_files = self.get_inchat_relative_files() chat_files = self.get_inchat_relative_files()
if not chat_files:
if mode == "diff":
return content
return
output = [] output = []
lines = content.splitlines(keepends=True) lines = content.splitlines(keepends=True)
allowed_to_edit = False
fname = None fname = None
new_lines = [] new_lines = []
for i, line in enumerate(lines): for i, line in enumerate(lines):
@ -55,6 +53,7 @@ class WholeFileCoder(Coder):
).splitlines() ).splitlines()
output += show_diff output += show_diff
else: else:
if allowed_to_edit:
edited.add(fname) edited.add(fname)
if not self.dry_run: if not self.dry_run:
new_lines = "".join(new_lines) new_lines = "".join(new_lines)
@ -74,8 +73,11 @@ class WholeFileCoder(Coder):
else: else:
fname = lines[i - 1].strip() fname = lines[i - 1].strip()
if mode == "update" and not self.allowed_to_edit(fname): if mode == "update":
raise ValueError(f"{fname} is not one of: {show_chat_files}") if self.allowed_to_edit(fname):
allowed_to_edit = True
else:
allowed_to_edit = False
elif fname: elif fname:
new_lines.append(line) new_lines.append(line)

View file

@ -14,14 +14,20 @@ Once you understand the request you MUST:
3. If changes are needed, output a copy of each file that needs changes. 3. If changes are needed, output a copy of each file that needs changes.
""" """
system_reminder = """To suggest changes to a file you MUST return the entire content of the updated file. system_reminder = """
You MUST use this format: To return code you MUST use this *file listing* format:
exact/path/to/filename.js exact/path/to/filename.js
```javascript ```javascript
// file content goes in the // file content goes in the
// triple backticked fenced block // triple backticked fenced block
``` ```
Every *file listing* must start with the full path!
Every *file listing* must be fenced w/triple backticks with the correct code language.
To suggest changes to a file you MUST return a *file listing* that contains the entire content of the file.
Create a new file you MUST return a *file listing* which includes an appropriate path/filename.
""" """
files_content_prefix = "Here is the current content of the files:\n" files_content_prefix = "Here is the current content of the files:\n"

View file

@ -41,6 +41,33 @@ class TestWholeFileCoder(unittest.TestCase):
updated_content = f.read() updated_content = f.read()
self.assertEqual(updated_content, "Updated content\n") self.assertEqual(updated_content, "Updated content\n")
def test_update_files_not_in_chat(self):
with tempfile.TemporaryDirectory() as temp_dir:
os.chdir(temp_dir)
# Create a sample file in the temporary directory
sample_file = "sample.txt"
with open(sample_file, "w") as f:
f.write("Original content\n")
# Initialize WholeFileCoder with the temporary directory
io = InputOutput(yes=True)
coder = WholeFileCoder(main_model=models.GPT35, io=io)
# Set the partial response content with the updated content
coder.partial_response_content = f"{sample_file}\n```\nUpdated content\n```"
# Call update_files method
edited_files = coder.update_files()
# Check if the sample file was updated
self.assertIn("sample.txt", edited_files)
# Check if the content of the sample file was updated
with open(sample_file, "r") as f:
updated_content = f.read()
self.assertEqual(updated_content, "Updated content\n")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()