diff --git a/aider/coder.py b/aider/coder.py index bad5e4cd5..3cb73cbed 100755 --- a/aider/coder.py +++ b/aider/coder.py @@ -414,69 +414,22 @@ class Coder: return edited - def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"): - repo = self.repo - if not repo: - return - - if not repo.is_dirty(): - return - - def get_dirty_files(file_list): - dirty_files = [] - relative_dirty_files = [] - for fname in file_list: - relative_fname = os.path.relpath(fname, repo.working_tree_dir) - if self.pretty: - these_diffs = repo.git.diff("HEAD", "--color", relative_fname) - else: - these_diffs = repo.git.diff("HEAD", relative_fname) - - if these_diffs: - dirty_files.append(fname) - relative_dirty_files.append(relative_fname) - - return dirty_files, relative_dirty_files - - if which == "repo_files": - all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()] - dirty_fnames, relative_dirty_fnames = get_dirty_files(all_files) - elif which == "chat_files": - dirty_fnames, relative_dirty_fnames = get_dirty_files(self.abs_fnames) - else: - raise ValueError(f"Invalid value for 'which': {which}") - - diffs = "" - for (abs_fname,relative_fname) in zip(dirty_fnames, relative_dirty_fnames): - if self.pretty: - these_diffs = repo.git.diff("HEAD", "--color", relative_fname) - else: - these_diffs = repo.git.diff("HEAD", relative_fname) - - diffs += these_diffs + "\n" - - if self.show_diffs or ask: - self.console.print(Text(diffs)) - - diffs = "# Diffs:\n" + diffs - - # for fname in dirty_fnames: - # self.console.print(f"[red] {fname}") - + def get_context_from_history(self, history): context = "" if history: context += "# Context:\n" for msg in history: context += msg["role"].upper() + ": " + msg["content"] + "\n" + return context + + def get_commit_message(self, diffs, context): + diffs = "# Diffs:\n" + diffs messages = [ dict(role="system", content=prompts.commit_system), dict(role="user", content=context + diffs), ] - # if history: - # self.show_messages(messages, "commit") - commit_message, interrupted = self.send( messages, model="gpt-3.5-turbo", @@ -491,6 +444,46 @@ class Coder: ) return + return commit_message + + def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"): + repo = self.repo + if not repo: + return + + if not repo.is_dirty(): + return + + def get_dirty_files(file_list): + diffs = "" + relative_dirty_files = [] + for fname in file_list: + relative_fname = os.path.relpath(fname, repo.working_tree_dir) + if self.pretty: + these_diffs = repo.git.diff("HEAD", "--color", relative_fname) + else: + these_diffs = repo.git.diff("HEAD", relative_fname) + + if these_diffs: + relative_dirty_files.append(relative_fname) + diffs += these_diffs + "\n" + + return relative_dirty_files, diffs + + if which == "repo_files": + all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()] + relative_dirty_fnames,diffs = get_dirty_files(all_files) + elif which == "chat_files": + relative_dirty_fnames,diffs = get_dirty_files(self.abs_fnames) + else: + raise ValueError(f"Invalid value for 'which': {which}") + + if self.show_diffs or ask: + self.console.print(Text(diffs)) + + context = self.get_context_from_history(history) + commit_message = self.get_commit_message(diffs, context) + if prefix: commit_message = prefix + commit_message