diff --git a/coder.py b/coder.py index 48f2205d0..2e0b183a7 100755 --- a/coder.py +++ b/coder.py @@ -219,7 +219,7 @@ class Coder: for line in content: print(role, line) - def send(self, messages, model=None, show_progress=0): + def send(self, messages, model=None, progress_bar_expected=0, silent=False): # self.show_messages(messages, "all") if not model: @@ -234,20 +234,20 @@ class Coder: interrupted = False try: - if show_progress is not None: - self.show_send_progress(completion, show_progress) - elif self.pretty and show_progress: + if progress_bar_expected: + self.show_send_progress(completion, progress_bar_expected) + elif self.pretty and not silent: self.show_send_output_color(completion) else: - self.show_send_output_plain(completion, False) + self.show_send_output_plain(completion, silent) except KeyboardInterrupt: interrupted = True return self.resp, interrupted - def show_send_progress(self, completion, show_progress): + def show_send_progress(self, completion, progress_bar_expected): self.resp = "" - pbar = tqdm(total=show_progress) + pbar = tqdm(total=progress_bar_expected) for chunk in completion: try: text = chunk.choices[0].delta.content @@ -257,10 +257,10 @@ class Coder: pbar.update(len(text)) - pbar.update(show_progress) + pbar.update(progress_bar_expected) pbar.close() - def show_send_output_plain(self, completion, show_output=True): + def show_send_output_plain(self, completion, silent): self.resp = "" for chunk in completion: @@ -272,7 +272,7 @@ class Coder: except AttributeError: continue - if show_output: + if not silent: sys.stdout.write(text) sys.stdout.flush() @@ -378,7 +378,7 @@ class Coder: dict(role="user", content=prompt), ] res, interrupted = self.send( - messages, show_progress=len(content) + len(edit) / 2, model=model + messages, progress_bar_expected=len(content) + len(edit) / 2, model=model ) if interrupted: return @@ -450,7 +450,7 @@ class Coder: commit_message, interrupted = self.send( messages, model="gpt-3.5-turbo", - show_progress=None, + silent=True, ) commit_message = commit_message.strip()