mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-28 08:14:59 +00:00
Merge pull request #137 from paul-gauthier/refactor-repo
Refactor git repo code into a new file
This commit is contained in:
commit
86309f336c
14 changed files with 642 additions and 570 deletions
|
@ -7,21 +7,19 @@ import sys
|
|||
import time
|
||||
import traceback
|
||||
from json.decoder import JSONDecodeError
|
||||
from pathlib import Path, PurePosixPath
|
||||
from pathlib import Path
|
||||
|
||||
import backoff
|
||||
import git
|
||||
import openai
|
||||
import requests
|
||||
from jsonschema import Draft7Validator
|
||||
from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
|
||||
from rich.console import Console, Text
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
|
||||
from aider import models, prompts, utils
|
||||
from aider.commands import Commands
|
||||
from aider.repo import GitRepo
|
||||
from aider.repomap import RepoMap
|
||||
from aider.sendchat import send_with_retries
|
||||
|
||||
from ..dump import dump # noqa: F401
|
||||
|
||||
|
@ -100,6 +98,7 @@ class Coder:
|
|||
main_model,
|
||||
io,
|
||||
fnames=None,
|
||||
git_dname=None,
|
||||
pretty=True,
|
||||
show_diffs=False,
|
||||
auto_commits=True,
|
||||
|
@ -150,13 +149,27 @@ class Coder:
|
|||
|
||||
self.commands = Commands(self.io, self)
|
||||
|
||||
for fname in fnames:
|
||||
fname = Path(fname)
|
||||
if not fname.exists():
|
||||
self.io.tool_output(f"Creating empty file {fname}")
|
||||
fname.parent.mkdir(parents=True, exist_ok=True)
|
||||
fname.touch()
|
||||
|
||||
if not fname.is_file():
|
||||
raise ValueError(f"{fname} is not a file")
|
||||
|
||||
self.abs_fnames.add(str(fname.resolve()))
|
||||
|
||||
if use_git:
|
||||
self.set_repo(fnames)
|
||||
else:
|
||||
self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames])
|
||||
try:
|
||||
self.repo = GitRepo(self.io, fnames, git_dname)
|
||||
self.root = self.repo.root
|
||||
except FileNotFoundError:
|
||||
self.repo = None
|
||||
|
||||
if self.repo:
|
||||
rel_repo_dir = self.get_rel_repo_dir()
|
||||
rel_repo_dir = self.repo.get_rel_repo_dir()
|
||||
self.io.tool_output(f"Git repo: {rel_repo_dir}")
|
||||
else:
|
||||
self.io.tool_output("Git repo: none")
|
||||
|
@ -187,6 +200,9 @@ class Coder:
|
|||
for fname in self.get_inchat_relative_files():
|
||||
self.io.tool_output(f"Added {fname} to the chat.")
|
||||
|
||||
if self.repo:
|
||||
self.repo.add_new_files(fname for fname in fnames if not Path(fname).is_dir())
|
||||
|
||||
# validate the functions jsonschema
|
||||
if self.functions:
|
||||
for function in self.functions:
|
||||
|
@ -206,12 +222,6 @@ class Coder:
|
|||
|
||||
self.root = utils.safe_abs_path(self.root)
|
||||
|
||||
def get_rel_repo_dir(self):
|
||||
try:
|
||||
return os.path.relpath(self.repo.git_dir, os.getcwd())
|
||||
except ValueError:
|
||||
return self.repo.git_dir
|
||||
|
||||
def add_rel_fname(self, rel_fname):
|
||||
self.abs_fnames.add(self.abs_root_path(rel_fname))
|
||||
|
||||
|
@ -219,73 +229,6 @@ class Coder:
|
|||
res = Path(self.root) / path
|
||||
return utils.safe_abs_path(res)
|
||||
|
||||
def set_repo(self, cmd_line_fnames):
|
||||
if not cmd_line_fnames:
|
||||
cmd_line_fnames = ["."]
|
||||
|
||||
repo_paths = []
|
||||
for fname in cmd_line_fnames:
|
||||
fname = Path(fname)
|
||||
if not fname.exists():
|
||||
self.io.tool_output(f"Creating empty file {fname}")
|
||||
fname.parent.mkdir(parents=True, exist_ok=True)
|
||||
fname.touch()
|
||||
|
||||
fname = fname.resolve()
|
||||
|
||||
try:
|
||||
repo_path = git.Repo(fname, search_parent_directories=True).working_dir
|
||||
repo_path = utils.safe_abs_path(repo_path)
|
||||
repo_paths.append(repo_path)
|
||||
except git.exc.InvalidGitRepositoryError:
|
||||
pass
|
||||
|
||||
if fname.is_dir():
|
||||
continue
|
||||
|
||||
self.abs_fnames.add(str(fname))
|
||||
|
||||
num_repos = len(set(repo_paths))
|
||||
|
||||
if num_repos == 0:
|
||||
return
|
||||
if num_repos > 1:
|
||||
self.io.tool_error("Files are in different git repos.")
|
||||
return
|
||||
|
||||
# https://github.com/gitpython-developers/GitPython/issues/427
|
||||
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
|
||||
|
||||
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
|
||||
|
||||
new_files = []
|
||||
for fname in self.abs_fnames:
|
||||
relative_fname = self.get_rel_fname(fname)
|
||||
|
||||
tracked_files = set(self.get_tracked_files())
|
||||
if relative_fname not in tracked_files:
|
||||
new_files.append(relative_fname)
|
||||
|
||||
if new_files:
|
||||
rel_repo_dir = self.get_rel_repo_dir()
|
||||
|
||||
self.io.tool_output(f"Files not tracked in {rel_repo_dir}:")
|
||||
for fn in new_files:
|
||||
self.io.tool_output(f" - {fn}")
|
||||
if self.io.confirm_ask("Add them?"):
|
||||
for relative_fname in new_files:
|
||||
self.repo.git.add(relative_fname)
|
||||
self.io.tool_output(f"Added {relative_fname} to the git repo")
|
||||
show_files = ", ".join(new_files)
|
||||
commit_message = f"Added new files to the git repo: {show_files}"
|
||||
self.repo.git.commit("-m", commit_message, "--no-verify")
|
||||
commit_hash = self.repo.head.commit.hexsha[:7]
|
||||
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
|
||||
else:
|
||||
self.io.tool_error("Skipped adding new files to the git repo.")
|
||||
return
|
||||
|
||||
# fences are obfuscated so aider can modify this file!
|
||||
fences = [
|
||||
("``" + "`", "``" + "`"),
|
||||
wrap_fence("source"),
|
||||
|
@ -412,25 +355,6 @@ class Coder:
|
|||
|
||||
self.last_keyboard_interrupt = now
|
||||
|
||||
def should_dirty_commit(self, inp):
|
||||
cmds = self.commands.matching_commands(inp)
|
||||
if cmds:
|
||||
matching_commands, _, _ = cmds
|
||||
if len(matching_commands) == 1:
|
||||
cmd = matching_commands[0][1:]
|
||||
if cmd in "add clear commit diff drop exit help ls tokens".split():
|
||||
return
|
||||
|
||||
if not self.dirty_commits:
|
||||
return
|
||||
if not self.repo:
|
||||
return
|
||||
if not self.repo.is_dirty():
|
||||
return
|
||||
if self.last_asked_for_commit_time >= self.get_last_modified():
|
||||
return
|
||||
return True
|
||||
|
||||
def move_back_cur_messages(self, message):
|
||||
self.done_messages += self.cur_messages
|
||||
if message:
|
||||
|
@ -448,13 +372,7 @@ class Coder:
|
|||
self.commands,
|
||||
)
|
||||
|
||||
if self.should_dirty_commit(inp):
|
||||
self.io.tool_output("Git repo has uncommitted changes, preparing commit...")
|
||||
self.commit(ask=True, which="repo_files")
|
||||
|
||||
# files changed, move cur messages back behind the files messages
|
||||
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
|
||||
|
||||
if self.should_dirty_commit(inp) and self.dirty_commit():
|
||||
if inp.strip():
|
||||
self.io.tool_output("Use up-arrow to retry previous command:", inp)
|
||||
return
|
||||
|
@ -569,23 +487,6 @@ class Coder:
|
|||
)
|
||||
]
|
||||
|
||||
def auto_commit(self):
|
||||
res = self.commit(history=self.cur_messages, prefix="aider: ")
|
||||
if res:
|
||||
commit_hash, commit_message = res
|
||||
self.last_aider_commit_hash = commit_hash
|
||||
|
||||
saved_message = self.gpt_prompts.files_content_gpt_edits.format(
|
||||
hash=commit_hash,
|
||||
message=commit_message,
|
||||
)
|
||||
else:
|
||||
if self.repo:
|
||||
self.io.tool_output("No changes made to git tracked files.")
|
||||
saved_message = self.gpt_prompts.files_content_gpt_no_edits
|
||||
|
||||
return saved_message
|
||||
|
||||
def check_for_file_mentions(self, content):
|
||||
words = set(word for word in content.split())
|
||||
|
||||
|
@ -627,44 +528,7 @@ class Coder:
|
|||
|
||||
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
(
|
||||
Timeout,
|
||||
APIError,
|
||||
ServiceUnavailableError,
|
||||
RateLimitError,
|
||||
requests.exceptions.ConnectionError,
|
||||
),
|
||||
max_tries=10,
|
||||
on_backoff=lambda details: print(
|
||||
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
|
||||
),
|
||||
)
|
||||
def send_with_retries(self, model, messages, functions):
|
||||
kwargs = dict(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
stream=self.stream,
|
||||
)
|
||||
if functions is not None:
|
||||
kwargs["functions"] = self.functions
|
||||
|
||||
# we are abusing the openai object to stash these values
|
||||
if hasattr(openai, "api_deployment_id"):
|
||||
kwargs["deployment_id"] = openai.api_deployment_id
|
||||
if hasattr(openai, "api_engine"):
|
||||
kwargs["engine"] = openai.api_engine
|
||||
|
||||
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
|
||||
hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())
|
||||
self.chat_completion_call_hashes.append(hash_object.hexdigest())
|
||||
|
||||
res = openai.ChatCompletion.create(**kwargs)
|
||||
return res
|
||||
|
||||
def send(self, messages, model=None, silent=False, functions=None):
|
||||
def send(self, messages, model=None, functions=None):
|
||||
if not model:
|
||||
model = self.main_model.name
|
||||
|
||||
|
@ -673,27 +537,28 @@ class Coder:
|
|||
|
||||
interrupted = False
|
||||
try:
|
||||
completion = self.send_with_retries(model, messages, functions)
|
||||
hash_object, completion = send_with_retries(model, messages, functions, self.stream)
|
||||
self.chat_completion_call_hashes.append(hash_object.hexdigest())
|
||||
|
||||
if self.stream:
|
||||
self.show_send_output_stream(completion, silent)
|
||||
self.show_send_output_stream(completion)
|
||||
else:
|
||||
self.show_send_output(completion, silent)
|
||||
self.show_send_output(completion)
|
||||
except KeyboardInterrupt:
|
||||
self.keyboard_interrupt()
|
||||
interrupted = True
|
||||
|
||||
if not silent:
|
||||
if self.partial_response_content:
|
||||
self.io.ai_output(self.partial_response_content)
|
||||
elif self.partial_response_function_call:
|
||||
# TODO: push this into subclasses
|
||||
args = self.parse_partial_args()
|
||||
if args:
|
||||
self.io.ai_output(json.dumps(args, indent=4))
|
||||
if self.partial_response_content:
|
||||
self.io.ai_output(self.partial_response_content)
|
||||
elif self.partial_response_function_call:
|
||||
# TODO: push this into subclasses
|
||||
args = self.parse_partial_args()
|
||||
if args:
|
||||
self.io.ai_output(json.dumps(args, indent=4))
|
||||
|
||||
return interrupted
|
||||
|
||||
def show_send_output(self, completion, silent):
|
||||
def show_send_output(self, completion):
|
||||
if self.verbose:
|
||||
print(completion)
|
||||
|
||||
|
@ -742,9 +607,9 @@ class Coder:
|
|||
self.io.console.print(show_resp)
|
||||
self.io.console.print(tokens)
|
||||
|
||||
def show_send_output_stream(self, completion, silent):
|
||||
def show_send_output_stream(self, completion):
|
||||
live = None
|
||||
if self.pretty and not silent:
|
||||
if self.pretty:
|
||||
live = Live(vertical_overflow="scroll")
|
||||
|
||||
try:
|
||||
|
@ -773,9 +638,6 @@ class Coder:
|
|||
except AttributeError:
|
||||
pass
|
||||
|
||||
if silent:
|
||||
continue
|
||||
|
||||
if self.pretty:
|
||||
self.live_incremental_response(live, False)
|
||||
else:
|
||||
|
@ -797,145 +659,6 @@ class Coder:
|
|||
def render_incremental_response(self, final):
|
||||
return self.partial_response_content
|
||||
|
||||
def get_context_from_history(self, history):
|
||||
context = ""
|
||||
if history:
|
||||
for msg in history:
|
||||
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
|
||||
return context
|
||||
|
||||
def get_commit_message(self, diffs, context):
|
||||
if len(diffs) >= 4 * 1024 * 4:
|
||||
self.io.tool_error(
|
||||
f"Diff is too large for {models.GPT35.name} to generate a commit message."
|
||||
)
|
||||
return
|
||||
|
||||
diffs = "# Diffs:\n" + diffs
|
||||
|
||||
messages = [
|
||||
dict(role="system", content=prompts.commit_system),
|
||||
dict(role="user", content=context + diffs),
|
||||
]
|
||||
|
||||
try:
|
||||
interrupted = self.send(
|
||||
messages,
|
||||
model=models.GPT35.name,
|
||||
silent=True,
|
||||
)
|
||||
except openai.error.InvalidRequestError:
|
||||
self.io.tool_error(
|
||||
f"Failed to generate commit message using {models.GPT35.name} due to an invalid"
|
||||
" request."
|
||||
)
|
||||
return
|
||||
|
||||
commit_message = self.partial_response_content
|
||||
commit_message = commit_message.strip()
|
||||
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
|
||||
commit_message = commit_message[1:-1].strip()
|
||||
|
||||
if interrupted:
|
||||
self.io.tool_error(
|
||||
f"Unable to get commit message from {models.GPT35.name}. Use /commit to try again."
|
||||
)
|
||||
return
|
||||
|
||||
return commit_message
|
||||
|
||||
def get_diffs(self, *args):
|
||||
if self.pretty:
|
||||
args = ["--color"] + list(args)
|
||||
|
||||
diffs = self.repo.git.diff(*args)
|
||||
return diffs
|
||||
|
||||
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
|
||||
repo = self.repo
|
||||
if not repo:
|
||||
return
|
||||
|
||||
if not repo.is_dirty():
|
||||
return
|
||||
|
||||
def get_dirty_files_and_diffs(file_list):
|
||||
diffs = ""
|
||||
relative_dirty_files = []
|
||||
for fname in file_list:
|
||||
relative_fname = self.get_rel_fname(fname)
|
||||
relative_dirty_files.append(relative_fname)
|
||||
|
||||
try:
|
||||
current_branch_commit_count = len(
|
||||
list(self.repo.iter_commits(self.repo.active_branch))
|
||||
)
|
||||
except git.exc.GitCommandError:
|
||||
current_branch_commit_count = None
|
||||
|
||||
if not current_branch_commit_count:
|
||||
continue
|
||||
|
||||
these_diffs = self.get_diffs("HEAD", "--", relative_fname)
|
||||
|
||||
if these_diffs:
|
||||
diffs += these_diffs + "\n"
|
||||
|
||||
return relative_dirty_files, diffs
|
||||
|
||||
if which == "repo_files":
|
||||
all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()]
|
||||
relative_dirty_fnames, diffs = get_dirty_files_and_diffs(all_files)
|
||||
elif which == "chat_files":
|
||||
relative_dirty_fnames, diffs = get_dirty_files_and_diffs(self.abs_fnames)
|
||||
else:
|
||||
raise ValueError(f"Invalid value for 'which': {which}")
|
||||
|
||||
if self.show_diffs or ask:
|
||||
# don't use io.tool_output() because we don't want to log or further colorize
|
||||
print(diffs)
|
||||
|
||||
context = self.get_context_from_history(history)
|
||||
if message:
|
||||
commit_message = message
|
||||
else:
|
||||
commit_message = self.get_commit_message(diffs, context)
|
||||
|
||||
if not commit_message:
|
||||
commit_message = "work in progress"
|
||||
|
||||
if prefix:
|
||||
commit_message = prefix + commit_message
|
||||
|
||||
if ask:
|
||||
if which == "repo_files":
|
||||
self.io.tool_output("Git repo has uncommitted changes.")
|
||||
else:
|
||||
self.io.tool_output("Files have uncommitted changes.")
|
||||
|
||||
res = self.io.prompt_ask(
|
||||
"Commit before the chat proceeds [y/n/commit message]?",
|
||||
default=commit_message,
|
||||
).strip()
|
||||
self.last_asked_for_commit_time = self.get_last_modified()
|
||||
|
||||
self.io.tool_output()
|
||||
|
||||
if res.lower() in ["n", "no"]:
|
||||
self.io.tool_error("Skipped commmit.")
|
||||
return
|
||||
if res.lower() not in ["y", "yes"] and res:
|
||||
commit_message = res
|
||||
|
||||
repo.git.add(*relative_dirty_fnames)
|
||||
|
||||
full_commit_message = commit_message + "\n\n# Aider chat conversation:\n\n" + context
|
||||
repo.git.commit("-m", full_commit_message, "--no-verify")
|
||||
commit_hash = repo.head.commit.hexsha[:7]
|
||||
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
|
||||
|
||||
return commit_hash, commit_message
|
||||
|
||||
def get_rel_fname(self, fname):
|
||||
return os.path.relpath(fname, self.root)
|
||||
|
||||
|
@ -945,7 +668,7 @@ class Coder:
|
|||
|
||||
def get_all_relative_files(self):
|
||||
if self.repo:
|
||||
files = self.get_tracked_files()
|
||||
files = self.repo.get_tracked_files()
|
||||
else:
|
||||
files = self.get_inchat_relative_files()
|
||||
|
||||
|
@ -1000,32 +723,6 @@ class Coder:
|
|||
|
||||
return full_path
|
||||
|
||||
def get_tracked_files(self):
|
||||
if not self.repo:
|
||||
return []
|
||||
|
||||
try:
|
||||
commit = self.repo.head.commit
|
||||
except ValueError:
|
||||
commit = None
|
||||
|
||||
files = []
|
||||
if commit:
|
||||
for blob in commit.tree.traverse():
|
||||
if blob.type == "blob": # blob is a file
|
||||
files.append(blob.path)
|
||||
|
||||
# Add staged files
|
||||
index = self.repo.index
|
||||
staged_files = [path for path, _ in index.entries.keys()]
|
||||
|
||||
files.extend(staged_files)
|
||||
|
||||
# convert to appropriate os.sep, since git always normalizes to /
|
||||
res = set(str(Path(PurePosixPath(path))) for path in files)
|
||||
|
||||
return res
|
||||
|
||||
apply_update_errors = 0
|
||||
|
||||
def apply_updates(self):
|
||||
|
@ -1094,6 +791,72 @@ class Coder:
|
|||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
# commits...
|
||||
|
||||
def get_context_from_history(self, history):
|
||||
context = ""
|
||||
if history:
|
||||
for msg in history:
|
||||
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
|
||||
return context
|
||||
|
||||
def auto_commit(self):
|
||||
context = self.get_context_from_history(self.cur_messages)
|
||||
res = self.repo.commit(context=context, prefix="aider: ")
|
||||
if res:
|
||||
commit_hash, commit_message = res
|
||||
self.last_aider_commit_hash = commit_hash
|
||||
|
||||
return self.gpt_prompts.files_content_gpt_edits.format(
|
||||
hash=commit_hash,
|
||||
message=commit_message,
|
||||
)
|
||||
|
||||
self.io.tool_output("No changes made to git tracked files.")
|
||||
return self.gpt_prompts.files_content_gpt_no_edits
|
||||
|
||||
def should_dirty_commit(self, inp):
|
||||
cmds = self.commands.matching_commands(inp)
|
||||
if cmds:
|
||||
matching_commands, _, _ = cmds
|
||||
if len(matching_commands) == 1:
|
||||
cmd = matching_commands[0][1:]
|
||||
if cmd in "add clear commit diff drop exit help ls tokens".split():
|
||||
return
|
||||
|
||||
if self.last_asked_for_commit_time >= self.get_last_modified():
|
||||
return
|
||||
return True
|
||||
|
||||
def dirty_commit(self):
|
||||
if not self.dirty_commits:
|
||||
return
|
||||
if not self.repo:
|
||||
return
|
||||
if not self.repo.is_dirty():
|
||||
return
|
||||
|
||||
self.io.tool_output("Git repo has uncommitted changes.")
|
||||
self.repo.show_diffs(self.pretty)
|
||||
self.last_asked_for_commit_time = self.get_last_modified()
|
||||
res = self.io.prompt_ask(
|
||||
"Commit before the chat proceeds [y/n/commit message]?",
|
||||
default="y",
|
||||
).strip()
|
||||
if res.lower() in ["n", "no"]:
|
||||
self.io.tool_error("Skipped commmit.")
|
||||
return
|
||||
if res.lower() in ["y", "yes"]:
|
||||
message = None
|
||||
else:
|
||||
message = res.strip()
|
||||
|
||||
self.repo.commit(message=message)
|
||||
|
||||
# files changed, move cur messages back behind the files messages
|
||||
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
|
||||
return True
|
||||
|
||||
|
||||
def check_model_availability(main_model):
|
||||
available_models = openai.Model.list()
|
||||
|
|
|
@ -42,15 +42,6 @@ class SingleWholeFileFunctionCoder(Coder):
|
|||
else:
|
||||
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
||||
|
||||
def get_context_from_history(self, history):
|
||||
context = ""
|
||||
if history:
|
||||
context += "# Context:\n"
|
||||
for msg in history:
|
||||
if msg["role"] == "user":
|
||||
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
||||
return context
|
||||
|
||||
def render_incremental_response(self, final=False):
|
||||
if self.partial_response_content:
|
||||
return self.partial_response_content
|
||||
|
|
|
@ -20,15 +20,6 @@ class WholeFileCoder(Coder):
|
|||
else:
|
||||
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
||||
|
||||
def get_context_from_history(self, history):
|
||||
context = ""
|
||||
if history:
|
||||
context += "# Context:\n"
|
||||
for msg in history:
|
||||
if msg["role"] == "user":
|
||||
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
||||
return context
|
||||
|
||||
def render_incremental_response(self, final):
|
||||
try:
|
||||
return self.update_files(mode="diff")
|
||||
|
|
|
@ -55,15 +55,6 @@ class WholeFileFunctionCoder(Coder):
|
|||
else:
|
||||
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
||||
|
||||
def get_context_from_history(self, history):
|
||||
context = ""
|
||||
if history:
|
||||
context += "# Context:\n"
|
||||
for msg in history:
|
||||
if msg["role"] == "user":
|
||||
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
||||
return context
|
||||
|
||||
def render_incremental_response(self, final=False):
|
||||
if self.partial_response_content:
|
||||
return self.partial_response_content
|
||||
|
|
|
@ -176,10 +176,10 @@ class Commands:
|
|||
)
|
||||
return
|
||||
|
||||
local_head = self.coder.repo.git.rev_parse("HEAD")
|
||||
current_branch = self.coder.repo.active_branch.name
|
||||
local_head = self.coder.repo.repo.git.rev_parse("HEAD")
|
||||
current_branch = self.coder.repo.repo.active_branch.name
|
||||
try:
|
||||
remote_head = self.coder.repo.git.rev_parse(f"origin/{current_branch}")
|
||||
remote_head = self.coder.repo.repo.git.rev_parse(f"origin/{current_branch}")
|
||||
has_origin = True
|
||||
except git.exc.GitCommandError:
|
||||
has_origin = False
|
||||
|
@ -192,14 +192,14 @@ class Commands:
|
|||
)
|
||||
return
|
||||
|
||||
last_commit = self.coder.repo.head.commit
|
||||
last_commit = self.coder.repo.repo.head.commit
|
||||
if (
|
||||
not last_commit.message.startswith("aider:")
|
||||
or last_commit.hexsha[:7] != self.coder.last_aider_commit_hash
|
||||
):
|
||||
self.io.tool_error("The last commit was not made by aider in this chat session.")
|
||||
return
|
||||
self.coder.repo.git.reset("--hard", "HEAD~1")
|
||||
self.coder.repo.repo.git.reset("--hard", "HEAD~1")
|
||||
self.io.tool_output(
|
||||
f"{last_commit.message.strip()}\n"
|
||||
f"The above commit {self.coder.last_aider_commit_hash} "
|
||||
|
@ -220,7 +220,11 @@ class Commands:
|
|||
return
|
||||
|
||||
commits = f"{self.coder.last_aider_commit_hash}~1"
|
||||
diff = self.coder.get_diffs(commits, self.coder.last_aider_commit_hash)
|
||||
diff = self.coder.repo.get_diffs(
|
||||
self.coder.pretty,
|
||||
commits,
|
||||
self.coder.last_aider_commit_hash,
|
||||
)
|
||||
|
||||
# don't use io.tool_output() because we don't want to log or further colorize
|
||||
print(diff)
|
||||
|
@ -243,7 +247,7 @@ class Commands:
|
|||
|
||||
# if repo, filter against it
|
||||
if self.coder.repo:
|
||||
git_files = self.coder.get_tracked_files()
|
||||
git_files = self.coder.repo.get_tracked_files()
|
||||
matched_files = [fn for fn in matched_files if str(fn) in git_files]
|
||||
|
||||
res = list(map(str, matched_files))
|
||||
|
@ -254,7 +258,7 @@ class Commands:
|
|||
|
||||
added_fnames = []
|
||||
git_added = []
|
||||
git_files = self.coder.get_tracked_files()
|
||||
git_files = self.coder.repo.get_tracked_files() if self.coder.repo else []
|
||||
|
||||
all_matched_files = set()
|
||||
for word in args.split():
|
||||
|
@ -281,7 +285,7 @@ class Commands:
|
|||
abs_file_path = self.coder.abs_root_path(matched_file)
|
||||
|
||||
if self.coder.repo and matched_file not in git_files:
|
||||
self.coder.repo.git.add(abs_file_path)
|
||||
self.coder.repo.repo.git.add(abs_file_path)
|
||||
git_added.append(matched_file)
|
||||
|
||||
if abs_file_path in self.coder.abs_fnames:
|
||||
|
@ -298,8 +302,8 @@ class Commands:
|
|||
if self.coder.repo and git_added:
|
||||
git_added = " ".join(git_added)
|
||||
commit_message = f"aider: Added {git_added}"
|
||||
self.coder.repo.git.commit("-m", commit_message, "--no-verify")
|
||||
commit_hash = self.coder.repo.head.commit.hexsha[:7]
|
||||
self.coder.repo.repo.git.commit("-m", commit_message, "--no-verify")
|
||||
commit_hash = self.coder.repo.repo.head.commit.hexsha[:7]
|
||||
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
|
||||
|
||||
if not added_fnames:
|
||||
|
|
116
aider/main.py
116
aider/main.py
|
@ -9,10 +9,14 @@ import openai
|
|||
from aider import __version__, models
|
||||
from aider.coders import Coder
|
||||
from aider.io import InputOutput
|
||||
from aider.repo import GitRepo
|
||||
from aider.versioncheck import check_version
|
||||
|
||||
from .dump import dump # noqa: F401
|
||||
|
||||
|
||||
def get_git_root():
|
||||
"""Try and guess the git repo, since the conf.yml can be at the repo root"""
|
||||
try:
|
||||
repo = git.Repo(search_parent_directories=True)
|
||||
return repo.working_tree_dir
|
||||
|
@ -20,6 +24,25 @@ def get_git_root():
|
|||
return None
|
||||
|
||||
|
||||
def guessed_wrong_repo(io, git_root, fnames, git_dname):
|
||||
"""After we parse the args, we can determine the real repo. Did we guess wrong?"""
|
||||
|
||||
try:
|
||||
check_repo = Path(GitRepo(io, fnames, git_dname).root).resolve()
|
||||
except FileNotFoundError:
|
||||
return
|
||||
|
||||
# we had no guess, rely on the "true" repo result
|
||||
if not git_root:
|
||||
return str(check_repo)
|
||||
|
||||
git_root = Path(git_root).resolve()
|
||||
if check_repo == git_root:
|
||||
return
|
||||
|
||||
return str(check_repo)
|
||||
|
||||
|
||||
def setup_git(git_root, io):
|
||||
if git_root:
|
||||
return git_root
|
||||
|
@ -71,11 +94,14 @@ def check_gitignore(git_root, io, ask=True):
|
|||
io.tool_output(f"Added {pat} to .gitignore")
|
||||
|
||||
|
||||
def main(args=None, input=None, output=None):
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
def main(argv=None, input=None, output=None, force_git_root=None):
|
||||
if argv is None:
|
||||
argv = sys.argv[1:]
|
||||
|
||||
git_root = get_git_root()
|
||||
if force_git_root:
|
||||
git_root = force_git_root
|
||||
else:
|
||||
git_root = get_git_root()
|
||||
|
||||
conf_fname = Path(".aider.conf.yml")
|
||||
|
||||
|
@ -101,7 +127,7 @@ def main(args=None, input=None, output=None):
|
|||
"files",
|
||||
metavar="FILE",
|
||||
nargs="*",
|
||||
help="a list of source code files to edit with GPT (optional)",
|
||||
help="the directory of a git repo, or a list of files to edit with GPT (optional)",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--openai-api-key",
|
||||
|
@ -344,7 +370,7 @@ def main(args=None, input=None, output=None):
|
|||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args(args)
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
if args.dark_mode:
|
||||
args.user_input_color = "#32FF32"
|
||||
|
@ -371,6 +397,37 @@ def main(args=None, input=None, output=None):
|
|||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
fnames = [str(Path(fn).resolve()) for fn in args.files]
|
||||
if len(args.files) > 1:
|
||||
good = True
|
||||
for fname in args.files:
|
||||
if Path(fname).is_dir():
|
||||
io.tool_error(f"{fname} is a directory, not provided alone.")
|
||||
good = False
|
||||
if not good:
|
||||
io.tool_error(
|
||||
"Provide either a single directory of a git repo, or a list of one or more files."
|
||||
)
|
||||
return 1
|
||||
|
||||
git_dname = None
|
||||
if len(args.files) == 1:
|
||||
if Path(args.files[0]).is_dir():
|
||||
if args.git:
|
||||
git_dname = str(Path(args.files[0]).resolve())
|
||||
fnames = []
|
||||
else:
|
||||
io.tool_error(f"{args.files[0]} is a directory, but --no-git selected.")
|
||||
return 1
|
||||
|
||||
# We can't know the git repo for sure until after parsing the args.
|
||||
# If we guessed wrong, reparse because that changes things like
|
||||
# the location of the config.yml and history files.
|
||||
if args.git and not force_git_root:
|
||||
right_repo_root = guessed_wrong_repo(io, git_root, fnames, git_dname)
|
||||
if right_repo_root:
|
||||
return main(argv, input, output, right_repo_root)
|
||||
|
||||
io.tool_output(f"Aider v{__version__}")
|
||||
|
||||
check_version(io.tool_error)
|
||||
|
@ -418,24 +475,29 @@ def main(args=None, input=None, output=None):
|
|||
setattr(openai, mod_key, val)
|
||||
io.tool_output(f"Setting openai.{mod_key}={val}")
|
||||
|
||||
coder = Coder.create(
|
||||
main_model,
|
||||
args.edit_format,
|
||||
io,
|
||||
##
|
||||
fnames=args.files,
|
||||
pretty=args.pretty,
|
||||
show_diffs=args.show_diffs,
|
||||
auto_commits=args.auto_commits,
|
||||
dirty_commits=args.dirty_commits,
|
||||
dry_run=args.dry_run,
|
||||
map_tokens=args.map_tokens,
|
||||
verbose=args.verbose,
|
||||
assistant_output_color=args.assistant_output_color,
|
||||
code_theme=args.code_theme,
|
||||
stream=args.stream,
|
||||
use_git=args.git,
|
||||
)
|
||||
try:
|
||||
coder = Coder.create(
|
||||
main_model,
|
||||
args.edit_format,
|
||||
io,
|
||||
##
|
||||
fnames=fnames,
|
||||
git_dname=git_dname,
|
||||
pretty=args.pretty,
|
||||
show_diffs=args.show_diffs,
|
||||
auto_commits=args.auto_commits,
|
||||
dirty_commits=args.dirty_commits,
|
||||
dry_run=args.dry_run,
|
||||
map_tokens=args.map_tokens,
|
||||
verbose=args.verbose,
|
||||
assistant_output_color=args.assistant_output_color,
|
||||
code_theme=args.code_theme,
|
||||
stream=args.stream,
|
||||
use_git=args.git,
|
||||
)
|
||||
except ValueError as err:
|
||||
io.tool_error(str(err))
|
||||
return 1
|
||||
|
||||
if args.show_repo_map:
|
||||
repo_map = coder.get_repo_map()
|
||||
|
@ -443,9 +505,6 @@ def main(args=None, input=None, output=None):
|
|||
io.tool_output(repo_map)
|
||||
return
|
||||
|
||||
if args.dirty_commits:
|
||||
coder.commit(ask=True, which="repo_files")
|
||||
|
||||
if args.apply:
|
||||
content = io.read_text(args.apply)
|
||||
if content is None:
|
||||
|
@ -454,6 +513,9 @@ def main(args=None, input=None, output=None):
|
|||
return
|
||||
|
||||
io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args")
|
||||
|
||||
coder.dirty_commit()
|
||||
|
||||
if args.message:
|
||||
io.tool_output()
|
||||
coder.run(with_message=args.message)
|
||||
|
|
177
aider/repo.py
Normal file
177
aider/repo.py
Normal file
|
@ -0,0 +1,177 @@
|
|||
import os
|
||||
from pathlib import Path, PurePosixPath
|
||||
|
||||
import git
|
||||
|
||||
from aider import models, prompts, utils
|
||||
from aider.sendchat import simple_send_with_retries
|
||||
|
||||
from .dump import dump # noqa: F401
|
||||
|
||||
|
||||
class GitRepo:
|
||||
repo = None
|
||||
|
||||
def __init__(self, io, fnames, git_dname):
|
||||
self.io = io
|
||||
|
||||
if git_dname:
|
||||
check_fnames = [git_dname]
|
||||
elif fnames:
|
||||
check_fnames = fnames
|
||||
else:
|
||||
check_fnames = ["."]
|
||||
|
||||
repo_paths = []
|
||||
for fname in check_fnames:
|
||||
fname = Path(fname)
|
||||
fname = fname.resolve()
|
||||
|
||||
if not fname.exists() and fname.parent.exists():
|
||||
fname = fname.parent
|
||||
|
||||
try:
|
||||
repo_path = git.Repo(fname, search_parent_directories=True).working_dir
|
||||
repo_path = utils.safe_abs_path(repo_path)
|
||||
repo_paths.append(repo_path)
|
||||
except git.exc.InvalidGitRepositoryError:
|
||||
pass
|
||||
|
||||
num_repos = len(set(repo_paths))
|
||||
|
||||
if num_repos == 0:
|
||||
raise FileNotFoundError
|
||||
if num_repos > 1:
|
||||
self.io.tool_error("Files are in different git repos.")
|
||||
raise FileNotFoundError
|
||||
|
||||
# https://github.com/gitpython-developers/GitPython/issues/427
|
||||
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
|
||||
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
|
||||
|
||||
def add_new_files(self, fnames):
|
||||
cur_files = [Path(fn).resolve() for fn in self.get_tracked_files()]
|
||||
for fname in fnames:
|
||||
if Path(fname).resolve() in cur_files:
|
||||
continue
|
||||
if not Path(fname).exists():
|
||||
continue
|
||||
self.io.tool_output(f"Adding {fname} to git")
|
||||
self.repo.git.add(fname)
|
||||
|
||||
def commit(self, context=None, prefix=None, message=None):
|
||||
if not self.repo.is_dirty():
|
||||
return
|
||||
|
||||
if message:
|
||||
commit_message = message
|
||||
else:
|
||||
diffs = self.get_diffs(False)
|
||||
commit_message = self.get_commit_message(diffs, context)
|
||||
|
||||
if not commit_message:
|
||||
commit_message = "(no commit message provided)"
|
||||
|
||||
if prefix:
|
||||
commit_message = prefix + commit_message
|
||||
|
||||
full_commit_message = commit_message
|
||||
if context:
|
||||
full_commit_message += "\n\n# Aider chat conversation:\n\n" + context
|
||||
|
||||
self.repo.git.commit("-a", "-m", full_commit_message, "--no-verify")
|
||||
commit_hash = self.repo.head.commit.hexsha[:7]
|
||||
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
|
||||
|
||||
return commit_hash, commit_message
|
||||
|
||||
def get_rel_repo_dir(self):
|
||||
try:
|
||||
return os.path.relpath(self.repo.git_dir, os.getcwd())
|
||||
except ValueError:
|
||||
return self.repo.git_dir
|
||||
|
||||
def get_commit_message(self, diffs, context):
|
||||
if len(diffs) >= 4 * 1024 * 4:
|
||||
self.io.tool_error(
|
||||
f"Diff is too large for {models.GPT35.name} to generate a commit message."
|
||||
)
|
||||
return
|
||||
|
||||
diffs = "# Diffs:\n" + diffs
|
||||
|
||||
content = ""
|
||||
if context:
|
||||
content += context + "\n"
|
||||
content += diffs
|
||||
|
||||
messages = [
|
||||
dict(role="system", content=prompts.commit_system),
|
||||
dict(role="user", content=content),
|
||||
]
|
||||
|
||||
for model in [models.GPT35.name, models.GPT35_16k.name]:
|
||||
commit_message = simple_send_with_retries(model, messages)
|
||||
if commit_message:
|
||||
break
|
||||
|
||||
if not commit_message:
|
||||
self.io.tool_error("Failed to generate commit message!")
|
||||
return
|
||||
|
||||
commit_message = commit_message.strip()
|
||||
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
|
||||
commit_message = commit_message[1:-1].strip()
|
||||
|
||||
return commit_message
|
||||
|
||||
def get_diffs(self, pretty, *args):
|
||||
try:
|
||||
commits = self.repo.iter_commits(self.repo.active_branch)
|
||||
current_branch_has_commits = any(commits)
|
||||
except git.exc.GitCommandError:
|
||||
current_branch_has_commits = False
|
||||
|
||||
if not current_branch_has_commits:
|
||||
return ""
|
||||
|
||||
if pretty:
|
||||
args = ["--color"] + list(args)
|
||||
if not args:
|
||||
args = ["HEAD"]
|
||||
|
||||
diffs = self.repo.git.diff(*args)
|
||||
return diffs
|
||||
|
||||
def show_diffs(self, pretty):
|
||||
diffs = self.get_diffs(pretty)
|
||||
print(diffs)
|
||||
|
||||
def get_tracked_files(self):
|
||||
if not self.repo:
|
||||
return []
|
||||
|
||||
try:
|
||||
commit = self.repo.head.commit
|
||||
except ValueError:
|
||||
commit = None
|
||||
|
||||
files = []
|
||||
if commit:
|
||||
for blob in commit.tree.traverse():
|
||||
if blob.type == "blob": # blob is a file
|
||||
files.append(blob.path)
|
||||
|
||||
# Add staged files
|
||||
index = self.repo.index
|
||||
staged_files = [path for path, _ in index.entries.keys()]
|
||||
|
||||
files.extend(staged_files)
|
||||
|
||||
# convert to appropriate os.sep, since git always normalizes to /
|
||||
res = set(str(Path(PurePosixPath(path))) for path in files)
|
||||
|
||||
return res
|
||||
|
||||
def is_dirty(self):
|
||||
return self.repo.is_dirty()
|
57
aider/sendchat.py
Normal file
57
aider/sendchat.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
import hashlib
|
||||
import json
|
||||
|
||||
import backoff
|
||||
import openai
|
||||
import requests
|
||||
from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
|
||||
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
(
|
||||
Timeout,
|
||||
APIError,
|
||||
ServiceUnavailableError,
|
||||
RateLimitError,
|
||||
requests.exceptions.ConnectionError,
|
||||
),
|
||||
max_tries=10,
|
||||
on_backoff=lambda details: print(
|
||||
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
|
||||
),
|
||||
)
|
||||
def send_with_retries(model, messages, functions, stream):
|
||||
kwargs = dict(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
stream=stream,
|
||||
)
|
||||
if functions is not None:
|
||||
kwargs["functions"] = functions
|
||||
|
||||
# we are abusing the openai object to stash these values
|
||||
if hasattr(openai, "api_deployment_id"):
|
||||
kwargs["deployment_id"] = openai.api_deployment_id
|
||||
if hasattr(openai, "api_engine"):
|
||||
kwargs["engine"] = openai.api_engine
|
||||
|
||||
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
|
||||
hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())
|
||||
|
||||
res = openai.ChatCompletion.create(**kwargs)
|
||||
return hash_object, res
|
||||
|
||||
|
||||
def simple_send_with_retries(model, messages):
|
||||
try:
|
||||
_hash, response = send_with_retries(
|
||||
model=model,
|
||||
messages=messages,
|
||||
functions=None,
|
||||
stream=False,
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except (AttributeError, openai.error.InvalidRequestError):
|
||||
return
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
@ -6,7 +5,6 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import git
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from aider import models
|
||||
from aider.coders import Coder
|
||||
|
@ -77,7 +75,7 @@ class TestCoder(unittest.TestCase):
|
|||
# Mock the git repo
|
||||
mock = MagicMock()
|
||||
mock.return_value = set(["file1.txt", "file2.py"])
|
||||
coder.get_tracked_files = mock
|
||||
coder.repo.get_tracked_files = mock
|
||||
|
||||
# Call the check_for_file_mentions method
|
||||
coder.check_for_file_mentions("Please check file1.txt and file2.py")
|
||||
|
@ -121,7 +119,7 @@ class TestCoder(unittest.TestCase):
|
|||
|
||||
mock = MagicMock()
|
||||
mock.return_value = set(["file1.txt", "file2.py"])
|
||||
coder.get_tracked_files = mock
|
||||
coder.repo.get_tracked_files = mock
|
||||
|
||||
# Call the check_for_file_mentions method
|
||||
coder.check_for_file_mentions("Please check file1.txt and file2.py")
|
||||
|
@ -152,7 +150,7 @@ class TestCoder(unittest.TestCase):
|
|||
|
||||
mock = MagicMock()
|
||||
mock.return_value = set([str(fname), str(other_fname)])
|
||||
coder.get_tracked_files = mock
|
||||
coder.repo.get_tracked_files = mock
|
||||
|
||||
# Call the check_for_file_mentions method
|
||||
coder.check_for_file_mentions(f"Please check {fname}!")
|
||||
|
@ -170,7 +168,7 @@ class TestCoder(unittest.TestCase):
|
|||
|
||||
mock = MagicMock()
|
||||
mock.return_value = set([str(fname)])
|
||||
coder.get_tracked_files = mock
|
||||
coder.repo.get_tracked_files = mock
|
||||
|
||||
dump(fname)
|
||||
# Call the check_for_file_mentions method
|
||||
|
@ -178,110 +176,6 @@ class TestCoder(unittest.TestCase):
|
|||
|
||||
self.assertEqual(coder.abs_fnames, set([str(fname.resolve())]))
|
||||
|
||||
def test_get_commit_message(self):
|
||||
# Mock the IO object
|
||||
mock_io = MagicMock()
|
||||
|
||||
# Initialize the Coder object with the mocked IO and mocked repo
|
||||
coder = Coder.create(models.GPT4, None, mock_io)
|
||||
|
||||
# Mock the send method to set partial_response_content and return False
|
||||
def mock_send(*args, **kwargs):
|
||||
coder.partial_response_content = "a good commit message"
|
||||
return False
|
||||
|
||||
coder.send = MagicMock(side_effect=mock_send)
|
||||
|
||||
# Call the get_commit_message method with dummy diff and context
|
||||
result = coder.get_commit_message("dummy diff", "dummy context")
|
||||
|
||||
# Assert that the returned message is the expected one
|
||||
self.assertEqual(result, "a good commit message")
|
||||
|
||||
def test_get_commit_message_strip_quotes(self):
|
||||
# Mock the IO object
|
||||
mock_io = MagicMock()
|
||||
|
||||
# Initialize the Coder object with the mocked IO and mocked repo
|
||||
coder = Coder.create(models.GPT4, None, mock_io)
|
||||
|
||||
# Mock the send method to set partial_response_content and return False
|
||||
def mock_send(*args, **kwargs):
|
||||
coder.partial_response_content = "a good commit message"
|
||||
return False
|
||||
|
||||
coder.send = MagicMock(side_effect=mock_send)
|
||||
|
||||
# Call the get_commit_message method with dummy diff and context
|
||||
result = coder.get_commit_message("dummy diff", "dummy context")
|
||||
|
||||
# Assert that the returned message is the expected one
|
||||
self.assertEqual(result, "a good commit message")
|
||||
|
||||
def test_get_commit_message_no_strip_unmatched_quotes(self):
|
||||
# Mock the IO object
|
||||
mock_io = MagicMock()
|
||||
|
||||
# Initialize the Coder object with the mocked IO and mocked repo
|
||||
coder = Coder.create(models.GPT4, None, mock_io)
|
||||
|
||||
# Mock the send method to set partial_response_content and return False
|
||||
def mock_send(*args, **kwargs):
|
||||
coder.partial_response_content = 'a good "commit message"'
|
||||
return False
|
||||
|
||||
coder.send = MagicMock(side_effect=mock_send)
|
||||
|
||||
# Call the get_commit_message method with dummy diff and context
|
||||
result = coder.get_commit_message("dummy diff", "dummy context")
|
||||
|
||||
# Assert that the returned message is the expected one
|
||||
self.assertEqual(result, 'a good "commit message"')
|
||||
|
||||
@patch("aider.coders.base_coder.openai.ChatCompletion.create")
|
||||
@patch("builtins.print")
|
||||
def test_send_with_retries_rate_limit_error(self, mock_print, mock_chat_completion_create):
|
||||
# Mock the IO object
|
||||
mock_io = MagicMock()
|
||||
|
||||
# Initialize the Coder object with the mocked IO and mocked repo
|
||||
coder = Coder.create(models.GPT4, None, mock_io)
|
||||
|
||||
# Set up the mock to raise RateLimitError on
|
||||
# the first call and return None on the second call
|
||||
mock_chat_completion_create.side_effect = [
|
||||
openai.error.RateLimitError("Rate limit exceeded"),
|
||||
None,
|
||||
]
|
||||
|
||||
# Call the send_with_retries method
|
||||
coder.send_with_retries("model", ["message"], None)
|
||||
|
||||
# Assert that print was called once
|
||||
mock_print.assert_called_once()
|
||||
|
||||
@patch("aider.coders.base_coder.openai.ChatCompletion.create")
|
||||
@patch("builtins.print")
|
||||
def test_send_with_retries_connection_error(self, mock_print, mock_chat_completion_create):
|
||||
# Mock the IO object
|
||||
mock_io = MagicMock()
|
||||
|
||||
# Initialize the Coder object with the mocked IO and mocked repo
|
||||
coder = Coder.create(models.GPT4, None, mock_io)
|
||||
|
||||
# Set up the mock to raise ConnectionError on the first call
|
||||
# and return None on the second call
|
||||
mock_chat_completion_create.side_effect = [
|
||||
requests.exceptions.ConnectionError("Connection error"),
|
||||
None,
|
||||
]
|
||||
|
||||
# Call the send_with_retries method
|
||||
coder.send_with_retries("model", ["message"], None)
|
||||
|
||||
# Assert that print was called once
|
||||
mock_print.assert_called_once()
|
||||
|
||||
def test_run_with_file_deletion(self):
|
||||
# Create a few temporary files
|
||||
|
||||
|
@ -419,49 +313,5 @@ class TestCoder(unittest.TestCase):
|
|||
with self.assertRaises(openai.error.InvalidRequestError):
|
||||
coder.run(with_message="hi")
|
||||
|
||||
def test_get_tracked_files(self):
|
||||
# Create a temporary directory
|
||||
tempdir = Path(tempfile.mkdtemp())
|
||||
|
||||
# Initialize a git repository in the temporary directory and set user name and email
|
||||
repo = git.Repo.init(tempdir)
|
||||
repo.config_writer().set_value("user", "name", "Test User").release()
|
||||
repo.config_writer().set_value("user", "email", "testuser@example.com").release()
|
||||
|
||||
# Create three empty files and add them to the git repository
|
||||
filenames = ["README.md", "subdir/fänny.md", "systemüber/blick.md", 'file"with"quotes.txt']
|
||||
created_files = []
|
||||
for filename in filenames:
|
||||
file_path = tempdir / filename
|
||||
try:
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.touch()
|
||||
repo.git.add(str(file_path))
|
||||
created_files.append(Path(filename))
|
||||
except OSError:
|
||||
# windows won't allow files with quotes, that's ok
|
||||
self.assertIn('"', filename)
|
||||
self.assertEqual(os.name, "nt")
|
||||
|
||||
self.assertTrue(len(created_files) >= 3)
|
||||
|
||||
repo.git.commit("-m", "added")
|
||||
|
||||
# Create a Coder object on the temporary directory
|
||||
coder = Coder.create(
|
||||
models.GPT4,
|
||||
None,
|
||||
io=InputOutput(),
|
||||
fnames=[str(tempdir / filenames[0])],
|
||||
)
|
||||
|
||||
tracked_files = coder.get_tracked_files()
|
||||
|
||||
# On windows, paths will come back \like\this, so normalize them back to Paths
|
||||
tracked_files = [Path(fn) for fn in tracked_files]
|
||||
|
||||
# Assert that coder.get_tracked_files() returns the three filenames
|
||||
self.assertEqual(set(tracked_files), set(created_files))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -35,14 +35,42 @@ class TestMain(TestCase):
|
|||
main(["--no-git"], input=DummyInput(), output=DummyOutput())
|
||||
|
||||
def test_main_with_empty_dir_new_file(self):
|
||||
main(["foo.txt", "--yes"], input=DummyInput(), output=DummyOutput())
|
||||
main(["foo.txt", "--yes", "--no-git"], input=DummyInput(), output=DummyOutput())
|
||||
self.assertTrue(os.path.exists("foo.txt"))
|
||||
|
||||
def test_main_with_empty_git_dir_new_file(self):
|
||||
@patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message")
|
||||
def test_main_with_empty_git_dir_new_file(self, _):
|
||||
make_repo()
|
||||
main(["--yes", "foo.txt"], input=DummyInput(), output=DummyOutput())
|
||||
self.assertTrue(os.path.exists("foo.txt"))
|
||||
|
||||
@patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message")
|
||||
def test_main_with_empty_git_dir_new_files(self, _):
|
||||
make_repo()
|
||||
main(["--yes", "foo.txt", "bar.txt"], input=DummyInput(), output=DummyOutput())
|
||||
self.assertTrue(os.path.exists("foo.txt"))
|
||||
self.assertTrue(os.path.exists("bar.txt"))
|
||||
|
||||
def test_main_with_dname_and_fname(self):
|
||||
subdir = Path("subdir")
|
||||
subdir.mkdir()
|
||||
make_repo(str(subdir))
|
||||
res = main(["subdir", "foo.txt"], input=DummyInput(), output=DummyOutput())
|
||||
self.assertNotEqual(res, None)
|
||||
|
||||
@patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message")
|
||||
def test_main_with_subdir_repo_fnames(self, _):
|
||||
subdir = Path("subdir")
|
||||
subdir.mkdir()
|
||||
make_repo(str(subdir))
|
||||
main(
|
||||
["--yes", str(subdir / "foo.txt"), str(subdir / "bar.txt")],
|
||||
input=DummyInput(),
|
||||
output=DummyOutput(),
|
||||
)
|
||||
self.assertTrue((subdir / "foo.txt").exists())
|
||||
self.assertTrue((subdir / "bar.txt").exists())
|
||||
|
||||
def test_main_with_git_config_yml(self):
|
||||
make_repo()
|
||||
|
||||
|
|
114
tests/test_repo.py
Normal file
114
tests/test_repo.py
Normal file
|
@ -0,0 +1,114 @@
|
|||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import git
|
||||
|
||||
from aider.dump import dump # noqa: F401
|
||||
from aider.io import InputOutput
|
||||
from aider.repo import GitRepo
|
||||
from tests.utils import GitTemporaryDirectory
|
||||
|
||||
|
||||
class TestRepo(unittest.TestCase):
|
||||
@patch("aider.repo.simple_send_with_retries")
|
||||
def test_get_commit_message(self, mock_send):
|
||||
mock_send.return_value = "a good commit message"
|
||||
|
||||
repo = GitRepo(InputOutput(), None, None)
|
||||
# Call the get_commit_message method with dummy diff and context
|
||||
result = repo.get_commit_message("dummy diff", "dummy context")
|
||||
|
||||
# Assert that the returned message is the expected one
|
||||
self.assertEqual(result, "a good commit message")
|
||||
|
||||
@patch("aider.repo.simple_send_with_retries")
|
||||
def test_get_commit_message_strip_quotes(self, mock_send):
|
||||
mock_send.return_value = '"a good commit message"'
|
||||
|
||||
repo = GitRepo(InputOutput(), None, None)
|
||||
# Call the get_commit_message method with dummy diff and context
|
||||
result = repo.get_commit_message("dummy diff", "dummy context")
|
||||
|
||||
# Assert that the returned message is the expected one
|
||||
self.assertEqual(result, "a good commit message")
|
||||
|
||||
@patch("aider.repo.simple_send_with_retries")
|
||||
def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send):
|
||||
mock_send.return_value = 'a good "commit message"'
|
||||
|
||||
repo = GitRepo(InputOutput(), None, None)
|
||||
# Call the get_commit_message method with dummy diff and context
|
||||
result = repo.get_commit_message("dummy diff", "dummy context")
|
||||
|
||||
# Assert that the returned message is the expected one
|
||||
self.assertEqual(result, 'a good "commit message"')
|
||||
|
||||
def test_get_tracked_files(self):
|
||||
# Create a temporary directory
|
||||
tempdir = Path(tempfile.mkdtemp())
|
||||
|
||||
# Initialize a git repository in the temporary directory and set user name and email
|
||||
repo = git.Repo.init(tempdir)
|
||||
repo.config_writer().set_value("user", "name", "Test User").release()
|
||||
repo.config_writer().set_value("user", "email", "testuser@example.com").release()
|
||||
|
||||
# Create three empty files and add them to the git repository
|
||||
filenames = ["README.md", "subdir/fänny.md", "systemüber/blick.md", 'file"with"quotes.txt']
|
||||
created_files = []
|
||||
for filename in filenames:
|
||||
file_path = tempdir / filename
|
||||
try:
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.touch()
|
||||
repo.git.add(str(file_path))
|
||||
created_files.append(Path(filename))
|
||||
except OSError:
|
||||
# windows won't allow files with quotes, that's ok
|
||||
self.assertIn('"', filename)
|
||||
self.assertEqual(os.name, "nt")
|
||||
|
||||
self.assertTrue(len(created_files) >= 3)
|
||||
|
||||
repo.git.commit("-m", "added")
|
||||
|
||||
tracked_files = GitRepo(InputOutput(), [tempdir], None).get_tracked_files()
|
||||
|
||||
# On windows, paths will come back \like\this, so normalize them back to Paths
|
||||
tracked_files = [Path(fn) for fn in tracked_files]
|
||||
|
||||
# Assert that coder.get_tracked_files() returns the three filenames
|
||||
self.assertEqual(set(tracked_files), set(created_files))
|
||||
|
||||
def test_get_tracked_files_with_new_staged_file(self):
|
||||
with GitTemporaryDirectory():
|
||||
# new repo
|
||||
raw_repo = git.Repo()
|
||||
|
||||
# add it, but no commits at all in the raw_repo yet
|
||||
fname = Path("new.txt")
|
||||
fname.touch()
|
||||
raw_repo.git.add(str(fname))
|
||||
|
||||
git_repo = GitRepo(InputOutput(), None, None)
|
||||
|
||||
# better be there
|
||||
fnames = git_repo.get_tracked_files()
|
||||
self.assertIn(str(fname), fnames)
|
||||
|
||||
# commit it, better still be there
|
||||
raw_repo.git.commit("-m", "new")
|
||||
fnames = git_repo.get_tracked_files()
|
||||
self.assertIn(str(fname), fnames)
|
||||
|
||||
# new file, added but not committed
|
||||
fname2 = Path("new2.txt")
|
||||
fname2.touch()
|
||||
raw_repo.git.add(str(fname2))
|
||||
|
||||
# both should be there
|
||||
fnames = git_repo.get_tracked_files()
|
||||
self.assertIn(str(fname), fnames)
|
||||
self.assertIn(str(fname2), fnames)
|
|
@ -4,7 +4,6 @@ from unittest.mock import patch
|
|||
|
||||
from aider.io import InputOutput
|
||||
from aider.repomap import RepoMap
|
||||
|
||||
from tests.utils import IgnorantTemporaryDirectory
|
||||
|
||||
|
||||
|
|
41
tests/test_sendchat.py
Normal file
41
tests/test_sendchat.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from aider.sendchat import send_with_retries
|
||||
|
||||
|
||||
class TestSendChat(unittest.TestCase):
|
||||
@patch("aider.sendchat.openai.ChatCompletion.create")
|
||||
@patch("builtins.print")
|
||||
def test_send_with_retries_rate_limit_error(self, mock_print, mock_chat_completion_create):
|
||||
# Set up the mock to raise RateLimitError on
|
||||
# the first call and return None on the second call
|
||||
mock_chat_completion_create.side_effect = [
|
||||
openai.error.RateLimitError("Rate limit exceeded"),
|
||||
None,
|
||||
]
|
||||
|
||||
# Call the send_with_retries method
|
||||
send_with_retries("model", ["message"], None, False)
|
||||
|
||||
# Assert that print was called once
|
||||
mock_print.assert_called_once()
|
||||
|
||||
@patch("aider.sendchat.openai.ChatCompletion.create")
|
||||
@patch("builtins.print")
|
||||
def test_send_with_retries_connection_error(self, mock_print, mock_chat_completion_create):
|
||||
# Set up the mock to raise ConnectionError on the first call
|
||||
# and return None on the second call
|
||||
mock_chat_completion_create.side_effect = [
|
||||
requests.exceptions.ConnectionError("Connection error"),
|
||||
None,
|
||||
]
|
||||
|
||||
# Call the send_with_retries method
|
||||
send_with_retries("model", ["message"], None, False)
|
||||
|
||||
# Assert that print was called once
|
||||
mock_print.assert_called_once()
|
|
@ -42,7 +42,11 @@ class GitTemporaryDirectory(ChdirTemporaryDirectory):
|
|||
return res
|
||||
|
||||
|
||||
def make_repo():
|
||||
repo = git.Repo.init()
|
||||
def make_repo(path=None):
|
||||
if not path:
|
||||
path = "."
|
||||
repo = git.Repo.init(path)
|
||||
repo.config_writer().set_value("user", "name", "Test User").release()
|
||||
repo.config_writer().set_value("user", "email", "testuser@example.com").release()
|
||||
|
||||
return repo
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue