fix send() output

This commit is contained in:
Paul Gauthier 2023-05-08 11:44:51 -07:00
parent 0f040949c4
commit 1dc8b3cd31

View file

@ -219,7 +219,7 @@ class Coder:
for line in content: for line in content:
print(role, line) print(role, line)
def send(self, messages, model=None, show_progress=0): def send(self, messages, model=None, progress_bar_expected=0, silent=False):
# self.show_messages(messages, "all") # self.show_messages(messages, "all")
if not model: if not model:
@ -234,20 +234,20 @@ class Coder:
interrupted = False interrupted = False
try: try:
if show_progress is not None: if progress_bar_expected:
self.show_send_progress(completion, show_progress) self.show_send_progress(completion, progress_bar_expected)
elif self.pretty and show_progress: elif self.pretty and not silent:
self.show_send_output_color(completion) self.show_send_output_color(completion)
else: else:
self.show_send_output_plain(completion, False) self.show_send_output_plain(completion, silent)
except KeyboardInterrupt: except KeyboardInterrupt:
interrupted = True interrupted = True
return self.resp, interrupted return self.resp, interrupted
def show_send_progress(self, completion, show_progress): def show_send_progress(self, completion, progress_bar_expected):
self.resp = "" self.resp = ""
pbar = tqdm(total=show_progress) pbar = tqdm(total=progress_bar_expected)
for chunk in completion: for chunk in completion:
try: try:
text = chunk.choices[0].delta.content text = chunk.choices[0].delta.content
@ -257,10 +257,10 @@ class Coder:
pbar.update(len(text)) pbar.update(len(text))
pbar.update(show_progress) pbar.update(progress_bar_expected)
pbar.close() pbar.close()
def show_send_output_plain(self, completion, show_output=True): def show_send_output_plain(self, completion, silent):
self.resp = "" self.resp = ""
for chunk in completion: for chunk in completion:
@ -272,7 +272,7 @@ class Coder:
except AttributeError: except AttributeError:
continue continue
if show_output: if not silent:
sys.stdout.write(text) sys.stdout.write(text)
sys.stdout.flush() sys.stdout.flush()
@ -378,7 +378,7 @@ class Coder:
dict(role="user", content=prompt), dict(role="user", content=prompt),
] ]
res, interrupted = self.send( res, interrupted = self.send(
messages, show_progress=len(content) + len(edit) / 2, model=model messages, progress_bar_expected=len(content) + len(edit) / 2, model=model
) )
if interrupted: if interrupted:
return return
@ -450,7 +450,7 @@ class Coder:
commit_message, interrupted = self.send( commit_message, interrupted = self.send(
messages, messages,
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
show_progress=None, silent=True,
) )
commit_message = commit_message.strip() commit_message = commit_message.strip()