This commit is contained in:
Paul Gauthier 2023-06-21 18:05:49 -07:00
parent 151b7b2811
commit ba6e2eeb63
3 changed files with 49 additions and 21 deletions

View file

@ -1,8 +1,10 @@
#!/usr/bin/env python
import json
import os
import sys
import traceback
from json.decoder import JSONDecodeError
from pathlib import Path
import backoff
@ -356,7 +358,7 @@ class Coder:
utils.show_messages(messages)
try:
content, interrupted = self.send(messages)
interrupted = self.send(messages)
except ExhaustedContextWindow:
self.io.tool_error("Exhausted context window!")
self.io.tool_error(" - Use /tokens to see token usage.")
@ -364,6 +366,8 @@ class Coder:
self.io.tool_error(" - Use /clear to clear chat history.")
return
content = self.partial_response_content
if interrupted:
self.io.tool_error("\n\n^C KeyboardInterrupt")
content += "\n^C KeyboardInterrupt"
@ -468,19 +472,13 @@ class Coder:
res = openai.ChatCompletion.create(**kwargs)
return res
dump(res)
msg = res.choices[0].message
dump(msg)
print(msg.content)
print(msg.function_call.arguments)
sys.exit()
return res
def send(self, messages, model=None, silent=False):
if not model:
model = self.main_model.name
self.resp = ""
self.partial_response_content = ""
self.partial_response_function_call = dict()
interrupted = False
try:
completion = self.send_with_retries(model, messages)
@ -489,9 +487,10 @@ class Coder:
interrupted = True
if not silent:
self.io.ai_output(self.resp)
# TODO: self.partial_response_function_call
self.io.ai_output(self.partial_response_content)
return self.resp, interrupted
return interrupted
def show_send_output(self, completion, silent):
live = None
@ -503,20 +502,29 @@ class Coder:
live.start()
for chunk in completion:
dump(chunk)
if chunk.choices[0].finish_reason == "length":
raise ExhaustedContextWindow()
try:
func = chunk.choices[0].delta.function_call
print(func)
# dump(func)
for k, v in func.items():
if k in self.partial_response_function_call:
self.partial_response_function_call[k] += v
else:
self.partial_response_function_call[k] = v
# dump(self.partial_response_function_call)
args = self.partial_response_function_call.get("arguments")
args = parse_partial_args(args)
dump(args)
except AttributeError:
pass
try:
text = chunk.choices[0].delta.content
if text:
self.resp += text
self.partial_response_content += text
except AttributeError:
pass
@ -524,7 +532,7 @@ class Coder:
continue
if self.pretty:
show_resp = self.modify_incremental_response(self.resp)
show_resp = self.modify_incremental_response()
md = Markdown(
show_resp, style=self.assistant_output_color, code_theme="default"
)
@ -536,8 +544,8 @@ class Coder:
if live:
live.stop()
def modify_incremental_response(self, resp):
return resp
def modify_incremental_response(self):
return self.partial_response_content
def get_context_from_history(self, history):
context = ""
@ -562,7 +570,7 @@ class Coder:
]
try:
commit_message, interrupted = self.send(
interrupted = self.send(
messages,
model=models.GPT35.name,
silent=True,
@ -574,6 +582,7 @@ class Coder:
)
return
commit_message = self.partial_response_content
commit_message = commit_message.strip()
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
commit_message = commit_message[1:-1].strip()
@ -728,3 +737,20 @@ def check_model_availability(main_model):
available_models = openai.Model.list()
model_ids = [model.id for model in available_models["data"]]
return main_model.name in model_ids
def parse_partial_args(data):
try:
return json.loads(data)
except JSONDecodeError:
pass
try:
return json.loads(data + "}")
except JSONDecodeError:
pass
try:
return json.loads(data + '"}')
except JSONDecodeError:
pass