mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-29 00:35:00 +00:00
Merge pull request #137 from paul-gauthier/refactor-repo
Refactor git repo code into a new file
This commit is contained in:
commit
86309f336c
14 changed files with 642 additions and 570 deletions
|
@ -7,21 +7,19 @@ import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from json.decoder import JSONDecodeError
|
from json.decoder import JSONDecodeError
|
||||||
from pathlib import Path, PurePosixPath
|
from pathlib import Path
|
||||||
|
|
||||||
import backoff
|
|
||||||
import git
|
|
||||||
import openai
|
import openai
|
||||||
import requests
|
|
||||||
from jsonschema import Draft7Validator
|
from jsonschema import Draft7Validator
|
||||||
from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
|
|
||||||
from rich.console import Console, Text
|
from rich.console import Console, Text
|
||||||
from rich.live import Live
|
from rich.live import Live
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
from aider import models, prompts, utils
|
from aider import models, prompts, utils
|
||||||
from aider.commands import Commands
|
from aider.commands import Commands
|
||||||
|
from aider.repo import GitRepo
|
||||||
from aider.repomap import RepoMap
|
from aider.repomap import RepoMap
|
||||||
|
from aider.sendchat import send_with_retries
|
||||||
|
|
||||||
from ..dump import dump # noqa: F401
|
from ..dump import dump # noqa: F401
|
||||||
|
|
||||||
|
@ -100,6 +98,7 @@ class Coder:
|
||||||
main_model,
|
main_model,
|
||||||
io,
|
io,
|
||||||
fnames=None,
|
fnames=None,
|
||||||
|
git_dname=None,
|
||||||
pretty=True,
|
pretty=True,
|
||||||
show_diffs=False,
|
show_diffs=False,
|
||||||
auto_commits=True,
|
auto_commits=True,
|
||||||
|
@ -150,13 +149,27 @@ class Coder:
|
||||||
|
|
||||||
self.commands = Commands(self.io, self)
|
self.commands = Commands(self.io, self)
|
||||||
|
|
||||||
|
for fname in fnames:
|
||||||
|
fname = Path(fname)
|
||||||
|
if not fname.exists():
|
||||||
|
self.io.tool_output(f"Creating empty file {fname}")
|
||||||
|
fname.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fname.touch()
|
||||||
|
|
||||||
|
if not fname.is_file():
|
||||||
|
raise ValueError(f"{fname} is not a file")
|
||||||
|
|
||||||
|
self.abs_fnames.add(str(fname.resolve()))
|
||||||
|
|
||||||
if use_git:
|
if use_git:
|
||||||
self.set_repo(fnames)
|
try:
|
||||||
else:
|
self.repo = GitRepo(self.io, fnames, git_dname)
|
||||||
self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames])
|
self.root = self.repo.root
|
||||||
|
except FileNotFoundError:
|
||||||
|
self.repo = None
|
||||||
|
|
||||||
if self.repo:
|
if self.repo:
|
||||||
rel_repo_dir = self.get_rel_repo_dir()
|
rel_repo_dir = self.repo.get_rel_repo_dir()
|
||||||
self.io.tool_output(f"Git repo: {rel_repo_dir}")
|
self.io.tool_output(f"Git repo: {rel_repo_dir}")
|
||||||
else:
|
else:
|
||||||
self.io.tool_output("Git repo: none")
|
self.io.tool_output("Git repo: none")
|
||||||
|
@ -187,6 +200,9 @@ class Coder:
|
||||||
for fname in self.get_inchat_relative_files():
|
for fname in self.get_inchat_relative_files():
|
||||||
self.io.tool_output(f"Added {fname} to the chat.")
|
self.io.tool_output(f"Added {fname} to the chat.")
|
||||||
|
|
||||||
|
if self.repo:
|
||||||
|
self.repo.add_new_files(fname for fname in fnames if not Path(fname).is_dir())
|
||||||
|
|
||||||
# validate the functions jsonschema
|
# validate the functions jsonschema
|
||||||
if self.functions:
|
if self.functions:
|
||||||
for function in self.functions:
|
for function in self.functions:
|
||||||
|
@ -206,12 +222,6 @@ class Coder:
|
||||||
|
|
||||||
self.root = utils.safe_abs_path(self.root)
|
self.root = utils.safe_abs_path(self.root)
|
||||||
|
|
||||||
def get_rel_repo_dir(self):
|
|
||||||
try:
|
|
||||||
return os.path.relpath(self.repo.git_dir, os.getcwd())
|
|
||||||
except ValueError:
|
|
||||||
return self.repo.git_dir
|
|
||||||
|
|
||||||
def add_rel_fname(self, rel_fname):
|
def add_rel_fname(self, rel_fname):
|
||||||
self.abs_fnames.add(self.abs_root_path(rel_fname))
|
self.abs_fnames.add(self.abs_root_path(rel_fname))
|
||||||
|
|
||||||
|
@ -219,73 +229,6 @@ class Coder:
|
||||||
res = Path(self.root) / path
|
res = Path(self.root) / path
|
||||||
return utils.safe_abs_path(res)
|
return utils.safe_abs_path(res)
|
||||||
|
|
||||||
def set_repo(self, cmd_line_fnames):
|
|
||||||
if not cmd_line_fnames:
|
|
||||||
cmd_line_fnames = ["."]
|
|
||||||
|
|
||||||
repo_paths = []
|
|
||||||
for fname in cmd_line_fnames:
|
|
||||||
fname = Path(fname)
|
|
||||||
if not fname.exists():
|
|
||||||
self.io.tool_output(f"Creating empty file {fname}")
|
|
||||||
fname.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
fname.touch()
|
|
||||||
|
|
||||||
fname = fname.resolve()
|
|
||||||
|
|
||||||
try:
|
|
||||||
repo_path = git.Repo(fname, search_parent_directories=True).working_dir
|
|
||||||
repo_path = utils.safe_abs_path(repo_path)
|
|
||||||
repo_paths.append(repo_path)
|
|
||||||
except git.exc.InvalidGitRepositoryError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if fname.is_dir():
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.abs_fnames.add(str(fname))
|
|
||||||
|
|
||||||
num_repos = len(set(repo_paths))
|
|
||||||
|
|
||||||
if num_repos == 0:
|
|
||||||
return
|
|
||||||
if num_repos > 1:
|
|
||||||
self.io.tool_error("Files are in different git repos.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# https://github.com/gitpython-developers/GitPython/issues/427
|
|
||||||
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
|
|
||||||
|
|
||||||
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
|
|
||||||
|
|
||||||
new_files = []
|
|
||||||
for fname in self.abs_fnames:
|
|
||||||
relative_fname = self.get_rel_fname(fname)
|
|
||||||
|
|
||||||
tracked_files = set(self.get_tracked_files())
|
|
||||||
if relative_fname not in tracked_files:
|
|
||||||
new_files.append(relative_fname)
|
|
||||||
|
|
||||||
if new_files:
|
|
||||||
rel_repo_dir = self.get_rel_repo_dir()
|
|
||||||
|
|
||||||
self.io.tool_output(f"Files not tracked in {rel_repo_dir}:")
|
|
||||||
for fn in new_files:
|
|
||||||
self.io.tool_output(f" - {fn}")
|
|
||||||
if self.io.confirm_ask("Add them?"):
|
|
||||||
for relative_fname in new_files:
|
|
||||||
self.repo.git.add(relative_fname)
|
|
||||||
self.io.tool_output(f"Added {relative_fname} to the git repo")
|
|
||||||
show_files = ", ".join(new_files)
|
|
||||||
commit_message = f"Added new files to the git repo: {show_files}"
|
|
||||||
self.repo.git.commit("-m", commit_message, "--no-verify")
|
|
||||||
commit_hash = self.repo.head.commit.hexsha[:7]
|
|
||||||
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
|
|
||||||
else:
|
|
||||||
self.io.tool_error("Skipped adding new files to the git repo.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# fences are obfuscated so aider can modify this file!
|
|
||||||
fences = [
|
fences = [
|
||||||
("``" + "`", "``" + "`"),
|
("``" + "`", "``" + "`"),
|
||||||
wrap_fence("source"),
|
wrap_fence("source"),
|
||||||
|
@ -412,25 +355,6 @@ class Coder:
|
||||||
|
|
||||||
self.last_keyboard_interrupt = now
|
self.last_keyboard_interrupt = now
|
||||||
|
|
||||||
def should_dirty_commit(self, inp):
|
|
||||||
cmds = self.commands.matching_commands(inp)
|
|
||||||
if cmds:
|
|
||||||
matching_commands, _, _ = cmds
|
|
||||||
if len(matching_commands) == 1:
|
|
||||||
cmd = matching_commands[0][1:]
|
|
||||||
if cmd in "add clear commit diff drop exit help ls tokens".split():
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self.dirty_commits:
|
|
||||||
return
|
|
||||||
if not self.repo:
|
|
||||||
return
|
|
||||||
if not self.repo.is_dirty():
|
|
||||||
return
|
|
||||||
if self.last_asked_for_commit_time >= self.get_last_modified():
|
|
||||||
return
|
|
||||||
return True
|
|
||||||
|
|
||||||
def move_back_cur_messages(self, message):
|
def move_back_cur_messages(self, message):
|
||||||
self.done_messages += self.cur_messages
|
self.done_messages += self.cur_messages
|
||||||
if message:
|
if message:
|
||||||
|
@ -448,13 +372,7 @@ class Coder:
|
||||||
self.commands,
|
self.commands,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.should_dirty_commit(inp):
|
if self.should_dirty_commit(inp) and self.dirty_commit():
|
||||||
self.io.tool_output("Git repo has uncommitted changes, preparing commit...")
|
|
||||||
self.commit(ask=True, which="repo_files")
|
|
||||||
|
|
||||||
# files changed, move cur messages back behind the files messages
|
|
||||||
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
|
|
||||||
|
|
||||||
if inp.strip():
|
if inp.strip():
|
||||||
self.io.tool_output("Use up-arrow to retry previous command:", inp)
|
self.io.tool_output("Use up-arrow to retry previous command:", inp)
|
||||||
return
|
return
|
||||||
|
@ -569,23 +487,6 @@ class Coder:
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def auto_commit(self):
|
|
||||||
res = self.commit(history=self.cur_messages, prefix="aider: ")
|
|
||||||
if res:
|
|
||||||
commit_hash, commit_message = res
|
|
||||||
self.last_aider_commit_hash = commit_hash
|
|
||||||
|
|
||||||
saved_message = self.gpt_prompts.files_content_gpt_edits.format(
|
|
||||||
hash=commit_hash,
|
|
||||||
message=commit_message,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if self.repo:
|
|
||||||
self.io.tool_output("No changes made to git tracked files.")
|
|
||||||
saved_message = self.gpt_prompts.files_content_gpt_no_edits
|
|
||||||
|
|
||||||
return saved_message
|
|
||||||
|
|
||||||
def check_for_file_mentions(self, content):
|
def check_for_file_mentions(self, content):
|
||||||
words = set(word for word in content.split())
|
words = set(word for word in content.split())
|
||||||
|
|
||||||
|
@ -627,44 +528,7 @@ class Coder:
|
||||||
|
|
||||||
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
|
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
|
||||||
|
|
||||||
@backoff.on_exception(
|
def send(self, messages, model=None, functions=None):
|
||||||
backoff.expo,
|
|
||||||
(
|
|
||||||
Timeout,
|
|
||||||
APIError,
|
|
||||||
ServiceUnavailableError,
|
|
||||||
RateLimitError,
|
|
||||||
requests.exceptions.ConnectionError,
|
|
||||||
),
|
|
||||||
max_tries=10,
|
|
||||||
on_backoff=lambda details: print(
|
|
||||||
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
def send_with_retries(self, model, messages, functions):
|
|
||||||
kwargs = dict(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
temperature=0,
|
|
||||||
stream=self.stream,
|
|
||||||
)
|
|
||||||
if functions is not None:
|
|
||||||
kwargs["functions"] = self.functions
|
|
||||||
|
|
||||||
# we are abusing the openai object to stash these values
|
|
||||||
if hasattr(openai, "api_deployment_id"):
|
|
||||||
kwargs["deployment_id"] = openai.api_deployment_id
|
|
||||||
if hasattr(openai, "api_engine"):
|
|
||||||
kwargs["engine"] = openai.api_engine
|
|
||||||
|
|
||||||
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
|
|
||||||
hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())
|
|
||||||
self.chat_completion_call_hashes.append(hash_object.hexdigest())
|
|
||||||
|
|
||||||
res = openai.ChatCompletion.create(**kwargs)
|
|
||||||
return res
|
|
||||||
|
|
||||||
def send(self, messages, model=None, silent=False, functions=None):
|
|
||||||
if not model:
|
if not model:
|
||||||
model = self.main_model.name
|
model = self.main_model.name
|
||||||
|
|
||||||
|
@ -673,27 +537,28 @@ class Coder:
|
||||||
|
|
||||||
interrupted = False
|
interrupted = False
|
||||||
try:
|
try:
|
||||||
completion = self.send_with_retries(model, messages, functions)
|
hash_object, completion = send_with_retries(model, messages, functions, self.stream)
|
||||||
|
self.chat_completion_call_hashes.append(hash_object.hexdigest())
|
||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
self.show_send_output_stream(completion, silent)
|
self.show_send_output_stream(completion)
|
||||||
else:
|
else:
|
||||||
self.show_send_output(completion, silent)
|
self.show_send_output(completion)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
self.keyboard_interrupt()
|
self.keyboard_interrupt()
|
||||||
interrupted = True
|
interrupted = True
|
||||||
|
|
||||||
if not silent:
|
if self.partial_response_content:
|
||||||
if self.partial_response_content:
|
self.io.ai_output(self.partial_response_content)
|
||||||
self.io.ai_output(self.partial_response_content)
|
elif self.partial_response_function_call:
|
||||||
elif self.partial_response_function_call:
|
# TODO: push this into subclasses
|
||||||
# TODO: push this into subclasses
|
args = self.parse_partial_args()
|
||||||
args = self.parse_partial_args()
|
if args:
|
||||||
if args:
|
self.io.ai_output(json.dumps(args, indent=4))
|
||||||
self.io.ai_output(json.dumps(args, indent=4))
|
|
||||||
|
|
||||||
return interrupted
|
return interrupted
|
||||||
|
|
||||||
def show_send_output(self, completion, silent):
|
def show_send_output(self, completion):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(completion)
|
print(completion)
|
||||||
|
|
||||||
|
@ -742,9 +607,9 @@ class Coder:
|
||||||
self.io.console.print(show_resp)
|
self.io.console.print(show_resp)
|
||||||
self.io.console.print(tokens)
|
self.io.console.print(tokens)
|
||||||
|
|
||||||
def show_send_output_stream(self, completion, silent):
|
def show_send_output_stream(self, completion):
|
||||||
live = None
|
live = None
|
||||||
if self.pretty and not silent:
|
if self.pretty:
|
||||||
live = Live(vertical_overflow="scroll")
|
live = Live(vertical_overflow="scroll")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -773,9 +638,6 @@ class Coder:
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if silent:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if self.pretty:
|
if self.pretty:
|
||||||
self.live_incremental_response(live, False)
|
self.live_incremental_response(live, False)
|
||||||
else:
|
else:
|
||||||
|
@ -797,145 +659,6 @@ class Coder:
|
||||||
def render_incremental_response(self, final):
|
def render_incremental_response(self, final):
|
||||||
return self.partial_response_content
|
return self.partial_response_content
|
||||||
|
|
||||||
def get_context_from_history(self, history):
|
|
||||||
context = ""
|
|
||||||
if history:
|
|
||||||
for msg in history:
|
|
||||||
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
|
|
||||||
return context
|
|
||||||
|
|
||||||
def get_commit_message(self, diffs, context):
|
|
||||||
if len(diffs) >= 4 * 1024 * 4:
|
|
||||||
self.io.tool_error(
|
|
||||||
f"Diff is too large for {models.GPT35.name} to generate a commit message."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
diffs = "# Diffs:\n" + diffs
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
dict(role="system", content=prompts.commit_system),
|
|
||||||
dict(role="user", content=context + diffs),
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
interrupted = self.send(
|
|
||||||
messages,
|
|
||||||
model=models.GPT35.name,
|
|
||||||
silent=True,
|
|
||||||
)
|
|
||||||
except openai.error.InvalidRequestError:
|
|
||||||
self.io.tool_error(
|
|
||||||
f"Failed to generate commit message using {models.GPT35.name} due to an invalid"
|
|
||||||
" request."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
commit_message = self.partial_response_content
|
|
||||||
commit_message = commit_message.strip()
|
|
||||||
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
|
|
||||||
commit_message = commit_message[1:-1].strip()
|
|
||||||
|
|
||||||
if interrupted:
|
|
||||||
self.io.tool_error(
|
|
||||||
f"Unable to get commit message from {models.GPT35.name}. Use /commit to try again."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
return commit_message
|
|
||||||
|
|
||||||
def get_diffs(self, *args):
|
|
||||||
if self.pretty:
|
|
||||||
args = ["--color"] + list(args)
|
|
||||||
|
|
||||||
diffs = self.repo.git.diff(*args)
|
|
||||||
return diffs
|
|
||||||
|
|
||||||
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
|
|
||||||
repo = self.repo
|
|
||||||
if not repo:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not repo.is_dirty():
|
|
||||||
return
|
|
||||||
|
|
||||||
def get_dirty_files_and_diffs(file_list):
|
|
||||||
diffs = ""
|
|
||||||
relative_dirty_files = []
|
|
||||||
for fname in file_list:
|
|
||||||
relative_fname = self.get_rel_fname(fname)
|
|
||||||
relative_dirty_files.append(relative_fname)
|
|
||||||
|
|
||||||
try:
|
|
||||||
current_branch_commit_count = len(
|
|
||||||
list(self.repo.iter_commits(self.repo.active_branch))
|
|
||||||
)
|
|
||||||
except git.exc.GitCommandError:
|
|
||||||
current_branch_commit_count = None
|
|
||||||
|
|
||||||
if not current_branch_commit_count:
|
|
||||||
continue
|
|
||||||
|
|
||||||
these_diffs = self.get_diffs("HEAD", "--", relative_fname)
|
|
||||||
|
|
||||||
if these_diffs:
|
|
||||||
diffs += these_diffs + "\n"
|
|
||||||
|
|
||||||
return relative_dirty_files, diffs
|
|
||||||
|
|
||||||
if which == "repo_files":
|
|
||||||
all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()]
|
|
||||||
relative_dirty_fnames, diffs = get_dirty_files_and_diffs(all_files)
|
|
||||||
elif which == "chat_files":
|
|
||||||
relative_dirty_fnames, diffs = get_dirty_files_and_diffs(self.abs_fnames)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid value for 'which': {which}")
|
|
||||||
|
|
||||||
if self.show_diffs or ask:
|
|
||||||
# don't use io.tool_output() because we don't want to log or further colorize
|
|
||||||
print(diffs)
|
|
||||||
|
|
||||||
context = self.get_context_from_history(history)
|
|
||||||
if message:
|
|
||||||
commit_message = message
|
|
||||||
else:
|
|
||||||
commit_message = self.get_commit_message(diffs, context)
|
|
||||||
|
|
||||||
if not commit_message:
|
|
||||||
commit_message = "work in progress"
|
|
||||||
|
|
||||||
if prefix:
|
|
||||||
commit_message = prefix + commit_message
|
|
||||||
|
|
||||||
if ask:
|
|
||||||
if which == "repo_files":
|
|
||||||
self.io.tool_output("Git repo has uncommitted changes.")
|
|
||||||
else:
|
|
||||||
self.io.tool_output("Files have uncommitted changes.")
|
|
||||||
|
|
||||||
res = self.io.prompt_ask(
|
|
||||||
"Commit before the chat proceeds [y/n/commit message]?",
|
|
||||||
default=commit_message,
|
|
||||||
).strip()
|
|
||||||
self.last_asked_for_commit_time = self.get_last_modified()
|
|
||||||
|
|
||||||
self.io.tool_output()
|
|
||||||
|
|
||||||
if res.lower() in ["n", "no"]:
|
|
||||||
self.io.tool_error("Skipped commmit.")
|
|
||||||
return
|
|
||||||
if res.lower() not in ["y", "yes"] and res:
|
|
||||||
commit_message = res
|
|
||||||
|
|
||||||
repo.git.add(*relative_dirty_fnames)
|
|
||||||
|
|
||||||
full_commit_message = commit_message + "\n\n# Aider chat conversation:\n\n" + context
|
|
||||||
repo.git.commit("-m", full_commit_message, "--no-verify")
|
|
||||||
commit_hash = repo.head.commit.hexsha[:7]
|
|
||||||
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
|
|
||||||
|
|
||||||
return commit_hash, commit_message
|
|
||||||
|
|
||||||
def get_rel_fname(self, fname):
|
def get_rel_fname(self, fname):
|
||||||
return os.path.relpath(fname, self.root)
|
return os.path.relpath(fname, self.root)
|
||||||
|
|
||||||
|
@ -945,7 +668,7 @@ class Coder:
|
||||||
|
|
||||||
def get_all_relative_files(self):
|
def get_all_relative_files(self):
|
||||||
if self.repo:
|
if self.repo:
|
||||||
files = self.get_tracked_files()
|
files = self.repo.get_tracked_files()
|
||||||
else:
|
else:
|
||||||
files = self.get_inchat_relative_files()
|
files = self.get_inchat_relative_files()
|
||||||
|
|
||||||
|
@ -1000,32 +723,6 @@ class Coder:
|
||||||
|
|
||||||
return full_path
|
return full_path
|
||||||
|
|
||||||
def get_tracked_files(self):
|
|
||||||
if not self.repo:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
commit = self.repo.head.commit
|
|
||||||
except ValueError:
|
|
||||||
commit = None
|
|
||||||
|
|
||||||
files = []
|
|
||||||
if commit:
|
|
||||||
for blob in commit.tree.traverse():
|
|
||||||
if blob.type == "blob": # blob is a file
|
|
||||||
files.append(blob.path)
|
|
||||||
|
|
||||||
# Add staged files
|
|
||||||
index = self.repo.index
|
|
||||||
staged_files = [path for path, _ in index.entries.keys()]
|
|
||||||
|
|
||||||
files.extend(staged_files)
|
|
||||||
|
|
||||||
# convert to appropriate os.sep, since git always normalizes to /
|
|
||||||
res = set(str(Path(PurePosixPath(path))) for path in files)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
apply_update_errors = 0
|
apply_update_errors = 0
|
||||||
|
|
||||||
def apply_updates(self):
|
def apply_updates(self):
|
||||||
|
@ -1094,6 +791,72 @@ class Coder:
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# commits...
|
||||||
|
|
||||||
|
def get_context_from_history(self, history):
|
||||||
|
context = ""
|
||||||
|
if history:
|
||||||
|
for msg in history:
|
||||||
|
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
|
||||||
|
return context
|
||||||
|
|
||||||
|
def auto_commit(self):
|
||||||
|
context = self.get_context_from_history(self.cur_messages)
|
||||||
|
res = self.repo.commit(context=context, prefix="aider: ")
|
||||||
|
if res:
|
||||||
|
commit_hash, commit_message = res
|
||||||
|
self.last_aider_commit_hash = commit_hash
|
||||||
|
|
||||||
|
return self.gpt_prompts.files_content_gpt_edits.format(
|
||||||
|
hash=commit_hash,
|
||||||
|
message=commit_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.io.tool_output("No changes made to git tracked files.")
|
||||||
|
return self.gpt_prompts.files_content_gpt_no_edits
|
||||||
|
|
||||||
|
def should_dirty_commit(self, inp):
|
||||||
|
cmds = self.commands.matching_commands(inp)
|
||||||
|
if cmds:
|
||||||
|
matching_commands, _, _ = cmds
|
||||||
|
if len(matching_commands) == 1:
|
||||||
|
cmd = matching_commands[0][1:]
|
||||||
|
if cmd in "add clear commit diff drop exit help ls tokens".split():
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.last_asked_for_commit_time >= self.get_last_modified():
|
||||||
|
return
|
||||||
|
return True
|
||||||
|
|
||||||
|
def dirty_commit(self):
|
||||||
|
if not self.dirty_commits:
|
||||||
|
return
|
||||||
|
if not self.repo:
|
||||||
|
return
|
||||||
|
if not self.repo.is_dirty():
|
||||||
|
return
|
||||||
|
|
||||||
|
self.io.tool_output("Git repo has uncommitted changes.")
|
||||||
|
self.repo.show_diffs(self.pretty)
|
||||||
|
self.last_asked_for_commit_time = self.get_last_modified()
|
||||||
|
res = self.io.prompt_ask(
|
||||||
|
"Commit before the chat proceeds [y/n/commit message]?",
|
||||||
|
default="y",
|
||||||
|
).strip()
|
||||||
|
if res.lower() in ["n", "no"]:
|
||||||
|
self.io.tool_error("Skipped commmit.")
|
||||||
|
return
|
||||||
|
if res.lower() in ["y", "yes"]:
|
||||||
|
message = None
|
||||||
|
else:
|
||||||
|
message = res.strip()
|
||||||
|
|
||||||
|
self.repo.commit(message=message)
|
||||||
|
|
||||||
|
# files changed, move cur messages back behind the files messages
|
||||||
|
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def check_model_availability(main_model):
|
def check_model_availability(main_model):
|
||||||
available_models = openai.Model.list()
|
available_models = openai.Model.list()
|
||||||
|
|
|
@ -42,15 +42,6 @@ class SingleWholeFileFunctionCoder(Coder):
|
||||||
else:
|
else:
|
||||||
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
||||||
|
|
||||||
def get_context_from_history(self, history):
|
|
||||||
context = ""
|
|
||||||
if history:
|
|
||||||
context += "# Context:\n"
|
|
||||||
for msg in history:
|
|
||||||
if msg["role"] == "user":
|
|
||||||
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
|
||||||
return context
|
|
||||||
|
|
||||||
def render_incremental_response(self, final=False):
|
def render_incremental_response(self, final=False):
|
||||||
if self.partial_response_content:
|
if self.partial_response_content:
|
||||||
return self.partial_response_content
|
return self.partial_response_content
|
||||||
|
|
|
@ -20,15 +20,6 @@ class WholeFileCoder(Coder):
|
||||||
else:
|
else:
|
||||||
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
||||||
|
|
||||||
def get_context_from_history(self, history):
|
|
||||||
context = ""
|
|
||||||
if history:
|
|
||||||
context += "# Context:\n"
|
|
||||||
for msg in history:
|
|
||||||
if msg["role"] == "user":
|
|
||||||
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
|
||||||
return context
|
|
||||||
|
|
||||||
def render_incremental_response(self, final):
|
def render_incremental_response(self, final):
|
||||||
try:
|
try:
|
||||||
return self.update_files(mode="diff")
|
return self.update_files(mode="diff")
|
||||||
|
|
|
@ -55,15 +55,6 @@ class WholeFileFunctionCoder(Coder):
|
||||||
else:
|
else:
|
||||||
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
||||||
|
|
||||||
def get_context_from_history(self, history):
|
|
||||||
context = ""
|
|
||||||
if history:
|
|
||||||
context += "# Context:\n"
|
|
||||||
for msg in history:
|
|
||||||
if msg["role"] == "user":
|
|
||||||
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
|
||||||
return context
|
|
||||||
|
|
||||||
def render_incremental_response(self, final=False):
|
def render_incremental_response(self, final=False):
|
||||||
if self.partial_response_content:
|
if self.partial_response_content:
|
||||||
return self.partial_response_content
|
return self.partial_response_content
|
||||||
|
|
|
@ -176,10 +176,10 @@ class Commands:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
local_head = self.coder.repo.git.rev_parse("HEAD")
|
local_head = self.coder.repo.repo.git.rev_parse("HEAD")
|
||||||
current_branch = self.coder.repo.active_branch.name
|
current_branch = self.coder.repo.repo.active_branch.name
|
||||||
try:
|
try:
|
||||||
remote_head = self.coder.repo.git.rev_parse(f"origin/{current_branch}")
|
remote_head = self.coder.repo.repo.git.rev_parse(f"origin/{current_branch}")
|
||||||
has_origin = True
|
has_origin = True
|
||||||
except git.exc.GitCommandError:
|
except git.exc.GitCommandError:
|
||||||
has_origin = False
|
has_origin = False
|
||||||
|
@ -192,14 +192,14 @@ class Commands:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
last_commit = self.coder.repo.head.commit
|
last_commit = self.coder.repo.repo.head.commit
|
||||||
if (
|
if (
|
||||||
not last_commit.message.startswith("aider:")
|
not last_commit.message.startswith("aider:")
|
||||||
or last_commit.hexsha[:7] != self.coder.last_aider_commit_hash
|
or last_commit.hexsha[:7] != self.coder.last_aider_commit_hash
|
||||||
):
|
):
|
||||||
self.io.tool_error("The last commit was not made by aider in this chat session.")
|
self.io.tool_error("The last commit was not made by aider in this chat session.")
|
||||||
return
|
return
|
||||||
self.coder.repo.git.reset("--hard", "HEAD~1")
|
self.coder.repo.repo.git.reset("--hard", "HEAD~1")
|
||||||
self.io.tool_output(
|
self.io.tool_output(
|
||||||
f"{last_commit.message.strip()}\n"
|
f"{last_commit.message.strip()}\n"
|
||||||
f"The above commit {self.coder.last_aider_commit_hash} "
|
f"The above commit {self.coder.last_aider_commit_hash} "
|
||||||
|
@ -220,7 +220,11 @@ class Commands:
|
||||||
return
|
return
|
||||||
|
|
||||||
commits = f"{self.coder.last_aider_commit_hash}~1"
|
commits = f"{self.coder.last_aider_commit_hash}~1"
|
||||||
diff = self.coder.get_diffs(commits, self.coder.last_aider_commit_hash)
|
diff = self.coder.repo.get_diffs(
|
||||||
|
self.coder.pretty,
|
||||||
|
commits,
|
||||||
|
self.coder.last_aider_commit_hash,
|
||||||
|
)
|
||||||
|
|
||||||
# don't use io.tool_output() because we don't want to log or further colorize
|
# don't use io.tool_output() because we don't want to log or further colorize
|
||||||
print(diff)
|
print(diff)
|
||||||
|
@ -243,7 +247,7 @@ class Commands:
|
||||||
|
|
||||||
# if repo, filter against it
|
# if repo, filter against it
|
||||||
if self.coder.repo:
|
if self.coder.repo:
|
||||||
git_files = self.coder.get_tracked_files()
|
git_files = self.coder.repo.get_tracked_files()
|
||||||
matched_files = [fn for fn in matched_files if str(fn) in git_files]
|
matched_files = [fn for fn in matched_files if str(fn) in git_files]
|
||||||
|
|
||||||
res = list(map(str, matched_files))
|
res = list(map(str, matched_files))
|
||||||
|
@ -254,7 +258,7 @@ class Commands:
|
||||||
|
|
||||||
added_fnames = []
|
added_fnames = []
|
||||||
git_added = []
|
git_added = []
|
||||||
git_files = self.coder.get_tracked_files()
|
git_files = self.coder.repo.get_tracked_files() if self.coder.repo else []
|
||||||
|
|
||||||
all_matched_files = set()
|
all_matched_files = set()
|
||||||
for word in args.split():
|
for word in args.split():
|
||||||
|
@ -281,7 +285,7 @@ class Commands:
|
||||||
abs_file_path = self.coder.abs_root_path(matched_file)
|
abs_file_path = self.coder.abs_root_path(matched_file)
|
||||||
|
|
||||||
if self.coder.repo and matched_file not in git_files:
|
if self.coder.repo and matched_file not in git_files:
|
||||||
self.coder.repo.git.add(abs_file_path)
|
self.coder.repo.repo.git.add(abs_file_path)
|
||||||
git_added.append(matched_file)
|
git_added.append(matched_file)
|
||||||
|
|
||||||
if abs_file_path in self.coder.abs_fnames:
|
if abs_file_path in self.coder.abs_fnames:
|
||||||
|
@ -298,8 +302,8 @@ class Commands:
|
||||||
if self.coder.repo and git_added:
|
if self.coder.repo and git_added:
|
||||||
git_added = " ".join(git_added)
|
git_added = " ".join(git_added)
|
||||||
commit_message = f"aider: Added {git_added}"
|
commit_message = f"aider: Added {git_added}"
|
||||||
self.coder.repo.git.commit("-m", commit_message, "--no-verify")
|
self.coder.repo.repo.git.commit("-m", commit_message, "--no-verify")
|
||||||
commit_hash = self.coder.repo.head.commit.hexsha[:7]
|
commit_hash = self.coder.repo.repo.head.commit.hexsha[:7]
|
||||||
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
|
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
|
||||||
|
|
||||||
if not added_fnames:
|
if not added_fnames:
|
||||||
|
|
116
aider/main.py
116
aider/main.py
|
@ -9,10 +9,14 @@ import openai
|
||||||
from aider import __version__, models
|
from aider import __version__, models
|
||||||
from aider.coders import Coder
|
from aider.coders import Coder
|
||||||
from aider.io import InputOutput
|
from aider.io import InputOutput
|
||||||
|
from aider.repo import GitRepo
|
||||||
from aider.versioncheck import check_version
|
from aider.versioncheck import check_version
|
||||||
|
|
||||||
|
from .dump import dump # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def get_git_root():
|
def get_git_root():
|
||||||
|
"""Try and guess the git repo, since the conf.yml can be at the repo root"""
|
||||||
try:
|
try:
|
||||||
repo = git.Repo(search_parent_directories=True)
|
repo = git.Repo(search_parent_directories=True)
|
||||||
return repo.working_tree_dir
|
return repo.working_tree_dir
|
||||||
|
@ -20,6 +24,25 @@ def get_git_root():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def guessed_wrong_repo(io, git_root, fnames, git_dname):
|
||||||
|
"""After we parse the args, we can determine the real repo. Did we guess wrong?"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
check_repo = Path(GitRepo(io, fnames, git_dname).root).resolve()
|
||||||
|
except FileNotFoundError:
|
||||||
|
return
|
||||||
|
|
||||||
|
# we had no guess, rely on the "true" repo result
|
||||||
|
if not git_root:
|
||||||
|
return str(check_repo)
|
||||||
|
|
||||||
|
git_root = Path(git_root).resolve()
|
||||||
|
if check_repo == git_root:
|
||||||
|
return
|
||||||
|
|
||||||
|
return str(check_repo)
|
||||||
|
|
||||||
|
|
||||||
def setup_git(git_root, io):
|
def setup_git(git_root, io):
|
||||||
if git_root:
|
if git_root:
|
||||||
return git_root
|
return git_root
|
||||||
|
@ -71,11 +94,14 @@ def check_gitignore(git_root, io, ask=True):
|
||||||
io.tool_output(f"Added {pat} to .gitignore")
|
io.tool_output(f"Added {pat} to .gitignore")
|
||||||
|
|
||||||
|
|
||||||
def main(args=None, input=None, output=None):
|
def main(argv=None, input=None, output=None, force_git_root=None):
|
||||||
if args is None:
|
if argv is None:
|
||||||
args = sys.argv[1:]
|
argv = sys.argv[1:]
|
||||||
|
|
||||||
git_root = get_git_root()
|
if force_git_root:
|
||||||
|
git_root = force_git_root
|
||||||
|
else:
|
||||||
|
git_root = get_git_root()
|
||||||
|
|
||||||
conf_fname = Path(".aider.conf.yml")
|
conf_fname = Path(".aider.conf.yml")
|
||||||
|
|
||||||
|
@ -101,7 +127,7 @@ def main(args=None, input=None, output=None):
|
||||||
"files",
|
"files",
|
||||||
metavar="FILE",
|
metavar="FILE",
|
||||||
nargs="*",
|
nargs="*",
|
||||||
help="a list of source code files to edit with GPT (optional)",
|
help="the directory of a git repo, or a list of files to edit with GPT (optional)",
|
||||||
)
|
)
|
||||||
core_group.add_argument(
|
core_group.add_argument(
|
||||||
"--openai-api-key",
|
"--openai-api-key",
|
||||||
|
@ -344,7 +370,7 @@ def main(args=None, input=None, output=None):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args(args)
|
args = parser.parse_args(argv)
|
||||||
|
|
||||||
if args.dark_mode:
|
if args.dark_mode:
|
||||||
args.user_input_color = "#32FF32"
|
args.user_input_color = "#32FF32"
|
||||||
|
@ -371,6 +397,37 @@ def main(args=None, input=None, output=None):
|
||||||
dry_run=args.dry_run,
|
dry_run=args.dry_run,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
fnames = [str(Path(fn).resolve()) for fn in args.files]
|
||||||
|
if len(args.files) > 1:
|
||||||
|
good = True
|
||||||
|
for fname in args.files:
|
||||||
|
if Path(fname).is_dir():
|
||||||
|
io.tool_error(f"{fname} is a directory, not provided alone.")
|
||||||
|
good = False
|
||||||
|
if not good:
|
||||||
|
io.tool_error(
|
||||||
|
"Provide either a single directory of a git repo, or a list of one or more files."
|
||||||
|
)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
git_dname = None
|
||||||
|
if len(args.files) == 1:
|
||||||
|
if Path(args.files[0]).is_dir():
|
||||||
|
if args.git:
|
||||||
|
git_dname = str(Path(args.files[0]).resolve())
|
||||||
|
fnames = []
|
||||||
|
else:
|
||||||
|
io.tool_error(f"{args.files[0]} is a directory, but --no-git selected.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# We can't know the git repo for sure until after parsing the args.
|
||||||
|
# If we guessed wrong, reparse because that changes things like
|
||||||
|
# the location of the config.yml and history files.
|
||||||
|
if args.git and not force_git_root:
|
||||||
|
right_repo_root = guessed_wrong_repo(io, git_root, fnames, git_dname)
|
||||||
|
if right_repo_root:
|
||||||
|
return main(argv, input, output, right_repo_root)
|
||||||
|
|
||||||
io.tool_output(f"Aider v{__version__}")
|
io.tool_output(f"Aider v{__version__}")
|
||||||
|
|
||||||
check_version(io.tool_error)
|
check_version(io.tool_error)
|
||||||
|
@ -418,24 +475,29 @@ def main(args=None, input=None, output=None):
|
||||||
setattr(openai, mod_key, val)
|
setattr(openai, mod_key, val)
|
||||||
io.tool_output(f"Setting openai.{mod_key}={val}")
|
io.tool_output(f"Setting openai.{mod_key}={val}")
|
||||||
|
|
||||||
coder = Coder.create(
|
try:
|
||||||
main_model,
|
coder = Coder.create(
|
||||||
args.edit_format,
|
main_model,
|
||||||
io,
|
args.edit_format,
|
||||||
##
|
io,
|
||||||
fnames=args.files,
|
##
|
||||||
pretty=args.pretty,
|
fnames=fnames,
|
||||||
show_diffs=args.show_diffs,
|
git_dname=git_dname,
|
||||||
auto_commits=args.auto_commits,
|
pretty=args.pretty,
|
||||||
dirty_commits=args.dirty_commits,
|
show_diffs=args.show_diffs,
|
||||||
dry_run=args.dry_run,
|
auto_commits=args.auto_commits,
|
||||||
map_tokens=args.map_tokens,
|
dirty_commits=args.dirty_commits,
|
||||||
verbose=args.verbose,
|
dry_run=args.dry_run,
|
||||||
assistant_output_color=args.assistant_output_color,
|
map_tokens=args.map_tokens,
|
||||||
code_theme=args.code_theme,
|
verbose=args.verbose,
|
||||||
stream=args.stream,
|
assistant_output_color=args.assistant_output_color,
|
||||||
use_git=args.git,
|
code_theme=args.code_theme,
|
||||||
)
|
stream=args.stream,
|
||||||
|
use_git=args.git,
|
||||||
|
)
|
||||||
|
except ValueError as err:
|
||||||
|
io.tool_error(str(err))
|
||||||
|
return 1
|
||||||
|
|
||||||
if args.show_repo_map:
|
if args.show_repo_map:
|
||||||
repo_map = coder.get_repo_map()
|
repo_map = coder.get_repo_map()
|
||||||
|
@ -443,9 +505,6 @@ def main(args=None, input=None, output=None):
|
||||||
io.tool_output(repo_map)
|
io.tool_output(repo_map)
|
||||||
return
|
return
|
||||||
|
|
||||||
if args.dirty_commits:
|
|
||||||
coder.commit(ask=True, which="repo_files")
|
|
||||||
|
|
||||||
if args.apply:
|
if args.apply:
|
||||||
content = io.read_text(args.apply)
|
content = io.read_text(args.apply)
|
||||||
if content is None:
|
if content is None:
|
||||||
|
@ -454,6 +513,9 @@ def main(args=None, input=None, output=None):
|
||||||
return
|
return
|
||||||
|
|
||||||
io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args")
|
io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args")
|
||||||
|
|
||||||
|
coder.dirty_commit()
|
||||||
|
|
||||||
if args.message:
|
if args.message:
|
||||||
io.tool_output()
|
io.tool_output()
|
||||||
coder.run(with_message=args.message)
|
coder.run(with_message=args.message)
|
||||||
|
|
177
aider/repo.py
Normal file
177
aider/repo.py
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
import os
|
||||||
|
from pathlib import Path, PurePosixPath
|
||||||
|
|
||||||
|
import git
|
||||||
|
|
||||||
|
from aider import models, prompts, utils
|
||||||
|
from aider.sendchat import simple_send_with_retries
|
||||||
|
|
||||||
|
from .dump import dump # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
class GitRepo:
|
||||||
|
repo = None
|
||||||
|
|
||||||
|
def __init__(self, io, fnames, git_dname):
|
||||||
|
self.io = io
|
||||||
|
|
||||||
|
if git_dname:
|
||||||
|
check_fnames = [git_dname]
|
||||||
|
elif fnames:
|
||||||
|
check_fnames = fnames
|
||||||
|
else:
|
||||||
|
check_fnames = ["."]
|
||||||
|
|
||||||
|
repo_paths = []
|
||||||
|
for fname in check_fnames:
|
||||||
|
fname = Path(fname)
|
||||||
|
fname = fname.resolve()
|
||||||
|
|
||||||
|
if not fname.exists() and fname.parent.exists():
|
||||||
|
fname = fname.parent
|
||||||
|
|
||||||
|
try:
|
||||||
|
repo_path = git.Repo(fname, search_parent_directories=True).working_dir
|
||||||
|
repo_path = utils.safe_abs_path(repo_path)
|
||||||
|
repo_paths.append(repo_path)
|
||||||
|
except git.exc.InvalidGitRepositoryError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
num_repos = len(set(repo_paths))
|
||||||
|
|
||||||
|
if num_repos == 0:
|
||||||
|
raise FileNotFoundError
|
||||||
|
if num_repos > 1:
|
||||||
|
self.io.tool_error("Files are in different git repos.")
|
||||||
|
raise FileNotFoundError
|
||||||
|
|
||||||
|
# https://github.com/gitpython-developers/GitPython/issues/427
|
||||||
|
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
|
||||||
|
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
|
||||||
|
|
||||||
|
def add_new_files(self, fnames):
|
||||||
|
cur_files = [Path(fn).resolve() for fn in self.get_tracked_files()]
|
||||||
|
for fname in fnames:
|
||||||
|
if Path(fname).resolve() in cur_files:
|
||||||
|
continue
|
||||||
|
if not Path(fname).exists():
|
||||||
|
continue
|
||||||
|
self.io.tool_output(f"Adding {fname} to git")
|
||||||
|
self.repo.git.add(fname)
|
||||||
|
|
||||||
|
def commit(self, context=None, prefix=None, message=None):
|
||||||
|
if not self.repo.is_dirty():
|
||||||
|
return
|
||||||
|
|
||||||
|
if message:
|
||||||
|
commit_message = message
|
||||||
|
else:
|
||||||
|
diffs = self.get_diffs(False)
|
||||||
|
commit_message = self.get_commit_message(diffs, context)
|
||||||
|
|
||||||
|
if not commit_message:
|
||||||
|
commit_message = "(no commit message provided)"
|
||||||
|
|
||||||
|
if prefix:
|
||||||
|
commit_message = prefix + commit_message
|
||||||
|
|
||||||
|
full_commit_message = commit_message
|
||||||
|
if context:
|
||||||
|
full_commit_message += "\n\n# Aider chat conversation:\n\n" + context
|
||||||
|
|
||||||
|
self.repo.git.commit("-a", "-m", full_commit_message, "--no-verify")
|
||||||
|
commit_hash = self.repo.head.commit.hexsha[:7]
|
||||||
|
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
|
||||||
|
|
||||||
|
return commit_hash, commit_message
|
||||||
|
|
||||||
|
def get_rel_repo_dir(self):
|
||||||
|
try:
|
||||||
|
return os.path.relpath(self.repo.git_dir, os.getcwd())
|
||||||
|
except ValueError:
|
||||||
|
return self.repo.git_dir
|
||||||
|
|
||||||
|
def get_commit_message(self, diffs, context):
|
||||||
|
if len(diffs) >= 4 * 1024 * 4:
|
||||||
|
self.io.tool_error(
|
||||||
|
f"Diff is too large for {models.GPT35.name} to generate a commit message."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
diffs = "# Diffs:\n" + diffs
|
||||||
|
|
||||||
|
content = ""
|
||||||
|
if context:
|
||||||
|
content += context + "\n"
|
||||||
|
content += diffs
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
dict(role="system", content=prompts.commit_system),
|
||||||
|
dict(role="user", content=content),
|
||||||
|
]
|
||||||
|
|
||||||
|
for model in [models.GPT35.name, models.GPT35_16k.name]:
|
||||||
|
commit_message = simple_send_with_retries(model, messages)
|
||||||
|
if commit_message:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not commit_message:
|
||||||
|
self.io.tool_error("Failed to generate commit message!")
|
||||||
|
return
|
||||||
|
|
||||||
|
commit_message = commit_message.strip()
|
||||||
|
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
|
||||||
|
commit_message = commit_message[1:-1].strip()
|
||||||
|
|
||||||
|
return commit_message
|
||||||
|
|
||||||
|
def get_diffs(self, pretty, *args):
|
||||||
|
try:
|
||||||
|
commits = self.repo.iter_commits(self.repo.active_branch)
|
||||||
|
current_branch_has_commits = any(commits)
|
||||||
|
except git.exc.GitCommandError:
|
||||||
|
current_branch_has_commits = False
|
||||||
|
|
||||||
|
if not current_branch_has_commits:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if pretty:
|
||||||
|
args = ["--color"] + list(args)
|
||||||
|
if not args:
|
||||||
|
args = ["HEAD"]
|
||||||
|
|
||||||
|
diffs = self.repo.git.diff(*args)
|
||||||
|
return diffs
|
||||||
|
|
||||||
|
def show_diffs(self, pretty):
|
||||||
|
diffs = self.get_diffs(pretty)
|
||||||
|
print(diffs)
|
||||||
|
|
||||||
|
def get_tracked_files(self):
|
||||||
|
if not self.repo:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
commit = self.repo.head.commit
|
||||||
|
except ValueError:
|
||||||
|
commit = None
|
||||||
|
|
||||||
|
files = []
|
||||||
|
if commit:
|
||||||
|
for blob in commit.tree.traverse():
|
||||||
|
if blob.type == "blob": # blob is a file
|
||||||
|
files.append(blob.path)
|
||||||
|
|
||||||
|
# Add staged files
|
||||||
|
index = self.repo.index
|
||||||
|
staged_files = [path for path, _ in index.entries.keys()]
|
||||||
|
|
||||||
|
files.extend(staged_files)
|
||||||
|
|
||||||
|
# convert to appropriate os.sep, since git always normalizes to /
|
||||||
|
res = set(str(Path(PurePosixPath(path))) for path in files)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def is_dirty(self):
|
||||||
|
return self.repo.is_dirty()
|
57
aider/sendchat.py
Normal file
57
aider/sendchat.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
|
||||||
|
import backoff
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
|
||||||
|
|
||||||
|
|
||||||
|
@backoff.on_exception(
|
||||||
|
backoff.expo,
|
||||||
|
(
|
||||||
|
Timeout,
|
||||||
|
APIError,
|
||||||
|
ServiceUnavailableError,
|
||||||
|
RateLimitError,
|
||||||
|
requests.exceptions.ConnectionError,
|
||||||
|
),
|
||||||
|
max_tries=10,
|
||||||
|
on_backoff=lambda details: print(
|
||||||
|
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def send_with_retries(model, messages, functions, stream):
|
||||||
|
kwargs = dict(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=0,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
if functions is not None:
|
||||||
|
kwargs["functions"] = functions
|
||||||
|
|
||||||
|
# we are abusing the openai object to stash these values
|
||||||
|
if hasattr(openai, "api_deployment_id"):
|
||||||
|
kwargs["deployment_id"] = openai.api_deployment_id
|
||||||
|
if hasattr(openai, "api_engine"):
|
||||||
|
kwargs["engine"] = openai.api_engine
|
||||||
|
|
||||||
|
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
|
||||||
|
hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())
|
||||||
|
|
||||||
|
res = openai.ChatCompletion.create(**kwargs)
|
||||||
|
return hash_object, res
|
||||||
|
|
||||||
|
|
||||||
|
def simple_send_with_retries(model, messages):
|
||||||
|
try:
|
||||||
|
_hash, response = send_with_retries(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
functions=None,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content
|
||||||
|
except (AttributeError, openai.error.InvalidRequestError):
|
||||||
|
return
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -6,7 +5,6 @@ from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import git
|
import git
|
||||||
import openai
|
import openai
|
||||||
import requests
|
|
||||||
|
|
||||||
from aider import models
|
from aider import models
|
||||||
from aider.coders import Coder
|
from aider.coders import Coder
|
||||||
|
@ -77,7 +75,7 @@ class TestCoder(unittest.TestCase):
|
||||||
# Mock the git repo
|
# Mock the git repo
|
||||||
mock = MagicMock()
|
mock = MagicMock()
|
||||||
mock.return_value = set(["file1.txt", "file2.py"])
|
mock.return_value = set(["file1.txt", "file2.py"])
|
||||||
coder.get_tracked_files = mock
|
coder.repo.get_tracked_files = mock
|
||||||
|
|
||||||
# Call the check_for_file_mentions method
|
# Call the check_for_file_mentions method
|
||||||
coder.check_for_file_mentions("Please check file1.txt and file2.py")
|
coder.check_for_file_mentions("Please check file1.txt and file2.py")
|
||||||
|
@ -121,7 +119,7 @@ class TestCoder(unittest.TestCase):
|
||||||
|
|
||||||
mock = MagicMock()
|
mock = MagicMock()
|
||||||
mock.return_value = set(["file1.txt", "file2.py"])
|
mock.return_value = set(["file1.txt", "file2.py"])
|
||||||
coder.get_tracked_files = mock
|
coder.repo.get_tracked_files = mock
|
||||||
|
|
||||||
# Call the check_for_file_mentions method
|
# Call the check_for_file_mentions method
|
||||||
coder.check_for_file_mentions("Please check file1.txt and file2.py")
|
coder.check_for_file_mentions("Please check file1.txt and file2.py")
|
||||||
|
@ -152,7 +150,7 @@ class TestCoder(unittest.TestCase):
|
||||||
|
|
||||||
mock = MagicMock()
|
mock = MagicMock()
|
||||||
mock.return_value = set([str(fname), str(other_fname)])
|
mock.return_value = set([str(fname), str(other_fname)])
|
||||||
coder.get_tracked_files = mock
|
coder.repo.get_tracked_files = mock
|
||||||
|
|
||||||
# Call the check_for_file_mentions method
|
# Call the check_for_file_mentions method
|
||||||
coder.check_for_file_mentions(f"Please check {fname}!")
|
coder.check_for_file_mentions(f"Please check {fname}!")
|
||||||
|
@ -170,7 +168,7 @@ class TestCoder(unittest.TestCase):
|
||||||
|
|
||||||
mock = MagicMock()
|
mock = MagicMock()
|
||||||
mock.return_value = set([str(fname)])
|
mock.return_value = set([str(fname)])
|
||||||
coder.get_tracked_files = mock
|
coder.repo.get_tracked_files = mock
|
||||||
|
|
||||||
dump(fname)
|
dump(fname)
|
||||||
# Call the check_for_file_mentions method
|
# Call the check_for_file_mentions method
|
||||||
|
@ -178,110 +176,6 @@ class TestCoder(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(coder.abs_fnames, set([str(fname.resolve())]))
|
self.assertEqual(coder.abs_fnames, set([str(fname.resolve())]))
|
||||||
|
|
||||||
def test_get_commit_message(self):
|
|
||||||
# Mock the IO object
|
|
||||||
mock_io = MagicMock()
|
|
||||||
|
|
||||||
# Initialize the Coder object with the mocked IO and mocked repo
|
|
||||||
coder = Coder.create(models.GPT4, None, mock_io)
|
|
||||||
|
|
||||||
# Mock the send method to set partial_response_content and return False
|
|
||||||
def mock_send(*args, **kwargs):
|
|
||||||
coder.partial_response_content = "a good commit message"
|
|
||||||
return False
|
|
||||||
|
|
||||||
coder.send = MagicMock(side_effect=mock_send)
|
|
||||||
|
|
||||||
# Call the get_commit_message method with dummy diff and context
|
|
||||||
result = coder.get_commit_message("dummy diff", "dummy context")
|
|
||||||
|
|
||||||
# Assert that the returned message is the expected one
|
|
||||||
self.assertEqual(result, "a good commit message")
|
|
||||||
|
|
||||||
def test_get_commit_message_strip_quotes(self):
|
|
||||||
# Mock the IO object
|
|
||||||
mock_io = MagicMock()
|
|
||||||
|
|
||||||
# Initialize the Coder object with the mocked IO and mocked repo
|
|
||||||
coder = Coder.create(models.GPT4, None, mock_io)
|
|
||||||
|
|
||||||
# Mock the send method to set partial_response_content and return False
|
|
||||||
def mock_send(*args, **kwargs):
|
|
||||||
coder.partial_response_content = "a good commit message"
|
|
||||||
return False
|
|
||||||
|
|
||||||
coder.send = MagicMock(side_effect=mock_send)
|
|
||||||
|
|
||||||
# Call the get_commit_message method with dummy diff and context
|
|
||||||
result = coder.get_commit_message("dummy diff", "dummy context")
|
|
||||||
|
|
||||||
# Assert that the returned message is the expected one
|
|
||||||
self.assertEqual(result, "a good commit message")
|
|
||||||
|
|
||||||
def test_get_commit_message_no_strip_unmatched_quotes(self):
|
|
||||||
# Mock the IO object
|
|
||||||
mock_io = MagicMock()
|
|
||||||
|
|
||||||
# Initialize the Coder object with the mocked IO and mocked repo
|
|
||||||
coder = Coder.create(models.GPT4, None, mock_io)
|
|
||||||
|
|
||||||
# Mock the send method to set partial_response_content and return False
|
|
||||||
def mock_send(*args, **kwargs):
|
|
||||||
coder.partial_response_content = 'a good "commit message"'
|
|
||||||
return False
|
|
||||||
|
|
||||||
coder.send = MagicMock(side_effect=mock_send)
|
|
||||||
|
|
||||||
# Call the get_commit_message method with dummy diff and context
|
|
||||||
result = coder.get_commit_message("dummy diff", "dummy context")
|
|
||||||
|
|
||||||
# Assert that the returned message is the expected one
|
|
||||||
self.assertEqual(result, 'a good "commit message"')
|
|
||||||
|
|
||||||
@patch("aider.coders.base_coder.openai.ChatCompletion.create")
|
|
||||||
@patch("builtins.print")
|
|
||||||
def test_send_with_retries_rate_limit_error(self, mock_print, mock_chat_completion_create):
|
|
||||||
# Mock the IO object
|
|
||||||
mock_io = MagicMock()
|
|
||||||
|
|
||||||
# Initialize the Coder object with the mocked IO and mocked repo
|
|
||||||
coder = Coder.create(models.GPT4, None, mock_io)
|
|
||||||
|
|
||||||
# Set up the mock to raise RateLimitError on
|
|
||||||
# the first call and return None on the second call
|
|
||||||
mock_chat_completion_create.side_effect = [
|
|
||||||
openai.error.RateLimitError("Rate limit exceeded"),
|
|
||||||
None,
|
|
||||||
]
|
|
||||||
|
|
||||||
# Call the send_with_retries method
|
|
||||||
coder.send_with_retries("model", ["message"], None)
|
|
||||||
|
|
||||||
# Assert that print was called once
|
|
||||||
mock_print.assert_called_once()
|
|
||||||
|
|
||||||
@patch("aider.coders.base_coder.openai.ChatCompletion.create")
|
|
||||||
@patch("builtins.print")
|
|
||||||
def test_send_with_retries_connection_error(self, mock_print, mock_chat_completion_create):
|
|
||||||
# Mock the IO object
|
|
||||||
mock_io = MagicMock()
|
|
||||||
|
|
||||||
# Initialize the Coder object with the mocked IO and mocked repo
|
|
||||||
coder = Coder.create(models.GPT4, None, mock_io)
|
|
||||||
|
|
||||||
# Set up the mock to raise ConnectionError on the first call
|
|
||||||
# and return None on the second call
|
|
||||||
mock_chat_completion_create.side_effect = [
|
|
||||||
requests.exceptions.ConnectionError("Connection error"),
|
|
||||||
None,
|
|
||||||
]
|
|
||||||
|
|
||||||
# Call the send_with_retries method
|
|
||||||
coder.send_with_retries("model", ["message"], None)
|
|
||||||
|
|
||||||
# Assert that print was called once
|
|
||||||
mock_print.assert_called_once()
|
|
||||||
|
|
||||||
def test_run_with_file_deletion(self):
|
def test_run_with_file_deletion(self):
|
||||||
# Create a few temporary files
|
# Create a few temporary files
|
||||||
|
|
||||||
|
@ -419,49 +313,5 @@ class TestCoder(unittest.TestCase):
|
||||||
with self.assertRaises(openai.error.InvalidRequestError):
|
with self.assertRaises(openai.error.InvalidRequestError):
|
||||||
coder.run(with_message="hi")
|
coder.run(with_message="hi")
|
||||||
|
|
||||||
def test_get_tracked_files(self):
|
|
||||||
# Create a temporary directory
|
|
||||||
tempdir = Path(tempfile.mkdtemp())
|
|
||||||
|
|
||||||
# Initialize a git repository in the temporary directory and set user name and email
|
|
||||||
repo = git.Repo.init(tempdir)
|
|
||||||
repo.config_writer().set_value("user", "name", "Test User").release()
|
|
||||||
repo.config_writer().set_value("user", "email", "testuser@example.com").release()
|
|
||||||
|
|
||||||
# Create three empty files and add them to the git repository
|
|
||||||
filenames = ["README.md", "subdir/fänny.md", "systemüber/blick.md", 'file"with"quotes.txt']
|
|
||||||
created_files = []
|
|
||||||
for filename in filenames:
|
|
||||||
file_path = tempdir / filename
|
|
||||||
try:
|
|
||||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
file_path.touch()
|
|
||||||
repo.git.add(str(file_path))
|
|
||||||
created_files.append(Path(filename))
|
|
||||||
except OSError:
|
|
||||||
# windows won't allow files with quotes, that's ok
|
|
||||||
self.assertIn('"', filename)
|
|
||||||
self.assertEqual(os.name, "nt")
|
|
||||||
|
|
||||||
self.assertTrue(len(created_files) >= 3)
|
|
||||||
|
|
||||||
repo.git.commit("-m", "added")
|
|
||||||
|
|
||||||
# Create a Coder object on the temporary directory
|
|
||||||
coder = Coder.create(
|
|
||||||
models.GPT4,
|
|
||||||
None,
|
|
||||||
io=InputOutput(),
|
|
||||||
fnames=[str(tempdir / filenames[0])],
|
|
||||||
)
|
|
||||||
|
|
||||||
tracked_files = coder.get_tracked_files()
|
|
||||||
|
|
||||||
# On windows, paths will come back \like\this, so normalize them back to Paths
|
|
||||||
tracked_files = [Path(fn) for fn in tracked_files]
|
|
||||||
|
|
||||||
# Assert that coder.get_tracked_files() returns the three filenames
|
|
||||||
self.assertEqual(set(tracked_files), set(created_files))
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -35,14 +35,42 @@ class TestMain(TestCase):
|
||||||
main(["--no-git"], input=DummyInput(), output=DummyOutput())
|
main(["--no-git"], input=DummyInput(), output=DummyOutput())
|
||||||
|
|
||||||
def test_main_with_empty_dir_new_file(self):
|
def test_main_with_empty_dir_new_file(self):
|
||||||
main(["foo.txt", "--yes"], input=DummyInput(), output=DummyOutput())
|
main(["foo.txt", "--yes", "--no-git"], input=DummyInput(), output=DummyOutput())
|
||||||
self.assertTrue(os.path.exists("foo.txt"))
|
self.assertTrue(os.path.exists("foo.txt"))
|
||||||
|
|
||||||
def test_main_with_empty_git_dir_new_file(self):
|
@patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message")
|
||||||
|
def test_main_with_empty_git_dir_new_file(self, _):
|
||||||
make_repo()
|
make_repo()
|
||||||
main(["--yes", "foo.txt"], input=DummyInput(), output=DummyOutput())
|
main(["--yes", "foo.txt"], input=DummyInput(), output=DummyOutput())
|
||||||
self.assertTrue(os.path.exists("foo.txt"))
|
self.assertTrue(os.path.exists("foo.txt"))
|
||||||
|
|
||||||
|
@patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message")
|
||||||
|
def test_main_with_empty_git_dir_new_files(self, _):
|
||||||
|
make_repo()
|
||||||
|
main(["--yes", "foo.txt", "bar.txt"], input=DummyInput(), output=DummyOutput())
|
||||||
|
self.assertTrue(os.path.exists("foo.txt"))
|
||||||
|
self.assertTrue(os.path.exists("bar.txt"))
|
||||||
|
|
||||||
|
def test_main_with_dname_and_fname(self):
|
||||||
|
subdir = Path("subdir")
|
||||||
|
subdir.mkdir()
|
||||||
|
make_repo(str(subdir))
|
||||||
|
res = main(["subdir", "foo.txt"], input=DummyInput(), output=DummyOutput())
|
||||||
|
self.assertNotEqual(res, None)
|
||||||
|
|
||||||
|
@patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message")
|
||||||
|
def test_main_with_subdir_repo_fnames(self, _):
|
||||||
|
subdir = Path("subdir")
|
||||||
|
subdir.mkdir()
|
||||||
|
make_repo(str(subdir))
|
||||||
|
main(
|
||||||
|
["--yes", str(subdir / "foo.txt"), str(subdir / "bar.txt")],
|
||||||
|
input=DummyInput(),
|
||||||
|
output=DummyOutput(),
|
||||||
|
)
|
||||||
|
self.assertTrue((subdir / "foo.txt").exists())
|
||||||
|
self.assertTrue((subdir / "bar.txt").exists())
|
||||||
|
|
||||||
def test_main_with_git_config_yml(self):
|
def test_main_with_git_config_yml(self):
|
||||||
make_repo()
|
make_repo()
|
||||||
|
|
||||||
|
|
114
tests/test_repo.py
Normal file
114
tests/test_repo.py
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import git
|
||||||
|
|
||||||
|
from aider.dump import dump # noqa: F401
|
||||||
|
from aider.io import InputOutput
|
||||||
|
from aider.repo import GitRepo
|
||||||
|
from tests.utils import GitTemporaryDirectory
|
||||||
|
|
||||||
|
|
||||||
|
class TestRepo(unittest.TestCase):
|
||||||
|
@patch("aider.repo.simple_send_with_retries")
|
||||||
|
def test_get_commit_message(self, mock_send):
|
||||||
|
mock_send.return_value = "a good commit message"
|
||||||
|
|
||||||
|
repo = GitRepo(InputOutput(), None, None)
|
||||||
|
# Call the get_commit_message method with dummy diff and context
|
||||||
|
result = repo.get_commit_message("dummy diff", "dummy context")
|
||||||
|
|
||||||
|
# Assert that the returned message is the expected one
|
||||||
|
self.assertEqual(result, "a good commit message")
|
||||||
|
|
||||||
|
@patch("aider.repo.simple_send_with_retries")
|
||||||
|
def test_get_commit_message_strip_quotes(self, mock_send):
|
||||||
|
mock_send.return_value = '"a good commit message"'
|
||||||
|
|
||||||
|
repo = GitRepo(InputOutput(), None, None)
|
||||||
|
# Call the get_commit_message method with dummy diff and context
|
||||||
|
result = repo.get_commit_message("dummy diff", "dummy context")
|
||||||
|
|
||||||
|
# Assert that the returned message is the expected one
|
||||||
|
self.assertEqual(result, "a good commit message")
|
||||||
|
|
||||||
|
@patch("aider.repo.simple_send_with_retries")
|
||||||
|
def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send):
|
||||||
|
mock_send.return_value = 'a good "commit message"'
|
||||||
|
|
||||||
|
repo = GitRepo(InputOutput(), None, None)
|
||||||
|
# Call the get_commit_message method with dummy diff and context
|
||||||
|
result = repo.get_commit_message("dummy diff", "dummy context")
|
||||||
|
|
||||||
|
# Assert that the returned message is the expected one
|
||||||
|
self.assertEqual(result, 'a good "commit message"')
|
||||||
|
|
||||||
|
def test_get_tracked_files(self):
|
||||||
|
# Create a temporary directory
|
||||||
|
tempdir = Path(tempfile.mkdtemp())
|
||||||
|
|
||||||
|
# Initialize a git repository in the temporary directory and set user name and email
|
||||||
|
repo = git.Repo.init(tempdir)
|
||||||
|
repo.config_writer().set_value("user", "name", "Test User").release()
|
||||||
|
repo.config_writer().set_value("user", "email", "testuser@example.com").release()
|
||||||
|
|
||||||
|
# Create three empty files and add them to the git repository
|
||||||
|
filenames = ["README.md", "subdir/fänny.md", "systemüber/blick.md", 'file"with"quotes.txt']
|
||||||
|
created_files = []
|
||||||
|
for filename in filenames:
|
||||||
|
file_path = tempdir / filename
|
||||||
|
try:
|
||||||
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
file_path.touch()
|
||||||
|
repo.git.add(str(file_path))
|
||||||
|
created_files.append(Path(filename))
|
||||||
|
except OSError:
|
||||||
|
# windows won't allow files with quotes, that's ok
|
||||||
|
self.assertIn('"', filename)
|
||||||
|
self.assertEqual(os.name, "nt")
|
||||||
|
|
||||||
|
self.assertTrue(len(created_files) >= 3)
|
||||||
|
|
||||||
|
repo.git.commit("-m", "added")
|
||||||
|
|
||||||
|
tracked_files = GitRepo(InputOutput(), [tempdir], None).get_tracked_files()
|
||||||
|
|
||||||
|
# On windows, paths will come back \like\this, so normalize them back to Paths
|
||||||
|
tracked_files = [Path(fn) for fn in tracked_files]
|
||||||
|
|
||||||
|
# Assert that coder.get_tracked_files() returns the three filenames
|
||||||
|
self.assertEqual(set(tracked_files), set(created_files))
|
||||||
|
|
||||||
|
def test_get_tracked_files_with_new_staged_file(self):
|
||||||
|
with GitTemporaryDirectory():
|
||||||
|
# new repo
|
||||||
|
raw_repo = git.Repo()
|
||||||
|
|
||||||
|
# add it, but no commits at all in the raw_repo yet
|
||||||
|
fname = Path("new.txt")
|
||||||
|
fname.touch()
|
||||||
|
raw_repo.git.add(str(fname))
|
||||||
|
|
||||||
|
git_repo = GitRepo(InputOutput(), None, None)
|
||||||
|
|
||||||
|
# better be there
|
||||||
|
fnames = git_repo.get_tracked_files()
|
||||||
|
self.assertIn(str(fname), fnames)
|
||||||
|
|
||||||
|
# commit it, better still be there
|
||||||
|
raw_repo.git.commit("-m", "new")
|
||||||
|
fnames = git_repo.get_tracked_files()
|
||||||
|
self.assertIn(str(fname), fnames)
|
||||||
|
|
||||||
|
# new file, added but not committed
|
||||||
|
fname2 = Path("new2.txt")
|
||||||
|
fname2.touch()
|
||||||
|
raw_repo.git.add(str(fname2))
|
||||||
|
|
||||||
|
# both should be there
|
||||||
|
fnames = git_repo.get_tracked_files()
|
||||||
|
self.assertIn(str(fname), fnames)
|
||||||
|
self.assertIn(str(fname2), fnames)
|
|
@ -4,7 +4,6 @@ from unittest.mock import patch
|
||||||
|
|
||||||
from aider.io import InputOutput
|
from aider.io import InputOutput
|
||||||
from aider.repomap import RepoMap
|
from aider.repomap import RepoMap
|
||||||
|
|
||||||
from tests.utils import IgnorantTemporaryDirectory
|
from tests.utils import IgnorantTemporaryDirectory
|
||||||
|
|
||||||
|
|
||||||
|
|
41
tests/test_sendchat.py
Normal file
41
tests/test_sendchat.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from aider.sendchat import send_with_retries
|
||||||
|
|
||||||
|
|
||||||
|
class TestSendChat(unittest.TestCase):
|
||||||
|
@patch("aider.sendchat.openai.ChatCompletion.create")
|
||||||
|
@patch("builtins.print")
|
||||||
|
def test_send_with_retries_rate_limit_error(self, mock_print, mock_chat_completion_create):
|
||||||
|
# Set up the mock to raise RateLimitError on
|
||||||
|
# the first call and return None on the second call
|
||||||
|
mock_chat_completion_create.side_effect = [
|
||||||
|
openai.error.RateLimitError("Rate limit exceeded"),
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the send_with_retries method
|
||||||
|
send_with_retries("model", ["message"], None, False)
|
||||||
|
|
||||||
|
# Assert that print was called once
|
||||||
|
mock_print.assert_called_once()
|
||||||
|
|
||||||
|
@patch("aider.sendchat.openai.ChatCompletion.create")
|
||||||
|
@patch("builtins.print")
|
||||||
|
def test_send_with_retries_connection_error(self, mock_print, mock_chat_completion_create):
|
||||||
|
# Set up the mock to raise ConnectionError on the first call
|
||||||
|
# and return None on the second call
|
||||||
|
mock_chat_completion_create.side_effect = [
|
||||||
|
requests.exceptions.ConnectionError("Connection error"),
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the send_with_retries method
|
||||||
|
send_with_retries("model", ["message"], None, False)
|
||||||
|
|
||||||
|
# Assert that print was called once
|
||||||
|
mock_print.assert_called_once()
|
|
@ -42,7 +42,11 @@ class GitTemporaryDirectory(ChdirTemporaryDirectory):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def make_repo():
|
def make_repo(path=None):
|
||||||
repo = git.Repo.init()
|
if not path:
|
||||||
|
path = "."
|
||||||
|
repo = git.Repo.init(path)
|
||||||
repo.config_writer().set_value("user", "name", "Test User").release()
|
repo.config_writer().set_value("user", "name", "Test User").release()
|
||||||
repo.config_writer().set_value("user", "email", "testuser@example.com").release()
|
repo.config_writer().set_value("user", "email", "testuser@example.com").release()
|
||||||
|
|
||||||
|
return repo
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue