refactor: Update repo map token handling and improve warning message

This commit is contained in:
Paul Gauthier 2025-01-10 14:38:12 -08:00 committed by Paul Gauthier (aider)
parent a9cf438100
commit d48008e13d
4 changed files with 15 additions and 6 deletions

View file

@ -287,7 +287,7 @@ def get_parser(default_config_files, git_root):
"--map-tokens", "--map-tokens",
type=int, type=int,
default=None, default=None,
help="Suggested number of tokens to use for repo map, use 0 to disable (default: 1024)", help="Suggested number of tokens to use for repo map, use 0 to disable",
) )
group.add_argument( group.add_argument(
"--map-refresh", "--map-refresh",

View file

@ -231,10 +231,10 @@ class Coder:
if map_tokens > 0: if map_tokens > 0:
refresh = self.repo_map.refresh refresh = self.repo_map.refresh
lines.append(f"Repo-map: using {map_tokens} tokens, {refresh} refresh") lines.append(f"Repo-map: using {map_tokens} tokens, {refresh} refresh")
max_map_tokens = 2048 max_map_tokens = self.main_model.get_repo_map_tokens() * 2
if map_tokens > max_map_tokens: if map_tokens > max_map_tokens:
lines.append( lines.append(
f"Warning: map-tokens > {max_map_tokens} is not recommended as too much" f"Warning: map-tokens > {max_map_tokens} is not recommended. Too much"
" irrelevant code can confuse LLMs." " irrelevant code can confuse LLMs."
) )
else: else:

View file

@ -636,8 +636,6 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
analytics.event("launched") analytics.event("launched")
# ai
if args.gui and not return_coder: if args.gui and not return_coder:
if not check_streamlit_install(io): if not check_streamlit_install(io):
analytics.event("exit", reason="Streamlit not installed") analytics.event("exit", reason="Streamlit not installed")
@ -850,6 +848,8 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
) )
args.stream = False args.stream = False
map_tokens = args.map_tokens or main_model.get_repo_map_tokens()
try: try:
coder = Coder.create( coder = Coder.create(
main_model=main_model, main_model=main_model,
@ -862,7 +862,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
auto_commits=args.auto_commits, auto_commits=args.auto_commits,
dirty_commits=args.dirty_commits, dirty_commits=args.dirty_commits,
dry_run=args.dry_run, dry_run=args.dry_run,
map_tokens=args.map_tokens, map_tokens=map_tokens,
verbose=args.verbose, verbose=args.verbose,
stream=args.stream, stream=args.stream,
use_git=args.git, use_git=args.git,

View file

@ -1177,6 +1177,15 @@ class Model(ModelSettings):
return res return res
def get_repo_map_tokens(self):
map_tokens = 1024
max_inp_tokens = self.info.get("max_input_tokens")
if max_inp_tokens:
map_tokens = max_inp_tokens / 8
map_tokens = min(map_tokens, 4096)
max_tokens = max(map_tokens, 1024)
return map_tokens
def register_models(model_settings_fnames): def register_models(model_settings_fnames):
files_loaded = [] files_loaded = []