refactor send output code

This commit is contained in:
Paul Gauthier 2023-05-10 21:42:22 -07:00
parent 84776593d3
commit 65436ae255

View file

@ -52,9 +52,7 @@ class Coder:
self.set_repo(fnames) self.set_repo(fnames)
if not self.repo: if not self.repo:
self.console.print( self.console.print("[red]No suitable git repo, will not automatically commit edits.")
"[red]No suitable git repo, will not automatically commit edits."
)
self.find_common_root() self.find_common_root()
self.pretty = pretty self.pretty = pretty
@ -116,13 +114,9 @@ class Coder:
if Confirm.ask("[bright_black]Add them?", console=self.console, default="y"): if Confirm.ask("[bright_black]Add them?", console=self.console, default="y"):
for relative_fname in new_files: for relative_fname in new_files:
repo.git.add(relative_fname) repo.git.add(relative_fname)
self.console.print( self.console.print(f"[bright_black]Added {relative_fname} to the git repo")
f"[bright_black]Added {relative_fname} to the git repo"
)
show_files = ", ".join(new_files) show_files = ", ".join(new_files)
commit_message = ( commit_message = f"Initial commit: Added new files to the git repo: {show_files}"
f"Initial commit: Added new files to the git repo: {show_files}"
)
repo.git.commit("-m", commit_message, "--no-verify") repo.git.commit("-m", commit_message, "--no-verify")
self.console.print( self.console.print(
f"[bright_black]Committed new files with message: {commit_message}" f"[bright_black]Committed new files with message: {commit_message}"
@ -299,48 +293,36 @@ class Coder:
# print(f"Rate limit exceeded. Retrying in {retry_after} seconds.") # print(f"Rate limit exceeded. Retrying in {retry_after} seconds.")
time.sleep(retry_after) time.sleep(retry_after)
if self.pretty and not silent: self.show_send_output(completion, silent)
self.show_send_output_color(completion)
else:
self.show_send_output_plain(completion, silent)
except KeyboardInterrupt: except KeyboardInterrupt:
interrupted = True interrupted = True
return self.resp, interrupted return self.resp, interrupted
def show_send_output_plain(self, completion, silent): def show_send_output(self, completion, silent):
for chunk in completion: if self.pretty:
if chunk.choices[0].finish_reason not in (None, "stop"): live = Live(vertical_overflow="scroll")
dump(chunk.choices[0].finish_reason) live.start()
try:
text = chunk.choices[0].delta.content
self.resp += text
except AttributeError:
continue
if not silent:
sys.stdout.write(text)
sys.stdout.flush()
def show_send_output_color(self, completion):
with Live(vertical_overflow="scroll") as live:
for chunk in completion: for chunk in completion:
if chunk.choices[0].finish_reason not in (None, "stop"): if chunk.choices[0].finish_reason not in (None, "stop"):
assert False, "Exceeded context window!" assert False, "Exceeded context window!"
try: try:
text = chunk.choices[0].delta.content text = chunk.choices[0].delta.content
self.resp += text self.resp += text
except AttributeError: except AttributeError:
continue continue
if self.pretty:
md = Markdown(self.resp, style="blue", code_theme="default") md = Markdown(self.resp, style="blue", code_theme="default")
live.update(md) live.update(md)
else:
sys.stdout.write(text)
sys.stdout.flush()
# live.update(Text("")) if self.pretty:
# live.stop() live.stop()
# md = Markdown(self.resp, style="blue", code_theme="default")
# self.console.print(md)
pattern = re.compile( pattern = re.compile(
# Optional: Matches the start of a code block (e.g., ```python) and any following whitespace # Optional: Matches the start of a code block (e.g., ```python) and any following whitespace
@ -370,9 +352,7 @@ class Coder:
if full_path not in self.abs_fnames: if full_path not in self.abs_fnames:
if not Path(full_path).exists(): if not Path(full_path).exists():
question = ( question = f"[bright_black]Allow creation of new file {path}?" # noqa: E501
f"[bright_black]Allow creation of new file {path}?" # noqa: E501
)
else: else:
question = f"[bright_black]Allow edits to {path} which was not previously provided?" # noqa: E501 question = f"[bright_black]Allow edits to {path} which was not previously provided?" # noqa: E501
if not Confirm.ask(question, console=self.console, default="y"): if not Confirm.ask(question, console=self.console, default="y"):
@ -456,8 +436,7 @@ class Coder:
if interrupted: if interrupted:
self.console.print( self.console.print(
"[red]Unable to get commit message from gpt-3.5-turbo. Use /commit to try" "[red]Unable to get commit message from gpt-3.5-turbo. Use /commit to try again.\n"
" again.\n"
) )
return return
@ -468,9 +447,7 @@ class Coder:
self.last_modified = self.get_last_modified() self.last_modified = self.get_last_modified()
self.console.print("[bright_black]Files have uncommitted changes.\n") self.console.print("[bright_black]Files have uncommitted changes.\n")
self.console.print( self.console.print(f"[bright_black]Suggested commit message:\n{commit_message}\n")
f"[bright_black]Suggested commit message:\n{commit_message}\n"
)
res = Prompt.ask( res = Prompt.ask(
"[bright_black]Commit before the chat proceeds? \[y/n/commit message]", # noqa: W605 E501 "[bright_black]Commit before the chat proceeds? \[y/n/commit message]", # noqa: W605 E501