diff --git a/tests/basic/test_io.py b/tests/basic/test_io.py index 89e759dca..642e19863 100644 --- a/tests/basic/test_io.py +++ b/tests/basic/test_io.py @@ -162,6 +162,20 @@ class TestInputOutput(unittest.TestCase): mock_prompt.assert_called_once() mock_prompt.reset_mock() + def test_get_command_completions(self): + root = "" + rel_fnames = [] + addable_rel_fnames = [] + commands = MagicMock() + commands.get_commands.return_value = ["model", "chat", "help"] + commands.get_completions.return_value = ["gpt-3.5-turbo", "gpt-4"] + + autocompleter = AutoCompleter(root, rel_fnames, addable_rel_fnames, commands, "utf-8") + + # Test case for "/model gpt" + result = autocompleter.get_command_completions("/model gpt", ["/model", "gpt"]) + self.assertEqual(result, ["gpt-3.5-turbo", "gpt-4"]) + commands.get_completions.assert_called_once_with("model") if __name__ == "__main__": unittest.main()