mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-23 22:04:59 +00:00
wip: Added methods to get context from history and commit message.
This commit is contained in:
parent
adbdbbcb74
commit
db2e062634
1 changed files with 45 additions and 52 deletions
|
@ -414,69 +414,22 @@ class Coder:
|
||||||
|
|
||||||
return edited
|
return edited
|
||||||
|
|
||||||
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
|
def get_context_from_history(self, history):
|
||||||
repo = self.repo
|
|
||||||
if not repo:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not repo.is_dirty():
|
|
||||||
return
|
|
||||||
|
|
||||||
def get_dirty_files(file_list):
|
|
||||||
dirty_files = []
|
|
||||||
relative_dirty_files = []
|
|
||||||
for fname in file_list:
|
|
||||||
relative_fname = os.path.relpath(fname, repo.working_tree_dir)
|
|
||||||
if self.pretty:
|
|
||||||
these_diffs = repo.git.diff("HEAD", "--color", relative_fname)
|
|
||||||
else:
|
|
||||||
these_diffs = repo.git.diff("HEAD", relative_fname)
|
|
||||||
|
|
||||||
if these_diffs:
|
|
||||||
dirty_files.append(fname)
|
|
||||||
relative_dirty_files.append(relative_fname)
|
|
||||||
|
|
||||||
return dirty_files, relative_dirty_files
|
|
||||||
|
|
||||||
if which == "repo_files":
|
|
||||||
all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()]
|
|
||||||
dirty_fnames, relative_dirty_fnames = get_dirty_files(all_files)
|
|
||||||
elif which == "chat_files":
|
|
||||||
dirty_fnames, relative_dirty_fnames = get_dirty_files(self.abs_fnames)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid value for 'which': {which}")
|
|
||||||
|
|
||||||
diffs = ""
|
|
||||||
for (abs_fname,relative_fname) in zip(dirty_fnames, relative_dirty_fnames):
|
|
||||||
if self.pretty:
|
|
||||||
these_diffs = repo.git.diff("HEAD", "--color", relative_fname)
|
|
||||||
else:
|
|
||||||
these_diffs = repo.git.diff("HEAD", relative_fname)
|
|
||||||
|
|
||||||
diffs += these_diffs + "\n"
|
|
||||||
|
|
||||||
if self.show_diffs or ask:
|
|
||||||
self.console.print(Text(diffs))
|
|
||||||
|
|
||||||
diffs = "# Diffs:\n" + diffs
|
|
||||||
|
|
||||||
# for fname in dirty_fnames:
|
|
||||||
# self.console.print(f"[red] {fname}")
|
|
||||||
|
|
||||||
context = ""
|
context = ""
|
||||||
if history:
|
if history:
|
||||||
context += "# Context:\n"
|
context += "# Context:\n"
|
||||||
for msg in history:
|
for msg in history:
|
||||||
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
||||||
|
return context
|
||||||
|
|
||||||
|
def get_commit_message(self, diffs, context):
|
||||||
|
diffs = "# Diffs:\n" + diffs
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
dict(role="system", content=prompts.commit_system),
|
dict(role="system", content=prompts.commit_system),
|
||||||
dict(role="user", content=context + diffs),
|
dict(role="user", content=context + diffs),
|
||||||
]
|
]
|
||||||
|
|
||||||
# if history:
|
|
||||||
# self.show_messages(messages, "commit")
|
|
||||||
|
|
||||||
commit_message, interrupted = self.send(
|
commit_message, interrupted = self.send(
|
||||||
messages,
|
messages,
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
|
@ -491,6 +444,46 @@ class Coder:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
return commit_message
|
||||||
|
|
||||||
|
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
|
||||||
|
repo = self.repo
|
||||||
|
if not repo:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not repo.is_dirty():
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_dirty_files(file_list):
|
||||||
|
diffs = ""
|
||||||
|
relative_dirty_files = []
|
||||||
|
for fname in file_list:
|
||||||
|
relative_fname = os.path.relpath(fname, repo.working_tree_dir)
|
||||||
|
if self.pretty:
|
||||||
|
these_diffs = repo.git.diff("HEAD", "--color", relative_fname)
|
||||||
|
else:
|
||||||
|
these_diffs = repo.git.diff("HEAD", relative_fname)
|
||||||
|
|
||||||
|
if these_diffs:
|
||||||
|
relative_dirty_files.append(relative_fname)
|
||||||
|
diffs += these_diffs + "\n"
|
||||||
|
|
||||||
|
return relative_dirty_files, diffs
|
||||||
|
|
||||||
|
if which == "repo_files":
|
||||||
|
all_files = [os.path.join(self.root, f) for f in self.get_all_relative_files()]
|
||||||
|
relative_dirty_fnames,diffs = get_dirty_files(all_files)
|
||||||
|
elif which == "chat_files":
|
||||||
|
relative_dirty_fnames,diffs = get_dirty_files(self.abs_fnames)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid value for 'which': {which}")
|
||||||
|
|
||||||
|
if self.show_diffs or ask:
|
||||||
|
self.console.print(Text(diffs))
|
||||||
|
|
||||||
|
context = self.get_context_from_history(history)
|
||||||
|
commit_message = self.get_commit_message(diffs, context)
|
||||||
|
|
||||||
if prefix:
|
if prefix:
|
||||||
commit_message = prefix + commit_message
|
commit_message = prefix + commit_message
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue