mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-31 01:35:00 +00:00
do not include added files in the repo map #315
This commit is contained in:
parent
95fe1be4c5
commit
d54f883698
2 changed files with 40 additions and 9 deletions
|
@ -231,7 +231,7 @@ class RepoMap:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# dump(fname)
|
# dump(fname)
|
||||||
rel_fname = os.path.relpath(fname, self.root)
|
rel_fname = self.get_rel_fname(fname)
|
||||||
|
|
||||||
if fname in chat_fnames:
|
if fname in chat_fnames:
|
||||||
personalization[rel_fname] = 1.0
|
personalization[rel_fname] = 1.0
|
||||||
|
@ -304,9 +304,7 @@ class RepoMap:
|
||||||
continue
|
continue
|
||||||
ranked_tags += list(definitions.get((fname, ident), []))
|
ranked_tags += list(definitions.get((fname, ident), []))
|
||||||
|
|
||||||
rel_other_fnames_without_tags = set(
|
rel_other_fnames_without_tags = set(self.get_rel_fname(fname) for fname in other_fnames)
|
||||||
os.path.relpath(fname, self.root) for fname in other_fnames
|
|
||||||
)
|
|
||||||
|
|
||||||
fnames_already_included = set(rt[0] for rt in ranked_tags)
|
fnames_already_included = set(rt[0] for rt in ranked_tags)
|
||||||
|
|
||||||
|
@ -333,9 +331,11 @@ class RepoMap:
|
||||||
upper_bound = num_tags
|
upper_bound = num_tags
|
||||||
best_tree = None
|
best_tree = None
|
||||||
|
|
||||||
|
chat_rel_fnames = [self.get_rel_fname(fname) for fname in chat_fnames]
|
||||||
|
|
||||||
while lower_bound <= upper_bound:
|
while lower_bound <= upper_bound:
|
||||||
middle = (lower_bound + upper_bound) // 2
|
middle = (lower_bound + upper_bound) // 2
|
||||||
tree = self.to_tree(ranked_tags[:middle])
|
tree = self.to_tree(ranked_tags[:middle], chat_rel_fnames)
|
||||||
num_tokens = self.token_count(tree)
|
num_tokens = self.token_count(tree)
|
||||||
|
|
||||||
if num_tokens < self.max_map_tokens:
|
if num_tokens < self.max_map_tokens:
|
||||||
|
@ -346,10 +346,11 @@ class RepoMap:
|
||||||
|
|
||||||
return best_tree
|
return best_tree
|
||||||
|
|
||||||
def to_tree(self, tags):
|
def to_tree(self, tags, chat_rel_fnames):
|
||||||
if not tags:
|
if not tags:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
tags = [tag for tag in tags if tag[0] not in chat_rel_fnames]
|
||||||
tags = sorted(tags)
|
tags = sorted(tags)
|
||||||
|
|
||||||
cur_fname = None
|
cur_fname = None
|
||||||
|
@ -359,10 +360,10 @@ class RepoMap:
|
||||||
# add a bogus tag at the end so we trip the this_fname != cur_fname...
|
# add a bogus tag at the end so we trip the this_fname != cur_fname...
|
||||||
dummy_tag = (None,)
|
dummy_tag = (None,)
|
||||||
for tag in tags + [dummy_tag]:
|
for tag in tags + [dummy_tag]:
|
||||||
this_fname = tag[0]
|
this_rel_fname = tag[0]
|
||||||
|
|
||||||
# ... here ... to output the final real entry in the list
|
# ... here ... to output the final real entry in the list
|
||||||
if this_fname != cur_fname:
|
if this_rel_fname != cur_fname:
|
||||||
if context:
|
if context:
|
||||||
context.add_context()
|
context.add_context()
|
||||||
output += "\n"
|
output += "\n"
|
||||||
|
@ -388,7 +389,7 @@ class RepoMap:
|
||||||
# header_max=30,
|
# header_max=30,
|
||||||
show_top_of_file_parent_scope=False,
|
show_top_of_file_parent_scope=False,
|
||||||
)
|
)
|
||||||
cur_fname = this_fname
|
cur_fname = this_rel_fname
|
||||||
|
|
||||||
if context:
|
if context:
|
||||||
context.add_lines_of_interest([tag.line])
|
context.add_lines_of_interest([tag.line])
|
||||||
|
|
|
@ -119,6 +119,36 @@ print(my_function(3, 4))
|
||||||
# close the open cache files, so Windows won't error
|
# close the open cache files, so Windows won't error
|
||||||
del repo_map
|
del repo_map
|
||||||
|
|
||||||
|
def test_get_repo_map_excludes_added_files(self):
|
||||||
|
# Create a temporary directory with sample files for testing
|
||||||
|
test_files = [
|
||||||
|
"test_file1.py",
|
||||||
|
"test_file2.py",
|
||||||
|
"test_file3.md",
|
||||||
|
"test_file4.json",
|
||||||
|
]
|
||||||
|
|
||||||
|
with IgnorantTemporaryDirectory() as temp_dir:
|
||||||
|
for file in test_files:
|
||||||
|
with open(os.path.join(temp_dir, file), "w") as f:
|
||||||
|
f.write("def foo(): pass\n")
|
||||||
|
|
||||||
|
io = InputOutput()
|
||||||
|
repo_map = RepoMap(root=temp_dir, io=io)
|
||||||
|
test_files = [os.path.join(temp_dir, file) for file in test_files]
|
||||||
|
result = repo_map.get_repo_map(test_files[:2], test_files[2:])
|
||||||
|
|
||||||
|
dump(result)
|
||||||
|
|
||||||
|
# Check if the result contains the expected tags map
|
||||||
|
self.assertNotIn("test_file1.py", result)
|
||||||
|
self.assertNotIn("test_file2.py", result)
|
||||||
|
self.assertIn("test_file3.md", result)
|
||||||
|
self.assertIn("test_file4.json", result)
|
||||||
|
|
||||||
|
# close the open cache files, so Windows won't error
|
||||||
|
del repo_map
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue