Merge pull request #207 from paul-gauthier/late-bind-commits

Only commit dirty files if GPT tries to edit them #200
This commit is contained in:
paul-gauthier 2023-08-18 10:43:06 -07:00 committed by GitHub
commit 9933ad85b0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 419 additions and 217 deletions

View file

@ -2,6 +2,7 @@
### main branch ### main branch
- [Only git commit dirty files that GPT tries to edit](https://aider.chat/docs/faq.html#how-does-aider-use-git)
- Send chat history as prompt/context for Whisper voice transcription - Send chat history as prompt/context for Whisper voice transcription
- Added `--voice-language` switch to constrain `/voice` to transcribe to a specific language - Added `--voice-language` switch to constrain `/voice` to transcribe to a specific language

View file

@ -57,13 +57,7 @@ class Coder:
io, io,
**kwargs, **kwargs,
): ):
from . import ( from . import EditBlockCoder, WholeFileCoder
EditBlockCoder,
EditBlockFunctionCoder,
SingleWholeFileFunctionCoder,
WholeFileCoder,
WholeFileFunctionCoder,
)
if not main_model: if not main_model:
main_model = models.GPT35_16k main_model = models.GPT35_16k
@ -84,14 +78,6 @@ class Coder:
return EditBlockCoder(main_model, io, **kwargs) return EditBlockCoder(main_model, io, **kwargs)
elif edit_format == "whole": elif edit_format == "whole":
return WholeFileCoder(main_model, io, **kwargs) return WholeFileCoder(main_model, io, **kwargs)
elif edit_format == "whole-func":
return WholeFileFunctionCoder(main_model, io, **kwargs)
elif edit_format == "single-whole-func":
return SingleWholeFileFunctionCoder(main_model, io, **kwargs)
elif edit_format == "diff-func-list":
return EditBlockFunctionCoder("list", main_model, io, **kwargs)
elif edit_format in ("diff-func", "diff-func-string"):
return EditBlockFunctionCoder("string", main_model, io, **kwargs)
else: else:
raise ValueError(f"Unknown edit format {edit_format}") raise ValueError(f"Unknown edit format {edit_format}")
@ -119,6 +105,7 @@ class Coder:
self.chat_completion_call_hashes = [] self.chat_completion_call_hashes = []
self.chat_completion_response_hashes = [] self.chat_completion_response_hashes = []
self.need_commit_before_edits = set()
self.verbose = verbose self.verbose = verbose
self.abs_fnames = set() self.abs_fnames = set()
@ -203,9 +190,6 @@ class Coder:
for fname in self.get_inchat_relative_files(): for fname in self.get_inchat_relative_files():
self.io.tool_output(f"Added {fname} to the chat.") self.io.tool_output(f"Added {fname} to the chat.")
if self.repo:
self.repo.add_new_files(fname for fname in fnames if not Path(fname).is_dir())
self.summarizer = ChatSummary() self.summarizer = ChatSummary()
self.summarizer_thread = None self.summarizer_thread = None
self.summarized_done_messages = None self.summarized_done_messages = None
@ -408,11 +392,6 @@ class Coder:
self.commands, self.commands,
) )
if self.should_dirty_commit(inp) and self.dirty_commit():
if inp.strip():
self.io.tool_output("Use up-arrow to retry previous command:", inp)
return
if not inp: if not inp:
return return
@ -500,7 +479,7 @@ class Coder:
if edited: if edited:
if self.repo and self.auto_commits and not self.dry_run: if self.repo and self.auto_commits and not self.dry_run:
saved_message = self.auto_commit() saved_message = self.auto_commit(edited)
elif hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"): elif hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"):
saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo
else: else:
@ -728,43 +707,94 @@ class Coder:
def get_addable_relative_files(self): def get_addable_relative_files(self):
return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files()) return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files())
def allowed_to_edit(self, path, write_content=None): def check_for_dirty_commit(self, path):
full_path = self.abs_root_path(path) if not self.repo:
return
if full_path in self.abs_fnames: if not self.dirty_commits:
if write_content: return
self.io.write_text(full_path, write_content) if not self.repo.is_dirty(path):
return full_path
if not Path(full_path).exists():
question = f"Allow creation of new file {path}?" # noqa: E501
else:
question = f"Allow edits to {path} which was not previously provided?" # noqa: E501
if not self.io.confirm_ask(question):
self.io.tool_error(f"Skipping edit to {path}")
return return
if not Path(full_path).exists() and not self.dry_run: fullp = Path(self.abs_root_path(path))
Path(full_path).parent.mkdir(parents=True, exist_ok=True) if not fullp.stat().st_size:
Path(full_path).touch() return
self.abs_fnames.add(full_path) self.io.tool_output(f"Committing {path} before applying edits.")
self.need_commit_before_edits.add(path)
return
# Check if the file is already in the repo def allowed_to_edit(self, path):
full_path = self.abs_root_path(path)
if self.repo: if self.repo:
tracked_files = set(self.repo.get_tracked_files()) need_to_add = not self.repo.path_in_repo(path)
relative_fname = self.get_rel_fname(full_path) else:
if relative_fname not in tracked_files and self.io.confirm_ask(f"Add {path} to git?"): need_to_add = False
if not self.dry_run:
if full_path in self.abs_fnames:
self.check_for_dirty_commit(path)
return True
if not Path(full_path).exists():
if not self.io.confirm_ask(f"Allow creation of new file {path}?"):
self.io.tool_error(f"Skipping edits to {path}")
return
if not self.dry_run:
Path(full_path).parent.mkdir(parents=True, exist_ok=True)
Path(full_path).touch()
# Seems unlikely that we needed to create the file, but it was
# actually already part of the repo.
# But let's only add if we need to, just to be safe.
if need_to_add:
self.repo.repo.git.add(full_path) self.repo.repo.git.add(full_path)
if write_content: self.abs_fnames.add(full_path)
self.io.write_text(full_path, write_content) return True
return full_path if not self.io.confirm_ask(
f"Allow edits to {path} which was not previously added to chat?"
):
self.io.tool_error(f"Skipping edits to {path}")
return
if need_to_add:
self.repo.repo.git.add(full_path)
self.abs_fnames.add(full_path)
self.check_for_dirty_commit(path)
return True
apply_update_errors = 0 apply_update_errors = 0
def prepare_to_edit(self, edits):
res = []
seen = dict()
self.need_commit_before_edits = set()
for edit in edits:
path = edit[0]
if path in seen:
allowed = seen[path]
else:
allowed = self.allowed_to_edit(path)
seen[path] = allowed
if allowed:
res.append(edit)
self.dirty_commit()
self.need_commit_before_edits = set()
return res
def update_files(self):
edits = self.get_edits()
edits = self.prepare_to_edit(edits)
self.apply_edits(edits)
return set(edit[0] for edit in edits)
def apply_updates(self): def apply_updates(self):
max_apply_update_errors = 3 max_apply_update_errors = 3
@ -795,12 +825,11 @@ class Coder:
self.apply_update_errors = 0 self.apply_update_errors = 0
if edited: for path in edited:
for path in sorted(edited): if self.dry_run:
if self.dry_run: self.io.tool_output(f"Did not apply edit to {path} (--dry-run)")
self.io.tool_output(f"Did not apply edit to {path} (--dry-run)") else:
else: self.io.tool_output(f"Applied edit to {path}")
self.io.tool_output(f"Applied edit to {path}")
return edited, None return edited, None
@ -840,9 +869,9 @@ class Coder:
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n" context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
return context return context
def auto_commit(self): def auto_commit(self, edited):
context = self.get_context_from_history(self.cur_messages) context = self.get_context_from_history(self.cur_messages)
res = self.repo.commit(context=context, prefix="aider: ") res = self.repo.commit(fnames=edited, context=context, prefix="aider: ")
if res: if res:
commit_hash, commit_message = res commit_hash, commit_message = res
self.last_aider_commit_hash = commit_hash self.last_aider_commit_hash = commit_hash
@ -855,43 +884,14 @@ class Coder:
self.io.tool_output("No changes made to git tracked files.") self.io.tool_output("No changes made to git tracked files.")
return self.gpt_prompts.files_content_gpt_no_edits return self.gpt_prompts.files_content_gpt_no_edits
def should_dirty_commit(self, inp):
cmds = self.commands.matching_commands(inp)
if cmds:
matching_commands, _, _ = cmds
if len(matching_commands) == 1:
cmd = matching_commands[0][1:]
if cmd in "add clear commit diff drop exit help ls tokens".split():
return
if self.last_asked_for_commit_time >= self.get_last_modified():
return
return True
def dirty_commit(self): def dirty_commit(self):
if not self.need_commit_before_edits:
return
if not self.dirty_commits: if not self.dirty_commits:
return return
if not self.repo: if not self.repo:
return return
if not self.repo.is_dirty(): self.repo.commit(fnames=self.need_commit_before_edits)
return
self.io.tool_output("Git repo has uncommitted changes.")
self.repo.show_diffs(self.pretty)
self.last_asked_for_commit_time = self.get_last_modified()
res = self.io.prompt_ask(
"Commit before the chat proceeds [y/n/commit message]?",
default="y",
).strip()
if res.lower() in ["n", "no"]:
self.io.tool_error("Skipped commmit.")
return
if res.lower() in ["y", "yes"]:
message = None
else:
message = res.strip()
self.repo.commit(message=message)
# files changed, move cur messages back behind the files messages # files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits) self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)

