mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-31 09:44:59 +00:00
feat: Add reasoning content handling in Coder and Model classes
This commit is contained in:
parent
52162a5604
commit
d1d40a9a76
2 changed files with 34 additions and 6 deletions
|
@ -36,6 +36,8 @@ from aider.utils import format_content, format_messages, format_tokens, is_image
|
||||||
from ..dump import dump # noqa: F401
|
from ..dump import dump # noqa: F401
|
||||||
from .chat_chunks import ChatChunks
|
from .chat_chunks import ChatChunks
|
||||||
|
|
||||||
|
REASONING_TAG = "reasoning-content-" + "7bbeb8e1441453ad999a0bbba8a46d4b"
|
||||||
|
|
||||||
|
|
||||||
class UnknownEditFormat(ValueError):
|
class UnknownEditFormat(ValueError):
|
||||||
def __init__(self, edit_format, valid_formats):
|
def __init__(self, edit_format, valid_formats):
|
||||||
|
@ -1306,6 +1308,8 @@ class Coder:
|
||||||
|
|
||||||
litellm_ex = LiteLLMExceptions()
|
litellm_ex = LiteLLMExceptions()
|
||||||
|
|
||||||
|
self.got_reasoning_content = False
|
||||||
|
self.ended_reasoning_content = False
|
||||||
self.usage_report = None
|
self.usage_report = None
|
||||||
exhausted = False
|
exhausted = False
|
||||||
interrupted = False
|
interrupted = False
|
||||||
|
@ -1372,8 +1376,15 @@ class Coder:
|
||||||
self.mdstream = None
|
self.mdstream = None
|
||||||
|
|
||||||
self.partial_response_content = self.get_multi_response_content_in_progress(True)
|
self.partial_response_content = self.get_multi_response_content_in_progress(True)
|
||||||
|
|
||||||
|
if self.got_reasoning_content:
|
||||||
|
reasoning_tag = REASONING_TAG
|
||||||
|
else:
|
||||||
|
reasoning_tag = None
|
||||||
|
|
||||||
self.partial_response_content = self.main_model.remove_reasoning_content(
|
self.partial_response_content = self.main_model.remove_reasoning_content(
|
||||||
self.partial_response_content
|
self.partial_response_content,
|
||||||
|
reasoning_tag=reasoning_tag,
|
||||||
)
|
)
|
||||||
self.multi_response_content = ""
|
self.multi_response_content = ""
|
||||||
|
|
||||||
|
@ -1744,9 +1755,23 @@ class Coder:
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
text = chunk.choices[0].delta.reasoning_content
|
||||||
|
if text:
|
||||||
|
self.got_reasoning_content = True
|
||||||
|
self.partial_response_content += text
|
||||||
|
received_content = True
|
||||||
|
except AttributeError:
|
||||||
|
text = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
text = chunk.choices[0].delta.content
|
text = chunk.choices[0].delta.content
|
||||||
if text:
|
if text:
|
||||||
|
if self.got_reasoning_content and not self.ended_reasoning_content:
|
||||||
|
tag = f"\n\n------\n\n</{REASONING_TAG}>\n\n"
|
||||||
|
self.partial_response_content += tag
|
||||||
|
self.ended_reasoning_content = True
|
||||||
|
|
||||||
self.partial_response_content += text
|
self.partial_response_content += text
|
||||||
received_content = True
|
received_content = True
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
|
|
@ -625,7 +625,7 @@ class Model(ModelSettings):
|
||||||
kwargs["num_ctx"] = num_ctx
|
kwargs["num_ctx"] = num_ctx
|
||||||
key = json.dumps(kwargs, sort_keys=True).encode()
|
key = json.dumps(kwargs, sort_keys=True).encode()
|
||||||
|
|
||||||
# dump(kwargs)
|
dump(kwargs)
|
||||||
|
|
||||||
hash_object = hashlib.sha1(key)
|
hash_object = hashlib.sha1(key)
|
||||||
if "timeout" not in kwargs:
|
if "timeout" not in kwargs:
|
||||||
|
@ -633,17 +633,20 @@ class Model(ModelSettings):
|
||||||
res = litellm.completion(**kwargs)
|
res = litellm.completion(**kwargs)
|
||||||
return hash_object, res
|
return hash_object, res
|
||||||
|
|
||||||
def remove_reasoning_content(self, res):
|
def remove_reasoning_content(self, res, reasoning_tag=None):
|
||||||
if not self.remove_reasoning:
|
if not reasoning_tag:
|
||||||
|
reasoning_tag = self.remove_reasoning
|
||||||
|
|
||||||
|
if not reasoning_tag:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
# Try to match the complete tag pattern first
|
# Try to match the complete tag pattern first
|
||||||
pattern = f"<{self.remove_reasoning}>.*?</{self.remove_reasoning}>"
|
pattern = f"<{reasoning_tag}>.*?</{reasoning_tag}>"
|
||||||
res = re.sub(pattern, "", res, flags=re.DOTALL).strip()
|
res = re.sub(pattern, "", res, flags=re.DOTALL).strip()
|
||||||
|
|
||||||
# If closing tag exists but opening tag might be missing, remove everything before closing
|
# If closing tag exists but opening tag might be missing, remove everything before closing
|
||||||
# tag
|
# tag
|
||||||
closing_tag = f"</{self.remove_reasoning}>"
|
closing_tag = f"</{reasoning_tag}>"
|
||||||
if closing_tag in res:
|
if closing_tag in res:
|
||||||
# Split on the closing tag and keep everything after it
|
# Split on the closing tag and keep everything after it
|
||||||
parts = res.split(closing_tag, 1)
|
parts = res.split(closing_tag, 1)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue