This commit is contained in:
Paul Gauthier 2023-07-21 11:49:19 -03:00
parent 289887d94f
commit 23beb7cb5d
6 changed files with 87 additions and 105 deletions

View file

@ -7,9 +7,8 @@ import sys
import time
import traceback
from json.decoder import JSONDecodeError
from pathlib import Path, PurePosixPath
from pathlib import Path
import git
import openai
from jsonschema import Draft7Validator
from rich.console import Console, Text
@ -18,6 +17,7 @@ from rich.markdown import Markdown
from aider import models, prompts, utils
from aider.commands import Commands
from aider.repo import AiderRepo
from aider.repomap import RepoMap
from aider.sendchat import send_with_retries
@ -149,12 +149,16 @@ class Coder:
self.commands = Commands(self.io, self)
if use_git:
self.set_repo(fnames)
try:
self.repo = AiderRepo(fnames)
self.root = self.repo.root
except FileNotFoundError:
self.repo = None
else:
self.abs_fnames = set([str(Path(fname).resolve()) for fname in fnames])
if self.repo:
rel_repo_dir = self.get_rel_repo_dir()
rel_repo_dir = self.repo.get_rel_repo_dir()
self.io.tool_output(f"Git repo: {rel_repo_dir}")
else:
self.io.tool_output("Git repo: none")
@ -376,7 +380,7 @@ class Coder:
if self.should_dirty_commit(inp):
self.io.tool_output("Git repo has uncommitted changes, preparing commit...")
self.commit(ask=True, which="repo_files")
self.commit(ask=True, which="repo_files", pretty=self.pretty)
# files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
@ -495,8 +499,16 @@ class Coder:
)
]
def get_context_from_history(self, history):
context = ""
if history:
for msg in history:
context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
return context
def auto_commit(self):
res = self.commit(history=self.cur_messages, prefix="aider: ")
context = self.get_context_from_history(self.cur_messages)
res = self.commit(context=context, prefix="aider: ", pretty=self.pretty)
if res:
commit_hash, commit_message = res
self.last_aider_commit_hash = commit_hash
@ -553,7 +565,7 @@ class Coder:
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
def send(self, messages, model=None, silent=False, functions=None):
def send(self, messages, model=None, functions=None):
if not model:
model = self.main_model.name
@ -566,25 +578,24 @@ class Coder:
self.chat_completion_call_hashes.append(hash_object.hexdigest())
if self.stream:
self.show_send_output_stream(completion, silent)
self.show_send_output_stream(completion)
else:
self.show_send_output(completion, silent)
self.show_send_output(completion)
except KeyboardInterrupt:
self.keyboard_interrupt()
interrupted = True
if not silent:
if self.partial_response_content:
self.io.ai_output(self.partial_response_content)
elif self.partial_response_function_call:
# TODO: push this into subclasses
args = self.parse_partial_args()
if args:
self.io.ai_output(json.dumps(args, indent=4))
if self.partial_response_content:
self.io.ai_output(self.partial_response_content)
elif self.partial_response_function_call:
# TODO: push this into subclasses
args = self.parse_partial_args()
if args:
self.io.ai_output(json.dumps(args, indent=4))
return interrupted
def show_send_output(self, completion, silent):
def show_send_output(self, completion):
if self.verbose:
print(completion)
@ -633,9 +644,9 @@ class Coder:
self.io.console.print(show_resp)
self.io.console.print(tokens)
def show_send_output_stream(self, completion, silent):
def show_send_output_stream(self, completion):
live = None
if self.pretty and not silent:
if self.pretty:
live = Live(vertical_overflow="scroll")
try:
@ -664,9 +675,6 @@ class Coder:
except AttributeError:
pass
if silent:
continue
if self.pretty:
self.live_incremental_response(live, False)
else:
@ -697,7 +705,7 @@ class Coder:
def get_all_relative_files(self):
if self.repo:
files = self.get_tracked_files()
files = self.repo.get_tracked_files()
else:
files = self.get_inchat_relative_files()
@ -752,25 +760,6 @@ class Coder:
return full_path
def get_tracked_files(self):
if not self.repo:
return []
try:
commit = self.repo.head.commit
except ValueError:
return set()
files = []
for blob in commit.tree.traverse():
if blob.type == "blob": # blob is a file
files.append(blob.path)
# convert to appropriate os.sep, since git always normalizes to /
res = set(str(Path(PurePosixPath(path))) for path in files)
return res
apply_update_errors = 0
def apply_updates(self):