This commit is contained in:
Paul Gauthier 2024-08-08 13:50:42 -03:00
parent 17c13da008
commit 5f3cf2cb2f

View file

@ -71,6 +71,8 @@ class Coder:
test_outcome = None test_outcome = None
multi_response_content = "" multi_response_content = ""
partial_response_content = None
@classmethod @classmethod
def create( def create(
self, self,
@ -609,7 +611,7 @@ class Coder:
def run_stream(self, user_message): def run_stream(self, user_message):
self.io.user_input(user_message) self.io.user_input(user_message)
self.init_before_message() self.init_before_message()
yield from self.send_new_user_message(user_message) yield from self.send_message(user_message)
def init_before_message(self): def init_before_message(self):
self.reflected_message = None self.reflected_message = None
@ -619,46 +621,29 @@ class Coder:
self.edit_outcome = None self.edit_outcome = None
def run(self, with_message=None): def run(self, with_message=None):
while True: try:
self.init_before_message() if with_message:
self.io.user_input(with_message)
self.run_one(with_message)
return self.partial_response_content
try: while True:
if with_message: user_message = self.get_input()
new_user_message = with_message self.run_one(user_message)
self.io.user_input(with_message) except KeyboardInterrupt:
else: self.keyboard_interrupt()
new_user_message = self.run_loop() except EOFError:
return
while new_user_message: def get_input(self):
self.reflected_message = None return self.io.get_input(
list(self.send_new_user_message(new_user_message))
new_user_message = None
if self.reflected_message:
if self.num_reflections < self.max_reflections:
self.num_reflections += 1
new_user_message = self.reflected_message
else:
self.io.tool_error(
f"Only {self.max_reflections} reflections allowed, stopping."
)
if with_message:
return self.partial_response_content
except KeyboardInterrupt:
self.keyboard_interrupt()
except EOFError:
return
def run_loop(self):
inp = self.io.get_input(
self.root, self.root,
self.get_inchat_relative_files(), self.get_inchat_relative_files(),
self.get_addable_relative_files(), self.get_addable_relative_files(),
self.commands, self.commands,
) )
def preproc_user_input(self, inp):
if not inp: if not inp:
return return
@ -670,6 +655,25 @@ class Coder:
return inp return inp
def run_one(self, user_message):
self.init_before_message()
message = self.preproc_user_input(user_message)
while message:
self.reflected_message = None
list(self.send_message(message))
if not self.reflected_message:
break
if self.num_reflections >= self.max_reflections:
self.io.tool_error(f"Only {self.max_reflections} reflections allowed, stopping.")
return
self.num_reflections += 1
message = self.reflected_message
def check_for_urls(self, inp): def check_for_urls(self, inp):
url_pattern = re.compile(r"(https?://[^\s/$.?#].[^\s]*[^\s,.])") url_pattern = re.compile(r"(https?://[^\s/$.?#].[^\s]*[^\s,.])")
urls = list(set(url_pattern.findall(inp))) # Use set to remove duplicates urls = list(set(url_pattern.findall(inp))) # Use set to remove duplicates
@ -872,7 +876,7 @@ class Coder:
return messages return messages
def send_new_user_message(self, inp): def send_message(self, inp):
self.aider_edited_files = None self.aider_edited_files = None
self.cur_messages += [ self.cur_messages += [