refactor: Refactor ChatSummary initialization and remove max_chat_history_tokens param from Coder

This commit is contained in:
Paul Gauthier (aider) 2024-08-01 14:49:54 -03:00
parent 2ee34ac189
commit aa3e17dae6
2 changed files with 10 additions and 6 deletions

View file

@ -23,6 +23,7 @@ from rich.markdown import Markdown
from aider import __version__, models, prompts, urls, utils from aider import __version__, models, prompts, urls, utils
from aider.commands import Commands from aider.commands import Commands
from aider.history import ChatSummary from aider.history import ChatSummary
from aider.coders import Coder
from aider.io import InputOutput from aider.io import InputOutput
from aider.linter import Linter from aider.linter import Linter
from aider.llm import litellm from aider.llm import litellm
@ -206,7 +207,6 @@ class Coder:
use_git=True, use_git=True,
cur_messages=None, cur_messages=None,
done_messages=None, done_messages=None,
max_chat_history_tokens=None,
restore_chat_history=False, restore_chat_history=False,
auto_lint=True, auto_lint=True,
auto_test=False, auto_test=False,
@ -215,6 +215,7 @@ class Coder:
aider_commit_hashes=None, aider_commit_hashes=None,
map_mul_no_files=8, map_mul_no_files=8,
commands=None, commands=None,
summarizer=None,
): ):
if not fnames: if not fnames:
fnames = [] fnames = []
@ -319,11 +320,9 @@ class Coder:
map_mul_no_files=map_mul_no_files, map_mul_no_files=map_mul_no_files,
) )
if max_chat_history_tokens is None: self.summarizer = summarizer or ChatSummary(
max_chat_history_tokens = self.main_model.max_chat_history_tokens
self.summarizer = ChatSummary(
[self.main_model, self.main_model.weak_model], [self.main_model, self.main_model.weak_model],
max_chat_history_tokens, self.main_model.max_chat_history_tokens,
) )
self.summarizer_thread = None self.summarizer_thread = None

View file

@ -477,6 +477,11 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
commands = Commands(io, None, verify_ssl=args.verify_ssl) commands = Commands(io, None, verify_ssl=args.verify_ssl)
summarizer = ChatSummary(
[main_model, main_model.weak_model],
args.max_chat_history_tokens,
)
try: try:
coder = Coder.create( coder = Coder.create(
main_model=main_model, main_model=main_model,
@ -495,13 +500,13 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
code_theme=args.code_theme, code_theme=args.code_theme,
stream=args.stream, stream=args.stream,
use_git=args.git, use_git=args.git,
max_chat_history_tokens=args.max_chat_history_tokens,
restore_chat_history=args.restore_chat_history, restore_chat_history=args.restore_chat_history,
auto_lint=args.auto_lint, auto_lint=args.auto_lint,
auto_test=args.auto_test, auto_test=args.auto_test,
lint_cmds=lint_cmds, lint_cmds=lint_cmds,
test_cmd=args.test_cmd, test_cmd=args.test_cmd,
commands=commands, commands=commands,
summarizer=summarizer,
) )
except ValueError as err: except ValueError as err: