mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-30 17:24:59 +00:00
test: add tests for Ollama model num_ctx handling
This commit is contained in:
parent
d7de908c66
commit
9094af565f
1 changed files with 59 additions and 0 deletions
|
@ -208,6 +208,65 @@ class TestModels(unittest.TestCase):
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@patch('aider.models.litellm.completion')
|
||||||
|
@patch.object(Model, 'token_count')
|
||||||
|
def test_ollama_num_ctx_set_when_missing(self, mock_token_count, mock_completion):
|
||||||
|
mock_token_count.return_value = 1000
|
||||||
|
|
||||||
|
model = Model("ollama/llama3")
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
|
||||||
|
model.send_completion(messages, functions=None, stream=False)
|
||||||
|
|
||||||
|
# Verify num_ctx was calculated and added to call
|
||||||
|
expected_ctx = int(1000 * 1.25) + 8192 # 9442
|
||||||
|
mock_completion.assert_called_once_with(
|
||||||
|
model=model.name,
|
||||||
|
messages=messages,
|
||||||
|
stream=False,
|
||||||
|
temperature=ANY,
|
||||||
|
num_ctx=expected_ctx,
|
||||||
|
tools=None,
|
||||||
|
tool_choice=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch('aider.models.litellm.completion')
|
||||||
|
def test_ollama_uses_existing_num_ctx(self, mock_completion):
|
||||||
|
model = Model("ollama/llama3")
|
||||||
|
model.extra_params = {"num_ctx": 4096}
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
model.send_completion(messages, functions=None, stream=False)
|
||||||
|
|
||||||
|
# Should use provided num_ctx from extra_params
|
||||||
|
mock_completion.assert_called_once_with(
|
||||||
|
model=model.name,
|
||||||
|
messages=messages,
|
||||||
|
stream=False,
|
||||||
|
temperature=ANY,
|
||||||
|
num_ctx=4096,
|
||||||
|
tools=None,
|
||||||
|
tool_choice=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch('aider.models.litellm.completion')
|
||||||
|
def test_non_ollama_no_num_ctx(self, mock_completion):
|
||||||
|
model = Model("gpt-4")
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
|
||||||
|
model.send_completion(messages, functions=None, stream=False)
|
||||||
|
|
||||||
|
# Regular models shouldn't get num_ctx
|
||||||
|
mock_completion.assert_called_once_with(
|
||||||
|
model=model.name,
|
||||||
|
messages=messages,
|
||||||
|
stream=False,
|
||||||
|
temperature=ANY,
|
||||||
|
tools=None,
|
||||||
|
tool_choice=None,
|
||||||
|
)
|
||||||
|
self.assertNotIn('num_ctx', mock_completion.call_args.kwargs)
|
||||||
|
|
||||||
|
|
||||||
# Allow running this test module directly: `python test_models.py`.
if __name__ == "__main__":
    unittest.main()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue