refactored summarize_chat_history_markdown

This commit is contained in:
Paul Gauthier 2024-05-05 18:58:16 -07:00
parent 3bb237bdc1
commit cd84cc0d3e

View file

@ -1,6 +1,6 @@
import argparse
from aider import prompts
from aider import models, prompts
from aider.dump import dump # noqa: F401
from aider.sendchat import simple_send_with_retries
@ -90,41 +90,47 @@ class ChatSummary:
return [dict(role="user", content=summary)]
def summarize_chat_history_markdown(self, text):
messages = []
assistant = []
for line in text.splitlines(keepends=True):
if line.startswith("# "):
continue
if line.startswith(">"):
continue
if line.startswith("#### /"):
continue
if line.startswith("#### "):
if assistant:
assistant = "".join(assistant)
if assistant.strip():
messages.append(dict(role="assistant", content=assistant))
assistant = []
content = line[5:]
if content.strip() and content.strip() != "<blank>":
messages.append(dict(role="user", content=line[5:]))
continue
assistant.append(line)
summary = self.summarize(messages[-40:])
return summary
def main():
parser = argparse.ArgumentParser()
parser.add_argument("filename", help="Markdown file to parse")
args = parser.parse_args()
model = models.Model("gpt-3.5-turbo")
summarizer = ChatSummary(model)
with open(args.filename, "r") as f:
text = f.read()
messages = []
assistant = []
for line in text.splitlines(keepends=True):
if line.startswith("# "):
continue
if line.startswith(">"):
continue
if line.startswith("#### /"):
continue
if line.startswith("#### "):
if assistant:
assistant = "".join(assistant)
if assistant.strip():
messages.append(dict(role="assistant", content=assistant))
assistant = []
content = line[5:]
if content.strip() and content.strip() != "<blank>":
messages.append(dict(role="user", content=line[5:]))
continue
assistant.append(line)
summarizer = ChatSummary("gpt-3.5-turbo", weak_model=False)
summary = summarizer.summarize(messages[-40:])
summary = summarizer.summarize_chat_history_markdown(text)
dump(summary)