This commit is contained in:
Paul Gauthier 2023-07-21 11:49:19 -03:00
parent 289887d94f
commit 23beb7cb5d
6 changed files with 87 additions and 105 deletions

View file

@ -7,9 +7,8 @@ import sys
import time import time
import traceback import traceback
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from pathlib import Path, PurePosixPath from pathlib import Path
import git
import openai import openai
from jsonschema import Draft7Validator from jsonschema import Draft7Validator
from rich.console import Console, Text from rich.console import Console, Text
@ -18,6 +17,7 @@ from rich.markdown import Markdown
from aider import models, prompts, utils from aider import models, prompts, utils
from aider.commands import Commands from aider.commands import Commands
from aider.repo import AiderRepo
from aider.repomap import RepoMap from aider.repomap import RepoMap
from aider.sendchat import send_with_retries from aider.sendchat import send_with_retries
@ -149,12 +149,16 @@ class Coder:
self.commands = Commands(self.io, self) self.commands = Commands(self.io, self)
if use_git: if use_git:
self.set_repo(fnames) try:
self.repo = AiderRepo(fnames)
self.root = self.repo.root
except FileNotFoundError:
self.repo = None
else: else:
self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames]) self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames])
if self.repo: if self.repo:
rel_repo_dir = self.get_rel_repo_dir() rel_repo_dir = self.repo.get_rel_repo_dir()
self.io.tool_output(f"Git repo: {rel_repo_dir}") self.io.tool_output(f"Git repo: {rel_repo_dir}")
else: else:
self.io.tool_output("Git repo: none") self.io.tool_output("Git repo: none")
@ -376,7 +380,7 @@ class Coder:
if self.should_dirty_commit(inp): if self.should_dirty_commit(inp):
self.io.tool_output("Git repo has uncommitted changes, preparing commit...") self.io.tool_output("Git repo has uncommitted changes, preparing commit...")
self.commit(ask=True, which="repo_files") self.commit(ask=True, which="repo_files", pretty=self.pretty)
# files changed, move cur messages back behind the files messages # files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits) self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
@ -495,8 +499,16 @@ class Coder:
) )
] ]
def get_context_from_history(self, history):
context = ""
if history:
for msg in history:
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def auto_commit(self): def auto_commit(self):
res = self.commit(history=self.cur_messages, prefix="aider: ") context = self.get_context_from_history(self.cur_messages)
res = self.commit(context=context, prefix="aider: ", pretty=self.pretty)
if res: if res:
commit_hash, commit_message = res commit_hash, commit_message = res
self.last_aider_commit_hash = commit_hash self.last_aider_commit_hash = commit_hash
@ -553,7 +565,7 @@ class Coder:
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames)) return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
def send(self, messages, model=None, silent=False, functions=None): def send(self, messages, model=None, functions=None):
if not model: if not model:
model = self.main_model.name model = self.main_model.name
@ -566,25 +578,24 @@ class Coder:
self.chat_completion_call_hashes.append(hash_object.hexdigest()) self.chat_completion_call_hashes.append(hash_object.hexdigest())
if self.stream: if self.stream:
self.show_send_output_stream(completion, silent) self.show_send_output_stream(completion)
else: else:
self.show_send_output(completion, silent) self.show_send_output(completion)
except KeyboardInterrupt: except KeyboardInterrupt:
self.keyboard_interrupt() self.keyboard_interrupt()
interrupted = True interrupted = True
if not silent: if self.partial_response_content:
if self.partial_response_content: self.io.ai_output(self.partial_response_content)
self.io.ai_output(self.partial_response_content) elif self.partial_response_function_call:
elif self.partial_response_function_call: # TODO: push this into subclasses
# TODO: push this into subclasses args = self.parse_partial_args()
args = self.parse_partial_args() if args:
if args: self.io.ai_output(json.dumps(args, indent=4))
self.io.ai_output(json.dumps(args, indent=4))
return interrupted return interrupted
def show_send_output(self, completion, silent): def show_send_output(self, completion):
if self.verbose: if self.verbose:
print(completion) print(completion)
@ -633,9 +644,9 @@ class Coder:
self.io.console.print(show_resp) self.io.console.print(show_resp)
self.io.console.print(tokens) self.io.console.print(tokens)
def show_send_output_stream(self, completion, silent): def show_send_output_stream(self, completion):
live = None live = None
if self.pretty and not silent: if self.pretty:
live = Live(vertical_overflow="scroll") live = Live(vertical_overflow="scroll")
try: try:
@ -664,9 +675,6 @@ class Coder:
except AttributeError: except AttributeError:
pass pass
if silent:
continue
if self.pretty: if self.pretty:
self.live_incremental_response(live, False) self.live_incremental_response(live, False)
else: else:
@ -697,7 +705,7 @@ class Coder:
def get_all_relative_files(self): def get_all_relative_files(self):
if self.repo: if self.repo:
files = self.get_tracked_files() files = self.repo.get_tracked_files()
else: else:
files = self.get_inchat_relative_files() files = self.get_inchat_relative_files()
@ -752,25 +760,6 @@ class Coder:
return full_path return full_path
def get_tracked_files(self):
if not self.repo:
return []
try:
commit = self.repo.head.commit
except ValueError:
return set()
files = []
for blob in commit.tree.traverse():
if blob.type == "blob": # blob is a file
files.append(blob.path)
# convert to appropriate os.sep, since git always normalizes to /
res = set(str(Path(PurePosixPath(path))) for path in files)
return res
apply_update_errors = 0 apply_update_errors = 0
def apply_updates(self): def apply_updates(self):

