mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-29 16:54:59 +00:00
refactor
This commit is contained in:
parent
a1c0a84f74
commit
80a2f10e83
2 changed files with 64 additions and 50 deletions
|
@ -10,7 +10,6 @@ from rich.console import Console
|
||||||
from rich.live import Live
|
from rich.live import Live
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tiktoken
|
|
||||||
|
|
||||||
import git
|
import git
|
||||||
import openai
|
import openai
|
||||||
|
@ -19,7 +18,7 @@ import openai
|
||||||
from aider import utils
|
from aider import utils
|
||||||
from aider import prompts
|
from aider import prompts
|
||||||
from aider.commands import Commands
|
from aider.commands import Commands
|
||||||
from aider.ctags import get_tags_map
|
from aider.ctags import RepoMap
|
||||||
|
|
||||||
openai.api_key = os.getenv("OPENAI_API_KEY")
|
openai.api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
@ -43,7 +42,6 @@ class Coder:
|
||||||
verbose=False,
|
verbose=False,
|
||||||
):
|
):
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.use_ctags = use_ctags
|
|
||||||
self.abs_fnames = set()
|
self.abs_fnames = set()
|
||||||
self.cur_messages = []
|
self.cur_messages = []
|
||||||
self.done_messages = []
|
self.done_messages = []
|
||||||
|
@ -75,7 +73,8 @@ class Coder:
|
||||||
|
|
||||||
self.pretty = pretty
|
self.pretty = pretty
|
||||||
self.show_diffs = show_diffs
|
self.show_diffs = show_diffs
|
||||||
self.tokenizer = tiktoken.encoding_for_model(self.main_model)
|
|
||||||
|
self.repo_map = RepoMap(use_ctags, self.root, self.main_model)
|
||||||
|
|
||||||
def find_common_root(self):
|
def find_common_root(self):
|
||||||
if self.abs_fnames:
|
if self.abs_fnames:
|
||||||
|
@ -172,7 +171,7 @@ class Coder:
|
||||||
all_content += files_content
|
all_content += files_content
|
||||||
|
|
||||||
other_files = set(self.get_all_abs_files()) - set(self.abs_fnames)
|
other_files = set(self.get_all_abs_files()) - set(self.abs_fnames)
|
||||||
repo_content = self.get_repo_map(self.abs_fnames, other_files)
|
repo_content = self.repo_map.get_repo_map(self.abs_fnames, other_files)
|
||||||
if repo_content:
|
if repo_content:
|
||||||
if all_content:
|
if all_content:
|
||||||
all_content += "\n"
|
all_content += "\n"
|
||||||
|
@ -189,51 +188,6 @@ class Coder:
|
||||||
|
|
||||||
return files_messages
|
return files_messages
|
||||||
|
|
||||||
def get_repo_map(self, chat_files, other_files):
|
|
||||||
res = self.choose_files_listing(other_files)
|
|
||||||
if not res:
|
|
||||||
return
|
|
||||||
|
|
||||||
files_listing, ctags_msg = res
|
|
||||||
|
|
||||||
if chat_files:
|
|
||||||
other = "other "
|
|
||||||
else:
|
|
||||||
other = ""
|
|
||||||
|
|
||||||
repo_content = prompts.repo_content_prefix.format(
|
|
||||||
other=other,
|
|
||||||
ctags_msg=ctags_msg,
|
|
||||||
)
|
|
||||||
repo_content += files_listing
|
|
||||||
|
|
||||||
return repo_content
|
|
||||||
|
|
||||||
def choose_files_listing(self, other_files):
|
|
||||||
# 1/4 of gpt-4's context window
|
|
||||||
max_map_tokens = 2048
|
|
||||||
|
|
||||||
if not other_files:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.use_ctags:
|
|
||||||
files_listing = get_tags_map(other_files)
|
|
||||||
if self.token_count(files_listing) < max_map_tokens:
|
|
||||||
ctags_msg = " with selected ctags info"
|
|
||||||
return files_listing, ctags_msg
|
|
||||||
|
|
||||||
files_listing = self.get_simple_files_map(other_files)
|
|
||||||
ctags_msg = ""
|
|
||||||
if self.token_count(files_listing) < max_map_tokens:
|
|
||||||
return files_listing, ctags_msg
|
|
||||||
|
|
||||||
def get_simple_files_map(self, other_files):
|
|
||||||
files_listing = "\n".join(self.get_rel_fname(ofn) for ofn in sorted(other_files))
|
|
||||||
return files_listing
|
|
||||||
|
|
||||||
def token_count(self, string):
|
|
||||||
return len(self.tokenizer.encode(string))
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.done_messages = []
|
self.done_messages = []
|
||||||
self.cur_messages = []
|
self.cur_messages = []
|
||||||
|
|
|
@ -2,6 +2,9 @@ import os
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
from aider import prompts
|
||||||
|
|
||||||
# Global cache for tags
|
# Global cache for tags
|
||||||
TAGS_CACHE = {}
|
TAGS_CACHE = {}
|
||||||
|
@ -91,6 +94,63 @@ def get_tags(filename, root_dname):
|
||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
|
||||||
|
class RepoMap:
|
||||||
|
use_ctags = False
|
||||||
|
|
||||||
|
def __init__(self, use_ctags, root, main_model):
|
||||||
|
self.use_ctags = use_ctags
|
||||||
|
self.tokenizer = tiktoken.encoding_for_model(main_model)
|
||||||
|
self.root = root
|
||||||
|
|
||||||
|
def get_repo_map(self, chat_files, other_files):
|
||||||
|
res = self.choose_files_listing(other_files)
|
||||||
|
if not res:
|
||||||
|
return
|
||||||
|
|
||||||
|
files_listing, ctags_msg = res
|
||||||
|
|
||||||
|
if chat_files:
|
||||||
|
other = "other "
|
||||||
|
else:
|
||||||
|
other = ""
|
||||||
|
|
||||||
|
repo_content = prompts.repo_content_prefix.format(
|
||||||
|
other=other,
|
||||||
|
ctags_msg=ctags_msg,
|
||||||
|
)
|
||||||
|
repo_content += files_listing
|
||||||
|
|
||||||
|
return repo_content
|
||||||
|
|
||||||
|
def choose_files_listing(self, other_files):
|
||||||
|
# 1/4 of gpt-4's context window
|
||||||
|
max_map_tokens = 2048
|
||||||
|
|
||||||
|
if not other_files:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.use_ctags:
|
||||||
|
files_listing = get_tags_map(other_files)
|
||||||
|
if self.token_count(files_listing) < max_map_tokens:
|
||||||
|
ctags_msg = " with selected ctags info"
|
||||||
|
return files_listing, ctags_msg
|
||||||
|
|
||||||
|
files_listing = self.get_simple_files_map(other_files)
|
||||||
|
ctags_msg = ""
|
||||||
|
if self.token_count(files_listing) < max_map_tokens:
|
||||||
|
return files_listing, ctags_msg
|
||||||
|
|
||||||
|
def get_simple_files_map(self, other_files):
|
||||||
|
files_listing = "\n".join(self.get_rel_fname(ofn) for ofn in sorted(other_files))
|
||||||
|
return files_listing
|
||||||
|
|
||||||
|
def token_count(self, string):
|
||||||
|
return len(self.tokenizer.encode(string))
|
||||||
|
|
||||||
|
def get_rel_fname(self, fname):
|
||||||
|
return os.path.relpath(fname, self.root)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
res = get_tags_map(sys.argv[1:])
|
res = get_tags_map(sys.argv[1:])
|
||||||
print(res)
|
print(res)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue