Merge pull request #137 from paul-gauthier/refactor-repo

Refactor git repo code into a new file
This commit is contained in:
paul-gauthier 2023-07-26 07:24:10 -03:00 committed by GitHub
commit 86309f336c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 642 additions and 570 deletions

View file

@ -7,21 +7,19 @@ import sys
import time import time
import traceback import traceback
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from pathlib import Path, PurePosixPath from pathlib import Path
import backoff
import git
import openai import openai
import requests
from jsonschema import Draft7Validator from jsonschema import Draft7Validator
from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
from rich.console import Console, Text from rich.console import Console, Text
from rich.live import Live from rich.live import Live
from rich.markdown import Markdown from rich.markdown import Markdown
from aider import models, prompts, utils from aider import models, prompts, utils
from aider.commands import Commands from aider.commands import Commands
from aider.repo import GitRepo
from aider.repomap import RepoMap from aider.repomap import RepoMap
from aider.sendchat import send_with_retries
from ..dump import dump # noqa: F401 from ..dump import dump # noqa: F401
@ -100,6 +98,7 @@ class Coder:
main_model, main_model,
io, io,
fnames=None, fnames=None,
git_dname=None,
pretty=True, pretty=True,
show_diffs=False, show_diffs=False,
auto_commits=True, auto_commits=True,
@ -150,13 +149,27 @@ class Coder:
self.commands = Commands(self.io, self) self.commands = Commands(self.io, self)
for fname in fnames:
fname = Path(fname)
if not fname.exists():
self.io.tool_output(f"Creating empty file {fname}")
fname.parent.mkdir(parents=True, exist_ok=True)
fname.touch()
if not fname.is_file():
raise ValueError(f"{fname} is not a file")
self.abs_fnames.add(str(fname.resolve()))
if use_git: if use_git:
self.set_repo(fnames) try:
else: self.repo = GitRepo(self.io, fnames, git_dname)
self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames]) self.root = self.repo.root
except FileNotFoundError:
self.repo = None
if self.repo: if self.repo:
rel_repo_dir = self.get_rel_repo_dir() rel_repo_dir = self.repo.get_rel_repo_dir()
self.io.tool_output(f"Git repo: {rel_repo_dir}") self.io.tool_output(f"Git repo: {rel_repo_dir}")
else: else:
self.io.tool_output("Git repo: none") self.io.tool_output("Git repo: none")
@ -187,6 +200,9 @@ class Coder:
for fname in self.get_inchat_relative_files(): for fname in self.get_inchat_relative_files():
self.io.tool_output(f"Added {fname} to the chat.") self.io.tool_output(f"Added {fname} to the chat.")
if self.repo:
self.repo.add_new_files(fname for fname in fnames if not Path(fname).is_dir())
# validate the functions jsonschema # validate the functions jsonschema
if self.functions: if self.functions:
for function in self.functions: for function in self.functions:
@ -206,12 +222,6 @@ class Coder:
self.root = utils.safe_abs_path(self.root) self.root = utils.safe_abs_path(self.root)
def get_rel_repo_dir(self):
try:
return os.path.relpath(self.repo.git_dir, os.getcwd())
except ValueError:
return self.repo.git_dir
def add_rel_fname(self, rel_fname): def add_rel_fname(self, rel_fname):
self.abs_fnames.add(self.abs_root_path(rel_fname)) self.abs_fnames.add(self.abs_root_path(rel_fname))
@ -219,73 +229,6 @@ class Coder:
res = Path(self.root) / path res = Path(self.root) / path
return utils.safe_abs_path(res) return utils.safe_abs_path(res)
def set_repo(self, cmd_line_fnames):
if not cmd_line_fnames:
cmd_line_fnames = ["."]
repo_paths = []
for fname in cmd_line_fnames:
fname = Path(fname)
if not fname.exists():
self.io.tool_output(f"Creating empty file {fname}")
fname.parent.mkdir(parents=True, exist_ok=True)
fname.touch()
fname = fname.resolve()
try:
repo_path = git.Repo(fname, search_parent_directories=True).working_dir
repo_path = utils.safe_abs_path(repo_path)
repo_paths.append(repo_path)
except git.exc.InvalidGitRepositoryError:
pass
if fname.is_dir():
continue
self.abs_fnames.add(str(fname))
num_repos = len(set(repo_paths))
if num_repos == 0:
return
if num_repos > 1:
self.io.tool_error("Files are in different git repos.")
return
# https://github.com/gitpython-developers/GitPython/issues/427
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
new_files = []
for fname in self.abs_fnames:
relative_fname = self.get_rel_fname(fname)
tracked_files = set(self.get_tracked_files())
if relative_fname not in tracked_files:
new_files.append(relative_fname)
if new_files:
rel_repo_dir = self.get_rel_repo_dir()
self.io.tool_output(f"Files not tracked in {rel_repo_dir}:")
for fn in new_files:
self.io.tool_output(f" - {fn}")
if self.io.confirm_ask("Add them?"):
for relative_fname in new_files:
self.repo.git.add(relative_fname)
self.io.tool_output(f"Added {relative_fname} to the git repo")
show_files = ", ".join(new_files)
commit_message = f"Added new files to the git repo: {show_files}"
self.repo.git.commit("-m", commit_message, "--no-verify")
commit_hash = self.repo.head.commit.hexsha[:7]
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
else:
self.io.tool_error("Skipped adding new files to the git repo.")
return
# fences are obfuscated so aider can modify this file!
fences = [ fences = [
("``" + "`", "``" + "`"), ("``" + "`", "``" + "`"),
wrap_fence("source"), wrap_fence("source"),
@ -412,25 +355,6 @@ class Coder:
self.last_keyboard_interrupt = now self.last_keyboard_interrupt = now
def should_dirty_commit(self, inp):
cmds = self.commands.matching_commands(inp)
if cmds:
matching_commands, _, _ = cmds
if len(matching_commands) == 1:
cmd = matching_commands[0][1:]
if cmd in "add clear commit diff drop exit help ls tokens".split():
return
if not self.dirty_commits:
return
if not self.repo:
return
if not self.repo.is_dirty():
return
if self.last_asked_for_commit_time >= self.get_last_modified():
return
return True
def move_back_cur_messages(self, message): def move_back_cur_messages(self, message):
self.done_messages += self.cur_messages self.done_messages += self.cur_messages
if message: if message:
@ -448,13 +372,7 @@ class Coder:
self.commands, self.commands,
) )
if self.should_dirty_commit(inp): if self.should_dirty_commit(inp) and self.dirty_commit():
self.io.tool_output("Git repo has uncommitted changes, preparing commit...")
self.commit(ask=True, which="repo_files")
# files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
if inp.strip(): if inp.strip():
self.io.tool_output("Use up-arrow to retry previous command:", inp) self.io.tool_output("Use up-arrow to retry previous command:", inp)
return return
@ -569,23 +487,6 @@ class Coder:
) )
] ]
def auto_commit(self):
res = self.commit(history=self.cur_messages, prefix="aider: ")
if res:
commit_hash, commit_message = res
self.last_aider_commit_hash = commit_hash
saved_message = self.gpt_prompts.files_content_gpt_edits.format(
hash=commit_hash,
message=commit_message,
)
else:
if self.repo:
self.io.tool_output("No changes made to git tracked files.")
saved_message = self.gpt_prompts.files_content_gpt_no_edits
return saved_message
def check_for_file_mentions(self, content): def check_for_file_mentions(self, content):
words = set(word for word in content.split()) words = set(word for word in content.split())
@ -627,44 +528,7 @@ class Coder:
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames)) return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
@backoff.on_exception( def send(self, messages, model=None, functions=None):
backoff.expo,
(
Timeout,
APIError,
ServiceUnavailableError,
RateLimitError,
requests.exceptions.ConnectionError,
),
max_tries=10,
on_backoff=lambda details: print(
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
),
)
def send_with_retries(self, model, messages, functions):
kwargs = dict(
model=model,
messages=messages,
temperature=0,
stream=self.stream,
)
if functions is not None:
kwargs["functions"] = self.functions
# we are abusing the openai object to stash these values
if hasattr(openai, "api_deployment_id"):
kwargs["deployment_id"] = openai.api_deployment_id
if hasattr(openai, "api_engine"):
kwargs["engine"] = openai.api_engine
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())
self.chat_completion_call_hashes.append(hash_object.hexdigest())
res = openai.ChatCompletion.create(**kwargs)
return res
def send(self, messages, model=None, silent=False, functions=None):
if not model: if not model:
model = self.main_model.name model = self.main_model.name
@ -673,16 +537,17 @@ class Coder:
interrupted = False interrupted = False
try: try:
completion = self.send_with_retries(model, messages, functions) hash_object, completion = send_with_retries(model, messages, functions, self.stream)
self.chat_completion_call_hashes.append(hash_object.hexdigest())
if self.stream: if self.stream:
self.show_send_output_stream(completion, silent) self.show_send_output_stream(completion)
else: else:
self.show_send_output(completion, silent) self.show_send_output(completion)
except KeyboardInterrupt: except KeyboardInterrupt:
self.keyboard_interrupt() self.keyboard_interrupt()
interrupted = True interrupted = True
if not silent:
if self.partial_response_content: if self.partial_response_content:
self.io.ai_output(self.partial_response_content) self.io.ai_output(self.partial_response_content)
elif self.partial_response_function_call: elif self.partial_response_function_call:
@ -693,7 +558,7 @@ class Coder:
return interrupted return interrupted
def show_send_output(self, completion, silent): def show_send_output(self, completion):
if self.verbose: if self.verbose:
print(completion) print(completion)
@ -742,9 +607,9 @@ class Coder:
self.io.console.print(show_resp) self.io.console.print(show_resp)
self.io.console.print(tokens) self.io.console.print(tokens)
def show_send_output_stream(self, completion, silent): def show_send_output_stream(self, completion):
live = None live = None
if self.pretty and not silent: if self.pretty:
live = Live(vertical_overflow="scroll") live = Live(vertical_overflow="scroll")
try: try:
@ -773,9 +638,6 @@ class Coder:
except AttributeError: except AttributeError:
pass pass
if silent:
continue
if self.pretty: if self.pretty:
self.live_incremental_response(live, False) self.live_incremental_response(live, False)
else: else:
@ -797,145 +659,6 @@ class Coder:
def render_incremental_response(self, final): def render_incremental_response(self, final):
return self.partial_response_content return self.partial_response_content
def get_context_from_history(self, history):
context = ""
if history:
for msg in history:
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def get_commit_message(self, diffs, context):
if len(diffs) >= 4 * 1024 * 4:
self.io.tool_error(
f"Diff is too large for {models.GPT35.name} to generate a commit message."
)
return
diffs = "# Diffs:\n" + diffs
messages = [
dict(role="system", content=prompts.commit_system),
dict(role="user", content=context + diffs),
]
try:
interrupted = self.send(
messages,
model=models.GPT35.name,
silent=True,
)
except openai.error.InvalidRequestError:
self.io.tool_error(
f"Failed to generate commit message using {models.GPT35.name} due to an invalid"
" request."
)
return
commit_message = self.partial_response_content
commit_message = commit_message.strip()
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
commit_message = commit_message[1:-1].strip()
if interrupted:
self.io.tool_error(
f"Unable to get commit message from {models.GPT35.name}. Use /commit to try again."
)
return
return commit_message
def get_diffs(self, *args):
if self.pretty:
args = ["--color"] + list(args)
diffs = self.repo.git.diff(*args)
return diffs
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
repo = self.repo
if not repo:
return
if not repo.is_dirty():
return
def get_dirty_files_and_diffs(file_list):
diffs = ""
relative_dirty_files = []
for fname in file_list:
relative_fname = self.get_rel_fname(fname)
relative_dirty_files.append(relative_fname)
try:
current_branch_commit_count = len(
list(self.repo.iter_commits(self.repo.active_branch))
)
except git.exc.GitCommandError:
current_branch_commit_count = None
if not current_branch_commit_count:
continue
these_diffs = self.get_diffs("HEAD", "--", relative_fname)
if these_diffs:
diffs += these_diffs + "\n"
return relative_dirty_files, diffs
if which == "repo_files":
all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()]
relative_dirty_fnames, diffs = get_dirty_files_and_diffs(all_files)
elif which == "chat_files":
relative_dirty_fnames, diffs = get_dirty_files_and_diffs(self.abs_fnames)
else:
raise ValueError(f"Invalid value for 'which': {which}")
if self.show_diffs or ask:
# don't use io.tool_output() because we don't want to log or further colorize
print(diffs)
context = self.get_context_from_history(history)
if message:
commit_message = message
else:
commit_message = self.get_commit_message(diffs, context)
if not commit_message:
commit_message = "work in progress"
if prefix:
commit_message = prefix + commit_message
if ask:
if which == "repo_files":
self.io.tool_output("Git repo has uncommitted changes.")
else:
self.io.tool_output("Files have uncommitted changes.")
res = self.io.prompt_ask(
"Commit before the chat proceeds [y/n/commit message]?",
default=commit_message,
).strip()
self.last_asked_for_commit_time = self.get_last_modified()
self.io.tool_output()
if res.lower() in ["n", "no"]:
self.io.tool_error("Skipped commmit.")
return
if res.lower() not in ["y", "yes"] and res:
commit_message = res
repo.git.add(*relative_dirty_fnames)
full_commit_message = commit_message + "\n\n# Aider chat conversation:\n\n" + context
repo.git.commit("-m", full_commit_message, "--no-verify")
commit_hash = repo.head.commit.hexsha[:7]
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
return commit_hash, commit_message
def get_rel_fname(self, fname): def get_rel_fname(self, fname):
return os.path.relpath(fname, self.root) return os.path.relpath(fname, self.root)
@ -945,7 +668,7 @@ class Coder:
def get_all_relative_files(self): def get_all_relative_files(self):
if self.repo: if self.repo:
files = self.get_tracked_files() files = self.repo.get_tracked_files()
else: else:
files = self.get_inchat_relative_files() files = self.get_inchat_relative_files()
@ -1000,32 +723,6 @@ class Coder:
return full_path return full_path
def get_tracked_files(self):
if not self.repo:
return []
try:
commit = self.repo.head.commit
except ValueError:
commit = None
files = []
if commit:
for blob in commit.tree.traverse():
if blob.type == "blob": # blob is a file
files.append(blob.path)
# Add staged files
index = self.repo.index
staged_files = [path for path, _ in index.entries.keys()]
files.extend(staged_files)
# convert to appropriate os.sep, since git always normalizes to /
res = set(str(Path(PurePosixPath(path))) for path in files)
return res
apply_update_errors = 0 apply_update_errors = 0
def apply_updates(self): def apply_updates(self):
@ -1094,6 +791,72 @@ class Coder:
except JSONDecodeError: except JSONDecodeError:
pass pass
# commits...
def get_context_from_history(self, history):
context = ""
if history:
for msg in history:
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def auto_commit(self):
context = self.get_context_from_history(self.cur_messages)
res = self.repo.commit(context=context, prefix="aider: ")
if res:
commit_hash, commit_message = res
self.last_aider_commit_hash = commit_hash
return self.gpt_prompts.files_content_gpt_edits.format(
hash=commit_hash,
message=commit_message,
)
self.io.tool_output("No changes made to git tracked files.")
return self.gpt_prompts.files_content_gpt_no_edits
def should_dirty_commit(self, inp):
cmds = self.commands.matching_commands(inp)
if cmds:
matching_commands, _, _ = cmds
if len(matching_commands) == 1:
cmd = matching_commands[0][1:]
if cmd in "add clear commit diff drop exit help ls tokens".split():
return
if self.last_asked_for_commit_time >= self.get_last_modified():
return
return True
def dirty_commit(self):
if not self.dirty_commits:
return
if not self.repo:
return
if not self.repo.is_dirty():
return
self.io.tool_output("Git repo has uncommitted changes.")
self.repo.show_diffs(self.pretty)
self.last_asked_for_commit_time = self.get_last_modified()
res = self.io.prompt_ask(
"Commit before the chat proceeds [y/n/commit message]?",
default="y",
).strip()
if res.lower() in ["n", "no"]:
self.io.tool_error("Skipped commmit.")
return
if res.lower() in ["y", "yes"]:
message = None
else:
message = res.strip()
self.repo.commit(message=message)
# files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
return True
def check_model_availability(main_model): def check_model_availability(main_model):
available_models = openai.Model.list() available_models = openai.Model.list()

View file

@ -42,15 +42,6 @@ class SingleWholeFileFunctionCoder(Coder):
else: else:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
def get_context_from_history(self, history):
context = ""
if history:
context += "# Context:\n"
for msg in history:
if msg["role"] == "user":
context += msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def render_incremental_response(self, final=False): def render_incremental_response(self, final=False):
if self.partial_response_content: if self.partial_response_content:
return self.partial_response_content return self.partial_response_content

View file

@ -20,15 +20,6 @@ class WholeFileCoder(Coder):
else: else:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
def get_context_from_history(self, history):
context = ""
if history:
context += "# Context:\n"
for msg in history:
if msg["role"] == "user":
context += msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def render_incremental_response(self, final): def render_incremental_response(self, final):
try: try:
return self.update_files(mode="diff") return self.update_files(mode="diff")

View file

@ -55,15 +55,6 @@ class WholeFileFunctionCoder(Coder):
else: else:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
def get_context_from_history(self, history):
context = ""
if history:
context += "# Context:\n"
for msg in history:
if msg["role"] == "user":
context += msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def render_incremental_response(self, final=False): def render_incremental_response(self, final=False):
if self.partial_response_content: if self.partial_response_content:
return self.partial_response_content return self.partial_response_content

View file

@ -176,10 +176,10 @@ class Commands:
) )
return return
local_head = self.coder.repo.git.rev_parse("HEAD") local_head = self.coder.repo.repo.git.rev_parse("HEAD")
current_branch = self.coder.repo.active_branch.name current_branch = self.coder.repo.repo.active_branch.name
try: try:
remote_head = self.coder.repo.git.rev_parse(f"origin/{current_branch}") remote_head = self.coder.repo.repo.git.rev_parse(f"origin/{current_branch}")
has_origin = True has_origin = True
except git.exc.GitCommandError: except git.exc.GitCommandError:
has_origin = False has_origin = False
@ -192,14 +192,14 @@ class Commands:
) )
return return
last_commit = self.coder.repo.head.commit last_commit = self.coder.repo.repo.head.commit
if ( if (
not last_commit.message.startswith("aider:") not last_commit.message.startswith("aider:")
or last_commit.hexsha[:7] != self.coder.last_aider_commit_hash or last_commit.hexsha[:7] != self.coder.last_aider_commit_hash
): ):
self.io.tool_error("The last commit was not made by aider in this chat session.") self.io.tool_error("The last commit was not made by aider in this chat session.")
return return
self.coder.repo.git.reset("--hard", "HEAD~1") self.coder.repo.repo.git.reset("--hard", "HEAD~1")
self.io.tool_output( self.io.tool_output(
f"{last_commit.message.strip()}\n" f"{last_commit.message.strip()}\n"
f"The above commit {self.coder.last_aider_commit_hash} " f"The above commit {self.coder.last_aider_commit_hash} "
@ -220,7 +220,11 @@ class Commands:
return return
commits = f"{self.coder.last_aider_commit_hash}~1" commits = f"{self.coder.last_aider_commit_hash}~1"
diff = self.coder.get_diffs(commits, self.coder.last_aider_commit_hash) diff = self.coder.repo.get_diffs(
self.coder.pretty,
commits,
self.coder.last_aider_commit_hash,
)
# don't use io.tool_output() because we don't want to log or further colorize # don't use io.tool_output() because we don't want to log or further colorize
print(diff) print(diff)
@ -243,7 +247,7 @@ class Commands:
# if repo, filter against it # if repo, filter against it
if self.coder.repo: if self.coder.repo:
git_files = self.coder.get_tracked_files() git_files = self.coder.repo.get_tracked_files()
matched_files = [fn for fn in matched_files if str(fn) in git_files] matched_files = [fn for fn in matched_files if str(fn) in git_files]
res = list(map(str, matched_files)) res = list(map(str, matched_files))
@ -254,7 +258,7 @@ class Commands:
added_fnames = [] added_fnames = []
git_added = [] git_added = []
git_files = self.coder.get_tracked_files() git_files = self.coder.repo.get_tracked_files() if self.coder.repo else []
all_matched_files = set() all_matched_files = set()
for word in args.split(): for word in args.split():
@ -281,7 +285,7 @@ class Commands:
abs_file_path = self.coder.abs_root_path(matched_file) abs_file_path = self.coder.abs_root_path(matched_file)
if self.coder.repo and matched_file not in git_files: if self.coder.repo and matched_file not in git_files:
self.coder.repo.git.add(abs_file_path) self.coder.repo.repo.git.add(abs_file_path)
git_added.append(matched_file) git_added.append(matched_file)
if abs_file_path in self.coder.abs_fnames: if abs_file_path in self.coder.abs_fnames:
@ -298,8 +302,8 @@ class Commands:
if self.coder.repo and git_added: if self.coder.repo and git_added:
git_added = " ".join(git_added) git_added = " ".join(git_added)
commit_message = f"aider: Added {git_added}" commit_message = f"aider: Added {git_added}"
self.coder.repo.git.commit("-m", commit_message, "--no-verify") self.coder.repo.repo.git.commit("-m", commit_message, "--no-verify")
commit_hash = self.coder.repo.head.commit.hexsha[:7] commit_hash = self.coder.repo.repo.head.commit.hexsha[:7]
self.io.tool_output(f"Commit {commit_hash} {commit_message}") self.io.tool_output(f"Commit {commit_hash} {commit_message}")
if not added_fnames: if not added_fnames:

View file

@ -9,10 +9,14 @@ import openai
from aider import __version__, models from aider import __version__, models
from aider.coders import Coder from aider.coders import Coder
from aider.io import InputOutput from aider.io import InputOutput
from aider.repo import GitRepo
from aider.versioncheck import check_version from aider.versioncheck import check_version
from .dump import dump # noqa: F401
def get_git_root(): def get_git_root():
"""Try and guess the git repo, since the conf.yml can be at the repo root"""
try: try:
repo = git.Repo(search_parent_directories=True) repo = git.Repo(search_parent_directories=True)
return repo.working_tree_dir return repo.working_tree_dir
@ -20,6 +24,25 @@ def get_git_root():
return None return None
def guessed_wrong_repo(io, git_root, fnames, git_dname):
"""After we parse the args, we can determine the real repo. Did we guess wrong?"""
try:
check_repo = Path(GitRepo(io, fnames, git_dname).root).resolve()
except FileNotFoundError:
return
# we had no guess, rely on the "true" repo result
if not git_root:
return str(check_repo)
git_root = Path(git_root).resolve()
if check_repo == git_root:
return
return str(check_repo)
def setup_git(git_root, io): def setup_git(git_root, io):
if git_root: if git_root:
return git_root return git_root
@ -71,10 +94,13 @@ def check_gitignore(git_root, io, ask=True):
io.tool_output(f"Added {pat} to .gitignore") io.tool_output(f"Added {pat} to .gitignore")
def main(args=None, input=None, output=None): def main(argv=None, input=None, output=None, force_git_root=None):
if args is None: if argv is None:
args = sys.argv[1:] argv = sys.argv[1:]
if force_git_root:
git_root = force_git_root
else:
git_root = get_git_root() git_root = get_git_root()
conf_fname = Path(".aider.conf.yml") conf_fname = Path(".aider.conf.yml")
@ -101,7 +127,7 @@ def main(args=None, input=None, output=None):
"files", "files",
metavar="FILE", metavar="FILE",
nargs="*", nargs="*",
help="a list of source code files to edit with GPT (optional)", help="the directory of a git repo, or a list of files to edit with GPT (optional)",
) )
core_group.add_argument( core_group.add_argument(
"--openai-api-key", "--openai-api-key",
@ -344,7 +370,7 @@ def main(args=None, input=None, output=None):
), ),
) )
args = parser.parse_args(args) args = parser.parse_args(argv)
if args.dark_mode: if args.dark_mode:
args.user_input_color = "#32FF32" args.user_input_color = "#32FF32"
@ -371,6 +397,37 @@ def main(args=None, input=None, output=None):
dry_run=args.dry_run, dry_run=args.dry_run,
) )
fnames = [str(Path(fn).resolve()) for fn in args.files]
if len(args.files) > 1:
good = True
for fname in args.files:
if Path(fname).is_dir():
io.tool_error(f"{fname} is a directory, not provided alone.")
good = False
if not good:
io.tool_error(
"Provide either a single directory of a git repo, or a list of one or more files."
)
return 1
git_dname = None
if len(args.files) == 1:
if Path(args.files[0]).is_dir():
if args.git:
git_dname = str(Path(args.files[0]).resolve())
fnames = []
else:
io.tool_error(f"{args.files[0]} is a directory, but --no-git selected.")
return 1
# We can't know the git repo for sure until after parsing the args.
# If we guessed wrong, reparse because that changes things like
# the location of the config.yml and history files.
if args.git and not force_git_root:
right_repo_root = guessed_wrong_repo(io, git_root, fnames, git_dname)
if right_repo_root:
return main(argv, input, output, right_repo_root)
io.tool_output(f"Aider v{__version__}") io.tool_output(f"Aider v{__version__}")
check_version(io.tool_error) check_version(io.tool_error)
@ -418,12 +475,14 @@ def main(args=None, input=None, output=None):
setattr(openai, mod_key, val) setattr(openai, mod_key, val)
io.tool_output(f"Setting openai.{mod_key}={val}") io.tool_output(f"Setting openai.{mod_key}={val}")
try:
coder = Coder.create( coder = Coder.create(
main_model, main_model,
args.edit_format, args.edit_format,
io, io,
## ##
fnames=args.files, fnames=fnames,
git_dname=git_dname,
pretty=args.pretty, pretty=args.pretty,
show_diffs=args.show_diffs, show_diffs=args.show_diffs,
auto_commits=args.auto_commits, auto_commits=args.auto_commits,
@ -436,6 +495,9 @@ def main(args=None, input=None, output=None):
stream=args.stream, stream=args.stream,
use_git=args.git, use_git=args.git,
) )
except ValueError as err:
io.tool_error(str(err))
return 1
if args.show_repo_map: if args.show_repo_map:
repo_map = coder.get_repo_map() repo_map = coder.get_repo_map()
@ -443,9 +505,6 @@ def main(args=None, input=None, output=None):
io.tool_output(repo_map) io.tool_output(repo_map)
return return
if args.dirty_commits:
coder.commit(ask=True, which="repo_files")
if args.apply: if args.apply:
content = io.read_text(args.apply) content = io.read_text(args.apply)
if content is None: if content is None:
@ -454,6 +513,9 @@ def main(args=None, input=None, output=None):
return return
io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args") io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args")
coder.dirty_commit()
if args.message: if args.message:
io.tool_output() io.tool_output()
coder.run(with_message=args.message) coder.run(with_message=args.message)

177
aider/repo.py Normal file
View file

@ -0,0 +1,177 @@
import os
from pathlib import Path, PurePosixPath
import git
from aider import models, prompts, utils
from aider.sendchat import simple_send_with_retries
from .dump import dump # noqa: F401
class GitRepo:
repo = None
def __init__(self, io, fnames, git_dname):
self.io = io
if git_dname:
check_fnames = [git_dname]
elif fnames:
check_fnames = fnames
else:
check_fnames = ["."]
repo_paths = []
for fname in check_fnames:
fname = Path(fname)
fname = fname.resolve()
if not fname.exists() and fname.parent.exists():
fname = fname.parent
try:
repo_path = git.Repo(fname, search_parent_directories=True).working_dir
repo_path = utils.safe_abs_path(repo_path)
repo_paths.append(repo_path)
except git.exc.InvalidGitRepositoryError:
pass
num_repos = len(set(repo_paths))
if num_repos == 0:
raise FileNotFoundError
if num_repos > 1:
self.io.tool_error("Files are in different git repos.")
raise FileNotFoundError
# https://github.com/gitpython-developers/GitPython/issues/427
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
def add_new_files(self, fnames):
cur_files = [Path(fn).resolve() for fn in self.get_tracked_files()]
for fname in fnames:
if Path(fname).resolve() in cur_files:
continue
if not Path(fname).exists():
continue
self.io.tool_output(f"Adding {fname} to git")
self.repo.git.add(fname)
def commit(self, context=None, prefix=None, message=None):
if not self.repo.is_dirty():
return
if message:
commit_message = message
else:
diffs = self.get_diffs(False)
commit_message = self.get_commit_message(diffs, context)
if not commit_message:
commit_message = "(no commit message provided)"
if prefix:
commit_message = prefix + commit_message
full_commit_message = commit_message
if context:
full_commit_message += "\n\n# Aider chat conversation:\n\n" + context
self.repo.git.commit("-a", "-m", full_commit_message, "--no-verify")
commit_hash = self.repo.head.commit.hexsha[:7]
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
return commit_hash, commit_message
def get_rel_repo_dir(self):
try:
return os.path.relpath(self.repo.git_dir, os.getcwd())
except ValueError:
return self.repo.git_dir
def get_commit_message(self, diffs, context):
if len(diffs) >= 4 * 1024 * 4:
self.io.tool_error(
f"Diff is too large for {models.GPT35.name} to generate a commit message."
)
return
diffs = "# Diffs:\n" + diffs
content = ""
if context:
content += context + "\n"
content += diffs
messages = [
dict(role="system", content=prompts.commit_system),
dict(role="user", content=content),
]
for model in [models.GPT35.name, models.GPT35_16k.name]:
commit_message = simple_send_with_retries(model, messages)
if commit_message:
break
if not commit_message:
self.io.tool_error("Failed to generate commit message!")
return
commit_message = commit_message.strip()
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
commit_message = commit_message[1:-1].strip()
return commit_message
def get_diffs(self, pretty, *args):
try:
commits = self.repo.iter_commits(self.repo.active_branch)
current_branch_has_commits = any(commits)
except git.exc.GitCommandError:
current_branch_has_commits = False
if not current_branch_has_commits:
return ""
if pretty:
args = ["--color"] + list(args)
if not args:
args = ["HEAD"]
diffs = self.repo.git.diff(*args)
return diffs
def show_diffs(self, pretty):
diffs = self.get_diffs(pretty)
print(diffs)
def get_tracked_files(self):
if not self.repo:
return []
try:
commit = self.repo.head.commit
except ValueError:
commit = None
files = []
if commit:
for blob in commit.tree.traverse():
if blob.type == "blob": # blob is a file
files.append(blob.path)
# Add staged files
index = self.repo.index
staged_files = [path for path, _ in index.entries.keys()]
files.extend(staged_files)
# convert to appropriate os.sep, since git always normalizes to /
res = set(str(Path(PurePosixPath(path))) for path in files)
return res
def is_dirty(self):
return self.repo.is_dirty()

57
aider/sendchat.py Normal file
View file

@ -0,0 +1,57 @@
import hashlib
import json
import backoff
import openai
import requests
from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
@backoff.on_exception(
backoff.expo,
(
Timeout,
APIError,
ServiceUnavailableError,
RateLimitError,
requests.exceptions.ConnectionError,
),
max_tries=10,
on_backoff=lambda details: print(
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
),
)
def send_with_retries(model, messages, functions, stream):
kwargs = dict(
model=model,
messages=messages,
temperature=0,
stream=stream,
)
if functions is not None:
kwargs["functions"] = functions
# we are abusing the openai object to stash these values
if hasattr(openai, "api_deployment_id"):
kwargs["deployment_id"] = openai.api_deployment_id
if hasattr(openai, "api_engine"):
kwargs["engine"] = openai.api_engine
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())
res = openai.ChatCompletion.create(**kwargs)
return hash_object, res
def simple_send_with_retries(model, messages):
try:
_hash, response = send_with_retries(
model=model,
messages=messages,
functions=None,
stream=False,
)
return response.choices[0].message.content
except (AttributeError, openai.error.InvalidRequestError):
return

View file

@ -1,4 +1,3 @@
import os
import tempfile import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
@ -6,7 +5,6 @@ from unittest.mock import MagicMock, patch
import git import git
import openai import openai
import requests
from aider import models from aider import models
from aider.coders import Coder from aider.coders import Coder
@ -77,7 +75,7 @@ class TestCoder(unittest.TestCase):
# Mock the git repo # Mock the git repo
mock = MagicMock() mock = MagicMock()
mock.return_value = set(["file1.txt", "file2.py"]) mock.return_value = set(["file1.txt", "file2.py"])
coder.get_tracked_files = mock coder.repo.get_tracked_files = mock
# Call the check_for_file_mentions method # Call the check_for_file_mentions method
coder.check_for_file_mentions("Please check file1.txt and file2.py") coder.check_for_file_mentions("Please check file1.txt and file2.py")
@ -121,7 +119,7 @@ class TestCoder(unittest.TestCase):
mock = MagicMock() mock = MagicMock()
mock.return_value = set(["file1.txt", "file2.py"]) mock.return_value = set(["file1.txt", "file2.py"])
coder.get_tracked_files = mock coder.repo.get_tracked_files = mock
# Call the check_for_file_mentions method # Call the check_for_file_mentions method
coder.check_for_file_mentions("Please check file1.txt and file2.py") coder.check_for_file_mentions("Please check file1.txt and file2.py")
@ -152,7 +150,7 @@ class TestCoder(unittest.TestCase):
mock = MagicMock() mock = MagicMock()
mock.return_value = set([str(fname), str(other_fname)]) mock.return_value = set([str(fname), str(other_fname)])
coder.get_tracked_files = mock coder.repo.get_tracked_files = mock
# Call the check_for_file_mentions method # Call the check_for_file_mentions method
coder.check_for_file_mentions(f"Please check {fname}!") coder.check_for_file_mentions(f"Please check {fname}!")
@ -170,7 +168,7 @@ class TestCoder(unittest.TestCase):
mock = MagicMock() mock = MagicMock()
mock.return_value = set([str(fname)]) mock.return_value = set([str(fname)])
coder.get_tracked_files = mock coder.repo.get_tracked_files = mock
dump(fname) dump(fname)
# Call the check_for_file_mentions method # Call the check_for_file_mentions method
@ -178,110 +176,6 @@ class TestCoder(unittest.TestCase):
self.assertEqual(coder.abs_fnames, set([str(fname.resolve())])) self.assertEqual(coder.abs_fnames, set([str(fname.resolve())]))
def test_get_commit_message(self):
# Mock the IO object
mock_io = MagicMock()
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(models.GPT4, None, mock_io)
# Mock the send method to set partial_response_content and return False
def mock_send(*args, **kwargs):
coder.partial_response_content = "a good commit message"
return False
coder.send = MagicMock(side_effect=mock_send)
# Call the get_commit_message method with dummy diff and context
result = coder.get_commit_message("dummy diff", "dummy context")
# Assert that the returned message is the expected one
self.assertEqual(result, "a good commit message")
def test_get_commit_message_strip_quotes(self):
# Mock the IO object
mock_io = MagicMock()
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(models.GPT4, None, mock_io)
# Mock the send method to set partial_response_content and return False
def mock_send(*args, **kwargs):
coder.partial_response_content = "a good commit message"
return False
coder.send = MagicMock(side_effect=mock_send)
# Call the get_commit_message method with dummy diff and context
result = coder.get_commit_message("dummy diff", "dummy context")
# Assert that the returned message is the expected one
self.assertEqual(result, "a good commit message")
def test_get_commit_message_no_strip_unmatched_quotes(self):
# Mock the IO object
mock_io = MagicMock()
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(models.GPT4, None, mock_io)
# Mock the send method to set partial_response_content and return False
def mock_send(*args, **kwargs):
coder.partial_response_content = 'a good "commit message"'
return False
coder.send = MagicMock(side_effect=mock_send)
# Call the get_commit_message method with dummy diff and context
result = coder.get_commit_message("dummy diff", "dummy context")
# Assert that the returned message is the expected one
self.assertEqual(result, 'a good "commit message"')
@patch("aider.coders.base_coder.openai.ChatCompletion.create")
@patch("builtins.print")
def test_send_with_retries_rate_limit_error(self, mock_print, mock_chat_completion_create):
# Mock the IO object
mock_io = MagicMock()
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(models.GPT4, None, mock_io)
# Set up the mock to raise RateLimitError on
# the first call and return None on the second call
mock_chat_completion_create.side_effect = [
openai.error.RateLimitError("Rate limit exceeded"),
None,
]
# Call the send_with_retries method
coder.send_with_retries("model", ["message"], None)
# Assert that print was called once
mock_print.assert_called_once()
@patch("aider.coders.base_coder.openai.ChatCompletion.create")
@patch("builtins.print")
def test_send_with_retries_connection_error(self, mock_print, mock_chat_completion_create):
# Mock the IO object
mock_io = MagicMock()
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(models.GPT4, None, mock_io)
# Set up the mock to raise ConnectionError on the first call
# and return None on the second call
mock_chat_completion_create.side_effect = [
requests.exceptions.ConnectionError("Connection error"),
None,
]
# Call the send_with_retries method
coder.send_with_retries("model", ["message"], None)
# Assert that print was called once
mock_print.assert_called_once()
def test_run_with_file_deletion(self): def test_run_with_file_deletion(self):
# Create a few temporary files # Create a few temporary files
@ -419,49 +313,5 @@ class TestCoder(unittest.TestCase):
with self.assertRaises(openai.error.InvalidRequestError): with self.assertRaises(openai.error.InvalidRequestError):
coder.run(with_message="hi") coder.run(with_message="hi")
def test_get_tracked_files(self):
# Create a temporary directory
tempdir = Path(tempfile.mkdtemp())
# Initialize a git repository in the temporary directory and set user name and email
repo = git.Repo.init(tempdir)
repo.config_writer().set_value("user", "name", "Test User").release()
repo.config_writer().set_value("user", "email", "testuser@example.com").release()
# Create three empty files and add them to the git repository
filenames = ["README.md", "subdir/fänny.md", "systemüber/blick.md", 'file"with"quotes.txt']
created_files = []
for filename in filenames:
file_path = tempdir / filename
try:
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.touch()
repo.git.add(str(file_path))
created_files.append(Path(filename))
except OSError:
# windows won't allow files with quotes, that's ok
self.assertIn('"', filename)
self.assertEqual(os.name, "nt")
self.assertTrue(len(created_files) >= 3)
repo.git.commit("-m", "added")
# Create a Coder object on the temporary directory
coder = Coder.create(
models.GPT4,
None,
io=InputOutput(),
fnames=[str(tempdir / filenames[0])],
)
tracked_files = coder.get_tracked_files()
# On windows, paths will come back \like\this, so normalize them back to Paths
tracked_files = [Path(fn) for fn in tracked_files]
# Assert that coder.get_tracked_files() returns the three filenames
self.assertEqual(set(tracked_files), set(created_files))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -35,14 +35,42 @@ class TestMain(TestCase):
main(["--no-git"], input=DummyInput(), output=DummyOutput()) main(["--no-git"], input=DummyInput(), output=DummyOutput())
def test_main_with_empty_dir_new_file(self): def test_main_with_empty_dir_new_file(self):
main(["foo.txt", "--yes"], input=DummyInput(), output=DummyOutput()) main(["foo.txt", "--yes", "--no-git"], input=DummyInput(), output=DummyOutput())
self.assertTrue(os.path.exists("foo.txt")) self.assertTrue(os.path.exists("foo.txt"))
def test_main_with_empty_git_dir_new_file(self): @patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message")
def test_main_with_empty_git_dir_new_file(self, _):
make_repo() make_repo()
main(["--yes", "foo.txt"], input=DummyInput(), output=DummyOutput()) main(["--yes", "foo.txt"], input=DummyInput(), output=DummyOutput())
self.assertTrue(os.path.exists("foo.txt")) self.assertTrue(os.path.exists("foo.txt"))
@patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message")
def test_main_with_empty_git_dir_new_files(self, _):
make_repo()
main(["--yes", "foo.txt", "bar.txt"], input=DummyInput(), output=DummyOutput())
self.assertTrue(os.path.exists("foo.txt"))
self.assertTrue(os.path.exists("bar.txt"))
def test_main_with_dname_and_fname(self):
subdir = Path("subdir")
subdir.mkdir()
make_repo(str(subdir))
res = main(["subdir", "foo.txt"], input=DummyInput(), output=DummyOutput())
self.assertNotEqual(res, None)
@patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message")
def test_main_with_subdir_repo_fnames(self, _):
subdir = Path("subdir")
subdir.mkdir()
make_repo(str(subdir))
main(
["--yes", str(subdir / "foo.txt"), str(subdir / "bar.txt")],
input=DummyInput(),
output=DummyOutput(),
)
self.assertTrue((subdir / "foo.txt").exists())
self.assertTrue((subdir / "bar.txt").exists())
def test_main_with_git_config_yml(self): def test_main_with_git_config_yml(self):
make_repo() make_repo()

114
tests/test_repo.py Normal file
View file

@ -0,0 +1,114 @@
import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
import git
from aider.dump import dump # noqa: F401
from aider.io import InputOutput
from aider.repo import GitRepo
from tests.utils import GitTemporaryDirectory
class TestRepo(unittest.TestCase):
@patch("aider.repo.simple_send_with_retries")
def test_get_commit_message(self, mock_send):
mock_send.return_value = "a good commit message"
repo = GitRepo(InputOutput(), None, None)
# Call the get_commit_message method with dummy diff and context
result = repo.get_commit_message("dummy diff", "dummy context")
# Assert that the returned message is the expected one
self.assertEqual(result, "a good commit message")
@patch("aider.repo.simple_send_with_retries")
def test_get_commit_message_strip_quotes(self, mock_send):
mock_send.return_value = '"a good commit message"'
repo = GitRepo(InputOutput(), None, None)
# Call the get_commit_message method with dummy diff and context
result = repo.get_commit_message("dummy diff", "dummy context")
# Assert that the returned message is the expected one
self.assertEqual(result, "a good commit message")
@patch("aider.repo.simple_send_with_retries")
def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send):
mock_send.return_value = 'a good "commit message"'
repo = GitRepo(InputOutput(), None, None)
# Call the get_commit_message method with dummy diff and context
result = repo.get_commit_message("dummy diff", "dummy context")
# Assert that the returned message is the expected one
self.assertEqual(result, 'a good "commit message"')
def test_get_tracked_files(self):
# Create a temporary directory
tempdir = Path(tempfile.mkdtemp())
# Initialize a git repository in the temporary directory and set user name and email
repo = git.Repo.init(tempdir)
repo.config_writer().set_value("user", "name", "Test User").release()
repo.config_writer().set_value("user", "email", "testuser@example.com").release()
# Create three empty files and add them to the git repository
filenames = ["README.md", "subdir/fänny.md", "systemüber/blick.md", 'file"with"quotes.txt']
created_files = []
for filename in filenames:
file_path = tempdir / filename
try:
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.touch()
repo.git.add(str(file_path))
created_files.append(Path(filename))
except OSError:
# windows won't allow files with quotes, that's ok
self.assertIn('"', filename)
self.assertEqual(os.name, "nt")
self.assertTrue(len(created_files) >= 3)
repo.git.commit("-m", "added")
tracked_files = GitRepo(InputOutput(), [tempdir], None).get_tracked_files()
# On windows, paths will come back \like\this, so normalize them back to Paths
tracked_files = [Path(fn) for fn in tracked_files]
# Assert that coder.get_tracked_files() returns the three filenames
self.assertEqual(set(tracked_files), set(created_files))
def test_get_tracked_files_with_new_staged_file(self):
with GitTemporaryDirectory():
# new repo
raw_repo = git.Repo()
# add it, but no commits at all in the raw_repo yet
fname = Path("new.txt")
fname.touch()
raw_repo.git.add(str(fname))
git_repo = GitRepo(InputOutput(), None, None)
# better be there
fnames = git_repo.get_tracked_files()
self.assertIn(str(fname), fnames)
# commit it, better still be there
raw_repo.git.commit("-m", "new")
fnames = git_repo.get_tracked_files()
self.assertIn(str(fname), fnames)
# new file, added but not committed
fname2 = Path("new2.txt")
fname2.touch()
raw_repo.git.add(str(fname2))
# both should be there
fnames = git_repo.get_tracked_files()
self.assertIn(str(fname), fnames)
self.assertIn(str(fname2), fnames)

View file

@ -4,7 +4,6 @@ from unittest.mock import patch
from aider.io import InputOutput from aider.io import InputOutput
from aider.repomap import RepoMap from aider.repomap import RepoMap
from tests.utils import IgnorantTemporaryDirectory from tests.utils import IgnorantTemporaryDirectory

41
tests/test_sendchat.py Normal file
View file

@ -0,0 +1,41 @@
import unittest
from unittest.mock import patch
import openai
import requests
from aider.sendchat import send_with_retries
class TestSendChat(unittest.TestCase):
@patch("aider.sendchat.openai.ChatCompletion.create")
@patch("builtins.print")
def test_send_with_retries_rate_limit_error(self, mock_print, mock_chat_completion_create):
# Set up the mock to raise RateLimitError on
# the first call and return None on the second call
mock_chat_completion_create.side_effect = [
openai.error.RateLimitError("Rate limit exceeded"),
None,
]
# Call the send_with_retries method
send_with_retries("model", ["message"], None, False)
# Assert that print was called once
mock_print.assert_called_once()
@patch("aider.sendchat.openai.ChatCompletion.create")
@patch("builtins.print")
def test_send_with_retries_connection_error(self, mock_print, mock_chat_completion_create):
# Set up the mock to raise ConnectionError on the first call
# and return None on the second call
mock_chat_completion_create.side_effect = [
requests.exceptions.ConnectionError("Connection error"),
None,
]
# Call the send_with_retries method
send_with_retries("model", ["message"], None, False)
# Assert that print was called once
mock_print.assert_called_once()

View file

@ -42,7 +42,11 @@ class GitTemporaryDirectory(ChdirTemporaryDirectory):
return res return res
def make_repo(): def make_repo(path=None):
repo = git.Repo.init() if not path:
path = "."
repo = git.Repo.init(path)
repo.config_writer().set_value("user", "name", "Test User").release() repo.config_writer().set_value("user", "name", "Test User").release()
repo.config_writer().set_value("user", "email", "testuser@example.com").release() repo.config_writer().set_value("user", "email", "testuser@example.com").release()
return repo