diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index c5ace1287..6fc81c8f2 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -7,21 +7,19 @@ import sys import time import traceback from json.decoder import JSONDecodeError -from pathlib import Path, PurePosixPath +from pathlib import Path -import backoff -import git import openai -import requests from jsonschema import Draft7Validator -from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout from rich.console import Console, Text from rich.live import Live from rich.markdown import Markdown from aider import models, prompts, utils from aider.commands import Commands +from aider.repo import GitRepo from aider.repomap import RepoMap +from aider.sendchat import send_with_retries from ..dump import dump # noqa: F401 @@ -100,6 +98,7 @@ class Coder: main_model, io, fnames=None, + git_dname=None, pretty=True, show_diffs=False, auto_commits=True, @@ -150,13 +149,27 @@ class Coder: self.commands = Commands(self.io, self) + for fname in fnames: + fname = Path(fname) + if not fname.exists(): + self.io.tool_output(f"Creating empty file {fname}") + fname.parent.mkdir(parents=True, exist_ok=True) + fname.touch() + + if not fname.is_file(): + raise ValueError(f"{fname} is not a file") + + self.abs_fnames.add(str(fname.resolve())) + if use_git: - self.set_repo(fnames) - else: - self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames]) + try: + self.repo = GitRepo(self.io, fnames, git_dname) + self.root = self.repo.root + except FileNotFoundError: + self.repo = None if self.repo: - rel_repo_dir = self.get_rel_repo_dir() + rel_repo_dir = self.repo.get_rel_repo_dir() self.io.tool_output(f"Git repo: {rel_repo_dir}") else: self.io.tool_output("Git repo: none") @@ -187,6 +200,9 @@ class Coder: for fname in self.get_inchat_relative_files(): self.io.tool_output(f"Added {fname} to the chat.") + if self.repo: + self.repo.add_new_files(fname for fname in fnames if not Path(fname).is_dir()) + # validate the functions jsonschema if self.functions: for function in self.functions: @@ -206,12 +222,6 @@ class Coder: self.root = utils.safe_abs_path(self.root) - def get_rel_repo_dir(self): - try: - return os.path.relpath(self.repo.git_dir, os.getcwd()) - except ValueError: - return self.repo.git_dir - def add_rel_fname(self, rel_fname): self.abs_fnames.add(self.abs_root_path(rel_fname)) @@ -219,73 +229,6 @@ class Coder: res = Path(self.root) / path return utils.safe_abs_path(res) - def set_repo(self, cmd_line_fnames): - if not cmd_line_fnames: - cmd_line_fnames = ["."] - - repo_paths = [] - for fname in cmd_line_fnames: - fname = Path(fname) - if not fname.exists(): - self.io.tool_output(f"Creating empty file {fname}") - fname.parent.mkdir(parents=True, exist_ok=True) - fname.touch() - - fname = fname.resolve() - - try: - repo_path = git.Repo(fname, search_parent_directories=True).working_dir - repo_path = utils.safe_abs_path(repo_path) - repo_paths.append(repo_path) - except git.exc.InvalidGitRepositoryError: - pass - - if fname.is_dir(): - continue - - self.abs_fnames.add(str(fname)) - - num_repos = len(set(repo_paths)) - - if num_repos == 0: - return - if num_repos > 1: - self.io.tool_error("Files are in different git repos.") - return - - # https://github.com/gitpython-developers/GitPython/issues/427 - self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB) - - self.root = utils.safe_abs_path(self.repo.working_tree_dir) - - new_files = [] - for fname in self.abs_fnames: - relative_fname = self.get_rel_fname(fname) - - tracked_files = set(self.get_tracked_files()) - if relative_fname not in tracked_files: - new_files.append(relative_fname) - - if new_files: - rel_repo_dir = self.get_rel_repo_dir() - - self.io.tool_output(f"Files not tracked in {rel_repo_dir}:") - for fn in new_files: - self.io.tool_output(f" - {fn}") - if self.io.confirm_ask("Add them?"): - for relative_fname in new_files: - self.repo.git.add(relative_fname) - self.io.tool_output(f"Added {relative_fname} to the git repo") - show_files = ", ".join(new_files) - commit_message = f"Added new files to the git repo: {show_files}" - self.repo.git.commit("-m", commit_message, "--no-verify") - commit_hash = self.repo.head.commit.hexsha[:7] - self.io.tool_output(f"Commit {commit_hash} {commit_message}") - else: - self.io.tool_error("Skipped adding new files to the git repo.") - return - - # fences are obfuscated so aider can modify this file! fences = [ ("``" + "`", "``" + "`"), wrap_fence("source"), @@ -412,25 +355,6 @@ class Coder: self.last_keyboard_interrupt = now - def should_dirty_commit(self, inp): - cmds = self.commands.matching_commands(inp) - if cmds: - matching_commands, _, _ = cmds - if len(matching_commands) == 1: - cmd = matching_commands[0][1:] - if cmd in "add clear commit diff drop exit help ls tokens".split(): - return - - if not self.dirty_commits: - return - if not self.repo: - return - if not self.repo.is_dirty(): - return - if self.last_asked_for_commit_time >= self.get_last_modified(): - return - return True - def move_back_cur_messages(self, message): self.done_messages += self.cur_messages if message: @@ -448,13 +372,7 @@ class Coder: self.commands, ) - if self.should_dirty_commit(inp): - self.io.tool_output("Git repo has uncommitted changes, preparing commit...") - self.commit(ask=True, which="repo_files") - - # files changed, move cur messages back behind the files messages - self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits) - + if self.should_dirty_commit(inp) and self.dirty_commit(): if inp.strip(): self.io.tool_output("Use up-arrow to retry previous command:", inp) return @@ -569,23 +487,6 @@ class Coder: ) ] - def auto_commit(self): - res = self.commit(history=self.cur_messages, prefix="aider: ") - if res: - commit_hash, commit_message = res - self.last_aider_commit_hash = commit_hash - - saved_message = self.gpt_prompts.files_content_gpt_edits.format( - hash=commit_hash, - message=commit_message, - ) - else: - if self.repo: - self.io.tool_output("No changes made to git tracked files.") - saved_message = self.gpt_prompts.files_content_gpt_no_edits - - return saved_message - def check_for_file_mentions(self, content): words = set(word for word in content.split()) @@ -627,44 +528,7 @@ class Coder: return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames)) - @backoff.on_exception( - backoff.expo, - ( - Timeout, - APIError, - ServiceUnavailableError, - RateLimitError, - requests.exceptions.ConnectionError, - ), - max_tries=10, - on_backoff=lambda details: print( - f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds." - ), - ) - def send_with_retries(self, model, messages, functions): - kwargs = dict( - model=model, - messages=messages, - temperature=0, - stream=self.stream, - ) - if functions is not None: - kwargs["functions"] = self.functions - - # we are abusing the openai object to stash these values - if hasattr(openai, "api_deployment_id"): - kwargs["deployment_id"] = openai.api_deployment_id - if hasattr(openai, "api_engine"): - kwargs["engine"] = openai.api_engine - - # Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes - hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode()) - self.chat_completion_call_hashes.append(hash_object.hexdigest()) - - res = openai.ChatCompletion.create(**kwargs) - return res - - def send(self, messages, model=None, silent=False, functions=None): + def send(self, messages, model=None, functions=None): if not model: model = self.main_model.name @@ -673,27 +537,28 @@ class Coder: interrupted = False try: - completion = self.send_with_retries(model, messages, functions) + hash_object, completion = send_with_retries(model, messages, functions, self.stream) + self.chat_completion_call_hashes.append(hash_object.hexdigest()) + if self.stream: - self.show_send_output_stream(completion, silent) + self.show_send_output_stream(completion) else: - self.show_send_output(completion, silent) + self.show_send_output(completion) except KeyboardInterrupt: self.keyboard_interrupt() interrupted = True - if not silent: - if self.partial_response_content: - self.io.ai_output(self.partial_response_content) - elif self.partial_response_function_call: - # TODO: push this into subclasses - args = self.parse_partial_args() - if args: - self.io.ai_output(json.dumps(args, indent=4)) + if self.partial_response_content: + self.io.ai_output(self.partial_response_content) + elif self.partial_response_function_call: + # TODO: push this into subclasses + args = self.parse_partial_args() + if args: + self.io.ai_output(json.dumps(args, indent=4)) return interrupted - def show_send_output(self, completion, silent): + def show_send_output(self, completion): if self.verbose: print(completion) @@ -742,9 +607,9 @@ class Coder: self.io.console.print(show_resp) self.io.console.print(tokens) - def show_send_output_stream(self, completion, silent): + def show_send_output_stream(self, completion): live = None - if self.pretty and not silent: + if self.pretty: live = Live(vertical_overflow="scroll") try: @@ -773,9 +638,6 @@ class Coder: except AttributeError: pass - if silent: - continue - if self.pretty: self.live_incremental_response(live, False) else: @@ -797,145 +659,6 @@ class Coder: def render_incremental_response(self, final): return self.partial_response_content - def get_context_from_history(self, history): - context = "" - if history: - for msg in history: - context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n" - return context - - def get_commit_message(self, diffs, context): - if len(diffs) >= 4 * 1024 * 4: - self.io.tool_error( - f"Diff is too large for {models.GPT35.name} to generate a commit message." - ) - return - - diffs = "# Diffs:\n" + diffs - - messages = [ - dict(role="system", content=prompts.commit_system), - dict(role="user", content=context + diffs), - ] - - try: - interrupted = self.send( - messages, - model=models.GPT35.name, - silent=True, - ) - except openai.error.InvalidRequestError: - self.io.tool_error( - f"Failed to generate commit message using {models.GPT35.name} due to an invalid" - " request." - ) - return - - commit_message = self.partial_response_content - commit_message = commit_message.strip() - if commit_message and commit_message[0] == '"' and commit_message[-1] == '"': - commit_message = commit_message[1:-1].strip() - - if interrupted: - self.io.tool_error( - f"Unable to get commit message from {models.GPT35.name}. Use /commit to try again." - ) - return - - return commit_message - - def get_diffs(self, *args): - if self.pretty: - args = ["--color"] + list(args) - - diffs = self.repo.git.diff(*args) - return diffs - - def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"): - repo = self.repo - if not repo: - return - - if not repo.is_dirty(): - return - - def get_dirty_files_and_diffs(file_list): - diffs = "" - relative_dirty_files = [] - for fname in file_list: - relative_fname = self.get_rel_fname(fname) - relative_dirty_files.append(relative_fname) - - try: - current_branch_commit_count = len( - list(self.repo.iter_commits(self.repo.active_branch)) - ) - except git.exc.GitCommandError: - current_branch_commit_count = None - - if not current_branch_commit_count: - continue - - these_diffs = self.get_diffs("HEAD", "--", relative_fname) - - if these_diffs: - diffs += these_diffs + "\n" - - return relative_dirty_files, diffs - - if which == "repo_files": - all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()] - relative_dirty_fnames, diffs = get_dirty_files_and_diffs(all_files) - elif which == "chat_files": - relative_dirty_fnames, diffs = get_dirty_files_and_diffs(self.abs_fnames) - else: - raise ValueError(f"Invalid value for 'which': {which}") - - if self.show_diffs or ask: - # don't use io.tool_output() because we don't want to log or further colorize - print(diffs) - - context = self.get_context_from_history(history) - if message: - commit_message = message - else: - commit_message = self.get_commit_message(diffs, context) - - if not commit_message: - commit_message = "work in progress" - - if prefix: - commit_message = prefix + commit_message - - if ask: - if which == "repo_files": - self.io.tool_output("Git repo has uncommitted changes.") - else: - self.io.tool_output("Files have uncommitted changes.") - - res = self.io.prompt_ask( - "Commit before the chat proceeds [y/n/commit message]?", - default=commit_message, - ).strip() - self.last_asked_for_commit_time = self.get_last_modified() - - self.io.tool_output() - - if res.lower() in ["n", "no"]: - self.io.tool_error("Skipped commmit.") - return - if res.lower() not in ["y", "yes"] and res: - commit_message = res - - repo.git.add(*relative_dirty_fnames) - - full_commit_message = commit_message + "\n\n# Aider chat conversation:\n\n" + context - repo.git.commit("-m", full_commit_message, "--no-verify") - commit_hash = repo.head.commit.hexsha[:7] - self.io.tool_output(f"Commit {commit_hash} {commit_message}") - - return commit_hash, commit_message - def get_rel_fname(self, fname): return os.path.relpath(fname, self.root) @@ -945,7 +668,7 @@ class Coder: def get_all_relative_files(self): if self.repo: - files = self.get_tracked_files() + files = self.repo.get_tracked_files() else: files = self.get_inchat_relative_files() @@ -1000,32 +723,6 @@ class Coder: return full_path - def get_tracked_files(self): - if not self.repo: - return [] - - try: - commit = self.repo.head.commit - except ValueError: - commit = None - - files = [] - if commit: - for blob in commit.tree.traverse(): - if blob.type == "blob": # blob is a file - files.append(blob.path) - - # Add staged files - index = self.repo.index - staged_files = [path for path, _ in index.entries.keys()] - - files.extend(staged_files) - - # convert to appropriate os.sep, since git always normalizes to / - res = set(str(Path(PurePosixPath(path))) for path in files) - - return res - apply_update_errors = 0 def apply_updates(self): @@ -1094,6 +791,72 @@ class Coder: except JSONDecodeError: pass + # commits... + + def get_context_from_history(self, history): + context = "" + if history: + for msg in history: + context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n" + return context + + def auto_commit(self): + context = self.get_context_from_history(self.cur_messages) + res = self.repo.commit(context=context, prefix="aider: ") + if res: + commit_hash, commit_message = res + self.last_aider_commit_hash = commit_hash + + return self.gpt_prompts.files_content_gpt_edits.format( + hash=commit_hash, + message=commit_message, + ) + + self.io.tool_output("No changes made to git tracked files.") + return self.gpt_prompts.files_content_gpt_no_edits + + def should_dirty_commit(self, inp): + cmds = self.commands.matching_commands(inp) + if cmds: + matching_commands, _, _ = cmds + if len(matching_commands) == 1: + cmd = matching_commands[0][1:] + if cmd in "add clear commit diff drop exit help ls tokens".split(): + return + + if self.last_asked_for_commit_time >= self.get_last_modified(): + return + return True + + def dirty_commit(self): + if not self.dirty_commits: + return + if not self.repo: + return + if not self.repo.is_dirty(): + return + + self.io.tool_output("Git repo has uncommitted changes.") + self.repo.show_diffs(self.pretty) + self.last_asked_for_commit_time = self.get_last_modified() + res = self.io.prompt_ask( + "Commit before the chat proceeds [y/n/commit message]?", + default="y", + ).strip() + if res.lower() in ["n", "no"]: + self.io.tool_error("Skipped commmit.") + return + if res.lower() in ["y", "yes"]: + message = None + else: + message = res.strip() + + self.repo.commit(message=message) + + # files changed, move cur messages back behind the files messages + self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits) + return True + def check_model_availability(main_model): available_models = openai.Model.list() diff --git a/aider/coders/single_wholefile_func_coder.py b/aider/coders/single_wholefile_func_coder.py index 2835c5f7c..2d99c3645 100644 --- a/aider/coders/single_wholefile_func_coder.py +++ b/aider/coders/single_wholefile_func_coder.py @@ -42,15 +42,6 @@ class SingleWholeFileFunctionCoder(Coder): else: self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] - def get_context_from_history(self, history): - context = "" - if history: - context += "# Context:\n" - for msg in history: - if msg["role"] == "user": - context += msg["role"].upper() + ": " + msg["content"] + "\n" - return context - def render_incremental_response(self, final=False): if self.partial_response_content: return self.partial_response_content diff --git a/aider/coders/wholefile_coder.py b/aider/coders/wholefile_coder.py index 3ccd0dac5..efec434af 100644 --- a/aider/coders/wholefile_coder.py +++ b/aider/coders/wholefile_coder.py @@ -20,15 +20,6 @@ class WholeFileCoder(Coder): else: self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] - def get_context_from_history(self, history): - context = "" - if history: - context += "# Context:\n" - for msg in history: - if msg["role"] == "user": - context += msg["role"].upper() + ": " + msg["content"] + "\n" - return context - def render_incremental_response(self, final): try: return self.update_files(mode="diff") diff --git a/aider/coders/wholefile_func_coder.py b/aider/coders/wholefile_func_coder.py index e51150482..ebf10764d 100644 --- a/aider/coders/wholefile_func_coder.py +++ b/aider/coders/wholefile_func_coder.py @@ -55,15 +55,6 @@ class WholeFileFunctionCoder(Coder): else: self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] - def get_context_from_history(self, history): - context = "" - if history: - context += "# Context:\n" - for msg in history: - if msg["role"] == "user": - context += msg["role"].upper() + ": " + msg["content"] + "\n" - return context - def render_incremental_response(self, final=False): if self.partial_response_content: return self.partial_response_content diff --git a/aider/commands.py b/aider/commands.py index 66e76b809..6fd5fa594 100644 --- a/aider/commands.py +++ b/aider/commands.py @@ -176,10 +176,10 @@ class Commands: ) return - local_head = self.coder.repo.git.rev_parse("HEAD") - current_branch = self.coder.repo.active_branch.name + local_head = self.coder.repo.repo.git.rev_parse("HEAD") + current_branch = self.coder.repo.repo.active_branch.name try: - remote_head = self.coder.repo.git.rev_parse(f"origin/{current_branch}") + remote_head = self.coder.repo.repo.git.rev_parse(f"origin/{current_branch}") has_origin = True except git.exc.GitCommandError: has_origin = False @@ -192,14 +192,14 @@ class Commands: ) return - last_commit = self.coder.repo.head.commit + last_commit = self.coder.repo.repo.head.commit if ( not last_commit.message.startswith("aider:") or last_commit.hexsha[:7] != self.coder.last_aider_commit_hash ): self.io.tool_error("The last commit was not made by aider in this chat session.") return - self.coder.repo.git.reset("--hard", "HEAD~1") + self.coder.repo.repo.git.reset("--hard", "HEAD~1") self.io.tool_output( f"{last_commit.message.strip()}\n" f"The above commit {self.coder.last_aider_commit_hash} " @@ -220,7 +220,11 @@ class Commands: return commits = f"{self.coder.last_aider_commit_hash}~1" - diff = self.coder.get_diffs(commits, self.coder.last_aider_commit_hash) + diff = self.coder.repo.get_diffs( + self.coder.pretty, + commits, + self.coder.last_aider_commit_hash, + ) # don't use io.tool_output() because we don't want to log or further colorize print(diff) @@ -243,7 +247,7 @@ class Commands: # if repo, filter against it if self.coder.repo: - git_files = self.coder.get_tracked_files() + git_files = self.coder.repo.get_tracked_files() matched_files = [fn for fn in matched_files if str(fn) in git_files] res = list(map(str, matched_files)) @@ -254,7 +258,7 @@ class Commands: added_fnames = [] git_added = [] - git_files = self.coder.get_tracked_files() + git_files = self.coder.repo.get_tracked_files() if self.coder.repo else [] all_matched_files = set() for word in args.split(): @@ -281,7 +285,7 @@ class Commands: abs_file_path = self.coder.abs_root_path(matched_file) if self.coder.repo and matched_file not in git_files: - self.coder.repo.git.add(abs_file_path) + self.coder.repo.repo.git.add(abs_file_path) git_added.append(matched_file) if abs_file_path in self.coder.abs_fnames: @@ -298,8 +302,8 @@ class Commands: if self.coder.repo and git_added: git_added = " ".join(git_added) commit_message = f"aider: Added {git_added}" - self.coder.repo.git.commit("-m", commit_message, "--no-verify") - commit_hash = self.coder.repo.head.commit.hexsha[:7] + self.coder.repo.repo.git.commit("-m", commit_message, "--no-verify") + commit_hash = self.coder.repo.repo.head.commit.hexsha[:7] self.io.tool_output(f"Commit {commit_hash} {commit_message}") if not added_fnames: diff --git a/aider/main.py b/aider/main.py index 8841aeb83..cceb8fae0 100644 --- a/aider/main.py +++ b/aider/main.py @@ -9,10 +9,14 @@ import openai from aider import __version__, models from aider.coders import Coder from aider.io import InputOutput +from aider.repo import GitRepo from aider.versioncheck import check_version +from .dump import dump # noqa: F401 + def get_git_root(): + """Try and guess the git repo, since the conf.yml can be at the repo root""" try: repo = git.Repo(search_parent_directories=True) return repo.working_tree_dir @@ -20,6 +24,25 @@ def get_git_root(): return None +def guessed_wrong_repo(io, git_root, fnames, git_dname): + """After we parse the args, we can determine the real repo. Did we guess wrong?""" + + try: + check_repo = Path(GitRepo(io, fnames, git_dname).root).resolve() + except FileNotFoundError: + return + + # we had no guess, rely on the "true" repo result + if not git_root: + return str(check_repo) + + git_root = Path(git_root).resolve() + if check_repo == git_root: + return + + return str(check_repo) + + def setup_git(git_root, io): if git_root: return git_root @@ -71,11 +94,14 @@ def check_gitignore(git_root, io, ask=True): io.tool_output(f"Added {pat} to .gitignore") -def main(args=None, input=None, output=None): - if args is None: - args = sys.argv[1:] +def main(argv=None, input=None, output=None, force_git_root=None): + if argv is None: + argv = sys.argv[1:] - git_root = get_git_root() + if force_git_root: + git_root = force_git_root + else: + git_root = get_git_root() conf_fname = Path(".aider.conf.yml") @@ -101,7 +127,7 @@ def main(args=None, input=None, output=None): "files", metavar="FILE", nargs="*", - help="a list of source code files to edit with GPT (optional)", + help="the directory of a git repo, or a list of files to edit with GPT (optional)", ) core_group.add_argument( "--openai-api-key", @@ -344,7 +370,7 @@ def main(args=None, input=None, output=None): ), ) - args = parser.parse_args(args) + args = parser.parse_args(argv) if args.dark_mode: args.user_input_color = "#32FF32" @@ -371,6 +397,37 @@ def main(args=None, input=None, output=None): dry_run=args.dry_run, ) + fnames = [str(Path(fn).resolve()) for fn in args.files] + if len(args.files) > 1: + good = True + for fname in args.files: + if Path(fname).is_dir(): + io.tool_error(f"{fname} is a directory, not provided alone.") + good = False + if not good: + io.tool_error( + "Provide either a single directory of a git repo, or a list of one or more files." + ) + return 1 + + git_dname = None + if len(args.files) == 1: + if Path(args.files[0]).is_dir(): + if args.git: + git_dname = str(Path(args.files[0]).resolve()) + fnames = [] + else: + io.tool_error(f"{args.files[0]} is a directory, but --no-git selected.") + return 1 + + # We can't know the git repo for sure until after parsing the args. + # If we guessed wrong, reparse because that changes things like + # the location of the config.yml and history files. + if args.git and not force_git_root: + right_repo_root = guessed_wrong_repo(io, git_root, fnames, git_dname) + if right_repo_root: + return main(argv, input, output, right_repo_root) + io.tool_output(f"Aider v{__version__}") check_version(io.tool_error) @@ -418,24 +475,29 @@ def main(args=None, input=None, output=None): setattr(openai, mod_key, val) io.tool_output(f"Setting openai.{mod_key}={val}") - coder = Coder.create( - main_model, - args.edit_format, - io, - ## - fnames=args.files, - pretty=args.pretty, - show_diffs=args.show_diffs, - auto_commits=args.auto_commits, - dirty_commits=args.dirty_commits, - dry_run=args.dry_run, - map_tokens=args.map_tokens, - verbose=args.verbose, - assistant_output_color=args.assistant_output_color, - code_theme=args.code_theme, - stream=args.stream, - use_git=args.git, - ) + try: + coder = Coder.create( + main_model, + args.edit_format, + io, + ## + fnames=fnames, + git_dname=git_dname, + pretty=args.pretty, + show_diffs=args.show_diffs, + auto_commits=args.auto_commits, + dirty_commits=args.dirty_commits, + dry_run=args.dry_run, + map_tokens=args.map_tokens, + verbose=args.verbose, + assistant_output_color=args.assistant_output_color, + code_theme=args.code_theme, + stream=args.stream, + use_git=args.git, + ) + except ValueError as err: + io.tool_error(str(err)) + return 1 if args.show_repo_map: repo_map = coder.get_repo_map() @@ -443,9 +505,6 @@ def main(args=None, input=None, output=None): io.tool_output(repo_map) return - if args.dirty_commits: - coder.commit(ask=True, which="repo_files") - if args.apply: content = io.read_text(args.apply) if content is None: @@ -454,6 +513,9 @@ def main(args=None, input=None, output=None): return io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args") + + coder.dirty_commit() + if args.message: io.tool_output() coder.run(with_message=args.message) diff --git a/aider/repo.py b/aider/repo.py new file mode 100644 index 000000000..16bcc45e4 --- /dev/null +++ b/aider/repo.py @@ -0,0 +1,177 @@ +import os +from pathlib import Path, PurePosixPath + +import git + +from aider import models, prompts, utils +from aider.sendchat import simple_send_with_retries + +from .dump import dump # noqa: F401 + + +class GitRepo: + repo = None + + def __init__(self, io, fnames, git_dname): + self.io = io + + if git_dname: + check_fnames = [git_dname] + elif fnames: + check_fnames = fnames + else: + check_fnames = ["."] + + repo_paths = [] + for fname in check_fnames: + fname = Path(fname) + fname = fname.resolve() + + if not fname.exists() and fname.parent.exists(): + fname = fname.parent + + try: + repo_path = git.Repo(fname, search_parent_directories=True).working_dir + repo_path = utils.safe_abs_path(repo_path) + repo_paths.append(repo_path) + except git.exc.InvalidGitRepositoryError: + pass + + num_repos = len(set(repo_paths)) + + if num_repos == 0: + raise FileNotFoundError + if num_repos > 1: + self.io.tool_error("Files are in different git repos.") + raise FileNotFoundError + + # https://github.com/gitpython-developers/GitPython/issues/427 + self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB) + self.root = utils.safe_abs_path(self.repo.working_tree_dir) + + def add_new_files(self, fnames): + cur_files = [Path(fn).resolve() for fn in self.get_tracked_files()] + for fname in fnames: + if Path(fname).resolve() in cur_files: + continue + if not Path(fname).exists(): + continue + self.io.tool_output(f"Adding {fname} to git") + self.repo.git.add(fname) + + def commit(self, context=None, prefix=None, message=None): + if not self.repo.is_dirty(): + return + + if message: + commit_message = message + else: + diffs = self.get_diffs(False) + commit_message = self.get_commit_message(diffs, context) + + if not commit_message: + commit_message = "(no commit message provided)" + + if prefix: + commit_message = prefix + commit_message + + full_commit_message = commit_message + if context: + full_commit_message += "\n\n# Aider chat conversation:\n\n" + context + + self.repo.git.commit("-a", "-m", full_commit_message, "--no-verify") + commit_hash = self.repo.head.commit.hexsha[:7] + self.io.tool_output(f"Commit {commit_hash} {commit_message}") + + return commit_hash, commit_message + + def get_rel_repo_dir(self): + try: + return os.path.relpath(self.repo.git_dir, os.getcwd()) + except ValueError: + return self.repo.git_dir + + def get_commit_message(self, diffs, context): + if len(diffs) >= 4 * 1024 * 4: + self.io.tool_error( + f"Diff is too large for {models.GPT35.name} to generate a commit message." + ) + return + + diffs = "# Diffs:\n" + diffs + + content = "" + if context: + content += context + "\n" + content += diffs + + messages = [ + dict(role="system", content=prompts.commit_system), + dict(role="user", content=content), + ] + + for model in [models.GPT35.name, models.GPT35_16k.name]: + commit_message = simple_send_with_retries(model, messages) + if commit_message: + break + + if not commit_message: + self.io.tool_error("Failed to generate commit message!") + return + + commit_message = commit_message.strip() + if commit_message and commit_message[0] == '"' and commit_message[-1] == '"': + commit_message = commit_message[1:-1].strip() + + return commit_message + + def get_diffs(self, pretty, *args): + try: + commits = self.repo.iter_commits(self.repo.active_branch) + current_branch_has_commits = any(commits) + except git.exc.GitCommandError: + current_branch_has_commits = False + + if not current_branch_has_commits: + return "" + + if pretty: + args = ["--color"] + list(args) + if not args: + args = ["HEAD"] + + diffs = self.repo.git.diff(*args) + return diffs + + def show_diffs(self, pretty): + diffs = self.get_diffs(pretty) + print(diffs) + + def get_tracked_files(self): + if not self.repo: + return [] + + try: + commit = self.repo.head.commit + except ValueError: + commit = None + + files = [] + if commit: + for blob in commit.tree.traverse(): + if blob.type == "blob": # blob is a file + files.append(blob.path) + + # Add staged files + index = self.repo.index + staged_files = [path for path, _ in index.entries.keys()] + + files.extend(staged_files) + + # convert to appropriate os.sep, since git always normalizes to / + res = set(str(Path(PurePosixPath(path))) for path in files) + + return res + + def is_dirty(self): + return self.repo.is_dirty() diff --git a/aider/sendchat.py b/aider/sendchat.py new file mode 100644 index 000000000..441447863 --- /dev/null +++ b/aider/sendchat.py @@ -0,0 +1,57 @@ +import hashlib +import json + +import backoff +import openai +import requests +from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout + + +@backoff.on_exception( + backoff.expo, + ( + Timeout, + APIError, + ServiceUnavailableError, + RateLimitError, + requests.exceptions.ConnectionError, + ), + max_tries=10, + on_backoff=lambda details: print( + f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds." + ), +) +def send_with_retries(model, messages, functions, stream): + kwargs = dict( + model=model, + messages=messages, + temperature=0, + stream=stream, + ) + if functions is not None: + kwargs["functions"] = functions + + # we are abusing the openai object to stash these values + if hasattr(openai, "api_deployment_id"): + kwargs["deployment_id"] = openai.api_deployment_id + if hasattr(openai, "api_engine"): + kwargs["engine"] = openai.api_engine + + # Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes + hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode()) + + res = openai.ChatCompletion.create(**kwargs) + return hash_object, res + + +def simple_send_with_retries(model, messages): + try: + _hash, response = send_with_retries( + model=model, + messages=messages, + functions=None, + stream=False, + ) + return response.choices[0].message.content + except (AttributeError, openai.error.InvalidRequestError): + return diff --git a/tests/test_coder.py b/tests/test_coder.py index 818f6fef9..a9ffaeca5 100644 --- a/tests/test_coder.py +++ b/tests/test_coder.py @@ -1,4 +1,3 @@ -import os import tempfile import unittest from pathlib import Path @@ -6,7 +5,6 @@ from unittest.mock import MagicMock, patch import git import openai -import requests from aider import models from aider.coders import Coder @@ -77,7 +75,7 @@ class TestCoder(unittest.TestCase): # Mock the git repo mock = MagicMock() mock.return_value = set(["file1.txt", "file2.py"]) - coder.get_tracked_files = mock + coder.repo.get_tracked_files = mock # Call the check_for_file_mentions method coder.check_for_file_mentions("Please check file1.txt and file2.py") @@ -121,7 +119,7 @@ class TestCoder(unittest.TestCase): mock = MagicMock() mock.return_value = set(["file1.txt", "file2.py"]) - coder.get_tracked_files = mock + coder.repo.get_tracked_files = mock # Call the check_for_file_mentions method coder.check_for_file_mentions("Please check file1.txt and file2.py") @@ -152,7 +150,7 @@ class TestCoder(unittest.TestCase): mock = MagicMock() mock.return_value = set([str(fname), str(other_fname)]) - coder.get_tracked_files = mock + coder.repo.get_tracked_files = mock # Call the check_for_file_mentions method coder.check_for_file_mentions(f"Please check {fname}!") @@ -170,7 +168,7 @@ class TestCoder(unittest.TestCase): mock = MagicMock() mock.return_value = set([str(fname)]) - coder.get_tracked_files = mock + coder.repo.get_tracked_files = mock dump(fname) # Call the check_for_file_mentions method @@ -178,110 +176,6 @@ class TestCoder(unittest.TestCase): self.assertEqual(coder.abs_fnames, set([str(fname.resolve())])) - def test_get_commit_message(self): - # Mock the IO object - mock_io = MagicMock() - - # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(models.GPT4, None, mock_io) - - # Mock the send method to set partial_response_content and return False - def mock_send(*args, **kwargs): - coder.partial_response_content = "a good commit message" - return False - - coder.send = MagicMock(side_effect=mock_send) - - # Call the get_commit_message method with dummy diff and context - result = coder.get_commit_message("dummy diff", "dummy context") - - # Assert that the returned message is the expected one - self.assertEqual(result, "a good commit message") - - def test_get_commit_message_strip_quotes(self): - # Mock the IO object - mock_io = MagicMock() - - # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(models.GPT4, None, mock_io) - - # Mock the send method to set partial_response_content and return False - def mock_send(*args, **kwargs): - coder.partial_response_content = "a good commit message" - return False - - coder.send = MagicMock(side_effect=mock_send) - - # Call the get_commit_message method with dummy diff and context - result = coder.get_commit_message("dummy diff", "dummy context") - - # Assert that the returned message is the expected one - self.assertEqual(result, "a good commit message") - - def test_get_commit_message_no_strip_unmatched_quotes(self): - # Mock the IO object - mock_io = MagicMock() - - # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(models.GPT4, None, mock_io) - - # Mock the send method to set partial_response_content and return False - def mock_send(*args, **kwargs): - coder.partial_response_content = 'a good "commit message"' - return False - - coder.send = MagicMock(side_effect=mock_send) - - # Call the get_commit_message method with dummy diff and context - result = coder.get_commit_message("dummy diff", "dummy context") - - # Assert that the returned message is the expected one - self.assertEqual(result, 'a good "commit message"') - - @patch("aider.coders.base_coder.openai.ChatCompletion.create") - @patch("builtins.print") - def test_send_with_retries_rate_limit_error(self, mock_print, mock_chat_completion_create): - # Mock the IO object - mock_io = MagicMock() - - # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(models.GPT4, None, mock_io) - - # Set up the mock to raise RateLimitError on - # the first call and return None on the second call - mock_chat_completion_create.side_effect = [ - openai.error.RateLimitError("Rate limit exceeded"), - None, - ] - - # Call the send_with_retries method - coder.send_with_retries("model", ["message"], None) - - # Assert that print was called once - mock_print.assert_called_once() - - @patch("aider.coders.base_coder.openai.ChatCompletion.create") - @patch("builtins.print") - def test_send_with_retries_connection_error(self, mock_print, mock_chat_completion_create): - # Mock the IO object - mock_io = MagicMock() - - # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(models.GPT4, None, mock_io) - - # Set up the mock to raise ConnectionError on the first call - # and return None on the second call - mock_chat_completion_create.side_effect = [ - requests.exceptions.ConnectionError("Connection error"), - None, - ] - - # Call the send_with_retries method - coder.send_with_retries("model", ["message"], None) - - # Assert that print was called once - mock_print.assert_called_once() - def test_run_with_file_deletion(self): # Create a few temporary files @@ -419,49 +313,5 @@ class TestCoder(unittest.TestCase): with self.assertRaises(openai.error.InvalidRequestError): coder.run(with_message="hi") - def test_get_tracked_files(self): - # Create a temporary directory - tempdir = Path(tempfile.mkdtemp()) - - # Initialize a git repository in the temporary directory and set user name and email - repo = git.Repo.init(tempdir) - repo.config_writer().set_value("user", "name", "Test User").release() - repo.config_writer().set_value("user", "email", "testuser@example.com").release() - - # Create three empty files and add them to the git repository - filenames = ["README.md", "subdir/fänny.md", "systemüber/blick.md", 'file"with"quotes.txt'] - created_files = [] - for filename in filenames: - file_path = tempdir / filename - try: - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.touch() - repo.git.add(str(file_path)) - created_files.append(Path(filename)) - except OSError: - # windows won't allow files with quotes, that's ok - self.assertIn('"', filename) - self.assertEqual(os.name, "nt") - - self.assertTrue(len(created_files) >= 3) - - repo.git.commit("-m", "added") - - # Create a Coder object on the temporary directory - coder = Coder.create( - models.GPT4, - None, - io=InputOutput(), - fnames=[str(tempdir / filenames[0])], - ) - - tracked_files = coder.get_tracked_files() - - # On windows, paths will come back \like\this, so normalize them back to Paths - tracked_files = [Path(fn) for fn in tracked_files] - - # Assert that coder.get_tracked_files() returns the three filenames - self.assertEqual(set(tracked_files), set(created_files)) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_main.py b/tests/test_main.py index e85c8439d..e23268c4c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -35,14 +35,42 @@ class TestMain(TestCase): main(["--no-git"], input=DummyInput(), output=DummyOutput()) def test_main_with_empty_dir_new_file(self): - main(["foo.txt", "--yes"], input=DummyInput(), output=DummyOutput()) + main(["foo.txt", "--yes", "--no-git"], input=DummyInput(), output=DummyOutput()) self.assertTrue(os.path.exists("foo.txt")) - def test_main_with_empty_git_dir_new_file(self): + @patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message") + def test_main_with_empty_git_dir_new_file(self, _): make_repo() main(["--yes", "foo.txt"], input=DummyInput(), output=DummyOutput()) self.assertTrue(os.path.exists("foo.txt")) + @patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message") + def test_main_with_empty_git_dir_new_files(self, _): + make_repo() + main(["--yes", "foo.txt", "bar.txt"], input=DummyInput(), output=DummyOutput()) + self.assertTrue(os.path.exists("foo.txt")) + self.assertTrue(os.path.exists("bar.txt")) + + def test_main_with_dname_and_fname(self): + subdir = Path("subdir") + subdir.mkdir() + make_repo(str(subdir)) + res = main(["subdir", "foo.txt"], input=DummyInput(), output=DummyOutput()) + self.assertNotEqual(res, None) + + @patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message") + def test_main_with_subdir_repo_fnames(self, _): + subdir = Path("subdir") + subdir.mkdir() + make_repo(str(subdir)) + main( + ["--yes", str(subdir / "foo.txt"), str(subdir / "bar.txt")], + input=DummyInput(), + output=DummyOutput(), + ) + self.assertTrue((subdir / "foo.txt").exists()) + self.assertTrue((subdir / "bar.txt").exists()) + def test_main_with_git_config_yml(self): make_repo() diff --git a/tests/test_repo.py b/tests/test_repo.py new file mode 100644 index 000000000..488959a09 --- /dev/null +++ b/tests/test_repo.py @@ -0,0 +1,114 @@ +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +import git + +from aider.dump import dump # noqa: F401 +from aider.io import InputOutput +from aider.repo import GitRepo +from tests.utils import GitTemporaryDirectory + + +class TestRepo(unittest.TestCase): + @patch("aider.repo.simple_send_with_retries") + def test_get_commit_message(self, mock_send): + mock_send.return_value = "a good commit message" + + repo = GitRepo(InputOutput(), None, None) + # Call the get_commit_message method with dummy diff and context + result = repo.get_commit_message("dummy diff", "dummy context") + + # Assert that the returned message is the expected one + self.assertEqual(result, "a good commit message") + + @patch("aider.repo.simple_send_with_retries") + def test_get_commit_message_strip_quotes(self, mock_send): + mock_send.return_value = '"a good commit message"' + + repo = GitRepo(InputOutput(), None, None) + # Call the get_commit_message method with dummy diff and context + result = repo.get_commit_message("dummy diff", "dummy context") + + # Assert that the returned message is the expected one + self.assertEqual(result, "a good commit message") + + @patch("aider.repo.simple_send_with_retries") + def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send): + mock_send.return_value = 'a good "commit message"' + + repo = GitRepo(InputOutput(), None, None) + # Call the get_commit_message method with dummy diff and context + result = repo.get_commit_message("dummy diff", "dummy context") + + # Assert that the returned message is the expected one + self.assertEqual(result, 'a good "commit message"') + + def test_get_tracked_files(self): + # Create a temporary directory + tempdir = Path(tempfile.mkdtemp()) + + # Initialize a git repository in the temporary directory and set user name and email + repo = git.Repo.init(tempdir) + repo.config_writer().set_value("user", "name", "Test User").release() + repo.config_writer().set_value("user", "email", "testuser@example.com").release() + + # Create three empty files and add them to the git repository + filenames = ["README.md", "subdir/fänny.md", "systemüber/blick.md", 'file"with"quotes.txt'] + created_files = [] + for filename in filenames: + file_path = tempdir / filename + try: + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.touch() + repo.git.add(str(file_path)) + created_files.append(Path(filename)) + except OSError: + # windows won't allow files with quotes, that's ok + self.assertIn('"', filename) + self.assertEqual(os.name, "nt") + + self.assertTrue(len(created_files) >= 3) + + repo.git.commit("-m", "added") + + tracked_files = GitRepo(InputOutput(), [tempdir], None).get_tracked_files() + + # On windows, paths will come back \like\this, so normalize them back to Paths + tracked_files = [Path(fn) for fn in tracked_files] + + # Assert that coder.get_tracked_files() returns the three filenames + self.assertEqual(set(tracked_files), set(created_files)) + + def test_get_tracked_files_with_new_staged_file(self): + with GitTemporaryDirectory(): + # new repo + raw_repo = git.Repo() + + # add it, but no commits at all in the raw_repo yet + fname = Path("new.txt") + fname.touch() + raw_repo.git.add(str(fname)) + + git_repo = GitRepo(InputOutput(), None, None) + + # better be there + fnames = git_repo.get_tracked_files() + self.assertIn(str(fname), fnames) + + # commit it, better still be there + raw_repo.git.commit("-m", "new") + fnames = git_repo.get_tracked_files() + self.assertIn(str(fname), fnames) + + # new file, added but not committed + fname2 = Path("new2.txt") + fname2.touch() + raw_repo.git.add(str(fname2)) + + # both should be there + fnames = git_repo.get_tracked_files() + self.assertIn(str(fname), fnames) + self.assertIn(str(fname2), fnames) diff --git a/tests/test_repomap.py b/tests/test_repomap.py index 1ecc951b2..82c100468 100644 --- a/tests/test_repomap.py +++ b/tests/test_repomap.py @@ -4,7 +4,6 @@ from unittest.mock import patch from aider.io import InputOutput from aider.repomap import RepoMap - from tests.utils import IgnorantTemporaryDirectory diff --git a/tests/test_sendchat.py b/tests/test_sendchat.py new file mode 100644 index 000000000..59c6f8c80 --- /dev/null +++ b/tests/test_sendchat.py @@ -0,0 +1,41 @@ +import unittest +from unittest.mock import patch + +import openai +import requests + +from aider.sendchat import send_with_retries + + +class TestSendChat(unittest.TestCase): + @patch("aider.sendchat.openai.ChatCompletion.create") + @patch("builtins.print") + def test_send_with_retries_rate_limit_error(self, mock_print, mock_chat_completion_create): + # Set up the mock to raise RateLimitError on + # the first call and return None on the second call + mock_chat_completion_create.side_effect = [ + openai.error.RateLimitError("Rate limit exceeded"), + None, + ] + + # Call the send_with_retries method + send_with_retries("model", ["message"], None, False) + + # Assert that print was called once + mock_print.assert_called_once() + + @patch("aider.sendchat.openai.ChatCompletion.create") + @patch("builtins.print") + def test_send_with_retries_connection_error(self, mock_print, mock_chat_completion_create): + # Set up the mock to raise ConnectionError on the first call + # and return None on the second call + mock_chat_completion_create.side_effect = [ + requests.exceptions.ConnectionError("Connection error"), + None, + ] + + # Call the send_with_retries method + send_with_retries("model", ["message"], None, False) + + # Assert that print was called once + mock_print.assert_called_once() diff --git a/tests/utils.py b/tests/utils.py index 8ece2b4a4..32a2deaf5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -42,7 +42,11 @@ class GitTemporaryDirectory(ChdirTemporaryDirectory): return res -def make_repo(): - repo = git.Repo.init() +def make_repo(path=None): + if not path: + path = "." + repo = git.Repo.init(path) repo.config_writer().set_value("user", "name", "Test User").release() repo.config_writer().set_value("user", "email", "testuser@example.com").release() + + return repo