mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-23 22:04:59 +00:00
wip: Added methods to get context from history and commit message.
This commit is contained in:
parent
adbdbbcb74
commit
db2e062634
1 changed files with 45 additions and 52 deletions
|
@ -414,69 +414,22 @@ class Coder:
|
|||
|
||||
return edited
|
||||
|
||||
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
|
||||
repo = self.repo
|
||||
if not repo:
|
||||
return
|
||||
|
||||
if not repo.is_dirty():
|
||||
return
|
||||
|
||||
def get_dirty_files(file_list):
|
||||
dirty_files = []
|
||||
relative_dirty_files = []
|
||||
for fname in file_list:
|
||||
relative_fname = os.path.relpath(fname, repo.working_tree_dir)
|
||||
if self.pretty:
|
||||
these_diffs = repo.git.diff("HEAD", "--color", relative_fname)
|
||||
else:
|
||||
these_diffs = repo.git.diff("HEAD", relative_fname)
|
||||
|
||||
if these_diffs:
|
||||
dirty_files.append(fname)
|
||||
relative_dirty_files.append(relative_fname)
|
||||
|
||||
return dirty_files, relative_dirty_files
|
||||
|
||||
if which == "repo_files":
|
||||
all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()]
|
||||
dirty_fnames, relative_dirty_fnames = get_dirty_files(all_files)
|
||||
elif which == "chat_files":
|
||||
dirty_fnames, relative_dirty_fnames = get_dirty_files(self.abs_fnames)
|
||||
else:
|
||||
raise ValueError(f"Invalid value for 'which': {which}")
|
||||
|
||||
diffs = ""
|
||||
for (abs_fname,relative_fname) in zip(dirty_fnames, relative_dirty_fnames):
|
||||
if self.pretty:
|
||||
these_diffs = repo.git.diff("HEAD", "--color", relative_fname)
|
||||
else:
|
||||
these_diffs = repo.git.diff("HEAD", relative_fname)
|
||||
|
||||
diffs += these_diffs + "\n"
|
||||
|
||||
if self.show_diffs or ask:
|
||||
self.console.print(Text(diffs))
|
||||
|
||||
diffs = "# Diffs:\n" + diffs
|
||||
|
||||
# for fname in dirty_fnames:
|
||||
# self.console.print(f"[red] {fname}")
|
||||
|
||||
def get_context_from_history(self, history):
|
||||
context = ""
|
||||
if history:
|
||||
context += "# Context:\n"
|
||||
for msg in history:
|
||||
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
||||
return context
|
||||
|
||||
def get_commit_message(self, diffs, context):
|
||||
diffs = "# Diffs:\n" + diffs
|
||||
|
||||
messages = [
|
||||
dict(role="system", content=prompts.commit_system),
|
||||
dict(role="user", content=context + diffs),
|
||||
]
|
||||
|
||||
# if history:
|
||||
# self.show_messages(messages, "commit")
|
||||
|
||||
commit_message, interrupted = self.send(
|
||||
messages,
|
||||
model="gpt-3.5-turbo",
|
||||
|
@ -491,6 +444,46 @@ class Coder:
|
|||
)
|
||||
return
|
||||
|
||||
return commit_message
|
||||
|
||||
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
|
||||
repo = self.repo
|
||||
if not repo:
|
||||
return
|
||||
|
||||
if not repo.is_dirty():
|
||||
return
|
||||
|
||||
def get_dirty_files(file_list):
|
||||
diffs = ""
|
||||
relative_dirty_files = []
|
||||
for fname in file_list:
|
||||
relative_fname = os.path.relpath(fname, repo.working_tree_dir)
|
||||
if self.pretty:
|
||||
these_diffs = repo.git.diff("HEAD", "--color", relative_fname)
|
||||
else:
|
||||
these_diffs = repo.git.diff("HEAD", relative_fname)
|
||||
|
||||
if these_diffs:
|
||||
relative_dirty_files.append(relative_fname)
|
||||
diffs += these_diffs + "\n"
|
||||
|
||||
return relative_dirty_files, diffs
|
||||
|
||||
if which == "repo_files":
|
||||
all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()]
|
||||
relative_dirty_fnames,diffs = get_dirty_files(all_files)
|
||||
elif which == "chat_files":
|
||||
relative_dirty_fnames,diffs = get_dirty_files(self.abs_fnames)
|
||||
else:
|
||||
raise ValueError(f"Invalid value for 'which': {which}")
|
||||
|
||||
if self.show_diffs or ask:
|
||||
self.console.print(Text(diffs))
|
||||
|
||||
context = self.get_context_from_history(history)
|
||||
commit_message = self.get_commit_message(diffs, context)
|
||||
|
||||
if prefix:
|
||||
commit_message = prefix + commit_message
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue