This commit is contained in:
Paul Gauthier 2023-06-21 14:05:23 -07:00
parent 2fce31209c
commit ab6ef7eb5c
2 changed files with 38 additions and 62 deletions

View file

@ -444,70 +444,18 @@ class Coder:
on_backoff=lambda details: print(f"Retry in {details['wait']} seconds."), on_backoff=lambda details: print(f"Retry in {details['wait']} seconds."),
) )
def send_with_retries(self, model, messages): def send_with_retries(self, model, messages):
_functions = [ kwargs = dict(
dict(
name="replace_lines",
description="replace a block of contiguous lines with a new set of lines",
parameters=dict(
type="object",
required=["file_path", "original_lines", "updated_lines"],
properties=dict(
file_path=dict(
type="string",
description="path of file to edit",
),
original_lines=dict(
type="string",
description=(
"block of contiguous lines from the file (including newlines)"
),
),
updated_lines=dict(
type="string",
description=(
"block of contiguous lines from the file (including newlines)"
),
),
),
),
),
]
functions = [
dict(
name="write_file",
description="create or update a file",
parameters=dict(
type="object",
required=["explanation", "file_path", "file_content"],
properties=dict(
explanation=dict(
type="string",
description=(
"Explanation of the changes to be made to the code (markdown"
" format)"
),
),
file_path=dict(
type="string",
description="Path of file to write",
),
file_content=dict(
type="string",
description="Content to write to the file",
),
),
),
),
]
dump(functions)
res = openai.ChatCompletion.create(
model=model, model=model,
messages=messages, messages=messages,
temperature=0, temperature=0,
stream=False, stream=True,
functions=functions,
) )
if self.functions:
kwargs["functions"] = self.functions
res = openai.ChatCompletion.create(**kwargs)
return res
dump(res) dump(res)
msg = res.choices[0].message msg = res.choices[0].message
dump(msg) dump(msg)
@ -543,12 +491,13 @@ class Coder:
live.start() live.start()
for chunk in completion: for chunk in completion:
if chunk.choices[0].finish_reason not in (None, "stop"): dump(chunk)
if chunk.choices[0].finish_reason not in (None, "stop", "function_call"):
assert False, "Exceeded context window!" assert False, "Exceeded context window!"
try: try:
func = chunk.choices[0].delta.function_call func = chunk.choices[0].delta.function_call
dump(func) print(func)
except AttributeError: except AttributeError:
pass pass

View file

@ -8,6 +8,33 @@ from .func_prompts import FunctionPrompts
class FunctionCoder(Coder): class FunctionCoder(Coder):
functions = [
dict(
name="write_file",
description="create or update a file",
parameters=dict(
type="object",
required=["explanation", "file_path", "file_content"],
properties=dict(
explanation=dict(
type="string",
description=(
"Explanation of the changes to be made to the code (markdown format)"
),
),
file_path=dict(
type="string",
description="Path of file to write",
),
file_content=dict(
type="string",
description="Content to write to the file",
),
),
),
),
]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.gpt_prompts = FunctionPrompts() self.gpt_prompts = FunctionPrompts()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)