Paul Gauthier 2023-05-24 14:17:13 -07:00
parent a1c0a84f74
commit 80a2f10e83
2 changed files with 64 additions and 50 deletions

aider/coder.py

@@ -10,7 +10,6 @@ from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from pathlib import Path
import tiktoken
import git
import openai
@@ -19,7 +18,7 @@ import openai
from aider import utils
from aider import prompts
from aider.commands import Commands
from aider.ctags import get_tags_map
from aider.ctags import RepoMap
openai.api_key = os.getenv("OPENAI_API_KEY")
@@ -43,7 +42,6 @@ class Coder:
verbose=False,
):
self.verbose = verbose
self.use_ctags = use_ctags
self.abs_fnames = set()
self.cur_messages = []
self.done_messages = []
@@ -75,7 +73,8 @@
self.pretty = pretty
self.show_diffs = show_diffs
self.tokenizer = tiktoken.encoding_for_model(self.main_model)
self.repo_map = RepoMap(use_ctags, self.root, self.main_model)
def find_common_root(self):
if self.abs_fnames:
@@ -172,7 +171,7 @@
all_content += files_content
other_files = set(self.get_all_abs_files()) - set(self.abs_fnames)
repo_content = self.get_repo_map(self.abs_fnames, other_files)
repo_content = self.repo_map.get_repo_map(self.abs_fnames, other_files)
if repo_content:
if all_content:
all_content += "\n"
@@ -189,51 +188,6 @@
return files_messages
def get_repo_map(self, chat_files, other_files):
res = self.choose_files_listing(other_files)
if not res:
return
files_listing, ctags_msg = res
if chat_files:
other = "other "
else:
other = ""
repo_content = prompts.repo_content_prefix.format(
other=other,
ctags_msg=ctags_msg,
)
repo_content += files_listing
return repo_content
def choose_files_listing(self, other_files):
# 1/4 of gpt-4's context window
max_map_tokens = 2048
if not other_files:
return
if self.use_ctags:
files_listing = get_tags_map(other_files)
if self.token_count(files_listing) < max_map_tokens:
ctags_msg = " with selected ctags info"
return files_listing, ctags_msg
files_listing = self.get_simple_files_map(other_files)
ctags_msg = ""
if self.token_count(files_listing) < max_map_tokens:
return files_listing, ctags_msg
def get_simple_files_map(self, other_files):
files_listing = "\n".join(self.get_rel_fname(ofn) for ofn in sorted(other_files))
return files_listing
def token_count(self, string):
return len(self.tokenizer.encode(string))
def run(self):
self.done_messages = []
self.cur_messages = []

aider/ctags.py

@@ -2,6 +2,9 @@ import os
import json
import sys
import subprocess
import tiktoken
from aider import prompts
# Global cache for tags
TAGS_CACHE = {}
@@ -91,6 +94,63 @@ def get_tags(filename, root_dname):
return tags
class RepoMap:
use_ctags = False
def __init__(self, use_ctags, root, main_model):
self.use_ctags = use_ctags
self.tokenizer = tiktoken.encoding_for_model(main_model)
self.root = root
def get_repo_map(self, chat_files, other_files):
res = self.choose_files_listing(other_files)
if not res:
return
files_listing, ctags_msg = res
if chat_files:
other = "other "
else:
other = ""
repo_content = prompts.repo_content_prefix.format(
other=other,
ctags_msg=ctags_msg,
)
repo_content += files_listing
return repo_content
def choose_files_listing(self, other_files):
# 1/4 of gpt-4's context window
max_map_tokens = 2048
if not other_files:
return
if self.use_ctags:
files_listing = get_tags_map(other_files)
if self.token_count(files_listing) < max_map_tokens:
ctags_msg = " with selected ctags info"
return files_listing, ctags_msg
files_listing = self.get_simple_files_map(other_files)
ctags_msg = ""
if self.token_count(files_listing) < max_map_tokens:
return files_listing, ctags_msg
def get_simple_files_map(self, other_files):
files_listing = "\n".join(self.get_rel_fname(ofn) for ofn in sorted(other_files))
return files_listing
def token_count(self, string):
return len(self.tokenizer.encode(string))
def get_rel_fname(self, fname):
return os.path.relpath(fname, self.root)
if __name__ == "__main__":
res = get_tags_map(sys.argv[1:])
print(res)
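
For reference, a minimal usage sketch of the RepoMap class introduced by this commit, based only on the signatures visible in the diff (RepoMap(use_ctags, root, main_model) and get_repo_map(chat_files, other_files)). The model name and file paths below are illustrative placeholders, not values from the commit.

# Illustrative only: construct a RepoMap the same way Coder.__init__ does above.
from aider.ctags import RepoMap

# use_ctags=True assumes a working ctags install; pass False to get the plain file listing.
repo_map = RepoMap(use_ctags=True, root="/path/to/repo", main_model="gpt-4")

# Files already in the chat vs. the rest of the repo (placeholder paths).
chat_files = {"/path/to/repo/app.py"}
other_files = {"/path/to/repo/utils.py", "/path/to/repo/README.md"}

# Returns a prompt-ready listing of the other files (with ctags info when it
# fits within the 2048-token budget), or None if no listing fits.
repo_content = repo_map.get_repo_map(chat_files, other_files)
if repo_content:
    print(repo_content)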