From b2232cda7b097e902044dbe08c5570cddda08f82 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Tue, 26 Nov 2024 15:04:19 -0800 Subject: [PATCH] refactor: modify check_for_urls to return modified input string --- aider/coders/base_coder.py | 8 +++----- tests/basic/test_coder.py | 7 ++++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index b350b7e90..1e0e8048b 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -784,7 +784,7 @@ class Coder: return self.commands.run(inp) self.check_for_file_mentions(inp) - self.check_for_urls(inp) + inp = self.check_for_urls(inp) return inp @@ -830,11 +830,10 @@ class Coder: def check_for_urls(self, inp: str) -> List[str]: """Check input for URLs and offer to add them to the chat.""" if not self.detect_urls: - return [] + return inp url_pattern = re.compile(r"(https?://[^\s/$.?#].[^\s]*[^\s,.])") urls = list(set(url_pattern.findall(inp))) # Use set to remove duplicates - added_urls = [] group = ConfirmGroup(urls) for url in urls: if url not in self.rejected_urls: @@ -844,11 +843,10 @@ class Coder: ): inp += "\n\n" inp += self.commands.cmd_web(url, return_content=True) - added_urls.append(url) else: self.rejected_urls.add(url) - return added_urls + return inp def keyboard_interrupt(self): now = time.time() diff --git a/tests/basic/test_coder.py b/tests/basic/test_coder.py index f04e6671e..88026cfb0 100644 --- a/tests/basic/test_coder.py +++ b/tests/basic/test_coder.py @@ -771,7 +771,7 @@ two # Test case with no URL no_url_input = "This text contains no URL" result = coder.check_for_urls(no_url_input) - self.assertEqual(result, []) + self.assertEqual(result, no_url_input) # Test case with the same URL appearing multiple times repeated_url_input = ( @@ -779,7 +779,8 @@ two " more time" ) result = coder.check_for_urls(repeated_url_input) - self.assertEqual(result.count("https://example.com"), 1) + # the original 3 in the input text, plus 1 more for the scraped text + self.assertEqual(result.count("https://example.com"), 4) self.assertIn("https://example.com", result) def test_coder_from_coder_with_subdir(self): @@ -877,7 +878,7 @@ This command will print 'Hello, World!' to the console.""" # Test with a message containing a URL message = "Check out https://example.com" result = coder.check_for_urls(message) - self.assertEqual(result, []) + self.assertEqual(result, message) coder.commands.scraper.scrape.assert_not_called() def test_unknown_edit_format_exception(self):