Use fq model name in --models search

This commit is contained in:
Paul Gauthier 2024-04-22 19:17:27 -07:00
parent 25b8d6fec8
commit 89a7b3470a
3 changed files with 48 additions and 29 deletions

View file

@ -346,7 +346,11 @@ def sanity_check_model(io, model):
if possible_matches:
io.tool_error("Did you mean one of these?")
for match in possible_matches:
io.tool_error(f"- {match}")
fq, m = match
if fq == m:
io.tool_error(f"- {m}")
else:
io.tool_error(f"- {m} ({fq})")
if show:
io.tool_error("For more info see https://aider.chat/docs/llms.html#model-warnings")
@ -355,30 +359,41 @@ def sanity_check_model(io, model):
def fuzzy_match_models(name):
chat_models = [
model for model, attrs in litellm.model_cost.items() if attrs.get("mode") == "chat"
]
chat_models = []
for model, attrs in litellm.model_cost.items():
if attrs.get("mode") != "chat":
continue
provider = attrs["litellm_provider"] + "/"
if model.startswith(provider):
fq_model = model
else:
fq_model = provider + model
chat_models.append((fq_model, model))
# exactly matching model
matching_models = [model for model in chat_models if name == model]
if matching_models:
return matching_models
# exactly matching provider
matching_models = [
model for model in chat_models if litellm.model_cost[model]["litellm_provider"] == name
]
if matching_models:
return matching_models
# matching_models = [
# (fq,m) for fq,m in chat_models
# if name == fq or name == m
# ]
# if matching_models:
# return matching_models
# Check for model names containing the name
matching_models = [model for model in chat_models if name in model]
matching_models = [(fq, m) for fq, m in chat_models if name in fq]
if matching_models:
return matching_models
# Check for slight misspellings
matching_models = difflib.get_close_matches(name, chat_models, n=3, cutoff=0.8)
return matching_models
models = [m for fq, m in chat_models]
matching_models = difflib.get_close_matches(name, models, n=3, cutoff=0.8)
if matching_models:
return list(zip(matching_models, matching_models))
fq_models = [fq for fq, m in chat_models]
matching_models = difflib.get_close_matches(name, fq_models, n=3, cutoff=0.8)
return list(zip(matching_models, matching_models))
def main():