mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 22:34:59 +00:00
Merge pull request #591 from paul-gauthier/restore-chat-history
Restore prior chat history on launch
This commit is contained in:
commit
45b2ba8a10
4 changed files with 68 additions and 29 deletions
|
@ -151,6 +151,12 @@ def get_parser(default_config_files, git_root):
|
||||||
default=1024,
|
default=1024,
|
||||||
help="Max number of tokens to use for repo map, use 0 to disable (default: 1024)",
|
help="Max number of tokens to use for repo map, use 0 to disable (default: 1024)",
|
||||||
)
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--max-chat-history-tokens",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Maximum number of tokens to use for chat history. If not specified, uses the model's max_chat_history_tokens.",
|
||||||
|
)
|
||||||
default_env_file = os.path.join(git_root, ".env") if git_root else ".env"
|
default_env_file = os.path.join(git_root, ".env") if git_root else ".env"
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--env-file",
|
"--env-file",
|
||||||
|
|
|
@ -165,6 +165,9 @@ class Coder:
|
||||||
for fname in self.get_inchat_relative_files():
|
for fname in self.get_inchat_relative_files():
|
||||||
lines.append(f"Added {fname} to the chat.")
|
lines.append(f"Added {fname} to the chat.")
|
||||||
|
|
||||||
|
if self.done_messages:
|
||||||
|
lines.append("Restored previous conversation history.")
|
||||||
|
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -188,6 +191,7 @@ class Coder:
|
||||||
aider_ignore_file=None,
|
aider_ignore_file=None,
|
||||||
cur_messages=None,
|
cur_messages=None,
|
||||||
done_messages=None,
|
done_messages=None,
|
||||||
|
max_chat_history_tokens=None,
|
||||||
):
|
):
|
||||||
if not fnames:
|
if not fnames:
|
||||||
fnames = []
|
fnames = []
|
||||||
|
@ -282,14 +286,22 @@ class Coder:
|
||||||
self.verbose,
|
self.verbose,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if max_chat_history_tokens is None:
|
||||||
|
max_chat_history_tokens = self.main_model.max_chat_history_tokens
|
||||||
self.summarizer = ChatSummary(
|
self.summarizer = ChatSummary(
|
||||||
self.main_model.weak_model,
|
self.main_model.weak_model,
|
||||||
self.main_model.max_chat_history_tokens,
|
max_chat_history_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.summarizer_thread = None
|
self.summarizer_thread = None
|
||||||
self.summarized_done_messages = []
|
self.summarized_done_messages = []
|
||||||
|
|
||||||
|
if not self.done_messages:
|
||||||
|
history_md = self.io.read_text(self.io.chat_history_file)
|
||||||
|
if history_md:
|
||||||
|
self.done_messages = self.summarizer.split_chat_history_markdown(history_md)
|
||||||
|
self.summarize_start()
|
||||||
|
|
||||||
# validate the functions jsonschema
|
# validate the functions jsonschema
|
||||||
if self.functions:
|
if self.functions:
|
||||||
for function in self.functions:
|
for function in self.functions:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from aider import prompts
|
from aider import models, prompts
|
||||||
from aider.dump import dump # noqa: F401
|
from aider.dump import dump # noqa: F401
|
||||||
from aider.sendchat import simple_send_with_retries
|
from aider.sendchat import simple_send_with_retries
|
||||||
|
|
||||||
|
@ -56,7 +56,21 @@ class ChatSummary:
|
||||||
head = messages[:split_index]
|
head = messages[:split_index]
|
||||||
tail = messages[split_index:]
|
tail = messages[split_index:]
|
||||||
|
|
||||||
summary = self.summarize_all(head)
|
sized = sized[:split_index]
|
||||||
|
head.reverse()
|
||||||
|
sized.reverse()
|
||||||
|
keep = []
|
||||||
|
total = 0
|
||||||
|
model_max_input_tokens = self.model.info.get("max_input_tokens", 4096) - 512
|
||||||
|
for i in range(split_index):
|
||||||
|
total += sized[i][0]
|
||||||
|
if total > model_max_input_tokens:
|
||||||
|
break
|
||||||
|
keep.append(head[i])
|
||||||
|
|
||||||
|
keep.reverse()
|
||||||
|
|
||||||
|
summary = self.summarize_all(keep)
|
||||||
|
|
||||||
tail_tokens = sum(tokens for tokens, msg in sized[split_index:])
|
tail_tokens = sum(tokens for tokens, msg in sized[split_index:])
|
||||||
summary_tokens = self.token_count(summary)
|
summary_tokens = self.token_count(summary)
|
||||||
|
@ -90,41 +104,47 @@ class ChatSummary:
|
||||||
|
|
||||||
return [dict(role="user", content=summary)]
|
return [dict(role="user", content=summary)]
|
||||||
|
|
||||||
|
def split_chat_history_markdown(self, text):
|
||||||
|
messages = []
|
||||||
|
assistant = []
|
||||||
|
lines = text.splitlines(keepends=True)
|
||||||
|
for line in lines:
|
||||||
|
if line.startswith("# "):
|
||||||
|
continue
|
||||||
|
if line.startswith(">"):
|
||||||
|
continue
|
||||||
|
if line.startswith("#### /"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if line.startswith("#### "):
|
||||||
|
if assistant:
|
||||||
|
assistant = "".join(assistant)
|
||||||
|
if assistant.strip():
|
||||||
|
messages.append(dict(role="assistant", content=assistant))
|
||||||
|
assistant = []
|
||||||
|
|
||||||
|
content = line[5:]
|
||||||
|
if content.strip() and content.strip() != "<blank>":
|
||||||
|
messages.append(dict(role="user", content=line[5:]))
|
||||||
|
continue
|
||||||
|
|
||||||
|
assistant.append(line)
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("filename", help="Markdown file to parse")
|
parser.add_argument("filename", help="Markdown file to parse")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
model = models.Model("gpt-3.5-turbo")
|
||||||
|
summarizer = ChatSummary(model)
|
||||||
|
|
||||||
with open(args.filename, "r") as f:
|
with open(args.filename, "r") as f:
|
||||||
text = f.read()
|
text = f.read()
|
||||||
|
|
||||||
messages = []
|
summary = summarizer.summarize_chat_history_markdown(text)
|
||||||
assistant = []
|
|
||||||
for line in text.splitlines(keepends=True):
|
|
||||||
if line.startswith("# "):
|
|
||||||
continue
|
|
||||||
if line.startswith(">"):
|
|
||||||
continue
|
|
||||||
if line.startswith("#### /"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if line.startswith("#### "):
|
|
||||||
if assistant:
|
|
||||||
assistant = "".join(assistant)
|
|
||||||
if assistant.strip():
|
|
||||||
messages.append(dict(role="assistant", content=assistant))
|
|
||||||
assistant = []
|
|
||||||
|
|
||||||
content = line[5:]
|
|
||||||
if content.strip() and content.strip() != "<blank>":
|
|
||||||
messages.append(dict(role="user", content=line[5:]))
|
|
||||||
continue
|
|
||||||
|
|
||||||
assistant.append(line)
|
|
||||||
|
|
||||||
summarizer = ChatSummary("gpt-3.5-turbo", weak_model=False)
|
|
||||||
summary = summarizer.summarize(messages[-40:])
|
|
||||||
dump(summary)
|
dump(summary)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -330,6 +330,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
|
||||||
use_git=args.git,
|
use_git=args.git,
|
||||||
voice_language=args.voice_language,
|
voice_language=args.voice_language,
|
||||||
aider_ignore_file=args.aiderignore,
|
aider_ignore_file=args.aiderignore,
|
||||||
|
max_chat_history_tokens=args.max_chat_history_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue