fixed bug found by testing

This commit is contained in:
Paul Gauthier 2023-05-24 19:09:05 -07:00
parent 82804e2a6c
commit 60031ecd29
2 changed files with 8 additions and 3 deletions

View file

@ -112,6 +112,7 @@ class RepoMap:
tags = [] tags = []
for filename in filenames: for filename in filenames:
if filename.endswith(".md") or filename.endswith(".json"): if filename.endswith(".md") or filename.endswith(".json"):
tags.append(self.split_path(filename))
continue continue
tags += self.get_tags(filename) tags += self.get_tags(filename)
if not tags: if not tags:

View file

@ -3,6 +3,7 @@ import tempfile
import unittest import unittest
from aider.repomap import RepoMap from aider.repomap import RepoMap
class TestRepoMap(unittest.TestCase): class TestRepoMap(unittest.TestCase):
def test_get_tags_map(self): def test_get_tags_map(self):
# Create a temporary directory with sample files for testing # Create a temporary directory with sample files for testing
@ -25,8 +26,8 @@ class TestRepoMap(unittest.TestCase):
# Check if the result contains the expected tags map # Check if the result contains the expected tags map
self.assertIn("test_file1.py:", result) self.assertIn("test_file1.py:", result)
self.assertIn("test_file2.py:", result) self.assertIn("test_file2.py:", result)
self.assertNotIn("test_file3.md:", result) self.assertIn("test_file3.md:", result)
self.assertNotIn("test_file4.json:", result) self.assertIn("test_file4.json:", result)
def test_get_tags_map_with_identifiers(self): def test_get_tags_map_with_identifiers(self):
# Create a temporary directory with a sample Python file containing identifiers # Create a temporary directory with a sample Python file containing identifiers
@ -77,7 +78,10 @@ def my_function(arg1, arg2):
# Check if the result contains each specific file in the expected tags map without ctags # Check if the result contains each specific file in the expected tags map without ctags
for file in test_files: for file in test_files:
self.assertIn(f"{os.path.splitext(file)[0]}.{os.path.splitext(file)[1][1:]}:", result) self.assertIn(
f"{os.path.splitext(file)[0]}.{os.path.splitext(file)[1][1:]}:", result
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()