This commit is contained in:
Paul Gauthier 2023-06-20 16:59:55 -07:00
parent 02c9a30c45
commit f1350f169f
5 changed files with 167 additions and 154 deletions

View file

@ -14,7 +14,7 @@ from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from aider import diffs, models, prompts, utils
from aider import models, prompts, utils
from aider.commands import Commands
from aider.repomap import RepoMap
@ -36,9 +36,10 @@ class Coder:
def create(self, edit_format, *args, **kwargs):
from . import EditBlockCoder, WholeFileCoder
if edit_format == "whole":
if edit_format == "diff":
dump("here")
return EditBlockCoder(*args, **kwargs)
elif edit_format == "diff":
elif edit_format == "whole":
return WholeFileCoder(*args, **kwargs)
else:
raise ValueError(f"Unknown edit format {edit_format}")
@ -363,16 +364,7 @@ class Coder:
if edit_error:
return edit_error
if not (self.edit_format == "whole" and edited):
# Don't add assistant messages to the history if they contain "edits"
# from the "whole" edit format.
# Because those edits are actually fully copies of the file!
# That wastes too much context window.
self.cur_messages += [dict(role="assistant", content=content)]
else:
self.cur_messages += [
dict(role="assistant", content=self.gpt_prompts.redacted_edit_message)
]
self.update_cur_messages(content, edited)
if edited:
if self.auto_commits:
@ -498,12 +490,7 @@ class Coder:
continue
if self.pretty:
show_resp = self.resp
if self.edit_format == "whole":
try:
show_resp = self.update_files_gpt35(self.resp, mode="diff")
except ValueError:
pass
show_resp = self.modify_incremental_response(self.resp)
md = Markdown(
show_resp, style=self.assistant_output_color, code_theme="default"
)
@ -515,126 +502,8 @@ class Coder:
if live:
live.stop()
def update_files_gpt35(self, content, mode="update"):
edited = set()
chat_files = self.get_inchat_relative_files()
if not chat_files:
if mode == "diff":
return content
return
output = []
lines = content.splitlines(keepends=True)
fname = None
new_lines = []
for i, line in enumerate(lines):
if line.startswith("```"):
if fname:
# ending an existing block
full_path = os.path.abspath(os.path.join(self.root, fname))
if mode == "diff":
with open(full_path, "r") as f:
orig_lines = f.readlines()
show_diff = diffs.diff_partial_update(
orig_lines,
new_lines,
final=True,
).splitlines()
output += show_diff
else:
new_lines = "".join(new_lines)
Path(full_path).write_text(new_lines)
edited.add(fname)
fname = None
new_lines = []
continue
# starting a new block
if i == 0:
raise ValueError("No filename provided before ``` block")
fname = lines[i - 1].strip()
if fname not in chat_files:
if len(chat_files) == 1:
fname = list(chat_files)[0]
else:
show_chat_files = " ".join(chat_files)
raise ValueError(f"{fname} is not one of: {show_chat_files}")
elif fname:
new_lines.append(line)
else:
output.append(line)
if mode == "diff":
if fname:
# ending an existing block
full_path = os.path.abspath(os.path.join(self.root, fname))
if mode == "diff":
with open(full_path, "r") as f:
orig_lines = f.readlines()
show_diff = diffs.diff_partial_update(
orig_lines,
new_lines,
).splitlines()
output += show_diff
return "\n".join(output)
if fname:
raise ValueError("Started a ``` block without closing it")
return edited
def update_files_gpt4(self, content):
# might raise ValueError for malformed ORIG/UPD blocks
edits = list(utils.find_original_update_blocks(content))
edited = set()
for path, original, updated in edits:
full_path = os.path.abspath(os.path.join(self.root, path))
if full_path not in self.abs_fnames:
if not Path(full_path).exists():
question = f"Allow creation of new file {path}?" # noqa: E501
else:
question = (
f"Allow edits to {path} which was not previously provided?" # noqa: E501
)
if not self.io.confirm_ask(question):
self.io.tool_error(f"Skipping edit to {path}")
continue
if not Path(full_path).exists():
Path(full_path).parent.mkdir(parents=True, exist_ok=True)
Path(full_path).touch()
self.abs_fnames.add(full_path)
# Check if the file is already in the repo
if self.repo:
tracked_files = set(self.repo.git.ls_files().splitlines())
relative_fname = self.get_rel_fname(full_path)
if relative_fname not in tracked_files and self.io.confirm_ask(
f"Add {path} to git?"
):
self.repo.git.add(full_path)
edited.add(path)
if utils.do_replace(full_path, original, updated, self.dry_run):
if self.dry_run:
self.io.tool_output(f"Dry run, did not apply edit to {path}")
else:
self.io.tool_output(f"Applied edit to {path}")
else:
self.io.tool_error(f"Failed to apply edit to {path}")
return edited
def modify_incremental_response(self, resp):
return resp
def get_context_from_history(self, history):
context = ""
@ -805,15 +674,8 @@ class Coder:
return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files())
def apply_updates(self, content):
if self.edit_format == "diff":
method = self.update_files_gpt4
elif self.edit_format == "whole":
method = self.update_files_gpt35
else:
raise ValueError(f"apply_updates() doesn't support {self.main_model.name}")
try:
edited = method(content)
edited = self.update_files(content)
return edited, None
except ValueError as err:
err = err.args[0]

View file

@ -1,8 +1,61 @@
import os
from pathlib import Path
from aider import utils
from ..editors import EditBlockPrompts
from .base import Coder
class EditBlockCoder(Coder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gpt_prompts = EditBlockPrompts()
super().__init__(*args, **kwargs)
def update_cur_messages(self, content, edited):
self.cur_messages += [dict(role="assistant", content=content)]
def update_files(self, content):
# might raise ValueError for malformed ORIG/UPD blocks
edits = list(utils.find_original_update_blocks(content))
edited = set()
for path, original, updated in edits:
full_path = os.path.abspath(os.path.join(self.root, path))
if full_path not in self.abs_fnames:
if not Path(full_path).exists():
question = f"Allow creation of new file {path}?" # noqa: E501
else:
question = (
f"Allow edits to {path} which was not previously provided?" # noqa: E501
)
if not self.io.confirm_ask(question):
self.io.tool_error(f"Skipping edit to {path}")
continue
if not Path(full_path).exists():
Path(full_path).parent.mkdir(parents=True, exist_ok=True)
Path(full_path).touch()
self.abs_fnames.add(full_path)
# Check if the file is already in the repo
if self.repo:
tracked_files = set(self.repo.git.ls_files().splitlines())
relative_fname = self.get_rel_fname(full_path)
if relative_fname not in tracked_files and self.io.confirm_ask(
f"Add {path} to git?"
):
self.repo.git.add(full_path)
edited.add(path)
if utils.do_replace(full_path, original, updated, self.dry_run):
if self.dry_run:
self.io.tool_output(f"Dry run, did not apply edit to {path}")
else:
self.io.tool_output(f"Applied edit to {path}")
else:
self.io.tool_error(f"Failed to apply edit to {path}")
return edited

View file

@ -1,8 +1,100 @@
import os
from pathlib import Path
from aider import diffs
from ..editors import WholeFilePrompts
from .base import Coder
class WholeFileCoder(Coder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gpt_prompts = WholeFilePrompts()
super().__init__(*args, **kwargs)
def update_cur_messages(self, content, edited):
if edited:
self.cur_messages += [
dict(role="assistant", content=self.gpt_prompts.redacted_edit_message)
]
else:
self.cur_messages += [dict(role="assistant", content=content)]
def modify_incremental_response(self, resp):
return self.update_files(resp, mode="diff")
def update_files(self, content, mode="update"):
edited = set()
chat_files = self.get_inchat_relative_files()
if not chat_files:
if mode == "diff":
return content
return
output = []
lines = content.splitlines(keepends=True)
fname = None
new_lines = []
for i, line in enumerate(lines):
if line.startswith("```"):
if fname:
# ending an existing block
full_path = os.path.abspath(os.path.join(self.root, fname))
if mode == "diff":
with open(full_path, "r") as f:
orig_lines = f.readlines()
show_diff = diffs.diff_partial_update(
orig_lines,
new_lines,
final=True,
).splitlines()
output += show_diff
else:
new_lines = "".join(new_lines)
Path(full_path).write_text(new_lines)
edited.add(fname)
fname = None
new_lines = []
continue
# starting a new block
if i == 0:
raise ValueError("No filename provided before ``` block")
fname = lines[i - 1].strip()
if fname not in chat_files:
if len(chat_files) == 1:
fname = list(chat_files)[0]
else:
show_chat_files = " ".join(chat_files)
raise ValueError(f"{fname} is not one of: {show_chat_files}")
elif fname:
new_lines.append(line)
else:
output.append(line)
if mode == "diff":
if fname:
# ending an existing block
full_path = os.path.abspath(os.path.join(self.root, fname))
if mode == "diff":
with open(full_path, "r") as f:
orig_lines = f.readlines()
show_diff = diffs.diff_partial_update(
orig_lines,
new_lines,
).splitlines()
output += show_diff
return "\n".join(output)
if fname:
raise ValueError("Started a ``` block without closing it")
return edited

View file

@ -1,18 +1,21 @@
import argparse
import os
import re
import subprocess
from packaging import version
def main():
parser = argparse.ArgumentParser(description="Bump version")
parser.add_argument("new_version", help="New version in x.y.z format")
parser.add_argument("--dry-run", action="store_true", help="Print each step without actually executing them")
parser.add_argument(
"--dry-run", action="store_true", help="Print each step without actually executing them"
)
args = parser.parse_args()
dry_run = args.dry_run
new_version_str = args.new_version
if not re.match(r'^\d+\.\d+\.\d+$', new_version_str):
if not re.match(r"^\d+\.\d+\.\d+$", new_version_str):
raise ValueError(f"Invalid version format, must be x.y.z: {new_version_str}")
new_version = version.parse(new_version_str)
@ -22,7 +25,9 @@ def main():
current_version = re.search(r'__version__ = "(.+?)"', content).group(1)
if new_version <= version.parse(current_version):
raise ValueError(f"New version {new_version} must be greater than the current version {current_version}")
raise ValueError(
f"New version {new_version} must be greater than the current version {current_version}"
)
updated_content = re.sub(r'__version__ = ".+?"', f'__version__ = "{new_version}"', content)
@ -47,5 +52,6 @@ def main():
else:
subprocess.run(cmd, check=True)
if __name__ == "__main__":
main()

View file

@ -200,7 +200,6 @@ These changes replace the `subprocess.run` patches with `subprocess.check_output
self.assertEqual(edit_blocks[0][0], "tests/test_repomap.py")
self.assertEqual(edit_blocks[1][0], "tests/test_repomap.py")
def test_replace_part_with_missing_leading_whitespace(self):
whole = " line1\n line2\n line3\n"
part = "line1\nline2"
@ -210,5 +209,6 @@ These changes replace the `subprocess.run` patches with `subprocess.check_output
result = utils.replace_part_with_missing_leading_whitespace(whole, part, replace)
self.assertEqual(result, expected_output)
if __name__ == "__main__":
unittest.main()