From ba6e2eeb639f4750c404d33bfa7e1e316b5d2cca Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Wed, 21 Jun 2023 18:05:49 -0700 Subject: [PATCH] wip --- aider/coders/base_coder.py | 64 +++++++++++++++++++++++---------- aider/coders/func_coder.py | 3 +- aider/coders/wholefile_coder.py | 3 +- 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 3a23f3cfd..39d63a45a 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -1,8 +1,10 @@ #!/usr/bin/env python +import json import os import sys import traceback +from json.decoder import JSONDecodeError from pathlib import Path import backoff @@ -356,7 +358,7 @@ class Coder: utils.show_messages(messages) try: - content, interrupted = self.send(messages) + interrupted = self.send(messages) except ExhaustedContextWindow: self.io.tool_error("Exhausted context window!") self.io.tool_error(" - Use /tokens to see token usage.") @@ -364,6 +366,8 @@ class Coder: self.io.tool_error(" - Use /clear to clear chat history.") return + content = self.partial_response_content + if interrupted: self.io.tool_error("\n\n^C KeyboardInterrupt") content += "\n^C KeyboardInterrupt" @@ -468,19 +472,13 @@ class Coder: res = openai.ChatCompletion.create(**kwargs) return res - dump(res) - msg = res.choices[0].message - dump(msg) - print(msg.content) - print(msg.function_call.arguments) - sys.exit() - return res - def send(self, messages, model=None, silent=False): if not model: model = self.main_model.name - self.resp = "" + self.partial_response_content = "" + self.partial_response_function_call = dict() + interrupted = False try: completion = self.send_with_retries(model, messages) @@ -489,9 +487,10 @@ class Coder: interrupted = True if not silent: - self.io.ai_output(self.resp) + # TODO: self.partial_response_function_call + self.io.ai_output(self.partial_response_content) - return self.resp, interrupted + return interrupted def show_send_output(self, completion, silent): live = None @@ -503,20 +502,29 @@ class Coder: live.start() for chunk in completion: - dump(chunk) if chunk.choices[0].finish_reason == "length": raise ExhaustedContextWindow() try: func = chunk.choices[0].delta.function_call - print(func) + # dump(func) + for k, v in func.items(): + if k in self.partial_response_function_call: + self.partial_response_function_call[k] += v + else: + self.partial_response_function_call[k] = v + # dump(self.partial_response_function_call) + args = self.partial_response_function_call.get("arguments") + args = parse_partial_args(args) + dump(args) + except AttributeError: pass try: text = chunk.choices[0].delta.content if text: - self.resp += text + self.partial_response_content += text except AttributeError: pass @@ -524,7 +532,7 @@ class Coder: continue if self.pretty: - show_resp = self.modify_incremental_response(self.resp) + show_resp = self.modify_incremental_response() md = Markdown( show_resp, style=self.assistant_output_color, code_theme="default" ) @@ -536,8 +544,8 @@ class Coder: if live: live.stop() - def modify_incremental_response(self, resp): - return resp + def modify_incremental_response(self): + return self.partial_response_content def get_context_from_history(self, history): context = "" @@ -562,7 +570,7 @@ class Coder: ] try: - commit_message, interrupted = self.send( + interrupted = self.send( messages, model=models.GPT35.name, silent=True, @@ -574,6 +582,7 @@ class Coder: ) return + commit_message = self.partial_response_content commit_message = commit_message.strip() if commit_message and commit_message[0] == '"' and commit_message[-1] == '"': commit_message = commit_message[1:-1].strip() @@ -728,3 +737,20 @@ def check_model_availability(main_model): available_models = openai.Model.list() model_ids = [model.id for model in available_models["data"]] return main_model.name in model_ids + + +def parse_partial_args(data): + try: + return json.loads(data) + except JSONDecodeError: + pass + + try: + return json.loads(data + "}") + except JSONDecodeError: + pass + + try: + return json.loads(data + '"}') + except JSONDecodeError: + pass diff --git a/aider/coders/func_coder.py b/aider/coders/func_coder.py index 72bc914b5..16444f09e 100644 --- a/aider/coders/func_coder.py +++ b/aider/coders/func_coder.py @@ -47,7 +47,8 @@ class FunctionCoder(Coder): else: self.cur_messages += [dict(role="assistant", content=content)] - def modify_incremental_response(self, resp): + def modify_incremental_response(self): + resp = self.partial_response_content return self.update_files(resp, mode="diff") def update_files(self, content, mode="update"): diff --git a/aider/coders/wholefile_coder.py b/aider/coders/wholefile_coder.py index a3f512b3b..1d50c5696 100644 --- a/aider/coders/wholefile_coder.py +++ b/aider/coders/wholefile_coder.py @@ -20,7 +20,8 @@ class WholeFileCoder(Coder): else: self.cur_messages += [dict(role="assistant", content=content)] - def modify_incremental_response(self, resp): + def modify_incremental_response(self): + resp = self.partial_response_content return self.update_files(resp, mode="diff") def update_files(self, content, mode="update"):