wip
parent 151b7b2811
commit ba6e2eeb63

3 changed files with 49 additions and 21 deletions
@@ -1,8 +1,10 @@
 #!/usr/bin/env python
 
+import json
 import os
 import sys
 import traceback
+from json.decoder import JSONDecodeError
 from pathlib import Path
 
 import backoff
@@ -356,7 +358,7 @@ class Coder:
             utils.show_messages(messages)
 
         try:
-            content, interrupted = self.send(messages)
+            interrupted = self.send(messages)
         except ExhaustedContextWindow:
             self.io.tool_error("Exhausted context window!")
             self.io.tool_error(" - Use /tokens to see token usage.")
@@ -364,6 +366,8 @@ class Coder:
             self.io.tool_error(" - Use /clear to clear chat history.")
             return
 
+        content = self.partial_response_content
+
         if interrupted:
             self.io.tool_error("\n\n^C KeyboardInterrupt")
             content += "\n^C KeyboardInterrupt"
@@ -468,19 +472,13 @@ class Coder:
         res = openai.ChatCompletion.create(**kwargs)
         return res
-        dump(res)
-        msg = res.choices[0].message
-        dump(msg)
-        print(msg.content)
-        print(msg.function_call.arguments)
-        sys.exit()
-        return res
 
     def send(self, messages, model=None, silent=False):
         if not model:
             model = self.main_model.name
 
-        self.resp = ""
+        self.partial_response_content = ""
+        self.partial_response_function_call = dict()
 
         interrupted = False
         try:
             completion = self.send_with_retries(model, messages)
@@ -489,9 +487,10 @@ class Coder:
             interrupted = True
 
         if not silent:
-            self.io.ai_output(self.resp)
+            # TODO: self.partial_response_function_call
+            self.io.ai_output(self.partial_response_content)
 
-        return self.resp, interrupted
+        return interrupted
 
     def show_send_output(self, completion, silent):
         live = None
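Taken together, the hunks above change the contract of send(): the streamed text (and any function-call fragments) now accumulate on the instance, so callers read them after the call instead of from the return value. In caller terms, a sketch distilled from the surrounding hunks, not new API:

    # before: content came back from send()
    content, interrupted = self.send(messages)

    # after: send() reports only whether the stream was interrupted;
    # the accumulated pieces live on the instance
    interrupted = self.send(messages)
    content = self.partial_response_content
    args = self.partial_response_function_call.get("arguments")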
@@ -503,20 +502,29 @@ class Coder:
             live.start()
 
         for chunk in completion:
-            dump(chunk)
             if chunk.choices[0].finish_reason == "length":
                 raise ExhaustedContextWindow()
 
             try:
                 func = chunk.choices[0].delta.function_call
-                print(func)
+                # dump(func)
+                for k, v in func.items():
+                    if k in self.partial_response_function_call:
+                        self.partial_response_function_call[k] += v
+                    else:
+                        self.partial_response_function_call[k] = v
+                # dump(self.partial_response_function_call)
+                args = self.partial_response_function_call.get("arguments")
+                args = parse_partial_args(args)
+                dump(args)
+
             except AttributeError:
                 pass
 
             try:
                 text = chunk.choices[0].delta.content
                 if text:
-                    self.resp += text
+                    self.partial_response_content += text
             except AttributeError:
                 pass
 
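For context on the per-key concatenation above: with the 2023-era functions API this commit targets, a streamed function call arrives as a series of delta.function_call fragments, where "name" typically arrives once and "arguments" is split across many chunks. A minimal sketch of the accumulation, with hypothetical delta values rather than captured output:

    import json

    # Hypothetical sequence of streamed function_call deltas:
    deltas = [
        {"name": "write_file", "arguments": ""},
        {"arguments": '{"path": "a.py", '},
        {"arguments": '"content": "print(1)"}'},
    ]

    acc = {}
    for func in deltas:
        for k, v in func.items():
            acc[k] = acc.get(k, "") + v  # same effect as the diff's if/else

    assert acc["name"] == "write_file"
    assert json.loads(acc["arguments"]) == {"path": "a.py", "content": "print(1)"}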
@@ -524,7 +532,7 @@ class Coder:
                 continue
 
             if self.pretty:
-                show_resp = self.modify_incremental_response(self.resp)
+                show_resp = self.modify_incremental_response()
                 md = Markdown(
                     show_resp, style=self.assistant_output_color, code_theme="default"
                 )
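In this hunk, live and Markdown come from the rich library: on each chunk the loop re-renders the whole accumulated response as Markdown. A minimal standalone sketch of that pattern, assuming only that rich is installed (the sample strings are illustrative):

    from rich.live import Live
    from rich.markdown import Markdown

    live = Live()
    live.start()
    buf = ""
    for text in ["# Streaming", " demo: ", "`code`\n"]:
        buf += text  # accumulate, then re-render the full buffer
        live.update(Markdown(buf, code_theme="default"))
    live.stop()

Re-rendering the full buffer each time is what keeps partially streamed Markdown displaying correctly as new chunks arrive.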
@@ -536,8 +544,8 @@ class Coder:
         if live:
             live.stop()
 
-    def modify_incremental_response(self, resp):
-        return resp
+    def modify_incremental_response(self):
+        return self.partial_response_content
 
     def get_context_from_history(self, history):
         context = ""
@@ -562,7 +570,7 @@ class Coder:
         ]
 
         try:
-            commit_message, interrupted = self.send(
+            interrupted = self.send(
                 messages,
                 model=models.GPT35.name,
                 silent=True,
@@ -574,6 +582,7 @@ class Coder:
             )
             return
 
+        commit_message = self.partial_response_content
         commit_message = commit_message.strip()
         if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
             commit_message = commit_message[1:-1].strip()
@@ -728,3 +737,20 @@ def check_model_availability(main_model):
     available_models = openai.Model.list()
     model_ids = [model.id for model in available_models["data"]]
     return main_model.name in model_ids
+
+
+def parse_partial_args(data):
+    try:
+        return json.loads(data)
+    except JSONDecodeError:
+        pass
+
+    try:
+        return json.loads(data + "}")
+    except JSONDecodeError:
+        pass
+
+    try:
+        return json.loads(data + '"}')
+    except JSONDecodeError:
+        pass
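The new parse_partial_args helper is a best-effort parser for the truncated JSON that a mid-stream arguments string usually is: it tries the data as-is, then with a closing brace, then with a closing quote plus brace, falling through to an implicit None if all three attempts fail. Illustrative inputs (hypothetical, not from the commit):

    parse_partial_args('{"path": "a.py"}')  # already valid JSON
    parse_partial_args('{"path": "a.py"')   # recovered by data + "}"
    parse_partial_args('{"path": "a.p')     # recovered by data + '"}' (value truncates to "a.p")
    parse_partial_args('{"pa')              # unrecoverable -> None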
@@ -47,7 +47,8 @@ class FunctionCoder(Coder):
         else:
             self.cur_messages += [dict(role="assistant", content=content)]
 
-    def modify_incremental_response(self, resp):
+    def modify_incremental_response(self):
+        resp = self.partial_response_content
         return self.update_files(resp, mode="diff")
 
     def update_files(self, content, mode="update"):
@@ -20,7 +20,8 @@ class WholeFileCoder(Coder):
         else:
             self.cur_messages += [dict(role="assistant", content=content)]
 
-    def modify_incremental_response(self, resp):
+    def modify_incremental_response(self):
+        resp = self.partial_response_content
         return self.update_files(resp, mode="diff")
 
     def update_files(self, content, mode="update"):