diff --git a/coder.py b/coder.py index 5ece1d1de..1e33bfa47 100755 --- a/coder.py +++ b/coder.py @@ -44,6 +44,7 @@ def find_index(list1, list2): class Coder: fnames = dict() + last_modified = 0 def __init__(self, use_gpt_4): if use_gpt_4: @@ -76,6 +77,31 @@ class Coder: prompt += self.quoted_file(fname) return prompt + def set_files_messages(self): + + last_modified = max(Path(fname).stat().st_mtime for fname in self.fnames) + if last_modified <= self.last_modified: + return + did_edits = (self.last_modified > 0) + + self.last_modified = last_modified + print('Reloading files...') + + if did_edits: + files_content = prompts.files_content_prefix_edited + else: + files_content = prompts.files_content_prefix_plain + + files_content += self.get_files_content() + files_content += prompts.files_content_suffix + + self.files_messages = [ + dict(role = 'user', content = files_content), + dict(role = 'assistant', content = "Ok."), + ] + + return True + def get_input(self): print() @@ -115,40 +141,40 @@ class Coder: return files_messages def run(self): - done_messages = [ - dict(role = 'system', content = prompts.main_system), - ] - cur_messages = [] + self.done_messages = [] + self.cur_messages = [] + self.set_files_messages() - files_messages = self.get_files_messages(False) while True: inp = self.get_input() if inp is None: return - cur_messages += [ + if self.set_files_messages(): + # files changed, move cur messages back behind the files messages + self.done_messages += self.cur_messages + self.cur_messages = [] + + self.cur_messages += [ dict(role = 'user', content = inp), ] - self.show_messages(done_messages, "done") - self.show_messages(cur_messages, "cur") + #self.show_messages(self.done_messages, "done") + #self.show_messages(self.files_messages, "files") + #self.show_messages(self.cur_messages, "cur") + + messages = [ + dict(role = 'system', content = prompts.main_system), + ] + messages += self.done_messages + messages += self.files_messages + messages += self.cur_messages + + self.show_messages(messages, 'all') - messages = ( - done_messages - + files_messages - + cur_messages - ) content = self.send(messages) - # disabled - if False and '```' in content: - messages += [ - dict(role = 'assistant', content = content), - dict(role = 'system', content = prompts.returned_code), - ] - content = self.send(messages) - - cur_messages += [ + self.cur_messages += [ dict(role = 'assistant', content = content), ] @@ -165,20 +191,9 @@ class Coder: if not edited: continue - files_messages = self.get_files_messages(True) - - edited_message = 'ORIGINAL/UPDATED formatted changes: ' - edited_message += ', '.join(edited) - edited_message += '' - #cur_messages.pop() - #cur_messages += [ - # dict(role = 'user', content = edited_message), - #] - done_messages += cur_messages - cur_messages = [] - - - + self.done_messages += self.cur_messages + self.files_messages = self.get_files_messages(True) + self.cur_messages = [] def show_messages(self, messages, title): print(title.upper(), '*' * 50)