This commit is contained in:
Paul Gauthier 2023-07-21 11:49:19 -03:00
parent 289887d94f
commit 23beb7cb5d
6 changed files with 87 additions and 105 deletions

View file

@ -7,9 +7,8 @@ import sys
import time
import traceback
from json.decoder import JSONDecodeError
from pathlib import Path, PurePosixPath
from pathlib import Path
import git
import openai
from jsonschema import Draft7Validator
from rich.console import Console, Text
@ -18,6 +17,7 @@ from rich.markdown import Markdown
from aider import models, prompts, utils
from aider.commands import Commands
from aider.repo import AiderRepo
from aider.repomap import RepoMap
from aider.sendchat import send_with_retries
@ -149,12 +149,16 @@ class Coder:
self.commands = Commands(self.io, self)
if use_git:
self.set_repo(fnames)
try:
self.repo = AiderRepo(fnames)
self.root = self.repo.root
except FileNotFoundError:
self.repo = None
else:
self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames])
if self.repo:
rel_repo_dir = self.get_rel_repo_dir()
rel_repo_dir = self.repo.get_rel_repo_dir()
self.io.tool_output(f"Git repo: {rel_repo_dir}")
else:
self.io.tool_output("Git repo: none")
@ -376,7 +380,7 @@ class Coder:
if self.should_dirty_commit(inp):
self.io.tool_output("Git repo has uncommitted changes, preparing commit...")
self.commit(ask=True, which="repo_files")
self.commit(ask=True, which="repo_files", pretty=self.pretty)
# files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
@ -495,8 +499,16 @@ class Coder:
)
]
def get_context_from_history(self, history):
context = ""
if history:
for msg in history:
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def auto_commit(self):
res = self.commit(history=self.cur_messages, prefix="aider: ")
context = self.get_context_from_history(self.cur_messages)
res = self.commit(context=context, prefix="aider: ", pretty=self.pretty)
if res:
commit_hash, commit_message = res
self.last_aider_commit_hash = commit_hash
@ -553,7 +565,7 @@ class Coder:
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
def send(self, messages, model=None, silent=False, functions=None):
def send(self, messages, model=None, functions=None):
if not model:
model = self.main_model.name
@ -566,25 +578,24 @@ class Coder:
self.chat_completion_call_hashes.append(hash_object.hexdigest())
if self.stream:
self.show_send_output_stream(completion, silent)
self.show_send_output_stream(completion)
else:
self.show_send_output(completion, silent)
self.show_send_output(completion)
except KeyboardInterrupt:
self.keyboard_interrupt()
interrupted = True
if not silent:
if self.partial_response_content:
self.io.ai_output(self.partial_response_content)
elif self.partial_response_function_call:
# TODO: push this into subclasses
args = self.parse_partial_args()
if args:
self.io.ai_output(json.dumps(args, indent=4))
if self.partial_response_content:
self.io.ai_output(self.partial_response_content)
elif self.partial_response_function_call:
# TODO: push this into subclasses
args = self.parse_partial_args()
if args:
self.io.ai_output(json.dumps(args, indent=4))
return interrupted
def show_send_output(self, completion, silent):
def show_send_output(self, completion):
if self.verbose:
print(completion)
@ -633,9 +644,9 @@ class Coder:
self.io.console.print(show_resp)
self.io.console.print(tokens)
def show_send_output_stream(self, completion, silent):
def show_send_output_stream(self, completion):
live = None
if self.pretty and not silent:
if self.pretty:
live = Live(vertical_overflow="scroll")
try:
@ -664,9 +675,6 @@ class Coder:
except AttributeError:
pass
if silent:
continue
if self.pretty:
self.live_incremental_response(live, False)
else:
@ -697,7 +705,7 @@ class Coder:
def get_all_relative_files(self):
if self.repo:
files = self.get_tracked_files()
files = self.repo.get_tracked_files()
else:
files = self.get_inchat_relative_files()
@ -752,25 +760,6 @@ class Coder:
return full_path
def get_tracked_files(self):
if not self.repo:
return []
try:
commit = self.repo.head.commit
except ValueError:
return set()
files = []
for blob in commit.tree.traverse():
if blob.type == "blob": # blob is a file
files.append(blob.path)
# convert to appropriate os.sep, since git always normalizes to /
res = set(str(Path(PurePosixPath(path))) for path in files)
return res
apply_update_errors = 0
def apply_updates(self):

View file

@ -42,15 +42,6 @@ class SingleWholeFileFunctionCoder(Coder):
else:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
def get_context_from_history(self, history):
context = ""
if history:
context += "# Context:\n"
for msg in history:
if msg["role"] == "user":
context += msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def render_incremental_response(self, final=False):
if self.partial_response_content:
return self.partial_response_content

View file

@ -20,15 +20,6 @@ class WholeFileCoder(Coder):
else:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
def get_context_from_history(self, history):
context = ""
if history:
context += "# Context:\n"
for msg in history:
if msg["role"] == "user":
context += msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def render_incremental_response(self, final):
try:
return self.update_files(mode="diff")

View file

@ -55,15 +55,6 @@ class WholeFileFunctionCoder(Coder):
else:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
def get_context_from_history(self, history):
context = ""
if history:
context += "# Context:\n"
for msg in history:
if msg["role"] == "user":
context += msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def render_incremental_response(self, final=False):
if self.partial_response_content:
return self.partial_response_content

View file

@ -215,7 +215,9 @@ class Commands:
return
commits = f"{self.coder.last_aider_commit_hash}~1"
diff = self.coder.get_diffs(commits, self.coder.last_aider_commit_hash)
diff = self.coder.get_diffs(
commits, self.coder.last_aider_commit_hash, pretty=self.coder.pretty
)
# don't use io.tool_output() because we don't want to log or further colorize
print(diff)
@ -247,7 +249,7 @@ class Commands:
added_fnames = []
git_added = []
git_files = self.coder.get_tracked_files()
git_files = self.coder.get_tracked_files() if self.coder.repo else []
all_matched_files = set()
for word in args.split():

View file

@ -1,4 +1,11 @@
import os
from pathlib import Path, PurePosixPath
import git
import openai
from aider import models, prompts, utils
from aider.sendchat import send_with_retries
class AiderRepo:
@ -30,23 +37,26 @@ class AiderRepo:
if fname.is_dir():
continue
self.abs_fnames.add(str(fname))
num_repos = len(set(repo_paths))
if num_repos == 0:
return
raise FileNotFoundError
if num_repos > 1:
self.io.tool_error("Files are in different git repos.")
return
raise FileNotFoundError
# https://github.com/gitpython-developers/GitPython/issues/427
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
def ___(self, fnames):
# TODO!
self.abs_fnames.add(str(fname))
new_files = []
for fname in self.abs_fnames:
for fname in fnames:
relative_fname = self.get_rel_fname(fname)
tracked_files = set(self.get_tracked_files())
@ -71,7 +81,12 @@ class AiderRepo:
else:
self.io.tool_error("Skipped adding new files to the git repo.")
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
def commit(
self, context=None, prefix=None, ask=False, message=None, which="chat_files", pretty=False
):
## TODO!
repo = self.repo
if not repo:
return
@ -96,7 +111,7 @@ class AiderRepo:
if not current_branch_commit_count:
continue
these_diffs = self.get_diffs("HEAD", "--", relative_fname)
these_diffs = self.get_diffs(pretty, "HEAD", "--", relative_fname)
if these_diffs:
diffs += these_diffs + "\n"
@ -115,7 +130,6 @@ class AiderRepo:
# don't use io.tool_output() because we don't want to log or further colorize
print(diffs)
context = self.get_context_from_history(history)
if message:
commit_message = message
else:
@ -162,13 +176,6 @@ class AiderRepo:
except ValueError:
return self.repo.git_dir
def get_context_from_history(self, history):
context = ""
if history:
for msg in history:
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def get_commit_message(self, diffs, context):
if len(diffs) >= 4 * 1024 * 4:
self.io.tool_error(
@ -184,34 +191,45 @@ class AiderRepo:
]
try:
interrupted = self.send(
messages,
_hash, response = send_with_retries(
model=models.GPT35.name,
silent=True,
)
except openai.error.InvalidRequestError:
self.io.tool_error(
f"Failed to generate commit message using {models.GPT35.name} due to an invalid"
" request."
messages=messages,
functions=None,
stream=False,
)
commit_message = completion.choices[0].message.content
except (AttributeError, openai.error.InvalidRequestError):
self.io.tool_error(f"Failed to generate commit message using {models.GPT35.name}")
return
commit_message = self.partial_response_content
commit_message = commit_message.strip()
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
commit_message = commit_message[1:-1].strip()
if interrupted:
self.io.tool_error(
f"Unable to get commit message from {models.GPT35.name}. Use /commit to try again."
)
return
return commit_message
def get_diffs(self, *args):
if self.pretty:
def get_diffs(self, pretty, *args):
if pretty:
args = ["--color"] + list(args)
diffs = self.repo.git.diff(*args)
return diffs
def get_tracked_files(self):
if not self.repo:
return []
try:
commit = self.repo.head.commit
except ValueError:
return set()
files = []
for blob in commit.tree.traverse():
if blob.type == "blob": # blob is a file
files.append(blob.path)
# convert to appropriate os.sep, since git always normalizes to /
res = set(str(Path(PurePosixPath(path))) for path in files)
return res