View file

@ -13,22 +13,21 @@ class EditBlockCoder(Coder):
self.gpt_prompts = EditBlockPrompts() self.gpt_prompts = EditBlockPrompts()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def update_files(self): def get_edits(self):
content = self.partial_response_content content = self.partial_response_content
# might raise ValueError for malformed ORIG/UPD blocks # might raise ValueError for malformed ORIG/UPD blocks
edits = list(find_original_update_blocks(content)) edits = list(find_original_update_blocks(content))
edited = set() return edits
def apply_edits(self, edits):
for path, original, updated in edits: for path, original, updated in edits:
full_path = self.allowed_to_edit(path) full_path = self.abs_root_path(path)
if not full_path:
continue
content = self.io.read_text(full_path) content = self.io.read_text(full_path)
content = do_replace(full_path, content, original, updated) content = do_replace(full_path, content, original, updated)
if content: if content:
self.io.write_text(full_path, content) self.io.write_text(full_path, content)
edited.add(path)
continue continue
raise ValueError(f"""InvalidEditBlock: edit failed! raise ValueError(f"""InvalidEditBlock: edit failed!
@ -42,8 +41,6 @@ The HEAD block needs to be EXACTLY the same as the lines in {path} with nothing
{original}``` {original}```
""") """)
return edited
def prep(content): def prep(content):
if content and not content.endswith("\n"): if content and not content.endswith("\n"):

View file

@ -58,6 +58,7 @@ class EditBlockFunctionCoder(Coder):
] ]
def __init__(self, code_format, *args, **kwargs): def __init__(self, code_format, *args, **kwargs):
raise RuntimeError("Deprecated, needs to be refactored to support get_edits/apply_edits")
self.code_format = code_format self.code_format = code_format
if code_format == "string": if code_format == "string":
@ -91,7 +92,7 @@ class EditBlockFunctionCoder(Coder):
res = json.dumps(args, indent=4) res = json.dumps(args, indent=4)
return res return res
def update_files(self): def _update_files(self):
name = self.partial_response_function_call.get("name") name = self.partial_response_function_call.get("name")
if name and name != "replace_lines": if name and name != "replace_lines":

View file

@ -31,6 +31,7 @@ class SingleWholeFileFunctionCoder(Coder):
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
raise RuntimeError("Deprecated, needs to be refactored to support get_edits/apply_edits")
self.gpt_prompts = SingleWholeFileFunctionPrompts() self.gpt_prompts = SingleWholeFileFunctionPrompts()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -94,7 +95,7 @@ class SingleWholeFileFunctionCoder(Coder):
return "\n".join(show_diff) return "\n".join(show_diff)
def update_files(self): def _update_files(self):
name = self.partial_response_function_call.get("name") name = self.partial_response_function_call.get("name")
if name and name != "write_file": if name and name != "write_file":
raise ValueError(f'Unknown function_call name="{name}", use name="write_file"') raise ValueError(f'Unknown function_call name="{name}", use name="write_file"')

View file

@ -22,11 +22,11 @@ class WholeFileCoder(Coder):
def render_incremental_response(self, final): def render_incremental_response(self, final):
try: try:
return self.update_files(mode="diff") return self.get_edits(mode="diff")
except ValueError: except ValueError:
return self.partial_response_content return self.partial_response_content
def update_files(self, mode="update"): def get_edits(self, mode="update"):
content = self.partial_response_content content = self.partial_response_content
chat_files = self.get_inchat_relative_files() chat_files = self.get_inchat_relative_files()
@ -46,7 +46,7 @@ class WholeFileCoder(Coder):
# ending an existing block # ending an existing block
saw_fname = None saw_fname = None
full_path = (Path(self.root) / fname).absolute() full_path = self.abs_root_path(fname)
if mode == "diff": if mode == "diff":
output += self.do_live_diff(full_path, new_lines, True) output += self.do_live_diff(full_path, new_lines, True)
@ -104,25 +104,30 @@ class WholeFileCoder(Coder):
if fname: if fname:
edits.append((fname, fname_source, new_lines)) edits.append((fname, fname_source, new_lines))
edited = set() seen = set()
refined_edits = []
# process from most reliable filename, to least reliable # process from most reliable filename, to least reliable
for source in ("block", "saw", "chat"): for source in ("block", "saw", "chat"):
for fname, fname_source, new_lines in edits: for fname, fname_source, new_lines in edits:
if fname_source != source: if fname_source != source:
continue continue
# if a higher priority source already edited the file, skip # if a higher priority source already edited the file, skip
if fname in edited: if fname in seen:
continue continue
# we have a winner seen.add(fname)
new_lines = "".join(new_lines) refined_edits.append((fname, fname_source, new_lines))
if self.allowed_to_edit(fname, new_lines):
edited.add(fname)
return edited return refined_edits
def apply_edits(self, edits):
for path, fname_source, new_lines in edits:
full_path = self.abs_root_path(path)
new_lines = "".join(new_lines)
self.io.write_text(full_path, new_lines)
def do_live_diff(self, full_path, new_lines, final): def do_live_diff(self, full_path, new_lines, final):
if full_path.exists(): if Path(full_path).exists():
orig_lines = self.io.read_text(full_path).splitlines(keepends=True) orig_lines = self.io.read_text(full_path).splitlines(keepends=True)
show_diff = diffs.diff_partial_update( show_diff = diffs.diff_partial_update(

View file

@ -44,6 +44,8 @@ class WholeFileFunctionCoder(Coder):
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
raise RuntimeError("Deprecated, needs to be refactored to support get_edits/apply_edits")
self.gpt_prompts = WholeFileFunctionPrompts() self.gpt_prompts = WholeFileFunctionPrompts()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -105,7 +107,7 @@ class WholeFileFunctionCoder(Coder):
return "\n".join(show_diff) return "\n".join(show_diff)
def update_files(self): def _update_files(self):
name = self.partial_response_function_call.get("name") name = self.partial_response_function_call.get("name")
if name and name != "write_file": if name and name != "write_file":
raise ValueError(f'Unknown function_call name="{name}", use name="write_file"') raise ValueError(f'Unknown function_call name="{name}", use name="write_file"')

View file

@ -230,7 +230,7 @@ class Commands:
return return
commits = f"{self.coder.last_aider_commit_hash}~1" commits = f"{self.coder.last_aider_commit_hash}~1"
diff = self.coder.repo.get_diffs( diff = self.coder.repo.diff_commits(
self.coder.pretty, self.coder.pretty,
commits, commits,
self.coder.last_aider_commit_hash, self.coder.last_aider_commit_hash,

View file

@ -522,8 +522,6 @@ def main(argv=None, input=None, output=None, force_git_root=None):
io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args") io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args")
coder.dirty_commit()
if args.message: if args.message:
io.tool_output() io.tool_output()
coder.run(with_message=args.message) coder.run(with_message=args.message)

View file

@ -49,24 +49,14 @@ class GitRepo:
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB) self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
self.root = utils.safe_abs_path(self.repo.working_tree_dir) self.root = utils.safe_abs_path(self.repo.working_tree_dir)
def add_new_files(self, fnames): def commit(self, fnames=None, context=None, prefix=None, message=None):
cur_files = [str(Path(fn).resolve()) for fn in self.get_tracked_files()] if not fnames and not self.repo.is_dirty():
for fname in fnames:
if str(Path(fname).resolve()) in cur_files:
continue
if not Path(fname).exists():
continue
self.io.tool_output(f"Adding {fname} to git")
self.repo.git.add(fname)
def commit(self, context=None, prefix=None, message=None):
if not self.repo.is_dirty():
return return
if message: if message:
commit_message = message commit_message = message
else: else:
diffs = self.get_diffs(False) diffs = self.get_diffs(fnames)
commit_message = self.get_commit_message(diffs, context) commit_message = self.get_commit_message(diffs, context)
if not commit_message: if not commit_message:
@ -79,7 +69,16 @@ class GitRepo:
if context: if context:
full_commit_message += "\n\n# Aider chat conversation:\n\n" + context full_commit_message += "\n\n# Aider chat conversation:\n\n" + context
self.repo.git.commit("-a", "-m", full_commit_message, "--no-verify") cmd = ["-m", full_commit_message, "--no-verify"]
if fnames:
fnames = [str(self.abs_root_path(fn)) for fn in fnames]
for fname in fnames:
self.repo.git.add(fname)
cmd += ["--"] + fnames
else:
cmd += ["-a"]
self.repo.git.commit(cmd)
commit_hash = self.repo.head.commit.hexsha[:7] commit_hash = self.repo.head.commit.hexsha[:7]
self.io.tool_output(f"Commit {commit_hash} {commit_message}") self.io.tool_output(f"Commit {commit_hash} {commit_message}")
@ -125,41 +124,38 @@ class GitRepo:
return commit_message return commit_message
def get_diffs(self, pretty, *args): def get_diffs(self, fnames=None):
args = list(args) # We always want diffs of index and working dir
# if args are specified, just add --pretty if needed
if args:
if pretty:
args = ["--color"] + args
return self.repo.git.diff(*args)
# otherwise, we always want diffs of index and working dir
try: try:
commits = self.repo.iter_commits(self.repo.active_branch) commits = self.repo.iter_commits(self.repo.active_branch)
current_branch_has_commits = any(commits) current_branch_has_commits = any(commits)
except git.exc.GitCommandError: except git.exc.GitCommandError:
current_branch_has_commits = False current_branch_has_commits = False
if pretty: if not fnames:
args = ["--color"] fnames = []
if current_branch_has_commits: if current_branch_has_commits:
# if there is a HEAD, just diff against it to pick up index + working args = ["HEAD", "--"] + list(fnames)
args += ["HEAD"]
return self.repo.git.diff(*args) return self.repo.git.diff(*args)
# diffs in the index wd_args = ["--"] + list(fnames)
diffs = self.repo.git.diff(*(args + ["--cached"])) index_args = ["--cached"] + wd_args
# plus, diffs in the working dir
diffs += self.repo.git.diff(*args) diffs = self.repo.git.diff(*index_args)
diffs += self.repo.git.diff(*wd_args)
return diffs return diffs
def show_diffs(self, pretty): def diff_commits(self, pretty, from_commit, to_commit):
diffs = self.get_diffs(pretty) args = []
print(diffs) if pretty:
args += ["--color"]
args += [from_commit, to_commit]
diffs = self.repo.git.diff(*args)
return diffs
def get_tracked_files(self): def get_tracked_files(self):
if not self.repo: if not self.repo:
@ -190,5 +186,19 @@ class GitRepo:
return res return res
def is_dirty(self): def path_in_repo(self, path):
return self.repo.is_dirty() if not self.repo:
return
tracked_files = set(self.get_tracked_files())
return path in tracked_files
def abs_root_path(self, path):
res = Path(self.root) / path
return utils.safe_abs_path(res)
def is_dirty(self, path=None):
if path and not self.path_in_repo(path):
return True
return self.repo.is_dirty(path=path)

View file

@ -13,29 +13,29 @@
It is recommended that you use aider with code that is part of a git repo. It is recommended that you use aider with code that is part of a git repo.
This allows aider to maintain the safety of your code. Using git makes it easy to: This allows aider to maintain the safety of your code. Using git makes it easy to:
- Review the changes GPT made to your code - Undo any changes that weren't appropriate
- Undo changes that weren't appropriate - Go back later to review the changes GPT made to your code
- Manage a series of GPT's changes on a git branch - Manage a series of GPT's changes on a git branch
- etc
Working without git means that GPT might drastically change your code without an easy way to undo the changes. Working without git means that GPT might drastically change your code
without an easy way to undo the changes.
Aider tries to provide safety using git in a few ways: Aider tries to provide safety using git in a few ways:
- It asks to create a git repo if you launch it in a directory without one. - It asks to create a git repo if you launch it in a directory without one.
- When you add a file to the chat, aider asks permission to add it to the git repo if needed. - Whenever GPT edits a file, those changes are committed with a descriptive commit message. This makes it easy to revert or review GPT's changes.
- At launch and before sending requests to GPT, aider checks if the repo is dirty and offers to commit those changes for you. This way, the GPT changes will be applied to a clean repo and won't be intermingled with your own changes. - If GPT tries to edit files that already have uncommitted changes (dirty files), aider will first commit those existing changes with a descriptive commit message. This makes sure you never lose your work if GPT makes an inappropriate change to uncommitted code.
- After GPT changes your code, aider commits those changes with a descriptive commit message.
Aider also allows you to use in-chat commands to `/diff` or `/undo` the last change made by GPT. Aider also allows you to use in-chat commands to `/diff` or `/undo` the last change made by GPT.
To do more complex management of your git history, you should use `git` on the command line outside of aider. To do more complex management of your git history, you cat use raw `git` commands,
either by using `/git` or with git tools outside of aider.
You can start a branch before using aider to make a sequence of changes. You can start a branch before using aider to make a sequence of changes.
Or you can `git reset` a longer series of aider changes that didn't pan out. Etc. Or you can `git reset` a longer series of aider changes that didn't pan out. Etc.
While it is not recommended, you can disable aider's use of git in a few ways: While it is not recommended, you can disable aider's use of git in a few ways:
- `--no-auto-commits` will stop aider from git committing each of GPT's changes. - `--no-auto-commits` will stop aider from git committing each of GPT's changes.
- `--no-dirty-commits` will stop aider from ensuring your repo is clean before sending requests to GPT. - `--no-dirty-commits` will stop aider from committing dirty files before applying GPT's edits.
- `--no-git` will completely stop aider from using git on your files. You should ensure you are keeping sensible backups of the files you are working with. - `--no-git` will completely stop aider from using git on your files. You should ensure you are keeping sensible backups of the files you are working with.

View file

@ -22,55 +22,83 @@ class TestCoder(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.patcher.stop() self.patcher.stop()
def test_new_file_commit_message(self):
with GitTemporaryDirectory():
repo = git.Repo()
fname = Path("foo.txt")
io = InputOutput(yes=True)
# Initialize the Coder object with the mocked IO and mocked repo
Coder.create(models.GPT4, None, io, fnames=[str(fname)])
self.assertTrue(fname.exists())
# Mock the get_commit_message method to return "I added str(fname)"
repo.get_commit_message = MagicMock(return_value=f"I added {str(fname)}")
# Get the latest commit message
commit_message = repo.get_commit_message()
# Check that the latest commit message is "I added str(fname)"
self.assertEqual(commit_message, f"I added {str(fname)}")
def test_allowed_to_edit(self): def test_allowed_to_edit(self):
with GitTemporaryDirectory(): with GitTemporaryDirectory():
repo = git.Repo(Path.cwd()) repo = git.Repo()
fname = Path("foo.txt")
fname = Path("added.txt")
fname.touch() fname.touch()
repo.git.add(str(fname)) repo.git.add(str(fname))
fname = Path("repo.txt")
fname.touch()
repo.git.add(str(fname))
repo.git.commit("-m", "init") repo.git.commit("-m", "init")
# YES!
io = InputOutput(yes=True) io = InputOutput(yes=True)
# Initialize the Coder object with the mocked IO and mocked repo coder = Coder.create(models.GPT4, None, io, fnames=["added.txt"])
coder = Coder.create(models.GPT4, None, io, fnames=["foo.txt"])
self.assertTrue(coder.allowed_to_edit("foo.txt")) self.assertTrue(coder.allowed_to_edit("added.txt"))
self.assertTrue(coder.allowed_to_edit("repo.txt"))
self.assertTrue(coder.allowed_to_edit("new.txt")) self.assertTrue(coder.allowed_to_edit("new.txt"))
self.assertIn("repo.txt", str(coder.abs_fnames))
self.assertIn("new.txt", str(coder.abs_fnames))
self.assertFalse(coder.need_commit_before_edits)
def test_allowed_to_edit_no(self): def test_allowed_to_edit_no(self):
with GitTemporaryDirectory(): with GitTemporaryDirectory():
repo = git.Repo(Path.cwd()) repo = git.Repo()
fname = Path("foo.txt")
fname = Path("added.txt")
fname.touch() fname.touch()
repo.git.add(str(fname)) repo.git.add(str(fname))
fname = Path("repo.txt")
fname.touch()
repo.git.add(str(fname))
repo.git.commit("-m", "init") repo.git.commit("-m", "init")
# say NO # say NO
io = InputOutput(yes=False) io = InputOutput(yes=False)
coder = Coder.create(models.GPT4, None, io, fnames=["foo.txt"]) coder = Coder.create(models.GPT4, None, io, fnames=["added.txt"])
self.assertTrue(coder.allowed_to_edit("foo.txt")) self.assertTrue(coder.allowed_to_edit("added.txt"))
self.assertFalse(coder.allowed_to_edit("repo.txt"))
self.assertFalse(coder.allowed_to_edit("new.txt")) self.assertFalse(coder.allowed_to_edit("new.txt"))
self.assertNotIn("repo.txt", str(coder.abs_fnames))
self.assertNotIn("new.txt", str(coder.abs_fnames))
self.assertFalse(coder.need_commit_before_edits)
def test_allowed_to_edit_dirty(self):
with GitTemporaryDirectory():
repo = git.Repo()
fname = Path("added.txt")
fname.touch()
repo.git.add(str(fname))
repo.git.commit("-m", "init")
# say NO
io = InputOutput(yes=False)
coder = Coder.create(models.GPT4, None, io, fnames=["added.txt"])
self.assertTrue(coder.allowed_to_edit("added.txt"))
self.assertFalse(coder.need_commit_before_edits)
fname.write_text("dirty!")
self.assertTrue(coder.allowed_to_edit("added.txt"))
self.assertTrue(coder.need_commit_before_edits)
def test_get_last_modified(self): def test_get_last_modified(self):
# Mock the IO object # Mock the IO object
mock_io = MagicMock() mock_io = MagicMock()
@ -94,26 +122,6 @@ class TestCoder(unittest.TestCase):
fname.unlink() fname.unlink()
self.assertEqual(coder.get_last_modified(), 0) self.assertEqual(coder.get_last_modified(), 0)
def test_should_dirty_commit(self):
# Mock the IO object
mock_io = MagicMock()
with GitTemporaryDirectory():
repo = git.Repo(Path.cwd())
fname = Path("new.txt")
fname.touch()
repo.git.add(str(fname))
repo.git.commit("-m", "new")
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(models.GPT4, None, mock_io)
fname.write_text("hi")
self.assertTrue(coder.should_dirty_commit("hi"))
self.assertFalse(coder.should_dirty_commit("/exit"))
self.assertFalse(coder.should_dirty_commit("/help"))
def test_check_for_file_mentions(self): def test_check_for_file_mentions(self):
# Mock the IO object # Mock the IO object
mock_io = MagicMock() mock_io = MagicMock()
@ -219,7 +227,6 @@ class TestCoder(unittest.TestCase):
mock.return_value = set([str(fname)]) mock.return_value = set([str(fname)])
coder.repo.get_tracked_files = mock coder.repo.get_tracked_files = mock
dump(fname)
# Call the check_for_file_mentions method # Call the check_for_file_mentions method
coder.check_for_file_mentions(f"Please check `{fname}`") coder.check_for_file_mentions(f"Please check `{fname}`")
@ -362,5 +369,185 @@ class TestCoder(unittest.TestCase):
with self.assertRaises(openai.error.InvalidRequestError): with self.assertRaises(openai.error.InvalidRequestError):
coder.run(with_message="hi") coder.run(with_message="hi")
if __name__ == "__main__": def test_new_file_edit_one_commit(self):
unittest.main() """A new file shouldn't get pre-committed before the GPT edit commit"""
with GitTemporaryDirectory():
repo = git.Repo()
fname = Path("file.txt")
io = InputOutput(yes=True)
coder = Coder.create(models.GPT4, "diff", io=io, fnames=[str(fname)])
self.assertTrue(fname.exists())
# make sure it was not committed
with self.assertRaises(git.exc.GitCommandError):
list(repo.iter_commits(repo.active_branch.name))
def mock_send(*args, **kwargs):
coder.partial_response_content = f"""
Do this:
{str(fname)}
<<<<<<< HEAD
=======
new
>>>>>>> updated
"""
coder.partial_response_function_call = dict()
coder.send = MagicMock(side_effect=mock_send)
coder.repo.get_commit_message = MagicMock()
coder.repo.get_commit_message.return_value = "commit message"
coder.run(with_message="hi")
content = fname.read_text()
self.assertEqual(content, "new\n")
num_commits = len(list(repo.iter_commits(repo.active_branch.name)))
self.assertEqual(num_commits, 1)
def test_only_commit_gpt_edited_file(self):
"""
Only commit file that gpt edits, not other dirty files.
Also ensure commit msg only depends on diffs from the GPT edited file.
"""
with GitTemporaryDirectory():
repo = git.Repo()
fname1 = Path("file1.txt")
fname2 = Path("file2.txt")
fname1.write_text("one\n")
fname2.write_text("two\n")
repo.git.add(str(fname1))
repo.git.add(str(fname2))
repo.git.commit("-m", "new")
# DIRTY!
fname1.write_text("ONE\n")
io = InputOutput(yes=True)
coder = Coder.create(models.GPT4, "diff", io=io, fnames=[str(fname1), str(fname2)])
def mock_send(*args, **kwargs):
coder.partial_response_content = f"""
Do this:
{str(fname2)}
<<<<<<< HEAD
two
=======
TWO
>>>>>>> updated
"""
coder.partial_response_function_call = dict()
def mock_get_commit_message(diffs, context):
self.assertNotIn("one", diffs)
self.assertNotIn("ONE", diffs)
return "commit message"
coder.send = MagicMock(side_effect=mock_send)
coder.repo.get_commit_message = MagicMock(side_effect=mock_get_commit_message)
coder.run(with_message="hi")
content = fname2.read_text()
self.assertEqual(content, "TWO\n")
self.assertTrue(repo.is_dirty(path=str(fname1)))
def test_gpt_edit_to_dirty_file(self):
"""A dirty file should be committed before the GPT edits are committed"""
with GitTemporaryDirectory():
repo = git.Repo()
fname = Path("file.txt")
fname.write_text("one\n")
repo.git.add(str(fname))
fname2 = Path("other.txt")
fname2.write_text("other\n")
repo.git.add(str(fname2))
repo.git.commit("-m", "new")
# dirty
fname.write_text("two\n")
fname2.write_text("OTHER\n")
io = InputOutput(yes=True)
coder = Coder.create(models.GPT4, "diff", io=io, fnames=[str(fname)])
def mock_send(*args, **kwargs):
coder.partial_response_content = f"""
Do this:
{str(fname)}
<<<<<<< HEAD
two
=======
three
>>>>>>> updated
"""
coder.partial_response_function_call = dict()
saved_diffs = []
def mock_get_commit_message(diffs, context):
saved_diffs.append(diffs)
return "commit message"
coder.repo.get_commit_message = MagicMock(side_effect=mock_get_commit_message)
coder.send = MagicMock(side_effect=mock_send)
coder.run(with_message="hi")
content = fname.read_text()
self.assertEqual(content, "three\n")
num_commits = len(list(repo.iter_commits(repo.active_branch.name)))
self.assertEqual(num_commits, 3)
diff = repo.git.diff(["HEAD~2", "HEAD~1"])
self.assertIn("one", diff)
self.assertIn("two", diff)
self.assertNotIn("three", diff)
self.assertNotIn("other", diff)
self.assertNotIn("OTHER", diff)
diff = saved_diffs[0]
self.assertIn("one", diff)
self.assertIn("two", diff)
self.assertNotIn("three", diff)
self.assertNotIn("other", diff)
self.assertNotIn("OTHER", diff)
diff = repo.git.diff(["HEAD~1", "HEAD"])
self.assertNotIn("one", diff)
self.assertIn("two", diff)
self.assertIn("three", diff)
self.assertNotIn("other", diff)
self.assertNotIn("OTHER", diff)
diff = saved_diffs[1]
self.assertNotIn("one", diff)
self.assertIn("two", diff)
self.assertIn("three", diff)
self.assertNotIn("other", diff)
self.assertNotIn("OTHER", diff)
self.assertEqual(len(saved_diffs), 2)
if __name__ == "__main__":
unittest.main()

View file

@ -26,7 +26,7 @@ class TestRepo(unittest.TestCase):
fname.write_text("workingdir\n") fname.write_text("workingdir\n")
git_repo = GitRepo(InputOutput(), None, ".") git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False) diffs = git_repo.get_diffs()
self.assertIn("index", diffs) self.assertIn("index", diffs)
self.assertIn("workingdir", diffs) self.assertIn("workingdir", diffs)
@ -49,7 +49,7 @@ class TestRepo(unittest.TestCase):
fname2.write_text("workingdir\n") fname2.write_text("workingdir\n")
git_repo = GitRepo(InputOutput(), None, ".") git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False) diffs = git_repo.get_diffs()
self.assertIn("index", diffs) self.assertIn("index", diffs)
self.assertIn("workingdir", diffs) self.assertIn("workingdir", diffs)
@ -67,7 +67,7 @@ class TestRepo(unittest.TestCase):
repo.git.commit("-m", "second") repo.git.commit("-m", "second")
git_repo = GitRepo(InputOutput(), None, ".") git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False, ["HEAD~1", "HEAD"]) diffs = git_repo.diff_commits(False, "HEAD~1", "HEAD")
dump(diffs) dump(diffs)
self.assertIn("two", diffs) self.assertIn("two", diffs)

View file

@ -90,7 +90,7 @@ class TestWholeFileCoder(unittest.TestCase):
# Set the partial response content with the updated content # Set the partial response content with the updated content
coder.partial_response_content = f"{sample_file}\n```\n0\n\1\n2\n" coder.partial_response_content = f"{sample_file}\n```\n0\n\1\n2\n"
lines = coder.update_files(mode="diff").splitlines() lines = coder.get_edits(mode="diff").splitlines()
# the live diff should be concise, since we haven't changed anything yet # the live diff should be concise, since we haven't changed anything yet
self.assertLess(len(lines), 20) self.assertLess(len(lines), 20)