diff --git a/tests/basic/test_scripting.py b/tests/basic/test_scripting.py index ed4edbb0b..b2873a15c 100644 --- a/tests/basic/test_scripting.py +++ b/tests/basic/test_scripting.py @@ -1,26 +1,27 @@ import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch + from aider.coders import Coder -from aider.models import Model from aider.io import InputOutput +from aider.models import Model + class TestScriptingAPI(unittest.TestCase): - - @patch('aider.coders.Coder.create') - @patch('aider.models.Model') + @patch("aider.coders.Coder.create") + @patch("aider.models.Model") def test_basic_scripting(self, mock_model, mock_coder_create): # Setup mock_coder = MagicMock() mock_coder_create.return_value = mock_coder - + # Test script fnames = ["greeting.py"] model = Model("gpt-4-turbo") coder = Coder.create(main_model=model, fnames=fnames) - + coder.run("make a script that prints hello world") coder.run("make it say goodbye") - + # Assertions mock_model.assert_called_once_with("gpt-4-turbo") mock_coder_create.assert_called_once_with(main_model=model, fnames=fnames) @@ -28,26 +29,27 @@ class TestScriptingAPI(unittest.TestCase): mock_coder.run.assert_any_call("make a script that prints hello world") mock_coder.run.assert_any_call("make it say goodbye") - @patch('aider.coders.Coder.create') - @patch('aider.models.Model') + @patch("aider.coders.Coder.create") + @patch("aider.models.Model") def test_scripting_with_io(self, mock_model, mock_coder_create): # Setup mock_coder = MagicMock() mock_coder_create.return_value = mock_coder - + # Test script fnames = ["greeting.py"] model = Model("gpt-4-turbo") io = InputOutput(yes=True) coder = Coder.create(main_model=model, fnames=fnames, io=io) - + coder.run("add a new function") - + # Assertions mock_model.assert_called_once_with("gpt-4-turbo") mock_coder_create.assert_called_once_with(main_model=model, fnames=fnames, io=io) mock_coder.run.assert_called_once_with("add a new function") self.assertTrue(io.yes) # Check that 'yes' is set to True -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main()