unified diffs

This commit is contained in:
Paul Gauthier 2023-12-17 12:54:34 -08:00
parent 3aa17c46dd
commit 7113a30271
18 changed files with 243 additions and 96 deletions

View file

@ -49,7 +49,9 @@ class Coder:
functions = None
total_cost = 0.0
num_exhausted_context_windows = 0
num_malformed_responses = 0
last_keyboard_interrupt = None
max_apply_update_errors = 3
@classmethod
def create(
@ -61,7 +63,7 @@ class Coder:
skip_model_availabily_check=False,
**kwargs,
):
from . import EditBlockCoder, WholeFileCoder
from . import EditBlockCoder, UnifiedDiffCoder, WholeFileCoder
if not main_model:
main_model = models.GPT4
@ -83,6 +85,8 @@ class Coder:
return EditBlockCoder(client, main_model, io, **kwargs)
elif edit_format == "whole":
return WholeFileCoder(client, main_model, io, **kwargs)
elif edit_format == "udiff":
return UnifiedDiffCoder(client, main_model, io, **kwargs)
else:
raise ValueError(f"Unknown edit format {edit_format}")
@ -296,7 +300,13 @@ class Coder:
prompt += "\n"
prompt += relative_fname
prompt += f"\n{self.fence[0]}\n"
prompt += content
# lines = content.splitlines(keepends=True)
# lines = [f"{i+1:03}:{line}" for i, line in enumerate(lines)]
# prompt += "".join(lines)
prompt += f"{self.fence[1]}\n"
return prompt
@ -346,7 +356,7 @@ class Coder:
new_user_message = self.send_new_user_message(new_user_message)
if with_message:
return
return self.partial_response_content
except KeyboardInterrupt:
self.keyboard_interrupt()
@ -456,12 +466,12 @@ class Coder:
# add the reminder anyway
total_tokens = 0
messages += self.cur_messages
# Add the reminder prompt if we still have room to include it.
if total_tokens < self.main_model.max_context_tokens:
messages += reminder_message
messages += self.cur_messages
return messages
def send_new_user_message(self, inp):
@ -850,19 +860,19 @@ class Coder:
return set(edit[0] for edit in edits)
def apply_updates(self):
max_apply_update_errors = 3
try:
edited = self.update_files()
except ValueError as err:
self.num_malformed_responses += 1
err = err.args[0]
self.apply_update_errors += 1
if self.apply_update_errors < max_apply_update_errors:
if self.apply_update_errors < self.max_apply_update_errors:
self.io.tool_error(f"Malformed response #{self.apply_update_errors}, retrying...")
self.io.tool_error(str(err))
return None, err
else:
self.io.tool_error(f"Malformed response #{self.apply_update_errors}, aborting.")
self.io.tool_error(str(err))
return False, None
except Exception as err:
@ -870,11 +880,13 @@ class Coder:
print()
traceback.print_exc()
self.apply_update_errors += 1
if self.apply_update_errors < max_apply_update_errors:
if self.apply_update_errors < self.max_apply_update_errors:
self.io.tool_error(f"Update exception #{self.apply_update_errors}, retrying...")
self.io.tool_error(str(err))
return None, str(err)
else:
self.io.tool_error(f"Update exception #{self.apply_update_errors}, aborting")
self.io.tool_error(str(err))
return False, None
self.apply_update_errors = 0