mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-30 01:04:59 +00:00
wip
This commit is contained in:
parent
289887d94f
commit
23beb7cb5d
6 changed files with 87 additions and 105 deletions
|
@ -1,4 +1,11 @@
|
|||
import os
|
||||
from pathlib import Path, PurePosixPath
|
||||
|
||||
import git
|
||||
import openai
|
||||
|
||||
from aider import models, prompts, utils
|
||||
from aider.sendchat import send_with_retries
|
||||
|
||||
|
||||
class AiderRepo:
|
||||
|
@ -30,23 +37,26 @@ class AiderRepo:
|
|||
if fname.is_dir():
|
||||
continue
|
||||
|
||||
self.abs_fnames.add(str(fname))
|
||||
|
||||
num_repos = len(set(repo_paths))
|
||||
|
||||
if num_repos == 0:
|
||||
return
|
||||
raise FileNotFoundError
|
||||
if num_repos > 1:
|
||||
self.io.tool_error("Files are in different git repos.")
|
||||
return
|
||||
raise FileNotFoundError
|
||||
|
||||
# https://github.com/gitpython-developers/GitPython/issues/427
|
||||
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
|
||||
|
||||
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
|
||||
|
||||
def ___(self, fnames):
|
||||
|
||||
# TODO!
|
||||
|
||||
self.abs_fnames.add(str(fname))
|
||||
|
||||
new_files = []
|
||||
for fname in self.abs_fnames:
|
||||
for fname in fnames:
|
||||
relative_fname = self.get_rel_fname(fname)
|
||||
|
||||
tracked_files = set(self.get_tracked_files())
|
||||
|
@ -71,7 +81,12 @@ class AiderRepo:
|
|||
else:
|
||||
self.io.tool_error("Skipped adding new files to the git repo.")
|
||||
|
||||
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
|
||||
def commit(
|
||||
self, context=None, prefix=None, ask=False, message=None, which="chat_files", pretty=False
|
||||
):
|
||||
|
||||
## TODO!
|
||||
|
||||
repo = self.repo
|
||||
if not repo:
|
||||
return
|
||||
|
@ -96,7 +111,7 @@ class AiderRepo:
|
|||
if not current_branch_commit_count:
|
||||
continue
|
||||
|
||||
these_diffs = self.get_diffs("HEAD", "--", relative_fname)
|
||||
these_diffs = self.get_diffs(pretty, "HEAD", "--", relative_fname)
|
||||
|
||||
if these_diffs:
|
||||
diffs += these_diffs + "\n"
|
||||
|
@ -115,7 +130,6 @@ class AiderRepo:
|
|||
# don't use io.tool_output() because we don't want to log or further colorize
|
||||
print(diffs)
|
||||
|
||||
context = self.get_context_from_history(history)
|
||||
if message:
|
||||
commit_message = message
|
||||
else:
|
||||
|
@ -162,13 +176,6 @@ class AiderRepo:
|
|||
except ValueError:
|
||||
return self.repo.git_dir
|
||||
|
||||
def get_context_from_history(self, history):
|
||||
context = ""
|
||||
if history:
|
||||
for msg in history:
|
||||
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
|
||||
return context
|
||||
|
||||
def get_commit_message(self, diffs, context):
|
||||
if len(diffs) >= 4 * 1024 * 4:
|
||||
self.io.tool_error(
|
||||
|
@ -184,34 +191,45 @@ class AiderRepo:
|
|||
]
|
||||
|
||||
try:
|
||||
interrupted = self.send(
|
||||
messages,
|
||||
_hash, response = send_with_retries(
|
||||
model=models.GPT35.name,
|
||||
silent=True,
|
||||
)
|
||||
except openai.error.InvalidRequestError:
|
||||
self.io.tool_error(
|
||||
f"Failed to generate commit message using {models.GPT35.name} due to an invalid"
|
||||
" request."
|
||||
messages=messages,
|
||||
functions=None,
|
||||
stream=False,
|
||||
)
|
||||
commit_message = completion.choices[0].message.content
|
||||
except (AttributeError, openai.error.InvalidRequestError):
|
||||
self.io.tool_error(f"Failed to generate commit message using {models.GPT35.name}")
|
||||
return
|
||||
|
||||
commit_message = self.partial_response_content
|
||||
commit_message = commit_message.strip()
|
||||
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
|
||||
commit_message = commit_message[1:-1].strip()
|
||||
|
||||
if interrupted:
|
||||
self.io.tool_error(
|
||||
f"Unable to get commit message from {models.GPT35.name}. Use /commit to try again."
|
||||
)
|
||||
return
|
||||
|
||||
return commit_message
|
||||
|
||||
def get_diffs(self, *args):
|
||||
if self.pretty:
|
||||
def get_diffs(self, pretty, *args):
|
||||
if pretty:
|
||||
args = ["--color"] + list(args)
|
||||
|
||||
diffs = self.repo.git.diff(*args)
|
||||
return diffs
|
||||
|
||||
def get_tracked_files(self):
|
||||
if not self.repo:
|
||||
return []
|
||||
|
||||
try:
|
||||
commit = self.repo.head.commit
|
||||
except ValueError:
|
||||
return set()
|
||||
|
||||
files = []
|
||||
for blob in commit.tree.traverse():
|
||||
if blob.type == "blob": # blob is a file
|
||||
files.append(blob.path)
|
||||
|
||||
# convert to appropriate os.sep, since git always normalizes to /
|
||||
res = set(str(Path(PurePosixPath(path))) for path in files)
|
||||
|
||||
return res
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue