refactor: Simplify context window handling and remove Ollama-specific warnings

commit 0af6dc3838
parent e313a2ea45
Author:    Paul Gauthier (2025-02-04 12:45:44 -08:00)
Committer: Paul Gauthier (aider)
2 changed files with 5 additions and 21 deletions

aider/coders/base_coder.py

@@ -1246,26 +1246,13 @@ class Coder:
         self.io.tool_output("- Use /drop to remove unneeded files from the chat")
         self.io.tool_output("- Use /clear to clear the chat history")
         self.io.tool_output("- Break your code into smaller files")
 
         proceed = "Y"
         self.io.tool_output(
             "It's probably safe to try and send the request, most providers won't charge if"
             " the context limit is exceeded."
         )
-
-        # Special warning for Ollama models about context window size
-        if self.main_model.name.startswith(("ollama/", "ollama_chat/")):
-            extra_params = getattr(self.main_model, "extra_params", None) or {}
-            num_ctx = extra_params.get("num_ctx", 2048)
-            if input_tokens > num_ctx:
-                proceed = "N"
-                self.io.tool_warning(f"""
-Your Ollama model is configured with num_ctx={num_ctx} tokens of context window.
-You are attempting to send {input_tokens} tokens.
-See https://aider.chat/docs/llms/ollama.html#setting-the-context-window-size
-""".strip())  # noqa
-
-        if proceed and not self.io.confirm_ask("Try to proceed anyway?", default=proceed):
-            return False
+        if not self.io.confirm_ask("Try to proceed anyway?"):
+            return False
         return True
 
     def send_message(self, inp):

aider/models.py

@@ -261,10 +261,6 @@ class Model(ModelSettings):
         if not exact_match:
             self.apply_generic_model_settings(model)
 
-        if model.startswith("ollama/") or model.startswith("ollama_chat/"):
-            if not (self.extra_params and "num_ctx" in self.extra_params):
-                self.extra_params = dict(num_ctx=8 * 1024)
-
         # Apply override settings last if they exist
         if self.extra_model_settings and self.extra_model_settings.extra_params:
             # Initialize extra_params if it doesn't exist
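With the static 8k default gone from model initialization, an explicit num_ctx in extra_params still takes precedence, and everything else falls through to the per-request computation in the next hunk. A minimal standalone sketch of the resulting precedence, using assumed function and parameter names rather than aider's actual methods:

    # Hypothetical sketch of the num_ctx precedence after this commit.
    # extra_params mirrors Model.extra_params; the fallback formula is the
    # one added to models.py in the hunk below.
    def effective_num_ctx(extra_params, prompt_tokens):
        if extra_params and "num_ctx" in extra_params:
            return extra_params["num_ctx"]  # user-pinned value always wins
        return int(prompt_tokens * 1.25) + 8192  # dynamic per-request default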
@@ -561,9 +557,10 @@ class Model(ModelSettings):
         if self.extra_params:
             kwargs.update(self.extra_params)
         if self.is_ollama() and "num_ctx" not in kwargs:
-            kwargs["num_ctx"] = int(self.token_count(messages) * 1.5)
+            num_ctx = int(self.token_count(messages) * 1.25) + 8192
+            kwargs["num_ctx"] = num_ctx
 
         key = json.dumps(kwargs, sort_keys=True).encode()
         # dump(kwargs)
         hash_object = hashlib.sha1(key)
 
         res = litellm.completion(**kwargs)
 
         return hash_object, res
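The new formula trades the flat 1.5x multiplier for 25% headroom plus a fixed 8192-token reserve, which raises the floor for short prompts while allocating slightly less for very long ones (the two rules cross at exactly 32,768 prompt tokens, where 0.25 * 32768 = 8192). A rough worked comparison, plain arithmetic rather than aider code:

    # Illustrative comparison of the removed rule and the rule added above.
    for prompt_tokens in (1_000, 10_000, 50_000):
        old = int(prompt_tokens * 1.5)          # removed: flat 1.5x
        new = int(prompt_tokens * 1.25) + 8192  # added: 1.25x plus 8k reserve
        print(prompt_tokens, old, new)
    # prompt   old (1.5x)   new (1.25x + 8192)
    #   1000         1500                 9442
    #  10000        15000                20692
    #  50000        75000                70692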