improved utils.split_chat_history_markdown

This commit is contained in:
Paul Gauthier 2024-05-11 08:06:16 -07:00
parent 602a0c7c31
commit 478b5c8e7c

View file

@ -107,30 +107,52 @@ def show_messages(messages, title=None, functions=None):
dump(functions)
def split_chat_history_markdown(text, include_tool=False):
    """Split a markdown chat transcript into a list of role/content messages.

    Each line of *text* is classified by its prefix:

    * ``"# "``    -- heading line; skipped entirely.
    * ``"> "``    -- tool output; the two-character prefix is stripped.
    * ``"#### "`` -- user input; the five-character prefix is stripped.
    * otherwise   -- assistant output; the line is kept verbatim.

    Consecutive lines that belong to the same role are concatenated (line
    endings preserved) into one message. Messages whose content is empty or
    only whitespace are dropped.

    Args:
        text: The markdown transcript to parse.
        include_tool: When false (the default), messages with role ``"tool"``
            are removed from the result.

    Returns:
        A list of ``dict(role=..., content=...)`` in transcript order, where
        role is ``"user"``, ``"assistant"`` or ``"tool"``.
    """
    messages = []

    def append_msg(role, lines):
        # Join the buffered lines; whitespace-only buffers produce no message.
        content = "".join(lines)
        if content.strip():
            messages.append(dict(role=role, content=content))

    current_role = None
    buffer = []

    for line in text.splitlines(keepends=True):
        if line.startswith("# "):
            # Headings do not interrupt the message being accumulated.
            continue

        if line.startswith("> "):
            role, content = "tool", line[2:]
        elif line.startswith("#### "):
            # Slash commands ("#### /...") are deliberately not filtered out;
            # they are recorded as ordinary user content.
            role, content = "user", line[5:]
        else:
            role, content = "assistant", line

        if role != current_role:
            # A change of speaker closes the message collected so far.
            append_msg(current_role, buffer)
            buffer = []
            current_role = role

        buffer.append(content)

    # Flush whatever is still buffered, including trailing tool output.
    append_msg(current_role, buffer)

    if not include_tool:
        messages = [m for m in messages if m["role"] != "tool"]

    return messages