refactor: modify check_for_urls to return modified input string

This commit is contained in:
Paul Gauthier 2024-11-26 15:04:19 -08:00 committed by Paul Gauthier (aider)
parent 3ba4aca268
commit b2232cda7b
2 changed files with 7 additions and 8 deletions

View file

@ -784,7 +784,7 @@ class Coder:
return self.commands.run(inp)
self.check_for_file_mentions(inp)
self.check_for_urls(inp)
inp = self.check_for_urls(inp)
return inp
@ -830,11 +830,10 @@ class Coder:
def check_for_urls(self, inp: str) -> List[str]:
"""Check input for URLs and offer to add them to the chat."""
if not self.detect_urls:
return []
return inp
url_pattern = re.compile(r"(https?://[^\s/$.?#].[^\s]*[^\s,.])")
urls = list(set(url_pattern.findall(inp))) # Use set to remove duplicates
added_urls = []
group = ConfirmGroup(urls)
for url in urls:
if url not in self.rejected_urls:
@ -844,11 +843,10 @@ class Coder:
):
inp += "\n\n"
inp += self.commands.cmd_web(url, return_content=True)
added_urls.append(url)
else:
self.rejected_urls.add(url)
return added_urls
return inp
def keyboard_interrupt(self):
now = time.time()

View file

@ -771,7 +771,7 @@ two
# Test case with no URL
no_url_input = "This text contains no URL"
result = coder.check_for_urls(no_url_input)
self.assertEqual(result, [])
self.assertEqual(result, no_url_input)
# Test case with the same URL appearing multiple times
repeated_url_input = (
@ -779,7 +779,8 @@ two
" more time"
)
result = coder.check_for_urls(repeated_url_input)
self.assertEqual(result.count("https://example.com"), 1)
# the original 3 in the input text, plus 1 more for the scraped text
self.assertEqual(result.count("https://example.com"), 4)
self.assertIn("https://example.com", result)
def test_coder_from_coder_with_subdir(self):
@ -877,7 +878,7 @@ This command will print 'Hello, World!' to the console."""
# Test with a message containing a URL
message = "Check out https://example.com"
result = coder.check_for_urls(message)
self.assertEqual(result, [])
self.assertEqual(result, message)
coder.commands.scraper.scrape.assert_not_called()
def test_unknown_edit_format_exception(self):