This commit is contained in:
Paul Gauthier 2023-05-24 14:17:13 -07:00
parent a1c0a84f74
commit 80a2f10e83
2 changed files with 64 additions and 50 deletions

View file

@ -10,7 +10,6 @@ from rich.console import Console
from rich.live import Live from rich.live import Live
from rich.markdown import Markdown from rich.markdown import Markdown
from pathlib import Path from pathlib import Path
import tiktoken
import git import git
import openai import openai
@ -19,7 +18,7 @@ import openai
from aider import utils from aider import utils
from aider import prompts from aider import prompts
from aider.commands import Commands from aider.commands import Commands
from aider.ctags import get_tags_map from aider.ctags import RepoMap
openai.api_key = os.getenv("OPENAI_API_KEY") openai.api_key = os.getenv("OPENAI_API_KEY")
@ -43,7 +42,6 @@ class Coder:
verbose=False, verbose=False,
): ):
self.verbose = verbose self.verbose = verbose
self.use_ctags = use_ctags
self.abs_fnames = set() self.abs_fnames = set()
self.cur_messages = [] self.cur_messages = []
self.done_messages = [] self.done_messages = []
@ -75,7 +73,8 @@ class Coder:
self.pretty = pretty self.pretty = pretty
self.show_diffs = show_diffs self.show_diffs = show_diffs
self.tokenizer = tiktoken.encoding_for_model(self.main_model)
self.repo_map = RepoMap(use_ctags, self.root, self.main_model)
def find_common_root(self): def find_common_root(self):
if self.abs_fnames: if self.abs_fnames:
@ -172,7 +171,7 @@ class Coder:
all_content += files_content all_content += files_content
other_files = set(self.get_all_abs_files()) - set(self.abs_fnames) other_files = set(self.get_all_abs_files()) - set(self.abs_fnames)
repo_content = self.get_repo_map(self.abs_fnames, other_files) repo_content = self.repo_map.get_repo_map(self.abs_fnames, other_files)
if repo_content: if repo_content:
if all_content: if all_content:
all_content += "\n" all_content += "\n"
@ -189,51 +188,6 @@ class Coder:
return files_messages return files_messages
def get_repo_map(self, chat_files, other_files):
res = self.choose_files_listing(other_files)
if not res:
return
files_listing, ctags_msg = res
if chat_files:
other = "other "
else:
other = ""
repo_content = prompts.repo_content_prefix.format(
other=other,
ctags_msg=ctags_msg,
)
repo_content += files_listing
return repo_content
def choose_files_listing(self, other_files):
# 1/4 of gpt-4's context window
max_map_tokens = 2048
if not other_files:
return
if self.use_ctags:
files_listing = get_tags_map(other_files)
if self.token_count(files_listing) < max_map_tokens:
ctags_msg = " with selected ctags info"
return files_listing, ctags_msg
files_listing = self.get_simple_files_map(other_files)
ctags_msg = ""
if self.token_count(files_listing) < max_map_tokens:
return files_listing, ctags_msg
def get_simple_files_map(self, other_files):
files_listing = "\n".join(self.get_rel_fname(ofn) for ofn in sorted(other_files))
return files_listing
def token_count(self, string):
return len(self.tokenizer.encode(string))
def run(self): def run(self):
self.done_messages = [] self.done_messages = []
self.cur_messages = [] self.cur_messages = []

View file

@ -2,6 +2,9 @@ import os
import json import json
import sys import sys
import subprocess import subprocess
import tiktoken
from aider import prompts
# Global cache for tags # Global cache for tags
TAGS_CACHE = {} TAGS_CACHE = {}
@ -91,6 +94,63 @@ def get_tags(filename, root_dname):
return tags return tags
class RepoMap:
use_ctags = False
def __init__(self, use_ctags, root, main_model):
self.use_ctags = use_ctags
self.tokenizer = tiktoken.encoding_for_model(main_model)
self.root = root
def get_repo_map(self, chat_files, other_files):
res = self.choose_files_listing(other_files)
if not res:
return
files_listing, ctags_msg = res
if chat_files:
other = "other "
else:
other = ""
repo_content = prompts.repo_content_prefix.format(
other=other,
ctags_msg=ctags_msg,
)
repo_content += files_listing
return repo_content
def choose_files_listing(self, other_files):
# 1/4 of gpt-4's context window
max_map_tokens = 2048
if not other_files:
return
if self.use_ctags:
files_listing = get_tags_map(other_files)
if self.token_count(files_listing) < max_map_tokens:
ctags_msg = " with selected ctags info"
return files_listing, ctags_msg
files_listing = self.get_simple_files_map(other_files)
ctags_msg = ""
if self.token_count(files_listing) < max_map_tokens:
return files_listing, ctags_msg
def get_simple_files_map(self, other_files):
files_listing = "\n".join(self.get_rel_fname(ofn) for ofn in sorted(other_files))
return files_listing
def token_count(self, string):
return len(self.tokenizer.encode(string))
def get_rel_fname(self, fname):
return os.path.relpath(fname, self.root)
if __name__ == "__main__": if __name__ == "__main__":
res = get_tags_map(sys.argv[1:]) res = get_tags_map(sys.argv[1:])
print(res) print(res)