View file

@ -42,15 +42,6 @@ class SingleWholeFileFunctionCoder(Coder):
else: else:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
def get_context_from_history(self, history):
context = ""
if history:
context += "# Context:\n"
for msg in history:
if msg["role"] == "user":
context += msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def render_incremental_response(self, final=False): def render_incremental_response(self, final=False):
if self.partial_response_content: if self.partial_response_content:
return self.partial_response_content return self.partial_response_content

View file

@ -20,15 +20,6 @@ class WholeFileCoder(Coder):
else: else:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
def get_context_from_history(self, history):
context = ""
if history:
context += "# Context:\n"
for msg in history:
if msg["role"] == "user":
context += msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def render_incremental_response(self, final): def render_incremental_response(self, final):
try: try:
return self.update_files(mode="diff") return self.update_files(mode="diff")

View file

@ -55,15 +55,6 @@ class WholeFileFunctionCoder(Coder):
else: else:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
def get_context_from_history(self, history):
context = ""
if history:
context += "# Context:\n"
for msg in history:
if msg["role"] == "user":
context += msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def render_incremental_response(self, final=False): def render_incremental_response(self, final=False):
if self.partial_response_content: if self.partial_response_content:
return self.partial_response_content return self.partial_response_content

View file

@ -215,7 +215,9 @@ class Commands:
return return
commits = f"{self.coder.last_aider_commit_hash}~1" commits = f"{self.coder.last_aider_commit_hash}~1"
diff = self.coder.get_diffs(commits, self.coder.last_aider_commit_hash) diff = self.coder.get_diffs(
commits, self.coder.last_aider_commit_hash, pretty=self.coder.pretty
)
# don't use io.tool_output() because we don't want to log or further colorize # don't use io.tool_output() because we don't want to log or further colorize
print(diff) print(diff)
@ -247,7 +249,7 @@ class Commands:
added_fnames = [] added_fnames = []
git_added = [] git_added = []
git_files = self.coder.get_tracked_files() git_files = self.coder.get_tracked_files() if self.coder.repo else []
all_matched_files = set() all_matched_files = set()
for word in args.split(): for word in args.split():

View file

@ -1,4 +1,11 @@
import os
from pathlib import Path, PurePosixPath
import git import git
import openai
from aider import models, prompts, utils
from aider.sendchat import send_with_retries
class AiderRepo: class AiderRepo:
@ -30,23 +37,26 @@ class AiderRepo:
if fname.is_dir(): if fname.is_dir():
continue continue
self.abs_fnames.add(str(fname))
num_repos = len(set(repo_paths)) num_repos = len(set(repo_paths))
if num_repos == 0: if num_repos == 0:
return raise FileNotFoundError
if num_repos > 1: if num_repos > 1:
self.io.tool_error("Files are in different git repos.") self.io.tool_error("Files are in different git repos.")
return raise FileNotFoundError
# https://github.com/gitpython-developers/GitPython/issues/427 # https://github.com/gitpython-developers/GitPython/issues/427
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB) self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
self.root = utils.safe_abs_path(self.repo.working_tree_dir) self.root = utils.safe_abs_path(self.repo.working_tree_dir)
def ___(self, fnames):
# TODO!
self.abs_fnames.add(str(fname))
new_files = [] new_files = []
for fname in self.abs_fnames: for fname in fnames:
relative_fname = self.get_rel_fname(fname) relative_fname = self.get_rel_fname(fname)
tracked_files = set(self.get_tracked_files()) tracked_files = set(self.get_tracked_files())
@ -71,7 +81,12 @@ class AiderRepo:
else: else:
self.io.tool_error("Skipped adding new files to the git repo.") self.io.tool_error("Skipped adding new files to the git repo.")
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"): def commit(
self, context=None, prefix=None, ask=False, message=None, which="chat_files", pretty=False
):
## TODO!
repo = self.repo repo = self.repo
if not repo: if not repo:
return return
@ -96,7 +111,7 @@ class AiderRepo:
if not current_branch_commit_count: if not current_branch_commit_count:
continue continue
these_diffs = self.get_diffs("HEAD", "--", relative_fname) these_diffs = self.get_diffs(pretty, "HEAD", "--", relative_fname)
if these_diffs: if these_diffs:
diffs += these_diffs + "\n" diffs += these_diffs + "\n"
@ -115,7 +130,6 @@ class AiderRepo:
# don't use io.tool_output() because we don't want to log or further colorize # don't use io.tool_output() because we don't want to log or further colorize
print(diffs) print(diffs)
context = self.get_context_from_history(history)
if message: if message:
commit_message = message commit_message = message
else: else:
@ -162,13 +176,6 @@ class AiderRepo:
except ValueError: except ValueError:
return self.repo.git_dir return self.repo.git_dir
def get_context_from_history(self, history):
context = ""
if history:
for msg in history:
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def get_commit_message(self, diffs, context): def get_commit_message(self, diffs, context):
if len(diffs) >= 4 * 1024 * 4: if len(diffs) >= 4 * 1024 * 4:
self.io.tool_error( self.io.tool_error(
@ -184,34 +191,45 @@ class AiderRepo:
] ]
try: try:
interrupted = self.send( _hash, response = send_with_retries(
messages,
model=models.GPT35.name, model=models.GPT35.name,
silent=True, messages=messages,
) functions=None,
except openai.error.InvalidRequestError: stream=False,
self.io.tool_error(
f"Failed to generate commit message using {models.GPT35.name} due to an invalid"
" request."
) )
commit_message = completion.choices[0].message.content
except (AttributeError, openai.error.InvalidRequestError):
self.io.tool_error(f"Failed to generate commit message using {models.GPT35.name}")
return return
commit_message = self.partial_response_content
commit_message = commit_message.strip() commit_message = commit_message.strip()
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"': if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
commit_message = commit_message[1:-1].strip() commit_message = commit_message[1:-1].strip()
if interrupted:
self.io.tool_error(
f"Unable to get commit message from {models.GPT35.name}. Use /commit to try again."
)
return
return commit_message return commit_message
def get_diffs(self, *args): def get_diffs(self, pretty, *args):
if self.pretty: if pretty:
args = ["--color"] + list(args) args = ["--color"] + list(args)
diffs = self.repo.git.diff(*args) diffs = self.repo.git.diff(*args)
return diffs return diffs
def get_tracked_files(self):
if not self.repo:
return []
try:
commit = self.repo.head.commit
except ValueError:
return set()
files = []
for blob in commit.tree.traverse():
if blob.type == "blob": # blob is a file
files.append(blob.path)
# convert to appropriate os.sep, since git always normalizes to /
res = set(str(Path(PurePosixPath(path))) for path in files)
return res