refactor orig/upd into utils

This commit is contained in:
Paul Gauthier 2023-05-11 20:57:27 -07:00
parent 5e593eb631
commit 797372c69e
2 changed files with 26 additions and 23 deletions

View file

@ -1,9 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
import os import os
import sys import sys
import re
import traceback import traceback
import time import time
from openai.error import RateLimitError from openai.error import RateLimitError
@ -374,29 +372,9 @@ class Coder:
if live: if live:
live.stop() live.stop()
pattern = re.compile(
# Optional: Matches the start of a code block (e.g., ```python) and any following whitespace
r"(^```\S*\s*)?"
# Matches the file path
r"^(\S*)\s*"
# Optional: Matches the end of a code block (e.g., ```) and any following whitespace
r"(^```\S*\s*)?"
# Matches the start of the ORIGINAL section and captures its content
r"^<<<<<<< ORIGINAL\n(.*?\n?)"
# Matches sep between ORIGINAL and UPDATED sections, captures UPDATED content
r"^=======\n(.*?)"
# Matches the end of the UPDATED section
r"^>>>>>>> UPDATED",
re.MULTILINE | re.DOTALL,
)
def update_files(self, content, inp): def update_files(self, content, inp):
edited = set() edited = set()
for match in self.pattern.finditer(content): for path, original, updated in utils.find_original_update_blocks(content):
_, path, _, original, updated = match.groups()
path = path.strip()
full_path = os.path.abspath(os.path.join(self.root, path)) full_path = os.path.abspath(os.path.join(self.root, path))
if full_path not in self.abs_fnames: if full_path not in self.abs_fnames:

View file

@ -1,3 +1,4 @@
import re
import math import math
from difflib import SequenceMatcher from difflib import SequenceMatcher
@ -126,3 +127,27 @@ def show_messages(messages, title):
content = msg["content"].splitlines() content = msg["content"].splitlines()
for line in content: for line in content:
print(role, line) print(role, line)
pattern = re.compile(
# Optional: Matches the start of a code block (e.g., ```python) and any following whitespace
r"(^```\S*\s*)?"
# Matches the file path
r"^(\S+)\s*"
# Optional: Matches the end of a code block (e.g., ```) and any following whitespace
r"(^```\S*\s*)?"
# Matches the start of the ORIGINAL section and captures its content
r"^<<<<<<< ORIGINAL\n(.*?\n?)"
# Matches sep between ORIGINAL and UPDATED sections, captures UPDATED content
r"^=======\n(.*?)"
# Matches the end of the UPDATED section
r"^>>>>>>> UPDATED",
re.MULTILINE | re.DOTALL,
)
def find_original_update_blocks(content):
for match in pattern.finditer(content):
_, path, _, original, updated = match.groups()
path = path.strip()
yield path, original, updated