fix: Improve confirm_ask function in io.py

This commit is contained in:
Paul Gauthier 2024-08-24 09:05:14 -07:00 committed by Paul Gauthier (aider)
parent 0b33892137
commit e5b0fe3dfd
2 changed files with 9 additions and 16 deletions

View file

@ -418,8 +418,6 @@ class InputOutput:
return True
return text.lower() in valid_responses
error_message = f"Please answer with one of: {', '.join(valid_responses)}"
if self.yes is True:
res = "n" if explicit_yes_required else "y"
elif self.yes is False:
@ -432,12 +430,17 @@ class InputOutput:
res = prompt(
question,
style=Style.from_dict(style),
validator=validator,
)
if not res:
res = "y" # Default to Yes if no input
break
# todo: break if res.lower() is a prefix a `valid_responses`
res = res.lower()
good = any(valid_response.startswith(res) for valid_response in valid_responses)
if good:
break
error_message = f"Please answer with one of: {', '.join(valid_responses)}"
self.tool_error(error_message)
res = res.lower()[0]

View file

@ -101,7 +101,7 @@ class TestInputOutput(unittest.TestCase):
mock_prompt.return_value = "a"
result = io.confirm_ask("Are you sure?", group=group)
self.assertTrue(result)
self.assertEqual(group.preference, "a")
self.assertEqual(group.preference, "all")
mock_prompt.assert_called_once()
mock_prompt.reset_mock()
@ -115,7 +115,7 @@ class TestInputOutput(unittest.TestCase):
mock_prompt.return_value = "s"
result = io.confirm_ask("Are you sure?", group=group)
self.assertFalse(result)
self.assertEqual(group.preference, "s")
self.assertEqual(group.preference, "skip")
mock_prompt.assert_called_once()
mock_prompt.reset_mock()
@ -134,16 +134,6 @@ class TestInputOutput(unittest.TestCase):
self.assertNotIn("(A)ll", mock_prompt.call_args[0][0])
mock_prompt.reset_mock()
# Test case 6: explicit_yes_required=True, user tries to select 'All'
group.preference = None
mock_prompt.return_value = "a"
result = io.confirm_ask("Are you sure?", group=group, explicit_yes_required=True)
self.assertFalse(result)
self.assertIsNone(group.preference)
mock_prompt.assert_called_once()
self.assertNotIn("(A)ll", mock_prompt.call_args[0][0])
mock_prompt.reset_mock()
@patch("aider.io.prompt")
def test_confirm_ask_yes_no(self, mock_prompt):
io = InputOutput(pretty=False)