fix: Refactor format_messages method to use ChatChunks dataclass

This commit is contained in:
Paul Gauthier 2024-08-16 16:34:41 -07:00 committed by Paul Gauthier (aider)
parent 653bb350ca
commit 2209f7b7eb

View file

@ -21,6 +21,7 @@ from pathlib import Path
import git import git
from rich.console import Console, Text from rich.console import Console, Text
from rich.markdown import Markdown from rich.markdown import Markdown
from dataclasses import dataclass, fields
from aider import __version__, models, prompts, urls, utils from aider import __version__, models, prompts, urls, utils
from aider.commands import Commands from aider.commands import Commands
@ -49,6 +50,11 @@ def wrap_fence(name):
return f"<{name}>", f"</{name}>" return f"<{name}>", f"</{name}>"
@dataclass
class ChatChunks:
pass
class Coder: class Coder:
abs_fnames = None abs_fnames = None
abs_read_only_fnames = None abs_read_only_fnames = None
@ -859,7 +865,7 @@ class Coder:
) )
return prompt return prompt
def format_messages(self): def format_chat_chunks(self):
self.choose_fence() self.choose_fence()
main_sys = self.fmt_system_prompt(self.gpt_prompts.main_system) main_sys = self.fmt_system_prompt(self.gpt_prompts.main_system)
@ -895,15 +901,17 @@ class Coder:
if self.gpt_prompts.system_reminder: if self.gpt_prompts.system_reminder:
main_sys += "\n" + self.fmt_system_prompt(self.gpt_prompts.system_reminder) main_sys += "\n" + self.fmt_system_prompt(self.gpt_prompts.system_reminder)
messages = [ chunks = ChatChunks()
chunks.system = [
dict(role="system", content=main_sys), dict(role="system", content=main_sys),
] ]
messages += example_messages chunks.examples = example_messages
self.summarize_end() self.summarize_end()
messages += self.done_messages chunks.done = self.done_messages
messages += self.get_files_messages() chunks.files = self.get_files_messages()
if self.gpt_prompts.system_reminder: if self.gpt_prompts.system_reminder:
reminder_message = [ reminder_message = [
@ -914,10 +922,13 @@ class Coder:
else: else:
reminder_message = [] reminder_message = []
chunks.cur = list(self.cur_messages)
chunks.reminder = []
# TODO review impact of token count on image messages # TODO review impact of token count on image messages
messages_tokens = self.main_model.token_count(messages) messages_tokens = 0 # self.main_model.token_count(messages)
reminder_tokens = self.main_model.token_count(reminder_message) reminder_tokens = self.main_model.token_count(reminder_message)
cur_tokens = self.main_model.token_count(self.cur_messages) cur_tokens = self.main_model.token_count(chunks.cur)
if None not in (messages_tokens, reminder_tokens, cur_tokens): if None not in (messages_tokens, reminder_tokens, cur_tokens):
total_tokens = messages_tokens + reminder_tokens + cur_tokens total_tokens = messages_tokens + reminder_tokens + cur_tokens
@ -925,9 +936,7 @@ class Coder:
# add the reminder anyway # add the reminder anyway
total_tokens = 0 total_tokens = 0
messages += self.cur_messages final = chunks.cur[-1]
final = messages[-1]
max_input_tokens = self.main_model.info.get("max_input_tokens") or 0 max_input_tokens = self.main_model.info.get("max_input_tokens") or 0
# Add the reminder prompt if we still have room to include it. # Add the reminder prompt if we still have room to include it.
@ -937,7 +946,7 @@ class Coder:
and self.gpt_prompts.system_reminder and self.gpt_prompts.system_reminder
): ):
if self.main_model.reminder_as_sys_msg: if self.main_model.reminder_as_sys_msg:
messages += reminder_message chunks.reminder = reminder_message
elif final["role"] == "user": elif final["role"] == "user":
# stuff it into the user message # stuff it into the user message
new_content = ( new_content = (
@ -945,9 +954,23 @@ class Coder:
+ "\n\n" + "\n\n"
+ self.fmt_system_prompt(self.gpt_prompts.system_reminder) + self.fmt_system_prompt(self.gpt_prompts.system_reminder)
) )
messages[-1] = dict(role=final["role"], content=new_content) chunks.cur[-1] = dict(role=final["role"], content=new_content)
return messages return chunks
def format_messages(self):
chunks = self.format_chat_chunks()
msgs = (
chunks.system
+ chunks.examples
+ chunks.done
+ chunks.files
+ chunks.cur
+ chunks.reminder
)
return msgs
def send_message(self, inp): def send_message(self, inp):
self.aider_edited_files = None self.aider_edited_files = None