From d54f8836980d27997e3a2ad0e28f3484d2b7d028 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Wed, 15 Nov 2023 12:52:59 -0800 Subject: [PATCH] do not include added files in the repo map #315 --- aider/repomap.py | 19 ++++++++++--------- tests/test_repomap.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/aider/repomap.py b/aider/repomap.py index 3bdc22899..25a6ad463 100644 --- a/aider/repomap.py +++ b/aider/repomap.py @@ -231,7 +231,7 @@ class RepoMap: continue # dump(fname) - rel_fname = os.path.relpath(fname, self.root) + rel_fname = self.get_rel_fname(fname) if fname in chat_fnames: personalization[rel_fname] = 1.0 @@ -304,9 +304,7 @@ class RepoMap: continue ranked_tags += list(definitions.get((fname, ident), [])) - rel_other_fnames_without_tags = set( - os.path.relpath(fname, self.root) for fname in other_fnames - ) + rel_other_fnames_without_tags = set(self.get_rel_fname(fname) for fname in other_fnames) fnames_already_included = set(rt[0] for rt in ranked_tags) @@ -333,9 +331,11 @@ class RepoMap: upper_bound = num_tags best_tree = None + chat_rel_fnames = [self.get_rel_fname(fname) for fname in chat_fnames] + while lower_bound <= upper_bound: middle = (lower_bound + upper_bound) // 2 - tree = self.to_tree(ranked_tags[:middle]) + tree = self.to_tree(ranked_tags[:middle], chat_rel_fnames) num_tokens = self.token_count(tree) if num_tokens < self.max_map_tokens: @@ -346,10 +346,11 @@ class RepoMap: return best_tree - def to_tree(self, tags): + def to_tree(self, tags, chat_rel_fnames): if not tags: return "" + tags = [tag for tag in tags if tag[0] not in chat_rel_fnames] tags = sorted(tags) cur_fname = None @@ -359,10 +360,10 @@ class RepoMap: # add a bogus tag at the end so we trip the this_fname != cur_fname... dummy_tag = (None,) for tag in tags + [dummy_tag]: - this_fname = tag[0] + this_rel_fname = tag[0] # ... here ... to output the final real entry in the list - if this_fname != cur_fname: + if this_rel_fname != cur_fname: if context: context.add_context() output += "\n" @@ -388,7 +389,7 @@ class RepoMap: # header_max=30, show_top_of_file_parent_scope=False, ) - cur_fname = this_fname + cur_fname = this_rel_fname if context: context.add_lines_of_interest([tag.line]) diff --git a/tests/test_repomap.py b/tests/test_repomap.py index 9fb7dd66f..e081cc66e 100644 --- a/tests/test_repomap.py +++ b/tests/test_repomap.py @@ -119,6 +119,36 @@ print(my_function(3, 4)) # close the open cache files, so Windows won't error del repo_map + def test_get_repo_map_excludes_added_files(self): + # Create a temporary directory with sample files for testing + test_files = [ + "test_file1.py", + "test_file2.py", + "test_file3.md", + "test_file4.json", + ] + + with IgnorantTemporaryDirectory() as temp_dir: + for file in test_files: + with open(os.path.join(temp_dir, file), "w") as f: + f.write("def foo(): pass\n") + + io = InputOutput() + repo_map = RepoMap(root=temp_dir, io=io) + test_files = [os.path.join(temp_dir, file) for file in test_files] + result = repo_map.get_repo_map(test_files[:2], test_files[2:]) + + dump(result) + + # Check if the result contains the expected tags map + self.assertNotIn("test_file1.py", result) + self.assertNotIn("test_file2.py", result) + self.assertIn("test_file3.md", result) + self.assertIn("test_file4.json", result) + + # close the open cache files, so Windows won't error + del repo_map + if __name__ == "__main__": unittest.main()