Yusuf Özer 2025-05-18 16:36:03 +10:00 committed by GitHub
commit 5a1344b709
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 201 additions and 2 deletions

aider/coders/__init__.py

@@ -12,7 +12,7 @@ from .patch_coder import PatchCoder
from .udiff_coder import UnifiedDiffCoder
from .udiff_simple import UnifiedDiffSimpleCoder
from .wholefile_coder import WholeFileCoder
from .batch_coder import BatchCoder
# from .single_wholefile_func_coder import SingleWholeFileFunctionCoder
__all__ = [
@@ -21,6 +21,7 @@ __all__ = [
    Coder,
    EditBlockCoder,
    EditBlockFencedCoder,
    BatchCoder,
    WholeFileCoder,
    PatchCoder,
    UnifiedDiffCoder,

110  aider/coders/batch_coder.py Normal file

@@ -0,0 +1,110 @@
import copy
from typing import Tuple, override

from aider.coders.base_coder import Coder
from aider.coders.base_prompts import CoderPrompts


class BatchCoder(Coder):
    """Perform a coding task on multiple files in batches that fit the
    context and output token limits, without sending them all at once."""

    coder: Coder = None
    original_kwargs: dict = None
    edit_format = "batch"

    def __init__(self, main_model, io, **kwargs):
        super().__init__(main_model, io, **kwargs)
        if "gpt_prompts" not in kwargs:
            self.gpt_prompts = CoderPrompts()
    @override
    def run_one(self, user_message, preproc):
        if self.coder is None:
            # The inner coder performs the actual edits; linting and committing
            # are disabled on it because BatchCoder handles both per batch.
            self.coder = Coder.create(
                main_model=self.main_model,
                edit_format=self.main_model.edit_format,
                from_coder=self,
                **self.original_kwargs,
            )
            self.coder.auto_lint, self.coder.auto_commits = False, False
        chat_files_with_type_and_length = self.get_chat_files_with_type_and_length()
        max_tokens = self.main_model.info.get("max_tokens")
        max_context = self.main_model.info["max_input_tokens"]
        max_output = max_tokens if max_tokens is not None else self.main_model.info["max_output_tokens"]
        repo_token_count = self.main_model.get_repo_map_tokens()
        history_token_count = sum(
            tup[0] for tup in self.summarizer.tokenize([msg["content"] for msg in self.done_messages])
        )
        # Auto-confirm prompts while batching. Shell commands will still need
        # confirmation for each command; this can be overridden by extending the
        # InputOutput class and overriding its confirm_ask method.
        prev_io = self.io.yes
        self.io.yes = True
        cruncher = self.file_cruncher(
            max_context,
            max_output,
            repo_token_count + history_token_count,
            chat_files_with_type_and_length,
        )
        edited_files = self.batch_process(user_message, preproc, cruncher)
        self.io.yes = prev_io
        if len(edited_files) == 0:
            return
        if self.auto_lint:
            cruncher.files_to_crunch = [
                (fname, True, self.main_model.token_count(self.io.read_text(fname)))
                for fname in edited_files
            ]
            self.batch_lint(cruncher, preproc)
        if self.auto_commits:
            self.batch_commit(edited_files)
    def get_chat_files_with_type_and_length(self):
        # Collect (absolute path, editable?, token count) for every file in the chat.
        chat_files_with_type_and_length: list[Tuple[str, bool, int]] = []
        for f in self.abs_fnames:
            chat_files_with_type_and_length.append((f, True, self.main_model.token_count(self.io.read_text(f))))
        for f in self.abs_read_only_fnames:
            chat_files_with_type_and_length.append((f, False, self.main_model.token_count(self.io.read_text(f))))
        return chat_files_with_type_and_length
    def batch_process(self, message, preproc, cruncher):
        edited_files = []
        for files_to_send_with_types in cruncher:
            self.prepare_batch(files_to_send_with_types)
            self.coder.run_one(message, preproc)
            edited_files.extend(self.coder.aider_edited_files)
            self.coder.aider_edited_files = set()
        return edited_files
    def prepare_batch(self, files_to_send_with_types: list[Tuple[str, bool]]):
        # Each batch starts from the chat history as it was before the command,
        # with only this batch's files in the chat.
        self.coder.done_messages = copy.deepcopy(self.done_messages)
        self.coder.cur_messages = []
        self.coder.abs_fnames = set(f[0] for f in files_to_send_with_types if f[1])
        self.coder.abs_read_only_fnames = set(f[0] for f in files_to_send_with_types if not f[1])
    def batch_lint(self, cruncher, preproc):
        for files_with_type in cruncher:
            files = [ft[0] for ft in files_with_type]
            lint_msg = self.coder.lint_edited(files)
            self.auto_commit(files, context="Ran the linter")
            if lint_msg:
                ok = self.io.confirm_ask("Attempt to fix lint errors?", subject="batch_lint", allow_never=True)
                if ok:
                    self.coder.done_messages, self.coder.cur_messages = [], []
                    self.coder.run_one(lint_msg, preproc)

    def batch_commit(self, files: list[str]):
        self.repo.commit(files)
    class file_cruncher:
        """Iterator that yields batches of files, fitting the input files +
        chat history + repo map into the context limit and the expected edits
        into the output limit. Output files are assumed to be half the size
        of input files."""

        context_tokens: int
        max_context: int
        max_output: int
        files_to_crunch: list[Tuple[str, bool, int]]
        PADDING: int = 50

        def __init__(self, max_context: int, max_output: int, context_tokens, files_to_crunch: list[Tuple[str, bool, int]]):
            self.context_tokens = context_tokens
            self.max_context = max_context
            self.max_output = max_output
            # Smallest files first, so each batch packs in as many as possible.
            self.files_to_crunch = sorted(files_to_crunch, key=lambda x: x[2])

        def __iter__(self):
            return self
        def __next__(self):
            if len(self.files_to_crunch) == 0:
                raise StopIteration
            files_to_send: list[Tuple[str, bool]] = []
            i: int = 0
            total_context = 0
            total_output = 0
            for file_name, type_, length in self.files_to_crunch:
                # Each file costs its own tokens in the request plus an expected
                # edit of half its size in the reply; stop before either limit.
                if (
                    length + length / 2 + self.context_tokens + total_context >= self.max_context
                    or length / 2 + total_output >= self.max_output
                ):
                    break
                total_context += length + length / 2 + self.PADDING
                total_output += length / 2 + self.PADDING
                files_to_send.append((file_name, type_))
                i += 1
            if i == 0:
                # No file fits the limits: roll the dice, send the smallest one
                # anyway and let the user deal with it.
                f, t, _ = self.files_to_crunch[0]
                files_to_send.append((f, t))
                i = 1
            self.files_to_crunch = self.files_to_crunch[i:]
            return files_to_send
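
To make the packing logic above concrete, here is a minimal sketch that drives the cruncher directly. It assumes only the class above is importable; the file names and token counts are invented for illustration.

# Hypothetical, standalone use of BatchCoder.file_cruncher.
# Token counts are invented; real counts come from main_model.token_count().
from aider.coders.batch_coder import BatchCoder

files = [
    ("small.py", True, 2_000),   # (absolute path, editable?, token count)
    ("medium.py", True, 10_000),
    ("huge.py", False, 60_000),  # too big to share a batch with anything
]
cruncher = BatchCoder.file_cruncher(
    max_context=32_000,
    max_output=8_000,
    context_tokens=1_500,  # repo map + chat history overhead
    files_to_crunch=files,
)
for batch in cruncher:
    print(batch)
# Prints [('small.py', True), ('medium.py', True)], then [('huge.py', False)]:
# the oversized file is still sent, alone, rather than silently skipped.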

aider/commands.py

@@ -1144,7 +1144,9 @@ class Commands:
            map_mul_no_files=map_mul_no_files,
            show_announcements=False,
        )

    def cmd_batch(self, args):
        """Iteratively perform the change on files, in batches that fit the context and output limits"""
        return self._generic_chat_command(args, "batch")

    def completions_ask(self):
        raise CommandCompletionException()
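
For context, a hypothetical chat session using the new command (file names invented) could look like this; each batch re-runs the same request against a slice of the added files:

/add src/parser.py src/lexer.py src/ast.py
/batch add docstrings to every public function

This keeps each request inside the model's context and output limits even when the full file set would not fit at once.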

86  tests/basic/test_batch.py Normal file

@@ -0,0 +1,86 @@
import os
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch
from aider.coders import Coder
from aider.io import InputOutput
from aider.models import Model
from aider.repo import GitRepo
from aider.utils import GitTemporaryDirectory
class TestBatchCoder(unittest.TestCase):
    def setUp(self):
        self.GPT35 = Model("gpt-3.5-turbo")
        self.io = InputOutput(yes=True)
        # self.webbrowser_patcher = patch("aider.io.webbrowser.open")
        # self.mock_webbrowser = self.webbrowser_patcher.start()

        # Get all Python files in the aider/coders directory
        coders_dir = Path(__file__).parent.parent.parent / "aider" / "coders"
        self.files = [str(f) for f in coders_dir.glob("*.py") if f.is_file()]

        # Create a batch coder with all the files in the chat
        self.coder = Coder.create(
            main_model=self.GPT35,
            io=self.io,
            fnames=self.files,
            edit_format="batch",
        )

    def tearDown(self):
        # self.webbrowser_patcher.stop()
        return
"""Tests that:
- Every request retains the chat history until the /batch command but not the history of other iterations.
- Added files and history until the /batch is unmodified.
- Every file is processed(even if a single file that'll be sent with the request exceeds the limits.) and no duplicate processing
"""
def test_iterate_resets_history_and_processes_all_files(self):
processed_files :list[str]= []
original_context:list[dict[str,str]]
prev_file_names : list[str] = None
# Track messages sent to LLM and files processed
def mock_send(self,messages, model=None, functions=None):
nonlocal original_context
nonlocal processed_files
nonlocal prev_file_names
for original_message in original_context:
assert original_message in messages, f"Chat history before start of the command is not retained."
# Simulate response mentioning filename
files_message = [msg['content'] for msg in messages if "*added these files to the chat*" in msg['content']][0]
from re import findall
file_names = findall(r'.*\n(\S+\.py)\n```.*',files_message)
for f_name in file_names:
assert prev_file_names == None or f_name not in prev_file_names, "files from previous iterations hasn't been cleaned up."
prev_file_names = file_names
processed_files.extend(file_names)
# Return minimal response
self.partial_response_content = "Done."
self.partial_response_function_call=dict()
        with GitTemporaryDirectory():
            # Mock the send method and stub out linting and committing
            with (
                patch.object(Coder, "send", new_callable=lambda: mock_send),
                patch.object(Coder, "lint_edited", lambda *_, **__: None),
                patch.object(GitRepo, "commit", lambda *_, **__: None),
            ):
                self.coder.coder = Coder.create(
                    main_model=self.coder.main_model,
                    edit_format=self.coder.main_model.edit_format,
                    from_coder=self.coder,
                    **self.coder.original_kwargs,
                )
                # Add initial conversation history
                original_context = self.coder.done_messages = [
                    {"role": "user", "content": "Initial conversation"},
                    {"role": "assistant", "content": "OK"},
                ]

                # Run the batch command
                self.coder.run(with_message="Process all files")

        # Verify all files were processed
        input_basenames = {Path(f).name for f in self.files}
        processed_basenames = {Path(f).name for f in processed_files}
        missing = input_basenames - processed_basenames
        assert not missing, f"Files not processed: {missing}"

        # Verify history preservation and structure
        assert len(self.coder.done_messages) == 2, "Original chat history was modified"

        # Verify final file state
        assert len(self.coder.abs_fnames) == len(self.files), "Not all files remained in the chat"
if __name__ == "__main__":
    unittest.main()
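
Assuming a development checkout of aider with pytest installed, the new test can be run on its own:

python -m pytest tests/basic/test_batch.py -v

Because Coder.send is patched out, the test should make no LLM API calls.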