This commit is contained in:
Paul Gauthier 2023-06-21 18:05:49 -07:00
parent 151b7b2811
commit ba6e2eeb63
3 changed files with 49 additions and 21 deletions

View file

@ -1,8 +1,10 @@
#!/usr/bin/env python #!/usr/bin/env python
import json
import os import os
import sys import sys
import traceback import traceback
from json.decoder import JSONDecodeError
from pathlib import Path from pathlib import Path
import backoff import backoff
@ -356,7 +358,7 @@ class Coder:
utils.show_messages(messages) utils.show_messages(messages)
try: try:
content, interrupted = self.send(messages) interrupted = self.send(messages)
except ExhaustedContextWindow: except ExhaustedContextWindow:
self.io.tool_error("Exhausted context window!") self.io.tool_error("Exhausted context window!")
self.io.tool_error(" - Use /tokens to see token usage.") self.io.tool_error(" - Use /tokens to see token usage.")
@ -364,6 +366,8 @@ class Coder:
self.io.tool_error(" - Use /clear to clear chat history.") self.io.tool_error(" - Use /clear to clear chat history.")
return return
content = self.partial_response_content
if interrupted: if interrupted:
self.io.tool_error("\n\n^C KeyboardInterrupt") self.io.tool_error("\n\n^C KeyboardInterrupt")
content += "\n^C KeyboardInterrupt" content += "\n^C KeyboardInterrupt"
@ -468,19 +472,13 @@ class Coder:
res = openai.ChatCompletion.create(**kwargs) res = openai.ChatCompletion.create(**kwargs)
return res return res
dump(res)
msg = res.choices[0].message
dump(msg)
print(msg.content)
print(msg.function_call.arguments)
sys.exit()
return res
def send(self, messages, model=None, silent=False): def send(self, messages, model=None, silent=False):
if not model: if not model:
model = self.main_model.name model = self.main_model.name
self.resp = "" self.partial_response_content = ""
self.partial_response_function_call = dict()
interrupted = False interrupted = False
try: try:
completion = self.send_with_retries(model, messages) completion = self.send_with_retries(model, messages)
@ -489,9 +487,10 @@ class Coder:
interrupted = True interrupted = True
if not silent: if not silent:
self.io.ai_output(self.resp) # TODO: self.partial_response_function_call
self.io.ai_output(self.partial_response_content)
return self.resp, interrupted return interrupted
def show_send_output(self, completion, silent): def show_send_output(self, completion, silent):
live = None live = None
@ -503,20 +502,29 @@ class Coder:
live.start() live.start()
for chunk in completion: for chunk in completion:
dump(chunk)
if chunk.choices[0].finish_reason == "length": if chunk.choices[0].finish_reason == "length":
raise ExhaustedContextWindow() raise ExhaustedContextWindow()
try: try:
func = chunk.choices[0].delta.function_call func = chunk.choices[0].delta.function_call
print(func) # dump(func)
for k, v in func.items():
if k in self.partial_response_function_call:
self.partial_response_function_call[k] += v
else:
self.partial_response_function_call[k] = v
# dump(self.partial_response_function_call)
args = self.partial_response_function_call.get("arguments")
args = parse_partial_args(args)
dump(args)
except AttributeError: except AttributeError:
pass pass
try: try:
text = chunk.choices[0].delta.content text = chunk.choices[0].delta.content
if text: if text:
self.resp += text self.partial_response_content += text
except AttributeError: except AttributeError:
pass pass
@ -524,7 +532,7 @@ class Coder:
continue continue
if self.pretty: if self.pretty:
show_resp = self.modify_incremental_response(self.resp) show_resp = self.modify_incremental_response()
md = Markdown( md = Markdown(
show_resp, style=self.assistant_output_color, code_theme="default" show_resp, style=self.assistant_output_color, code_theme="default"
) )
@ -536,8 +544,8 @@ class Coder:
if live: if live:
live.stop() live.stop()
def modify_incremental_response(self, resp): def modify_incremental_response(self):
return resp return self.partial_response_content
def get_context_from_history(self, history): def get_context_from_history(self, history):
context = "" context = ""
@ -562,7 +570,7 @@ class Coder:
] ]
try: try:
commit_message, interrupted = self.send( interrupted = self.send(
messages, messages,
model=models.GPT35.name, model=models.GPT35.name,
silent=True, silent=True,
@ -574,6 +582,7 @@ class Coder:
) )
return return
commit_message = self.partial_response_content
commit_message = commit_message.strip() commit_message = commit_message.strip()
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"': if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
commit_message = commit_message[1:-1].strip() commit_message = commit_message[1:-1].strip()
@ -728,3 +737,20 @@ def check_model_availability(main_model):
available_models = openai.Model.list() available_models = openai.Model.list()
model_ids = [model.id for model in available_models["data"]] model_ids = [model.id for model in available_models["data"]]
return main_model.name in model_ids return main_model.name in model_ids
def parse_partial_args(data):
try:
return json.loads(data)
except JSONDecodeError:
pass
try:
return json.loads(data + "}")
except JSONDecodeError:
pass
try:
return json.loads(data + '"}')
except JSONDecodeError:
pass

View file

@ -47,7 +47,8 @@ class FunctionCoder(Coder):
else: else:
self.cur_messages += [dict(role="assistant", content=content)] self.cur_messages += [dict(role="assistant", content=content)]
def modify_incremental_response(self, resp): def modify_incremental_response(self):
resp = self.partial_response_content
return self.update_files(resp, mode="diff") return self.update_files(resp, mode="diff")
def update_files(self, content, mode="update"): def update_files(self, content, mode="update"):

View file

@ -20,7 +20,8 @@ class WholeFileCoder(Coder):
else: else:
self.cur_messages += [dict(role="assistant", content=content)] self.cur_messages += [dict(role="assistant", content=content)]
def modify_incremental_response(self, resp): def modify_incremental_response(self):
resp = self.partial_response_content
return self.update_files(resp, mode="diff") return self.update_files(resp, mode="diff")
def update_files(self, content, mode="update"): def update_files(self, content, mode="update"):