track and report errors/asks during benchmarking

This commit is contained in:
Paul Gauthier 2023-06-26 10:33:16 -07:00
parent cbfda391bb
commit 1370da14fb
2 changed files with 16 additions and 1 deletions

View file

@ -83,6 +83,9 @@ class AutoCompleter(Completer):
class InputOutput: class InputOutput:
num_error_outputs = 0
num_user_asks = 0
def __init__( def __init__(
self, self,
pretty=True, pretty=True,
@ -208,6 +211,8 @@ class InputOutput:
self.append_chat_history(hist) self.append_chat_history(hist)
def confirm_ask(self, question, default="y"): def confirm_ask(self, question, default="y"):
self.num_user_asks += 1
if self.yes is True: if self.yes is True:
res = "yes" res = "yes"
elif self.yes is False: elif self.yes is False:
@ -217,12 +222,16 @@ class InputOutput:
hist = f"{question.strip()} {res.strip()}" hist = f"{question.strip()} {res.strip()}"
self.append_chat_history(hist, linebreak=True, blockquote=True) self.append_chat_history(hist, linebreak=True, blockquote=True)
if self.yes in (True, False):
self.tool_output(hist)
if not res or not res.strip(): if not res or not res.strip():
return return
return res.strip().lower().startswith("y") return res.strip().lower().startswith("y")
def prompt_ask(self, question, default=None): def prompt_ask(self, question, default=None):
self.num_user_asks += 1
if self.yes is True: if self.yes is True:
res = "yes" res = "yes"
elif self.yes is False: elif self.yes is False:
@ -232,10 +241,14 @@ class InputOutput:
hist = f"{question.strip()} {res.strip()}" hist = f"{question.strip()} {res.strip()}"
self.append_chat_history(hist, linebreak=True, blockquote=True) self.append_chat_history(hist, linebreak=True, blockquote=True)
if self.yes in (True, False):
self.tool_output(hist)
return res return res
def tool_error(self, message): def tool_error(self, message):
self.num_error_outputs += 1
if message.strip(): if message.strip():
hist = f"{message.strip()}" hist = f"{message.strip()}"
self.append_chat_history(hist, linebreak=True, blockquote=True) self.append_chat_history(hist, linebreak=True, blockquote=True)

View file

@ -180,7 +180,7 @@ def summarize_results(dirname, all_results, total_tests=None):
total_cost += results["cost"] total_cost += results["cost"]
duration += results["duration"] duration += results["duration"]
for key in "model edit_format commit_hash".split(): for key in "model edit_format commit_hash num_error_outputs num_user_asks".split():
val = results.get(key) val = results.get(key)
variants[key].add(val) variants[key].add(val)
@ -339,6 +339,8 @@ Only use standard python libraries, don't suggest installing any packages.
cost=coder.total_cost, cost=coder.total_cost,
duration=dur, duration=dur,
commit_hash=commit_hash, commit_hash=commit_hash,
num_error_outputs=io.num_error_outputs,
num_user_asks=io.num_user_asks,
) )
dump(results) dump(results)