continue roughly working using anthropic's prefill

Paul Gauthier 2024-06-27 14:40:46 -07:00
parent 87f4d25133
commit 044617b1b7
2 changed files with 43 additions and 24 deletions
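The change leans on Anthropic-style assistant "prefill": when a reply is cut off at the output-token limit (finish_reason "length"), the partial text is carried back as a trailing assistant message, so the next request continues from where the previous one stopped instead of starting over. A minimal sketch of the idea using litellm, as the diff below does; the model name and prompt are placeholders, not part of the commit:

    import litellm

    messages = [
        {"role": "user", "content": "Write a long poem about the sea."},
        # Trailing assistant message = prefill: the model is asked to
        # continue this text rather than start a fresh reply.
        {"role": "assistant", "content": "The tide rolls in beneath a pewter sky,"},
    ]

    # Placeholder model name; any Anthropic model routed through litellm behaves the same way.
    resp = litellm.completion(model="claude-3-5-sonnet-20240620", messages=messages)
    print(resp.choices[0].message.content)  # continuation of the prefilled line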

@@ -13,7 +13,6 @@ from json.decoder import JSONDecodeError
 from pathlib import Path

 import git
-import openai
 from jsonschema import Draft7Validator
 from rich.console import Console, Text
 from rich.markdown import Markdown
@@ -37,7 +36,7 @@ class MissingAPIKeyError(ValueError):
     pass


-class ExhaustedContextWindow(Exception):
+class FinishReasonLength(Exception):
     pass
@@ -812,28 +811,43 @@ class Coder:
         if self.verbose:
             utils.show_messages(messages, functions=self.functions)

+        multi_response_content = ""
         exhausted = False
         interrupted = False
-        try:
-            yield from self.send(messages, functions=self.functions)
-        except KeyboardInterrupt:
-            interrupted = True
-        except ExhaustedContextWindow:
-            exhausted = True
-        except litellm.exceptions.BadRequestError as err:
-            if "ContextWindowExceededError" in err.message:
-                # the input is overflowing the context window
-                exhausted = True
-            else:
-                self.io.tool_error(f"BadRequestError: {err}")
-                return
-        except openai.BadRequestError as err:
-            if "maximum context length" in str(err):
-                exhausted = True
-            else:
-                raise err
-        except Exception as err:
-            self.io.tool_error(f"Unexpected error: {err}")
-            return
+        while True:
+            try:
+                yield from self.send(messages, functions=self.functions)
+                break
+            except KeyboardInterrupt:
+                interrupted = True
+                break
+            except litellm.ContextWindowExceededError as cwe_err:
+                exhausted = True
+                dump(cwe_err)
+                break
+            except litellm.exceptions.BadRequestError as br_err:
+                dump(br_err)
+                self.io.tool_error(f"BadRequestError: {br_err}")
+                return
+            except FinishReasonLength as frl_err:
+                # finish_reason=length means 4k output limit?
+                dump(frl_err)
+                # exhausted = True
+
+                multi_response_content += self.partial_response_content
+                if messages[-1]["role"] == "assistant":
+                    messages[-1]["content"] = multi_response_content
+                else:
+                    messages.append(dict(role="assistant", content=multi_response_content))
+            except Exception as err:
+                self.io.tool_error(f"Unexpected error: {err}")
+                traceback.print_exc()
+                return
+
+        if multi_response_content:
+            multi_response_content += self.partial_response_content
+            self.partial_response_content = multi_response_content

         if exhausted:
             self.show_exhausted_error()
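A self-contained sketch of the loop above with the Coder plumbing stripped out; the helper names below are illustrative, not aider's API. Each truncated reply is folded into an accumulator and sent back as the prefilled assistant turn until the model finishes:

    import litellm

    class FinishReasonLength(Exception):
        """Raised when a completion stops with finish_reason == 'length'."""

    def send_once(model, messages):
        # One non-streaming completion; surface truncation as an exception.
        resp = litellm.completion(model=model, messages=messages)
        text = resp.choices[0].message.content or ""
        if resp.choices[0].finish_reason == "length":
            raise FinishReasonLength(text)
        return text

    def send_with_prefill(model, messages, max_rounds=5):
        # Re-send with the partial reply prefilled until the model finishes,
        # or give up after max_rounds truncations.
        multi_response_content = ""
        for _ in range(max_rounds):
            try:
                return multi_response_content + send_once(model, messages)
            except FinishReasonLength as err:
                multi_response_content += err.args[0]
                # Prefill: carry the partial reply back as the trailing assistant turn.
                if messages[-1]["role"] == "assistant":
                    messages[-1]["content"] = multi_response_content
                else:
                    messages.append(dict(role="assistant", content=multi_response_content))
        return multi_response_content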
@@ -1103,7 +1117,7 @@ class Coder:
         if show_func_err and show_content_err:
             self.io.tool_error(show_func_err)
             self.io.tool_error(show_content_err)
-            raise Exception("No data found in openai response!")
+            raise Exception("No data found in LLM response!")

         tokens = None
         if hasattr(completion, "usage") and completion.usage is not None:
@@ -1131,6 +1145,12 @@ class Coder:
         if tokens is not None:
             self.io.tool_output(tokens)

+        if (
+            hasattr(completion.choices[0], "finish_reason")
+            and completion.choices[0].finish_reason == "length"
+        ):
+            raise FinishReasonLength()
+
     def show_send_output_stream(self, completion):
         if self.show_pretty():
             mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
@@ -1147,7 +1167,7 @@
                 hasattr(chunk.choices[0], "finish_reason")
                 and chunk.choices[0].finish_reason == "length"
             ):
-                raise ExhaustedContextWindow()
+                raise FinishReasonLength()

             try:
                 func = chunk.choices[0].delta.function_call
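The finish_reason check now appears twice in this file: once for non-streaming completions and once per chunk in show_send_output_stream, where the partial text has already been accumulated by the time the "length" chunk arrives. A rough standalone sketch of the streaming case (the function name and the getattr-based attribute handling are approximations, not aider's code); the caller keeps whatever text was yielded, and that becomes the prefill when FinishReasonLength propagates:

    import litellm

    class FinishReasonLength(Exception):
        pass

    def stream_text(model, messages):
        # Yield text deltas; raise FinishReasonLength if the stream ends with
        # finish_reason == 'length', i.e. the output limit was hit mid-reply.
        for chunk in litellm.completion(model=model, messages=messages, stream=True):
            choice = chunk.choices[0]
            if getattr(choice, "finish_reason", None) == "length":
                raise FinishReasonLength()
            text = getattr(choice.delta, "content", None)
            if text:
                yield text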

@@ -3,7 +3,6 @@ import json

 import backoff
 import httpx
-import openai

 from aider.dump import dump  # noqa: F401
 from aider.litellm import litellm
@@ -85,5 +84,5 @@ def simple_send_with_retries(model_name, messages):
             stream=False,
         )
         return response.choices[0].message.content
-    except (AttributeError, openai.BadRequestError):
+    except (AttributeError, litellm.exceptions.BadRequestError):
         return
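In the second changed file, simple_send_with_retries drops the openai import and swallows litellm's BadRequestError instead. A hedged sketch of a retrying helper along these lines; the retried exception types and the 60-second budget are assumptions, not copied from aider:

    import backoff
    import httpx
    import litellm

    @backoff.on_exception(
        backoff.expo,
        (httpx.ConnectError, litellm.exceptions.RateLimitError),
        max_time=60,
    )
    def simple_send_with_retries(model_name, messages):
        try:
            response = litellm.completion(
                model=model_name,
                messages=messages,
                stream=False,
            )
            return response.choices[0].message.content
        except (AttributeError, litellm.exceptions.BadRequestError):
            return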