use valid_fnames to improve find_filename

This commit is contained in:
Paul Gauthier 2024-08-26 12:03:36 -07:00
parent ce2324c0c6
commit 8c766f81b2

View file

@ -23,7 +23,13 @@ class EditBlockCoder(Coder):
content = self.partial_response_content
# might raise ValueError for malformed ORIG/UPD blocks
edits = list(find_original_update_blocks(content, self.fence))
edits = list(
find_original_update_blocks(
content,
self.fence,
self.get_inchat_relative_files(),
)
)
self.shell_commands += [edit[1] for edit in edits if edit[0] is None]
edits = [edit for edit in edits if edit[0] is not None]
@ -414,7 +420,7 @@ def strip_filename(filename, fence):
return filename
def find_original_update_blocks(content, fence=DEFAULT_FENCE):
def find_original_update_blocks(content, fence=DEFAULT_FENCE, valid_fnames=None):
lines = content.splitlines(keepends=True)
i = 0
current_filename = None
@ -454,7 +460,7 @@ def find_original_update_blocks(content, fence=DEFAULT_FENCE):
# Check for SEARCH/REPLACE blocks
if line.strip() == HEAD:
try:
filename = find_filename(lines[max(0, i - 3) : i], fence)
filename = find_filename(lines[max(0, i - 3) : i], fence, valid_fnames)
if not filename:
if current_filename:
filename = current_filename
@ -491,7 +497,7 @@ def find_original_update_blocks(content, fence=DEFAULT_FENCE):
i += 1
def find_filename(lines, fence):
def find_filename(lines, fence, valid_fnames):
"""
Deepseek Coder v2 has been doing this:
@ -505,19 +511,50 @@ def find_filename(lines, fence):
This is a more flexible search back for filenames.
"""
if valid_fnames is None:
valid_fnames = []
# Go back through the 3 preceding lines
lines.reverse()
lines = lines[:3]
filenames = []
for line in lines:
# If we find a filename, done
filename = strip_filename(line, fence)
if filename:
return filename
filenames.append(filename)
# Only continue as long as we keep seeing fences
if not line.startswith(fence[0]):
return
break
if not len(filenames):
return
if len(filenames) == 1:
return filenames[0]
# pick the *best* filename found
# pick a valid fname
for fname in filenames:
if fname in valid_fnames:
return fname
# match just on basename, return full valid fname
for fname in filenames:
for vfn in valid_fnames:
if fname == vfn.name:
return vfn
# look for a file w/extension
for fname in filenames:
if "." in fname:
return fname
return filenames[0]
def find_similar_lines(search_lines, content_lines, threshold=0.6):