fix send() output

This commit is contained in:
Paul Gauthier 2023-05-08 11:44:51 -07:00
parent 0f040949c4
commit 1dc8b3cd31

View file

@ -219,7 +219,7 @@ class Coder:
for line in content:
print(role, line)
def send(self, messages, model=None, show_progress=0):
def send(self, messages, model=None, progress_bar_expected=0, silent=False):
# self.show_messages(messages, "all")
if not model:
@ -234,20 +234,20 @@ class Coder:
interrupted = False
try:
if show_progress is not None:
self.show_send_progress(completion, show_progress)
elif self.pretty and show_progress:
if progress_bar_expected:
self.show_send_progress(completion, progress_bar_expected)
elif self.pretty and not silent:
self.show_send_output_color(completion)
else:
self.show_send_output_plain(completion, False)
self.show_send_output_plain(completion, silent)
except KeyboardInterrupt:
interrupted = True
return self.resp, interrupted
def show_send_progress(self, completion, show_progress):
def show_send_progress(self, completion, progress_bar_expected):
self.resp = ""
pbar = tqdm(total=show_progress)
pbar = tqdm(total=progress_bar_expected)
for chunk in completion:
try:
text = chunk.choices[0].delta.content
@ -257,10 +257,10 @@ class Coder:
pbar.update(len(text))
pbar.update(show_progress)
pbar.update(progress_bar_expected)
pbar.close()
def show_send_output_plain(self, completion, show_output=True):
def show_send_output_plain(self, completion, silent):
self.resp = ""
for chunk in completion:
@ -272,7 +272,7 @@ class Coder:
except AttributeError:
continue
if show_output:
if not silent:
sys.stdout.write(text)
sys.stdout.flush()
@ -378,7 +378,7 @@ class Coder:
dict(role="user", content=prompt),
]
res, interrupted = self.send(
messages, show_progress=len(content) + len(edit) / 2, model=model
messages, progress_bar_expected=len(content) + len(edit) / 2, model=model
)
if interrupted:
return
@ -450,7 +450,7 @@ class Coder:
commit_message, interrupted = self.send(
messages,
model="gpt-3.5-turbo",
show_progress=None,
silent=True,
)
commit_message = commit_message.strip()