diff --git a/aider/coder.py b/aider/coder.py index bf2e1164c..dd43f392d 100755 --- a/aider/coder.py +++ b/aider/coder.py @@ -10,7 +10,6 @@ from rich.console import Console from rich.live import Live from rich.markdown import Markdown from pathlib import Path -import tiktoken import git import openai @@ -19,7 +18,7 @@ import openai from aider import utils from aider import prompts from aider.commands import Commands -from aider.ctags import get_tags_map +from aider.ctags import RepoMap openai.api_key = os.getenv("OPENAI_API_KEY") @@ -43,7 +42,6 @@ class Coder: verbose=False, ): self.verbose = verbose - self.use_ctags = use_ctags self.abs_fnames = set() self.cur_messages = [] self.done_messages = [] @@ -75,7 +73,8 @@ class Coder: self.pretty = pretty self.show_diffs = show_diffs - self.tokenizer = tiktoken.encoding_for_model(self.main_model) + + self.repo_map = RepoMap(use_ctags, self.root, self.main_model) def find_common_root(self): if self.abs_fnames: @@ -172,7 +171,7 @@ class Coder: all_content += files_content other_files = set(self.get_all_abs_files()) - set(self.abs_fnames) - repo_content = self.get_repo_map(self.abs_fnames, other_files) + repo_content = self.repo_map.get_repo_map(self.abs_fnames, other_files) if repo_content: if all_content: all_content += "\n" @@ -189,51 +188,6 @@ class Coder: return files_messages - def get_repo_map(self, chat_files, other_files): - res = self.choose_files_listing(other_files) - if not res: - return - - files_listing, ctags_msg = res - - if chat_files: - other = "other " - else: - other = "" - - repo_content = prompts.repo_content_prefix.format( - other=other, - ctags_msg=ctags_msg, - ) - repo_content += files_listing - - return repo_content - - def choose_files_listing(self, other_files): - # 1/4 of gpt-4's context window - max_map_tokens = 2048 - - if not other_files: - return - - if self.use_ctags: - files_listing = get_tags_map(other_files) - if self.token_count(files_listing) < max_map_tokens: - ctags_msg = " with selected ctags info" - return files_listing, ctags_msg - - files_listing = self.get_simple_files_map(other_files) - ctags_msg = "" - if self.token_count(files_listing) < max_map_tokens: - return files_listing, ctags_msg - - def get_simple_files_map(self, other_files): - files_listing = "\n".join(self.get_rel_fname(ofn) for ofn in sorted(other_files)) - return files_listing - - def token_count(self, string): - return len(self.tokenizer.encode(string)) - def run(self): self.done_messages = [] self.cur_messages = [] diff --git a/aider/ctags.py b/aider/ctags.py index 983bbd11e..30de11320 100644 --- a/aider/ctags.py +++ b/aider/ctags.py @@ -2,6 +2,9 @@ import os import json import sys import subprocess +import tiktoken + +from aider import prompts # Global cache for tags TAGS_CACHE = {} @@ -91,6 +94,63 @@ def get_tags(filename, root_dname): return tags +class RepoMap: + use_ctags = False + + def __init__(self, use_ctags, root, main_model): + self.use_ctags = use_ctags + self.tokenizer = tiktoken.encoding_for_model(main_model) + self.root = root + + def get_repo_map(self, chat_files, other_files): + res = self.choose_files_listing(other_files) + if not res: + return + + files_listing, ctags_msg = res + + if chat_files: + other = "other " + else: + other = "" + + repo_content = prompts.repo_content_prefix.format( + other=other, + ctags_msg=ctags_msg, + ) + repo_content += files_listing + + return repo_content + + def choose_files_listing(self, other_files): + # 1/4 of gpt-4's context window + max_map_tokens = 2048 + + if not other_files: + return + + if self.use_ctags: + files_listing = get_tags_map(other_files) + if self.token_count(files_listing) < max_map_tokens: + ctags_msg = " with selected ctags info" + return files_listing, ctags_msg + + files_listing = self.get_simple_files_map(other_files) + ctags_msg = "" + if self.token_count(files_listing) < max_map_tokens: + return files_listing, ctags_msg + + def get_simple_files_map(self, other_files): + files_listing = "\n".join(self.get_rel_fname(ofn) for ofn in sorted(other_files)) + return files_listing + + def token_count(self, string): + return len(self.tokenizer.encode(string)) + + def get_rel_fname(self, fname): + return os.path.relpath(fname, self.root) + + if __name__ == "__main__": res = get_tags_map(sys.argv[1:]) print(res)