Merge pull request #207 from paul-gauthier/late-bind-commits

Only commit dirty files if GPT tries to edit them #200
This commit is contained in:
paul-gauthier 2023-08-18 10:43:06 -07:00 committed by GitHub
commit 9933ad85b0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 419 additions and 217 deletions

View file

@ -2,6 +2,7 @@
### main branch
- [Only git commit dirty files that GPT tries to edit](https://aider.chat/docs/faq.html#how-does-aider-use-git)
- Send chat history as prompt/context for Whisper voice transcription
- Added `--voice-language` switch to constrain `/voice` to transcribe to a specific language

View file

@ -57,13 +57,7 @@ class Coder:
io,
**kwargs,
):
from . import (
EditBlockCoder,
EditBlockFunctionCoder,
SingleWholeFileFunctionCoder,
WholeFileCoder,
WholeFileFunctionCoder,
)
from . import EditBlockCoder, WholeFileCoder
if not main_model:
main_model = models.GPT35_16k
@ -84,14 +78,6 @@ class Coder:
return EditBlockCoder(main_model, io, **kwargs)
elif edit_format == "whole":
return WholeFileCoder(main_model, io, **kwargs)
elif edit_format == "whole-func":
return WholeFileFunctionCoder(main_model, io, **kwargs)
elif edit_format == "single-whole-func":
return SingleWholeFileFunctionCoder(main_model, io, **kwargs)
elif edit_format == "diff-func-list":
return EditBlockFunctionCoder("list", main_model, io, **kwargs)
elif edit_format in ("diff-func", "diff-func-string"):
return EditBlockFunctionCoder("string", main_model, io, **kwargs)
else:
raise ValueError(f"Unknown edit format {edit_format}")
@ -119,6 +105,7 @@ class Coder:
self.chat_completion_call_hashes = []
self.chat_completion_response_hashes = []
self.need_commit_before_edits = set()
self.verbose = verbose
self.abs_fnames = set()
@ -203,9 +190,6 @@ class Coder:
for fname in self.get_inchat_relative_files():
self.io.tool_output(f"Added {fname} to the chat.")
if self.repo:
self.repo.add_new_files(fname for fname in fnames if not Path(fname).is_dir())
self.summarizer = ChatSummary()
self.summarizer_thread = None
self.summarized_done_messages = None
@ -408,11 +392,6 @@ class Coder:
self.commands,
)
if self.should_dirty_commit(inp) and self.dirty_commit():
if inp.strip():
self.io.tool_output("Use up-arrow to retry previous command:", inp)
return
if not inp:
return
@ -500,7 +479,7 @@ class Coder:
if edited:
if self.repo and self.auto_commits and not self.dry_run:
saved_message = self.auto_commit()
saved_message = self.auto_commit(edited)
elif hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"):
saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo
else:
@ -728,43 +707,94 @@ class Coder:
def get_addable_relative_files(self):
return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files())
def allowed_to_edit(self, path, write_content=None):
full_path = self.abs_root_path(path)
if full_path in self.abs_fnames:
if write_content:
self.io.write_text(full_path, write_content)
return full_path
if not Path(full_path).exists():
question = f"Allow creation of new file {path}?" # noqa: E501
else:
question = f"Allow edits to {path} which was not previously provided?" # noqa: E501
if not self.io.confirm_ask(question):
self.io.tool_error(f"Skipping edit to {path}")
def check_for_dirty_commit(self, path):
if not self.repo:
return
if not self.dirty_commits:
return
if not self.repo.is_dirty(path):
return
if not Path(full_path).exists() and not self.dry_run:
Path(full_path).parent.mkdir(parents=True, exist_ok=True)
Path(full_path).touch()
fullp = Path(self.abs_root_path(path))
if not fullp.stat().st_size:
return
self.abs_fnames.add(full_path)
self.io.tool_output(f"Committing {path} before applying edits.")
self.need_commit_before_edits.add(path)
return
# Check if the file is already in the repo
def allowed_to_edit(self, path):
full_path = self.abs_root_path(path)
if self.repo:
tracked_files = set(self.repo.get_tracked_files())
relative_fname = self.get_rel_fname(full_path)
if relative_fname not in tracked_files and self.io.confirm_ask(f"Add {path} to git?"):
if not self.dry_run:
need_to_add = not self.repo.path_in_repo(path)
else:
need_to_add = False
if full_path in self.abs_fnames:
self.check_for_dirty_commit(path)
return True
if not Path(full_path).exists():
if not self.io.confirm_ask(f"Allow creation of new file {path}?"):
self.io.tool_error(f"Skipping edits to {path}")
return
if not self.dry_run:
Path(full_path).parent.mkdir(parents=True, exist_ok=True)
Path(full_path).touch()
# Seems unlikely that we needed to create the file, but it was
# actually already part of the repo.
# But let's only add if we need to, just to be safe.
if need_to_add:
self.repo.repo.git.add(full_path)
if write_content:
self.io.write_text(full_path, write_content)
self.abs_fnames.add(full_path)
return True
return full_path
if not self.io.confirm_ask(
f"Allow edits to {path} which was not previously added to chat?"
):
self.io.tool_error(f"Skipping edits to {path}")
return
if need_to_add:
self.repo.repo.git.add(full_path)
self.abs_fnames.add(full_path)
self.check_for_dirty_commit(path)
return True
apply_update_errors = 0
def prepare_to_edit(self, edits):
res = []
seen = dict()
self.need_commit_before_edits = set()
for edit in edits:
path = edit[0]
if path in seen:
allowed = seen[path]
else:
allowed = self.allowed_to_edit(path)
seen[path] = allowed
if allowed:
res.append(edit)
self.dirty_commit()
self.need_commit_before_edits = set()
return res
def update_files(self):
edits = self.get_edits()
edits = self.prepare_to_edit(edits)
self.apply_edits(edits)
return set(edit[0] for edit in edits)
def apply_updates(self):
max_apply_update_errors = 3
@ -795,12 +825,11 @@ class Coder:
self.apply_update_errors = 0
if edited:
for path in sorted(edited):
if self.dry_run:
self.io.tool_output(f"Did not apply edit to {path} (--dry-run)")
else:
self.io.tool_output(f"Applied edit to {path}")
for path in edited:
if self.dry_run:
self.io.tool_output(f"Did not apply edit to {path} (--dry-run)")
else:
self.io.tool_output(f"Applied edit to {path}")
return edited, None
@ -840,9 +869,9 @@ class Coder:
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def auto_commit(self):
def auto_commit(self, edited):
context = self.get_context_from_history(self.cur_messages)
res = self.repo.commit(context=context, prefix="aider: ")
res = self.repo.commit(fnames=edited, context=context, prefix="aider: ")
if res:
commit_hash, commit_message = res
self.last_aider_commit_hash = commit_hash
@ -855,43 +884,14 @@ class Coder:
self.io.tool_output("No changes made to git tracked files.")
return self.gpt_prompts.files_content_gpt_no_edits
def should_dirty_commit(self, inp):
cmds = self.commands.matching_commands(inp)
if cmds:
matching_commands, _, _ = cmds
if len(matching_commands) == 1:
cmd = matching_commands[0][1:]
if cmd in "add clear commit diff drop exit help ls tokens".split():
return
if self.last_asked_for_commit_time >= self.get_last_modified():
return
return True
def dirty_commit(self):
if not self.need_commit_before_edits:
return
if not self.dirty_commits:
return
if not self.repo:
return
if not self.repo.is_dirty():
return
self.io.tool_output("Git repo has uncommitted changes.")
self.repo.show_diffs(self.pretty)
self.last_asked_for_commit_time = self.get_last_modified()
res = self.io.prompt_ask(
"Commit before the chat proceeds [y/n/commit message]?",
default="y",
).strip()
if res.lower() in ["n", "no"]:
self.io.tool_error("Skipped commmit.")
return
if res.lower() in ["y", "yes"]:
message = None
else:
message = res.strip()
self.repo.commit(message=message)
self.repo.commit(fnames=self.need_commit_before_edits)
# files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)

View file

@ -13,22 +13,21 @@ class EditBlockCoder(Coder):
self.gpt_prompts = EditBlockPrompts()
super().__init__(*args, **kwargs)
def update_files(self):
def get_edits(self):
content = self.partial_response_content
# might raise ValueError for malformed ORIG/UPD blocks
edits = list(find_original_update_blocks(content))
edited = set()
return edits
def apply_edits(self, edits):
for path, original, updated in edits:
full_path = self.allowed_to_edit(path)
if not full_path:
continue
full_path = self.abs_root_path(path)
content = self.io.read_text(full_path)
content = do_replace(full_path, content, original, updated)
if content:
self.io.write_text(full_path, content)
edited.add(path)
continue
raise ValueError(f"""InvalidEditBlock: edit failed!
@ -42,8 +41,6 @@ The HEAD block needs to be EXACTLY the same as the lines in {path} with nothing
{original}```
""")
return edited
def prep(content):
if content and not content.endswith("\n"):

View file

@ -58,6 +58,7 @@ class EditBlockFunctionCoder(Coder):
]
def __init__(self, code_format, *args, **kwargs):
raise RuntimeError("Deprecated, needs to be refactored to support get_edits/apply_edits")
self.code_format = code_format
if code_format == "string":
@ -91,7 +92,7 @@ class EditBlockFunctionCoder(Coder):
res = json.dumps(args, indent=4)
return res
def update_files(self):
def _update_files(self):
name = self.partial_response_function_call.get("name")
if name and name != "replace_lines":

