mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-28 00:05:01 +00:00
wip
This commit is contained in:
parent
289887d94f
commit
23beb7cb5d
6 changed files with 87 additions and 105 deletions
|
@ -7,9 +7,8 @@ import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from json.decoder import JSONDecodeError
|
from json.decoder import JSONDecodeError
|
||||||
from pathlib import Path, PurePosixPath
|
from pathlib import Path
|
||||||
|
|
||||||
import git
|
|
||||||
import openai
|
import openai
|
||||||
from jsonschema import Draft7Validator
|
from jsonschema import Draft7Validator
|
||||||
from rich.console import Console, Text
|
from rich.console import Console, Text
|
||||||
|
@ -18,6 +17,7 @@ from rich.markdown import Markdown
|
||||||
|
|
||||||
from aider import models, prompts, utils
|
from aider import models, prompts, utils
|
||||||
from aider.commands import Commands
|
from aider.commands import Commands
|
||||||
|
from aider.repo import AiderRepo
|
||||||
from aider.repomap import RepoMap
|
from aider.repomap import RepoMap
|
||||||
from aider.sendchat import send_with_retries
|
from aider.sendchat import send_with_retries
|
||||||
|
|
||||||
|
@ -149,12 +149,16 @@ class Coder:
|
||||||
self.commands = Commands(self.io, self)
|
self.commands = Commands(self.io, self)
|
||||||
|
|
||||||
if use_git:
|
if use_git:
|
||||||
self.set_repo(fnames)
|
try:
|
||||||
|
self.repo = AiderRepo(fnames)
|
||||||
|
self.root = self.repo.root
|
||||||
|
except FileNotFoundError:
|
||||||
|
self.repo = None
|
||||||
else:
|
else:
|
||||||
self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames])
|
self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames])
|
||||||
|
|
||||||
if self.repo:
|
if self.repo:
|
||||||
rel_repo_dir = self.get_rel_repo_dir()
|
rel_repo_dir = self.repo.get_rel_repo_dir()
|
||||||
self.io.tool_output(f"Git repo: {rel_repo_dir}")
|
self.io.tool_output(f"Git repo: {rel_repo_dir}")
|
||||||
else:
|
else:
|
||||||
self.io.tool_output("Git repo: none")
|
self.io.tool_output("Git repo: none")
|
||||||
|
@ -376,7 +380,7 @@ class Coder:
|
||||||
|
|
||||||
if self.should_dirty_commit(inp):
|
if self.should_dirty_commit(inp):
|
||||||
self.io.tool_output("Git repo has uncommitted changes, preparing commit...")
|
self.io.tool_output("Git repo has uncommitted changes, preparing commit...")
|
||||||
self.commit(ask=True, which="repo_files")
|
self.commit(ask=True, which="repo_files", pretty=self.pretty)
|
||||||
|
|
||||||
# files changed, move cur messages back behind the files messages
|
# files changed, move cur messages back behind the files messages
|
||||||
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
|
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
|
||||||
|
@ -495,8 +499,16 @@ class Coder:
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def get_context_from_history(self, history):
|
||||||
|
context = ""
|
||||||
|
if history:
|
||||||
|
for msg in history:
|
||||||
|
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
|
||||||
|
return context
|
||||||
|
|
||||||
def auto_commit(self):
|
def auto_commit(self):
|
||||||
res = self.commit(history=self.cur_messages, prefix="aider: ")
|
context = self.get_context_from_history(self.cur_messages)
|
||||||
|
res = self.commit(context=context, prefix="aider: ", pretty=self.pretty)
|
||||||
if res:
|
if res:
|
||||||
commit_hash, commit_message = res
|
commit_hash, commit_message = res
|
||||||
self.last_aider_commit_hash = commit_hash
|
self.last_aider_commit_hash = commit_hash
|
||||||
|
@ -553,7 +565,7 @@ class Coder:
|
||||||
|
|
||||||
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
|
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
|
||||||
|
|
||||||
def send(self, messages, model=None, silent=False, functions=None):
|
def send(self, messages, model=None, functions=None):
|
||||||
if not model:
|
if not model:
|
||||||
model = self.main_model.name
|
model = self.main_model.name
|
||||||
|
|
||||||
|
@ -566,25 +578,24 @@ class Coder:
|
||||||
self.chat_completion_call_hashes.append(hash_object.hexdigest())
|
self.chat_completion_call_hashes.append(hash_object.hexdigest())
|
||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
self.show_send_output_stream(completion, silent)
|
self.show_send_output_stream(completion)
|
||||||
else:
|
else:
|
||||||
self.show_send_output(completion, silent)
|
self.show_send_output(completion)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
self.keyboard_interrupt()
|
self.keyboard_interrupt()
|
||||||
interrupted = True
|
interrupted = True
|
||||||
|
|
||||||
if not silent:
|
if self.partial_response_content:
|
||||||
if self.partial_response_content:
|
self.io.ai_output(self.partial_response_content)
|
||||||
self.io.ai_output(self.partial_response_content)
|
elif self.partial_response_function_call:
|
||||||
elif self.partial_response_function_call:
|
# TODO: push this into subclasses
|
||||||
# TODO: push this into subclasses
|
args = self.parse_partial_args()
|
||||||
args = self.parse_partial_args()
|
if args:
|
||||||
if args:
|
self.io.ai_output(json.dumps(args, indent=4))
|
||||||
self.io.ai_output(json.dumps(args, indent=4))
|
|
||||||
|
|
||||||
return interrupted
|
return interrupted
|
||||||
|
|
||||||
def show_send_output(self, completion, silent):
|
def show_send_output(self, completion):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(completion)
|
print(completion)
|
||||||
|
|
||||||
|
@ -633,9 +644,9 @@ class Coder:
|
||||||
self.io.console.print(show_resp)
|
self.io.console.print(show_resp)
|
||||||
self.io.console.print(tokens)
|
self.io.console.print(tokens)
|
||||||
|
|
||||||
def show_send_output_stream(self, completion, silent):
|
def show_send_output_stream(self, completion):
|
||||||
live = None
|
live = None
|
||||||
if self.pretty and not silent:
|
if self.pretty:
|
||||||
live = Live(vertical_overflow="scroll")
|
live = Live(vertical_overflow="scroll")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -664,9 +675,6 @@ class Coder:
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if silent:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if self.pretty:
|
if self.pretty:
|
||||||
self.live_incremental_response(live, False)
|
self.live_incremental_response(live, False)
|
||||||
else:
|
else:
|
||||||
|
@ -697,7 +705,7 @@ class Coder:
|
||||||
|
|
||||||
def get_all_relative_files(self):
|
def get_all_relative_files(self):
|
||||||
if self.repo:
|
if self.repo:
|
||||||
files = self.get_tracked_files()
|
files = self.repo.get_tracked_files()
|
||||||
else:
|
else:
|
||||||
files = self.get_inchat_relative_files()
|
files = self.get_inchat_relative_files()
|
||||||
|
|
||||||
|
@ -752,25 +760,6 @@ class Coder:
|
||||||
|
|
||||||
return full_path
|
return full_path
|
||||||
|
|
||||||
def get_tracked_files(self):
|
|
||||||
if not self.repo:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
commit = self.repo.head.commit
|
|
||||||
except ValueError:
|
|
||||||
return set()
|
|
||||||
|
|
||||||
files = []
|
|
||||||
for blob in commit.tree.traverse():
|
|
||||||
if blob.type == "blob": # blob is a file
|
|
||||||
files.append(blob.path)
|
|
||||||
|
|
||||||
# convert to appropriate os.sep, since git always normalizes to /
|
|
||||||
res = set(str(Path(PurePosixPath(path))) for path in files)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
apply_update_errors = 0
|
apply_update_errors = 0
|
||||||
|
|
||||||
def apply_updates(self):
|
def apply_updates(self):
|
||||||
|
|
|
@ -42,15 +42,6 @@ class SingleWholeFileFunctionCoder(Coder):
|
||||||
else:
|
else:
|
||||||
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
||||||
|
|
||||||
def get_context_from_history(self, history):
|
|
||||||
context = ""
|
|
||||||
if history:
|
|
||||||
context += "# Context:\n"
|
|
||||||
for msg in history:
|
|
||||||
if msg["role"] == "user":
|
|
||||||
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
|
||||||
return context
|
|
||||||
|
|
||||||
def render_incremental_response(self, final=False):
|
def render_incremental_response(self, final=False):
|
||||||
if self.partial_response_content:
|
if self.partial_response_content:
|
||||||
return self.partial_response_content
|
return self.partial_response_content
|
||||||
|
|
|
@ -20,15 +20,6 @@ class WholeFileCoder(Coder):
|
||||||
else:
|
else:
|
||||||
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
||||||
|
|
||||||
def get_context_from_history(self, history):
|
|
||||||
context = ""
|
|
||||||
if history:
|
|
||||||
context += "# Context:\n"
|
|
||||||
for msg in history:
|
|
||||||
if msg["role"] == "user":
|
|
||||||
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
|
||||||
return context
|
|
||||||
|
|
||||||
def render_incremental_response(self, final):
|
def render_incremental_response(self, final):
|
||||||
try:
|
try:
|
||||||
return self.update_files(mode="diff")
|
return self.update_files(mode="diff")
|
||||||
|
|
|
@ -55,15 +55,6 @@ class WholeFileFunctionCoder(Coder):
|
||||||
else:
|
else:
|
||||||
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
|
||||||
|
|
||||||
def get_context_from_history(self, history):
|
|
||||||
context = ""
|
|
||||||
if history:
|
|
||||||
context += "# Context:\n"
|
|
||||||
for msg in history:
|
|
||||||
if msg["role"] == "user":
|
|
||||||
context += msg["role"].upper() + ": " + msg["content"] + "\n"
|
|
||||||
return context
|
|
||||||
|
|
||||||
def render_incremental_response(self, final=False):
|
def render_incremental_response(self, final=False):
|
||||||
if self.partial_response_content:
|
if self.partial_response_content:
|
||||||
return self.partial_response_content
|
return self.partial_response_content
|
||||||
|
|
|
@ -215,7 +215,9 @@ class Commands:
|
||||||
return
|
return
|
||||||
|
|
||||||
commits = f"{self.coder.last_aider_commit_hash}~1"
|
commits = f"{self.coder.last_aider_commit_hash}~1"
|
||||||
diff = self.coder.get_diffs(commits, self.coder.last_aider_commit_hash)
|
diff = self.coder.get_diffs(
|
||||||
|
commits, self.coder.last_aider_commit_hash, pretty=self.coder.pretty
|
||||||
|
)
|
||||||
|
|
||||||
# don't use io.tool_output() because we don't want to log or further colorize
|
# don't use io.tool_output() because we don't want to log or further colorize
|
||||||
print(diff)
|
print(diff)
|
||||||
|
@ -247,7 +249,7 @@ class Commands:
|
||||||
|
|
||||||
added_fnames = []
|
added_fnames = []
|
||||||
git_added = []
|
git_added = []
|
||||||
git_files = self.coder.get_tracked_files()
|
git_files = self.coder.get_tracked_files() if self.coder.repo else []
|
||||||
|
|
||||||
all_matched_files = set()
|
all_matched_files = set()
|
||||||
for word in args.split():
|
for word in args.split():
|
||||||
|
|
|
@ -1,4 +1,11 @@
|
||||||
|
import os
|
||||||
|
from pathlib import Path, PurePosixPath
|
||||||
|
|
||||||
import git
|
import git
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from aider import models, prompts, utils
|
||||||
|
from aider.sendchat import send_with_retries
|
||||||
|
|
||||||
|
|
||||||
class AiderRepo:
|
class AiderRepo:
|
||||||
|
@ -30,23 +37,26 @@ class AiderRepo:
|
||||||
if fname.is_dir():
|
if fname.is_dir():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.abs_fnames.add(str(fname))
|
|
||||||
|
|
||||||
num_repos = len(set(repo_paths))
|
num_repos = len(set(repo_paths))
|
||||||
|
|
||||||
if num_repos == 0:
|
if num_repos == 0:
|
||||||
return
|
raise FileNotFoundError
|
||||||
if num_repos > 1:
|
if num_repos > 1:
|
||||||
self.io.tool_error("Files are in different git repos.")
|
self.io.tool_error("Files are in different git repos.")
|
||||||
return
|
raise FileNotFoundError
|
||||||
|
|
||||||
# https://github.com/gitpython-developers/GitPython/issues/427
|
# https://github.com/gitpython-developers/GitPython/issues/427
|
||||||
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
|
self.repo = git.Repo(repo_paths.pop(), odbt=git.GitDB)
|
||||||
|
|
||||||
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
|
self.root = utils.safe_abs_path(self.repo.working_tree_dir)
|
||||||
|
|
||||||
|
def ___(self, fnames):
|
||||||
|
|
||||||
|
# TODO!
|
||||||
|
|
||||||
|
self.abs_fnames.add(str(fname))
|
||||||
|
|
||||||
new_files = []
|
new_files = []
|
||||||
for fname in self.abs_fnames:
|
for fname in fnames:
|
||||||
relative_fname = self.get_rel_fname(fname)
|
relative_fname = self.get_rel_fname(fname)
|
||||||
|
|
||||||
tracked_files = set(self.get_tracked_files())
|
tracked_files = set(self.get_tracked_files())
|
||||||
|
@ -71,7 +81,12 @@ class AiderRepo:
|
||||||
else:
|
else:
|
||||||
self.io.tool_error("Skipped adding new files to the git repo.")
|
self.io.tool_error("Skipped adding new files to the git repo.")
|
||||||
|
|
||||||
def commit(self, history=None, prefix=None, ask=False, message=None, which="chat_files"):
|
def commit(
|
||||||
|
self, context=None, prefix=None, ask=False, message=None, which="chat_files", pretty=False
|
||||||
|
):
|
||||||
|
|
||||||
|
## TODO!
|
||||||
|
|
||||||
repo = self.repo
|
repo = self.repo
|
||||||
if not repo:
|
if not repo:
|
||||||
return
|
return
|
||||||
|
@ -96,7 +111,7 @@ class AiderRepo:
|
||||||
if not current_branch_commit_count:
|
if not current_branch_commit_count:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
these_diffs = self.get_diffs("HEAD", "--", relative_fname)
|
these_diffs = self.get_diffs(pretty, "HEAD", "--", relative_fname)
|
||||||
|
|
||||||
if these_diffs:
|
if these_diffs:
|
||||||
diffs += these_diffs + "\n"
|
diffs += these_diffs + "\n"
|
||||||
|
@ -115,7 +130,6 @@ class AiderRepo:
|
||||||
# don't use io.tool_output() because we don't want to log or further colorize
|
# don't use io.tool_output() because we don't want to log or further colorize
|
||||||
print(diffs)
|
print(diffs)
|
||||||
|
|
||||||
context = self.get_context_from_history(history)
|
|
||||||
if message:
|
if message:
|
||||||
commit_message = message
|
commit_message = message
|
||||||
else:
|
else:
|
||||||
|
@ -162,13 +176,6 @@ class AiderRepo:
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return self.repo.git_dir
|
return self.repo.git_dir
|
||||||
|
|
||||||
def get_context_from_history(self, history):
|
|
||||||
context = ""
|
|
||||||
if history:
|
|
||||||
for msg in history:
|
|
||||||
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
|
|
||||||
return context
|
|
||||||
|
|
||||||
def get_commit_message(self, diffs, context):
|
def get_commit_message(self, diffs, context):
|
||||||
if len(diffs) >= 4 * 1024 * 4:
|
if len(diffs) >= 4 * 1024 * 4:
|
||||||
self.io.tool_error(
|
self.io.tool_error(
|
||||||
|
@ -184,34 +191,45 @@ class AiderRepo:
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
interrupted = self.send(
|
_hash, response = send_with_retries(
|
||||||
messages,
|
|
||||||
model=models.GPT35.name,
|
model=models.GPT35.name,
|
||||||
silent=True,
|
messages=messages,
|
||||||
)
|
functions=None,
|
||||||
except openai.error.InvalidRequestError:
|
stream=False,
|
||||||
self.io.tool_error(
|
|
||||||
f"Failed to generate commit message using {models.GPT35.name} due to an invalid"
|
|
||||||
" request."
|
|
||||||
)
|
)
|
||||||
|
commit_message = completion.choices[0].message.content
|
||||||
|
except (AttributeError, openai.error.InvalidRequestError):
|
||||||
|
self.io.tool_error(f"Failed to generate commit message using {models.GPT35.name}")
|
||||||
return
|
return
|
||||||
|
|
||||||
commit_message = self.partial_response_content
|
|
||||||
commit_message = commit_message.strip()
|
commit_message = commit_message.strip()
|
||||||
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
|
if commit_message and commit_message[0] == '"' and commit_message[-1] == '"':
|
||||||
commit_message = commit_message[1:-1].strip()
|
commit_message = commit_message[1:-1].strip()
|
||||||
|
|
||||||
if interrupted:
|
|
||||||
self.io.tool_error(
|
|
||||||
f"Unable to get commit message from {models.GPT35.name}. Use /commit to try again."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
return commit_message
|
return commit_message
|
||||||
|
|
||||||
def get_diffs(self, *args):
|
def get_diffs(self, pretty, *args):
|
||||||
if self.pretty:
|
if pretty:
|
||||||
args = ["--color"] + list(args)
|
args = ["--color"] + list(args)
|
||||||
|
|
||||||
diffs = self.repo.git.diff(*args)
|
diffs = self.repo.git.diff(*args)
|
||||||
return diffs
|
return diffs
|
||||||
|
|
||||||
|
def get_tracked_files(self):
|
||||||
|
if not self.repo:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
commit = self.repo.head.commit
|
||||||
|
except ValueError:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
files = []
|
||||||
|
for blob in commit.tree.traverse():
|
||||||
|
if blob.type == "blob": # blob is a file
|
||||||
|
files.append(blob.path)
|
||||||
|
|
||||||
|
# convert to appropriate os.sep, since git always normalizes to /
|
||||||
|
res = set(str(Path(PurePosixPath(path))) for path in files)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue