improved utils.split_chat_history_markdown

This commit is contained in:
Paul Gauthier 2024-05-11 08:06:16 -07:00
parent 602a0c7c31
commit 478b5c8e7c

View file

@ -107,30 +107,52 @@ def show_messages(messages, title=None, functions=None):
dump(functions) dump(functions)
def split_chat_history_markdown(text): def split_chat_history_markdown(text, include_tool=False):
messages = [] messages = []
user = []
assistant = [] assistant = []
tool = []
lines = text.splitlines(keepends=True) lines = text.splitlines(keepends=True)
def append_msg(role, lines):
lines = "".join(lines)
if lines.strip():
messages.append(dict(role=role, content=lines))
for line in lines: for line in lines:
if line.startswith("# "): if line.startswith("# "):
continue continue
if line.startswith(">"): if line.startswith("> "):
continue append_msg("assistant", assistant)
if line.startswith("#### /"): assistant = []
append_msg("user", user)
user = []
tool.append(line[2:])
continue continue
# if line.startswith("#### /"):
# continue
if line.startswith("#### "): if line.startswith("#### "):
if assistant: append_msg("assistant", assistant)
assistant = "".join(assistant) assistant = []
if assistant.strip(): append_msg("tool", tool)
messages.append(dict(role="assistant", content=assistant)) tool = []
assistant = []
content = line[5:] content = line[5:]
if content.strip() and content.strip() != "<blank>": user.append(content)
messages.append(dict(role="user", content=line[5:]))
continue continue
append_msg("user", user)
user = []
append_msg("tool", tool)
tool = []
assistant.append(line) assistant.append(line)
append_msg("assistant", assistant)
append_msg("user", user)
if not include_tool:
messages = [m for m in messages if m["role"] != "tool"]
return messages return messages