mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-29 08:44:59 +00:00
114 lines
3.1 KiB
Python
114 lines
3.1 KiB
Python
import argparse
|
|
import json
|
|
|
|
import tiktoken
|
|
|
|
from aider import models, prompts
|
|
from aider.dump import dump # noqa: F401
|
|
from aider.sendchat import simple_send_with_retries
|
|
|
|
|
|
class ChatSummary:
|
|
def __init__(self, model=models.GPT35.name, max_tokens=1024):
|
|
self.tokenizer = tiktoken.encoding_for_model(model)
|
|
self.max_tokens = max_tokens
|
|
|
|
def summarize(self, messages):
|
|
num = len(messages)
|
|
dump(num)
|
|
if num < 2:
|
|
return messages
|
|
|
|
total = 0
|
|
sized = []
|
|
for msg in messages:
|
|
tokens = len(self.tokenizer.encode(json.dumps(msg)))
|
|
sized.append((tokens, msg))
|
|
total += tokens
|
|
|
|
dump(total, self.max_tokens)
|
|
if total <= self.max_tokens:
|
|
return messages
|
|
|
|
num = num // 2
|
|
|
|
# we want the head to end with an assistant msg
|
|
while messages[num]["role"] == "assistant" and num < len(messages) - 1:
|
|
num += 1
|
|
|
|
head = messages[:num]
|
|
tail = messages[num:]
|
|
|
|
summary = self.summarize_all(head)
|
|
|
|
tail_tokens = sum(tokens for tokens, msg in sized[num:])
|
|
summary_tokens = len(self.tokenizer.encode(json.dumps(summary)))
|
|
|
|
result = summary + tail
|
|
if summary_tokens + tail_tokens < self.max_tokens:
|
|
return result
|
|
|
|
return self.summarize(result)
|
|
|
|
def summarize_all(self, messages):
|
|
content = ""
|
|
for msg in messages:
|
|
role = msg["role"].upper()
|
|
if role not in ("USER", "ASSISTANT"):
|
|
continue
|
|
content += f"# {role}\n"
|
|
content += msg["content"]
|
|
if not content.endswith("\n"):
|
|
content += "\n"
|
|
|
|
messages = [
|
|
dict(role="system", content=prompts.summarize),
|
|
dict(role="user", content=content),
|
|
]
|
|
|
|
summary = simple_send_with_retries(model=models.GPT35.name, messages=messages)
|
|
summary = prompts.summary_prefix + summary
|
|
dump(summary)
|
|
|
|
return [dict(role="user", content=summary)]
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("filename", help="Markdown file to parse")
|
|
args = parser.parse_args()
|
|
|
|
with open(args.filename, "r") as f:
|
|
text = f.read()
|
|
|
|
messages = []
|
|
assistant = []
|
|
for line in text.splitlines(keepends=True):
|
|
if line.startswith("# "):
|
|
continue
|
|
if line.startswith(">"):
|
|
continue
|
|
if line.startswith("#### /"):
|
|
continue
|
|
|
|
if line.startswith("#### "):
|
|
if assistant:
|
|
assistant = "".join(assistant)
|
|
if assistant.strip():
|
|
messages.append(dict(role="assistant", content=assistant))
|
|
assistant = []
|
|
|
|
content = line[5:]
|
|
if content.strip() and content.strip() != "<blank>":
|
|
messages.append(dict(role="user", content=line[5:]))
|
|
continue
|
|
|
|
assistant.append(line)
|
|
|
|
summarizer = ChatSummary(models.GPT35.name)
|
|
summary = summarizer.summarize(messages[-40:])
|
|
dump(summary)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|