mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 22:34:59 +00:00
wip
This commit is contained in:
parent
151b7b2811
commit
ba6e2eeb63
3 changed files with 49 additions and 21 deletions
|
@ -1,8 +1,10 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from json.decoder import JSONDecodeError
|
||||
from pathlib import Path
|
||||
|
||||
import backoff
|
||||
|
@ -356,7 +358,7 @@ class Coder:
|
|||
utils.show_messages(messages)
|
||||
|
||||
try:
|
||||
content, interrupted = self.send(messages)
|
||||
interrupted = self.send(messages)
|
||||
except ExhaustedContextWindow:
|
||||
self.io.tool_error("Exhausted context window!")
|
||||
self.io.tool_error(" - Use /tokens to see token usage.")
|
||||
|
@ -364,6 +366,8 @@ class Coder:
|
|||
self.io.tool_error(" - Use /clear to clear chat history.")
|
||||
return
|
||||
|
||||
content = self.partial_response_content
|
||||
|
||||
if interrupted:
|
||||
self.io.tool_error("\n\n^C KeyboardInterrupt")
|
||||
content += "\n^C KeyboardInterrupt"
|
||||
|
@ -468,19 +472,13 @@ class Coder:
|
|||
res = openai.ChatCompletion.create(**kwargs)
|
||||
return res
|
||||
|
||||
dump(res)
|
||||
msg = res.choices[0].message
|
||||
dump(msg)
|
||||
print(msg.content)
|
||||
print(msg.function_call.arguments)
|
||||
sys.exit()
|
||||
return res
|
||||
|
||||
def send(self, messages, model=None, silent=False):
|
||||
if not model:
|
||||
model = self.main_model.name
|
||||
|
||||
self.resp = ""
|
||||
self.partial_response_content = ""
|
||||
self.partial_response_function_call = dict()
|
||||
|
||||
interrupted = False
|
||||
try:
|
||||
completion = self.send_with_retries(model, messages)
|
||||
|
@ -489,9 +487,10 @@ class Coder:
|
|||
interrupted = True
|
||||
|
||||
if not silent:
|
||||
self.io.ai_output(self.resp)
|
||||
# TODO: self.partial_response_function_call
|
||||
self.io.ai_output(self.partial_response_content)
|
||||
|
||||
return self.resp, interrupted
|
||||
return interrupted
|
||||
|
||||
def show_send_output(self, completion, silent):
|
||||
live = None
|
||||
|
@ -503,20 +502,29 @@ class Coder:
|
|||
live.start()
|
||||
|
||||
for chunk in completion:
|
||||
dump(chunk)
|
||||
if chunk.choices[0].finish_reason == "length":
|
||||
raise ExhaustedContextWindow()
|
||||
|
||||
try:
|
||||
func = chunk.choices[0].delta.function_call
|
||||
print(func)
|
||||
# dump(func)
|
||||
for k, v in func.items():
|
||||
if k in self.partial_response_function_call:
|
||||
self.partial_response_function_call[k] += v
|
||||
else:
|
||||
self.partial_response_function_call[k] = v
|
||||
# dump(self.partial_response_function_call)
|
||||
args = self.partial_response_function_call.get("arguments")
|
||||
args = parse_partial_args(args)
|
||||
dump(args)
|
||||
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
text = chunk.choices[0].delta.content
|
||||
if text:
|
||||
self.resp += text
|
||||
self.partial_response_content += text
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
@ -524,7 +532,7 @@ class Coder:
|
|||
continue
|
||||
|
||||
if self.pretty:
|
||||
show_resp = self.modify_incremental_response(self.resp)
|
||||
show_resp = self.modify_incremental_response()
|
||||
md = Markdown(
|
||||
show_resp, style=self.assistant_output_color, code_theme="default"
|
||||
)
|
||||
|
@ -536,8 +544,8 @@ class Coder:
|
|||
if live:
|
||||
live.stop()
|
||||
|
||||
def modify_incremental_response(self, resp):
|
||||
return resp
|
||||
def modify_incremental_response(self):
|
||||
return self.partial_response_content
|
||||
|
||||
def get_context_from_history(self, history):
|
||||
context = ""
|
||||
|
@ -562,7 +570,7 @@ class Coder:
|
|||
]
|
||||
|
||||
try:
|
||||
commit_message, interrupted = self.send(
|
||||
interrupted = self.send(
|
||||
messages,
|
||||
model=models.GPT35.name,
|
||||
silent=True,
|
||||
|
@ -574,6 +582,7 @@ class Coder:
|
|||
)
|
||||
return
|
||||
|
||||
commit_message = self.partial_response_content
|
||||
commit_message = commit_message.strip()
|
||||
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
|
||||
commit_message = commit_message[1:-1].strip()
|
||||
|
@ -728,3 +737,20 @@ def check_model_availability(main_model):
|
|||
available_models = openai.Model.list()
|
||||
model_ids = [model.id for model in available_models["data"]]
|
||||
return main_model.name in model_ids
|
||||
|
||||
|
||||
def parse_partial_args(data):
|
||||
try:
|
||||
return json.loads(data)
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return json.loads(data + "}")
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return json.loads(data + '"}')
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
|
|
@ -47,7 +47,8 @@ class FunctionCoder(Coder):
|
|||
else:
|
||||
self.cur_messages += [dict(role="assistant", content=content)]
|
||||
|
||||
def modify_incremental_response(self, resp):
|
||||
def modify_incremental_response(self):
|
||||
resp = self.partial_response_content
|
||||
return self.update_files(resp, mode="diff")
|
||||
|
||||
def update_files(self, content, mode="update"):
|
||||
|
|
|
@ -20,7 +20,8 @@ class WholeFileCoder(Coder):
|
|||
else:
|
||||
self.cur_messages += [dict(role="assistant", content=content)]
|
||||
|
||||
def modify_incremental_response(self, resp):
|
||||
def modify_incremental_response(self):
|
||||
resp = self.partial_response_content
|
||||
return self.update_files(resp, mode="diff")
|
||||
|
||||
def update_files(self, content, mode="update"):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue