diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 89a9017bf..e7947961b 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -7,9 +7,8 @@ import sys import time import traceback from json.decoder import JSONDecodeError -from pathlib import Path, PurePosixPath +from pathlib import Path -import git import openai from jsonschema import Draft7Validator from rich.console import Console, Text @@ -18,6 +17,7 @@ from rich.markdown import Markdown from aider import models, prompts, utils from aider.commands import Commands +from aider.repo import AiderRepo from aider.repomap import RepoMap from aider.sendchat import send_with_retries @@ -149,12 +149,16 @@ class Coder: self.commands = Commands(self.io, self) if use_git: - self.set_repo(fnames) + try: + self.repo = AiderRepo(fnames) + self.root = self.repo.root + except FileNotFoundError: + self.repo = None else: self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames]) if self.repo: - rel_repo_dir = self.get_rel_repo_dir() + rel_repo_dir = self.repo.get_rel_repo_dir() self.io.tool_output(f"Git repo: {rel_repo_dir}") else: self.io.tool_output("Git repo: none") @@ -376,7 +380,7 @@ class Coder: if self.should_dirty_commit(inp): self.io.tool_output("Git repo has uncommitted changes, preparing commit...") - self.commit(ask=True, which="repo_files") + self.commit(ask=True, which="repo_files", pretty=self.pretty) # files changed, move cur messages back behind the files messages self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits) @@ -495,8 +499,16 @@ class Coder: ) ] + def get_context_from_history(self, history): + context = "" + if history: + for msg in history: + context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n" + return context + def auto_commit(self): - res = self.commit(history=self.cur_messages, prefix="aider: ") + context = self.get_context_from_history(self.cur_messages) + res = self.commit(context=context, prefix="aider: ", pretty=self.pretty) if res: commit_hash, commit_message = res self.last_aider_commit_hash = commit_hash @@ -553,7 +565,7 @@ class Coder: return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames)) - def send(self, messages, model=None, silent=False, functions=None): + def send(self, messages, model=None, functions=None): if not model: model = self.main_model.name @@ -566,25 +578,24 @@ class Coder: self.chat_completion_call_hashes.append(hash_object.hexdigest()) if self.stream: - self.show_send_output_stream(completion, silent) + self.show_send_output_stream(completion) else: - self.show_send_output(completion, silent) + self.show_send_output(completion) except KeyboardInterrupt: self.keyboard_interrupt() interrupted = True - if not silent: - if self.partial_response_content: - self.io.ai_output(self.partial_response_content) - elif self.partial_response_function_call: - # TODO: push this into subclasses - args = self.parse_partial_args() - if args: - self.io.ai_output(json.dumps(args, indent=4)) + if self.partial_response_content: + self.io.ai_output(self.partial_response_content) + elif self.partial_response_function_call: + # TODO: push this into subclasses + args = self.parse_partial_args() + if args: + self.io.ai_output(json.dumps(args, indent=4)) return interrupted - def show_send_output(self, completion, silent): + def show_send_output(self, completion): if self.verbose: print(completion) @@ -633,9 +644,9 @@ class Coder: self.io.console.print(show_resp) self.io.console.print(tokens) - def show_send_output_stream(self, completion, silent): + def show_send_output_stream(self, completion): live = None - if self.pretty and not silent: + if self.pretty: live = Live(vertical_overflow="scroll") try: @@ -664,9 +675,6 @@ class Coder: except AttributeError: pass - if silent: - continue - if self.pretty: self.live_incremental_response(live, False) else: @@ -697,7 +705,7 @@ class Coder: def get_all_relative_files(self): if self.repo: - files = self.get_tracked_files() + files = self.repo.get_tracked_files() else: files = self.get_inchat_relative_files() @@ -752,25 +760,6 @@ class Coder: return full_path - def get_tracked_files(self): - if not self.repo: - return [] - - try: - commit = self.repo.head.commit - except ValueError: - return set() - - files = [] - for blob in commit.tree.traverse(): - if blob.type == "blob": # blob is a file - files.append(blob.path) - - # convert to appropriate os.sep, since git always normalizes to / - res = set(str(Path(PurePosixPath(path))) for path in files) - - return res - apply_update_errors = 0 def apply_updates(self): diff --git a/aider/coders/single_wholefile_func_coder.py b/aider/coders/single_wholefile_func_coder.py index 2835c5f7c..2d99c3645 100644 --- a/aider/coders/single_wholefile_func_coder.py +++ b/aider/coders/single_wholefile_func_coder.py @@ -42,15 +42,6 @@ class SingleWholeFileFunctionCoder(Coder): else: self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] - def get_context_from_history(self, history): - context = "" - if history: - context += "# Context:\n" - for msg in history: - if msg["role"] == "user": - context += msg["role"].upper() + ": " + msg["content"] + "\n" - return context - def render_incremental_response(self, final=False): if self.partial_response_content: return self.partial_response_content diff --git a/aider/coders/wholefile_coder.py b/aider/coders/wholefile_coder.py index 3ccd0dac5..efec434af 100644 --- a/aider/coders/wholefile_coder.py +++ b/aider/coders/wholefile_coder.py @@ -20,15 +20,6 @@ class WholeFileCoder(Coder): else: self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] - def get_context_from_history(self, history): - context = "" - if history: - context += "# Context:\n" - for msg in history: - if msg["role"] == "user": - context += msg["role"].upper() + ": " + msg["content"] + "\n" - return context - def render_incremental_response(self, final): try: return self.update_files(mode="diff") diff --git a/aider/coders/wholefile_func_coder.py b/aider/coders/wholefile_func_coder.py index e51150482..ebf10764d 100644 --- a/aider/coders/wholefile_func_coder.py +++ b/aider/coders/wholefile_func_coder.py @@ -55,15 +55,6 @@ class WholeFileFunctionCoder(Coder): else: self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] - def get_context_from_history(self, history): - context = "" - if history: - context += "# Context:\n" - for msg in history: - if msg["role"] == "user": - context += msg["role"].upper() + ": " + msg["content"] + "\n" - return context - def render_incremental_response(self, final=False): if self.partial_response_content: return self.partial_response_content diff --git a/aider/commands.py b/aider/commands.py index 10e9cd740..11af5ec7f 100644 --- a/aider/commands.py +++ b/aider/commands.py @@ -215,7 +215,9 @@ class Commands: return commits = f"{self.coder.last_aider_commit_hash}~1" - diff = self.coder.get_diffs(commits, self.coder.last_aider_commit_hash) + diff = self.coder.get_diffs( + commits, self.coder.last_aider_commit_hash, pretty=self.coder.pretty + ) # don't use io.tool_output() because we don't want to log or further colorize print(diff) @@ -247,7 +249,7 @@ class Commands: added_fnames = [] git_added = [] - git_files = self.coder.get_tracked_files() + git_files = self.coder.get_tracked_files() if self.coder.repo else [] all_matched_files = set() for word in args.split(): diff --git a/aider/repo.py b/aider/repo.py index ca05a721c..d570748bc 100644 --- a/aider/repo.py +++ b/aider/repo.py @@ -1,4 +1,11 @@ +import os +from pathlib import Path, PurePosixPath + import git +import openai + +from aider import models, prompts, utils +from aider.sendchat import send_with_retries class AiderRepo: @@ -30,23 +37,26 @@ class AiderRepo: if fname.is_dir(): continue - self.abs_fnames.add(str(fname)) - num_repos = len(set(repo_paths)) if num_repos == 0: - return + raise FileNotFoundError if num_repos > 1: self.io.tool_error("Files are in different git repos.") - return + raise FileNotFoundError # https://github.com/gitpython-developers/GitPython/issues/427 self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB) - self.root = utils.safe_abs_path(self.repo.working_tree_dir) + def ___(self, fnames): + + # TODO! + + self.abs_fnames.add(str(fname)) + new_files = [] - for fname in self.abs_fnames: + for fname in fnames: relative_fname = self.get_rel_fname(fname) tracked_files = set(self.get_tracked_files()) @@ -71,7 +81,12 @@ class AiderRepo: else: self.io.tool_error("Skipped adding new files to the git repo.") - def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"): + def commit( + self, context=None, prefix=None, ask=False, message=None, which="chat_files", pretty=False + ): + + ## TODO! + repo = self.repo if not repo: return @@ -96,7 +111,7 @@ class AiderRepo: if not current_branch_commit_count: continue - these_diffs = self.get_diffs("HEAD", "--", relative_fname) + these_diffs = self.get_diffs(pretty, "HEAD", "--", relative_fname) if these_diffs: diffs += these_diffs + "\n" @@ -115,7 +130,6 @@ class AiderRepo: # don't use io.tool_output() because we don't want to log or further colorize print(diffs) - context = self.get_context_from_history(history) if message: commit_message = message else: @@ -162,13 +176,6 @@ class AiderRepo: except ValueError: return self.repo.git_dir - def get_context_from_history(self, history): - context = "" - if history: - for msg in history: - context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n" - return context - def get_commit_message(self, diffs, context): if len(diffs) >= 4 * 1024 * 4: self.io.tool_error( @@ -184,34 +191,45 @@ class AiderRepo: ] try: - interrupted = self.send( - messages, + _hash, response = send_with_retries( model=models.GPT35.name, - silent=True, - ) - except openai.error.InvalidRequestError: - self.io.tool_error( - f"Failed to generate commit message using {models.GPT35.name} due to an invalid" - " request." + messages=messages, + functions=None, + stream=False, ) + commit_message = completion.choices[0].message.content + except (AttributeError, openai.error.InvalidRequestError): + self.io.tool_error(f"Failed to generate commit message using {models.GPT35.name}") return - commit_message = self.partial_response_content commit_message = commit_message.strip() if commit_message and commit_message[0] == '"' and commit_message[-1] == '"': commit_message = commit_message[1:-1].strip() - if interrupted: - self.io.tool_error( - f"Unable to get commit message from {models.GPT35.name}. Use /commit to try again." - ) - return - return commit_message - def get_diffs(self, *args): - if self.pretty: + def get_diffs(self, pretty, *args): + if pretty: args = ["--color"] + list(args) diffs = self.repo.git.diff(*args) return diffs + + def get_tracked_files(self): + if not self.repo: + return [] + + try: + commit = self.repo.head.commit + except ValueError: + return set() + + files = [] + for blob in commit.tree.traverse(): + if blob.type == "blob": # blob is a file + files.append(blob.path) + + # convert to appropriate os.sep, since git always normalizes to / + res = set(str(Path(PurePosixPath(path))) for path in files) + + return res