fix: Implement cache warming with a background thread

Author:    Paul Gauthier  2024-08-26 16:22:53 -07:00
Committer: Paul Gauthier (aider)
parent 97a70830e9
commit 07767e2961


@@ -14,7 +14,7 @@ import threading
 import time
 import traceback
 from collections import defaultdict
-from datetime import datetime
+from datetime import datetime, timedelta
 from json.decoder import JSONDecodeError
 from pathlib import Path
@@ -992,16 +992,30 @@ class Coder:
         if not self.num_cache_warming_pings:
             return

-        if self.cache_warming_thread and self.cache_warming_thread.is_alive():
-            self.cache_warming_thread.cancel()
+        delay = 20
+        self.next_cache_warm = time.time() + delay
+        self.warming_pings_left = self.num_cache_warming_pings
+        self.cache_warming_chunks = chunks
+
+        if self.cache_warming_thread:
+            return

         def warm_cache_worker():
-            for i in range(self.num_cache_warming_pings):
-                time.sleep(20)  # 290 == 4 minutes and 50 seconds
+            while True:
+                time.sleep(1)
+                if self.warming_pings_left <= 0:
+                    continue
+                now = time.time()
+                if now < self.next_cache_warm:
+                    continue
+
+                self.warming_pings_left -= 1
+                self.next_cache_warm = time.time() + delay
+
                 try:
                     completion = litellm.completion(
                         model=self.main_model.name,
-                        messages=chunks.cacheable_messages(),
+                        messages=self.cache_warming_chunks.cacheable_messages(),
                         stream=False,
                         max_tokens=1,
                         extra_headers=self.main_model.extra_headers,
@@ -1014,12 +1028,8 @@ class Coder:
                     completion.usage, "prompt_cache_hit_tokens", 0
                 ) or getattr(completion.usage, "cache_read_input_tokens", 0)

-                self.io.tool_output(
-                    f"Warmed {format_tokens(cache_hit_tokens)} cached tokens."
-                    f" ({i + 1}/{self.num_cache_warming_pings})"
-                )
-
-            self.io.tool_output("Stopped warming.")
+                # if self.verbose:
+                self.io.tool_output(f"Warmed {format_tokens(cache_hit_tokens)} cached tokens.")

         self.cache_warming_thread = threading.Timer(0, warm_cache_worker)
         self.cache_warming_thread.start()
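
As a rough standalone sketch of the pattern this commit introduces (class and method names such as CacheWarmer and warm_cache are illustrative, not aider's actual API): a single long-lived worker thread polls once a second, and whenever the deadline has passed and pings remain, it fires a 1-token, non-streaming litellm.completion call to keep the provider's prompt cache warm. Repeated calls only refresh the shared schedule, so no extra threads are spawned. The diff starts its worker with threading.Timer(0, ...); the sketch uses a daemon threading.Thread, which serves the same purpose here.

import threading
import time

import litellm  # assumed available, as in the diff above


class CacheWarmer:
    """Illustrative sketch of this commit's approach, not aider's real class."""

    def __init__(self, model_name, delay=20):
        self.model_name = model_name
        self.delay = delay
        self.pings_left = 0
        self.next_warm = 0.0
        self.messages = None
        self.thread = None

    def warm_cache(self, messages, num_pings):
        # Refresh the shared schedule; a running worker picks up the new
        # deadline without being restarted.
        self.messages = messages
        self.pings_left = num_pings
        self.next_warm = time.time() + self.delay

        if self.thread:  # worker already running
            return
        self.thread = threading.Thread(target=self._worker, daemon=True)
        self.thread.start()

    def _worker(self):
        while True:
            time.sleep(1)
            if self.pings_left <= 0 or time.time() < self.next_warm:
                continue
            self.pings_left -= 1
            self.next_warm = time.time() + self.delay
            try:
                # 1-token, non-streaming request: enough to touch the prompt
                # cache without producing meaningful output.
                litellm.completion(
                    model=self.model_name,
                    messages=self.messages,
                    stream=False,
                    max_tokens=1,
                )
            except Exception:
                continue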