View file

@ -31,6 +31,7 @@ class SingleWholeFileFunctionCoder(Coder):
]
def __init__(self, *args, **kwargs):
raise RuntimeError("Deprecated, needs to be refactored to support get_edits/apply_edits")
self.gpt_prompts = SingleWholeFileFunctionPrompts()
super().__init__(*args, **kwargs)
@ -94,7 +95,7 @@ class SingleWholeFileFunctionCoder(Coder):
return "\n".join(show_diff)
def update_files(self):
def _update_files(self):
name = self.partial_response_function_call.get("name")
if name and name != "write_file":
raise ValueError(f'Unknown function_call name="{name}", use name="write_file"')

View file

@ -22,11 +22,11 @@ class WholeFileCoder(Coder):
def render_incremental_response(self, final):
try:
return self.update_files(mode="diff")
return self.get_edits(mode="diff")
except ValueError:
return self.partial_response_content
def update_files(self, mode="update"):
def get_edits(self, mode="update"):
content = self.partial_response_content
chat_files = self.get_inchat_relative_files()
@ -46,7 +46,7 @@ class WholeFileCoder(Coder):
# ending an existing block
saw_fname = None
full_path = (Path(self.root) / fname).absolute()
full_path = self.abs_root_path(fname)
if mode == "diff":
output += self.do_live_diff(full_path, new_lines, True)
@ -104,25 +104,30 @@ class WholeFileCoder(Coder):
if fname:
edits.append((fname, fname_source, new_lines))
edited = set()
seen = set()
refined_edits = []
# process from most reliable filename, to least reliable
for source in ("block", "saw", "chat"):
for fname, fname_source, new_lines in edits:
if fname_source != source:
continue
# if a higher priority source already edited the file, skip
if fname in edited:
if fname in seen:
continue
# we have a winner
new_lines = "".join(new_lines)
if self.allowed_to_edit(fname, new_lines):
edited.add(fname)
seen.add(fname)
refined_edits.append((fname, fname_source, new_lines))
return edited
return refined_edits
def apply_edits(self, edits):
for path, fname_source, new_lines in edits:
full_path = self.abs_root_path(path)
new_lines = "".join(new_lines)
self.io.write_text(full_path, new_lines)
def do_live_diff(self, full_path, new_lines, final):
if full_path.exists():
if Path(full_path).exists():
orig_lines = self.io.read_text(full_path).splitlines(keepends=True)
show_diff = diffs.diff_partial_update(

View file

@ -44,6 +44,8 @@ class WholeFileFunctionCoder(Coder):
]
def __init__(self, *args, **kwargs):
raise RuntimeError("Deprecated, needs to be refactored to support get_edits/apply_edits")
self.gpt_prompts = WholeFileFunctionPrompts()
super().__init__(*args, **kwargs)
@ -105,7 +107,7 @@ class WholeFileFunctionCoder(Coder):
return "\n".join(show_diff)
def update_files(self):
def _update_files(self):
name = self.partial_response_function_call.get("name")
if name and name != "write_file":
raise ValueError(f'Unknown function_call name="{name}", use name="write_file"')

View file

@ -230,7 +230,7 @@ class Commands:
return
commits = f"{self.coder.last_aider_commit_hash}~1"
diff = self.coder.repo.get_diffs(
diff = self.coder.repo.diff_commits(
self.coder.pretty,
commits,
self.coder.last_aider_commit_hash,

View file

@ -522,8 +522,6 @@ def main(argv=None, input=None, output=None, force_git_root=None):
io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args")
coder.dirty_commit()
if args.message:
io.tool_output()
coder.run(with_message=args.message)

View file

@ -49,24 +49,14 @@ class GitRepo:
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
def add_new_files(self, fnames):
cur_files = [str(Path(fn).resolve()) for fn in self.get_tracked_files()]
for fname in fnames:
if str(Path(fname).resolve()) in cur_files:
continue
if not Path(fname).exists():
continue
self.io.tool_output(f"Adding {fname} to git")
self.repo.git.add(fname)
def commit(self, context=None, prefix=None, message=None):
if not self.repo.is_dirty():
def commit(self, fnames=None, context=None, prefix=None, message=None):
if not fnames and not self.repo.is_dirty():
return
if message:
commit_message = message
else:
diffs = self.get_diffs(False)
diffs = self.get_diffs(fnames)
commit_message = self.get_commit_message(diffs, context)
if not commit_message:
@ -79,7 +69,16 @@ class GitRepo:
if context:
full_commit_message += "\n\n# Aider chat conversation:\n\n" + context
self.repo.git.commit("-a", "-m", full_commit_message, "--no-verify")
cmd = ["-m", full_commit_message, "--no-verify"]
if fnames:
fnames = [str(self.abs_root_path(fn)) for fn in fnames]
for fname in fnames:
self.repo.git.add(fname)
cmd += ["--"] + fnames
else:
cmd += ["-a"]
self.repo.git.commit(cmd)
commit_hash = self.repo.head.commit.hexsha[:7]
self.io.tool_output(f"Commit {commit_hash} {commit_message}")
@ -125,41 +124,38 @@ class GitRepo:
return commit_message
def get_diffs(self, pretty, *args):
args = list(args)
# if args are specified, just add --pretty if needed
if args:
if pretty:
args = ["--color"] + args
return self.repo.git.diff(*args)
# otherwise, we always want diffs of index and working dir
def get_diffs(self, fnames=None):
# We always want diffs of index and working dir
try:
commits = self.repo.iter_commits(self.repo.active_branch)
current_branch_has_commits = any(commits)
except git.exc.GitCommandError:
current_branch_has_commits = False
if pretty:
args = ["--color"]
if not fnames:
fnames = []
if current_branch_has_commits:
# if there is a HEAD, just diff against it to pick up index + working
args += ["HEAD"]
args = ["HEAD", "--"] + list(fnames)
return self.repo.git.diff(*args)
# diffs in the index
diffs = self.repo.git.diff(*(args + ["--cached"]))
# plus, diffs in the working dir
diffs += self.repo.git.diff(*args)
wd_args = ["--"] + list(fnames)
index_args = ["--cached"] + wd_args
diffs = self.repo.git.diff(*index_args)
diffs += self.repo.git.diff(*wd_args)
return diffs
def show_diffs(self, pretty):
diffs = self.get_diffs(pretty)
print(diffs)
def diff_commits(self, pretty, from_commit, to_commit):
args = []
if pretty:
args += ["--color"]
args += [from_commit, to_commit]
diffs = self.repo.git.diff(*args)
return diffs
def get_tracked_files(self):
if not self.repo:
@ -190,5 +186,19 @@ class GitRepo:
return res
def is_dirty(self):
return self.repo.is_dirty()
def path_in_repo(self, path):
if not self.repo:
return
tracked_files = set(self.get_tracked_files())
return path in tracked_files
def abs_root_path(self, path):
res = Path(self.root) / path
return utils.safe_abs_path(res)
def is_dirty(self, path=None):
if path and not self.path_in_repo(path):
return True
return self.repo.is_dirty(path=path)

View file

@ -13,29 +13,29 @@
It is recommended that you use aider with code that is part of a git repo.
This allows aider to maintain the safety of your code. Using git makes it easy to:
- Review the changes GPT made to your code
- Undo changes that weren't appropriate
- Undo any changes that weren't appropriate
- Go back later to review the changes GPT made to your code
- Manage a series of GPT's changes on a git branch
- etc
Working without git means that GPT might drastically change your code without an easy way to undo the changes.
Working without git means that GPT might drastically change your code
without an easy way to undo the changes.
Aider tries to provide safety using git in a few ways:
- It asks to create a git repo if you launch it in a directory without one.
- When you add a file to the chat, aider asks permission to add it to the git repo if needed.
- At launch and before sending requests to GPT, aider checks if the repo is dirty and offers to commit those changes for you. This way, the GPT changes will be applied to a clean repo and won't be intermingled with your own changes.
- After GPT changes your code, aider commits those changes with a descriptive commit message.
- Whenever GPT edits a file, those changes are committed with a descriptive commit message. This makes it easy to revert or review GPT's changes.
- If GPT tries to edit files that already have uncommitted changes (dirty files), aider will first commit those existing changes with a descriptive commit message. This makes sure you never lose your work if GPT makes an inappropriate change to uncommitted code.
Aider also allows you to use in-chat commands to `/diff` or `/undo` the last change made by GPT.
To do more complex management of your git history, you should use `git` on the command line outside of aider.
To do more complex management of your git history, you cat use raw `git` commands,
either by using `/git` or with git tools outside of aider.
You can start a branch before using aider to make a sequence of changes.
Or you can `git reset` a longer series of aider changes that didn't pan out. Etc.
While it is not recommended, you can disable aider's use of git in a few ways:
- `--no-auto-commits` will stop aider from git committing each of GPT's changes.
- `--no-dirty-commits` will stop aider from ensuring your repo is clean before sending requests to GPT.
- `--no-dirty-commits` will stop aider from committing dirty files before applying GPT's edits.
- `--no-git` will completely stop aider from using git on your files. You should ensure you are keeping sensible backups of the files you are working with.

View file

@ -22,55 +22,83 @@ class TestCoder(unittest.TestCase):
def tearDown(self):
self.patcher.stop()
def test_new_file_commit_message(self):
with GitTemporaryDirectory():
repo = git.Repo()
fname = Path("foo.txt")
io = InputOutput(yes=True)
# Initialize the Coder object with the mocked IO and mocked repo
Coder.create(models.GPT4, None, io, fnames=[str(fname)])
self.assertTrue(fname.exists())
# Mock the get_commit_message method to return "I added str(fname)"
repo.get_commit_message = MagicMock(return_value=f"I added {str(fname)}")
# Get the latest commit message
commit_message = repo.get_commit_message()
# Check that the latest commit message is "I added str(fname)"
self.assertEqual(commit_message, f"I added {str(fname)}")
def test_allowed_to_edit(self):
with GitTemporaryDirectory():
repo = git.Repo(Path.cwd())
fname = Path("foo.txt")
repo = git.Repo()
fname = Path("added.txt")
fname.touch()
repo.git.add(str(fname))
fname = Path("repo.txt")
fname.touch()
repo.git.add(str(fname))
repo.git.commit("-m", "init")
# YES!
io = InputOutput(yes=True)
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(models.GPT4, None, io, fnames=["foo.txt"])
coder = Coder.create(models.GPT4, None, io, fnames=["added.txt"])
self.assertTrue(coder.allowed_to_edit("foo.txt"))
self.assertTrue(coder.allowed_to_edit("added.txt"))
self.assertTrue(coder.allowed_to_edit("repo.txt"))
self.assertTrue(coder.allowed_to_edit("new.txt"))
self.assertIn("repo.txt", str(coder.abs_fnames))
self.assertIn("new.txt", str(coder.abs_fnames))
self.assertFalse(coder.need_commit_before_edits)
def test_allowed_to_edit_no(self):
with GitTemporaryDirectory():
repo = git.Repo(Path.cwd())
fname = Path("foo.txt")
repo = git.Repo()
fname = Path("added.txt")
fname.touch()
repo.git.add(str(fname))
fname = Path("repo.txt")
fname.touch()
repo.git.add(str(fname))
repo.git.commit("-m", "init")
# say NO
io = InputOutput(yes=False)
coder = Coder.create(models.GPT4, None, io, fnames=["foo.txt"])
coder = Coder.create(models.GPT4, None, io, fnames=["added.txt"])
self.assertTrue(coder.allowed_to_edit("foo.txt"))
self.assertTrue(coder.allowed_to_edit("added.txt"))
self.assertFalse(coder.allowed_to_edit("repo.txt"))
self.assertFalse(coder.allowed_to_edit("new.txt"))
self.assertNotIn("repo.txt", str(coder.abs_fnames))
self.assertNotIn("new.txt", str(coder.abs_fnames))
self.assertFalse(coder.need_commit_before_edits)
def test_allowed_to_edit_dirty(self):
with GitTemporaryDirectory():
repo = git.Repo()
fname = Path("added.txt")
fname.touch()
repo.git.add(str(fname))
repo.git.commit("-m", "init")
# say NO
io = InputOutput(yes=False)
coder = Coder.create(models.GPT4, None, io, fnames=["added.txt"])
self.assertTrue(coder.allowed_to_edit("added.txt"))
self.assertFalse(coder.need_commit_before_edits)
fname.write_text("dirty!")
self.assertTrue(coder.allowed_to_edit("added.txt"))
self.assertTrue(coder.need_commit_before_edits)
def test_get_last_modified(self):
# Mock the IO object
mock_io = MagicMock()
@ -94,26 +122,6 @@ class TestCoder(unittest.TestCase):
fname.unlink()
self.assertEqual(coder.get_last_modified(), 0)
def test_should_dirty_commit(self):
# Mock the IO object
mock_io = MagicMock()
with GitTemporaryDirectory():
repo = git.Repo(Path.cwd())
fname = Path("new.txt")
fname.touch()
repo.git.add(str(fname))
repo.git.commit("-m", "new")
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(models.GPT4, None, mock_io)
fname.write_text("hi")
self.assertTrue(coder.should_dirty_commit("hi"))
self.assertFalse(coder.should_dirty_commit("/exit"))
self.assertFalse(coder.should_dirty_commit("/help"))
def test_check_for_file_mentions(self):
# Mock the IO object
mock_io = MagicMock()
@ -219,7 +227,6 @@ class TestCoder(unittest.TestCase):
mock.return_value = set([str(fname)])
coder.repo.get_tracked_files = mock
dump(fname)
# Call the check_for_file_mentions method
coder.check_for_file_mentions(f"Please check `{fname}`")
@ -362,5 +369,185 @@ class TestCoder(unittest.TestCase):
with self.assertRaises(openai.error.InvalidRequestError):
coder.run(with_message="hi")
if __name__ == "__main__":
unittest.main()
def test_new_file_edit_one_commit(self):
"""A new file shouldn't get pre-committed before the GPT edit commit"""
with GitTemporaryDirectory():
repo = git.Repo()
fname = Path("file.txt")
io = InputOutput(yes=True)
coder = Coder.create(models.GPT4, "diff", io=io, fnames=[str(fname)])
self.assertTrue(fname.exists())
# make sure it was not committed
with self.assertRaises(git.exc.GitCommandError):
list(repo.iter_commits(repo.active_branch.name))
def mock_send(*args, **kwargs):
coder.partial_response_content = f"""
Do this:
{str(fname)}
<<<<<<< HEAD
=======
new
>>>>>>> updated
"""
coder.partial_response_function_call = dict()
coder.send = MagicMock(side_effect=mock_send)
coder.repo.get_commit_message = MagicMock()
coder.repo.get_commit_message.return_value = "commit message"
coder.run(with_message="hi")
content = fname.read_text()
self.assertEqual(content, "new\n")
num_commits = len(list(repo.iter_commits(repo.active_branch.name)))
self.assertEqual(num_commits, 1)
def test_only_commit_gpt_edited_file(self):
"""
Only commit file that gpt edits, not other dirty files.
Also ensure commit msg only depends on diffs from the GPT edited file.
"""
with GitTemporaryDirectory():
repo = git.Repo()
fname1 = Path("file1.txt")
fname2 = Path("file2.txt")
fname1.write_text("one\n")
fname2.write_text("two\n")
repo.git.add(str(fname1))
repo.git.add(str(fname2))
repo.git.commit("-m", "new")
# DIRTY!
fname1.write_text("ONE\n")
io = InputOutput(yes=True)
coder = Coder.create(models.GPT4, "diff", io=io, fnames=[str(fname1), str(fname2)])
def mock_send(*args, **kwargs):
coder.partial_response_content = f"""
Do this:
{str(fname2)}
<<<<<<< HEAD
two
=======
TWO
>>>>>>> updated
"""
coder.partial_response_function_call = dict()
def mock_get_commit_message(diffs, context):
self.assertNotIn("one", diffs)
self.assertNotIn("ONE", diffs)
return "commit message"
coder.send = MagicMock(side_effect=mock_send)
coder.repo.get_commit_message = MagicMock(side_effect=mock_get_commit_message)
coder.run(with_message="hi")
content = fname2.read_text()
self.assertEqual(content, "TWO\n")
self.assertTrue(repo.is_dirty(path=str(fname1)))
def test_gpt_edit_to_dirty_file(self):
"""A dirty file should be committed before the GPT edits are committed"""
with GitTemporaryDirectory():
repo = git.Repo()
fname = Path("file.txt")
fname.write_text("one\n")
repo.git.add(str(fname))
fname2 = Path("other.txt")
fname2.write_text("other\n")
repo.git.add(str(fname2))
repo.git.commit("-m", "new")
# dirty
fname.write_text("two\n")
fname2.write_text("OTHER\n")
io = InputOutput(yes=True)
coder = Coder.create(models.GPT4, "diff", io=io, fnames=[str(fname)])
def mock_send(*args, **kwargs):
coder.partial_response_content = f"""
Do this:
{str(fname)}
<<<<<<< HEAD
two
=======
three
>>>>>>> updated
"""
coder.partial_response_function_call = dict()
saved_diffs = []
def mock_get_commit_message(diffs, context):
saved_diffs.append(diffs)
return "commit message"
coder.repo.get_commit_message = MagicMock(side_effect=mock_get_commit_message)
coder.send = MagicMock(side_effect=mock_send)
coder.run(with_message="hi")
content = fname.read_text()
self.assertEqual(content, "three\n")
num_commits = len(list(repo.iter_commits(repo.active_branch.name)))
self.assertEqual(num_commits, 3)
diff = repo.git.diff(["HEAD~2", "HEAD~1"])
self.assertIn("one", diff)
self.assertIn("two", diff)
self.assertNotIn("three", diff)
self.assertNotIn("other", diff)
self.assertNotIn("OTHER", diff)
diff = saved_diffs[0]
self.assertIn("one", diff)
self.assertIn("two", diff)
self.assertNotIn("three", diff)
self.assertNotIn("other", diff)
self.assertNotIn("OTHER", diff)
diff = repo.git.diff(["HEAD~1", "HEAD"])
self.assertNotIn("one", diff)
self.assertIn("two", diff)
self.assertIn("three", diff)
self.assertNotIn("other", diff)
self.assertNotIn("OTHER", diff)
diff = saved_diffs[1]
self.assertNotIn("one", diff)
self.assertIn("two", diff)
self.assertIn("three", diff)
self.assertNotIn("other", diff)
self.assertNotIn("OTHER", diff)
self.assertEqual(len(saved_diffs), 2)
if __name__ == "__main__":
unittest.main()

View file

@ -26,7 +26,7 @@ class TestRepo(unittest.TestCase):
fname.write_text("workingdir\n")
git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False)
diffs = git_repo.get_diffs()
self.assertIn("index", diffs)
self.assertIn("workingdir", diffs)
@ -49,7 +49,7 @@ class TestRepo(unittest.TestCase):
fname2.write_text("workingdir\n")
git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False)
diffs = git_repo.get_diffs()
self.assertIn("index", diffs)
self.assertIn("workingdir", diffs)
@ -67,7 +67,7 @@ class TestRepo(unittest.TestCase):
repo.git.commit("-m", "second")
git_repo = GitRepo(InputOutput(), None, ".")
diffs = git_repo.get_diffs(False, ["HEAD~1", "HEAD"])
diffs = git_repo.diff_commits(False, "HEAD~1", "HEAD")
dump(diffs)
self.assertIn("two", diffs)

View file

@ -90,7 +90,7 @@ class TestWholeFileCoder(unittest.TestCase):
# Set the partial response content with the updated content
coder.partial_response_content = f"{sample_file}\n```\n0\n\1\n2\n"
lines = coder.update_files(mode="diff").splitlines()
lines = coder.get_edits(mode="diff").splitlines()
# the live diff should be concise, since we haven't changed anything yet
self.assertLess(len(lines), 20)