feat(transformers): merge musicgen functionalities to a single backend

So we optimize space

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2025-01-17 14:59:41 +01:00
parent b5eeb5c5ab
commit a96c1f9bcd
21 changed files with 173 additions and 401 deletions

View file

@ -22,6 +22,8 @@ import torch.cuda
XPU=os.environ.get("XPU", "0") == "1"
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from scipy.io import wavfile
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@ -191,6 +193,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
export=True,
device=device_map)
self.OV = True
elif request.Type == "MusicgenForConditionalGeneration":
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
else:
print("Automodel", file=sys.stderr)
self.model = AutoModel.from_pretrained(model_name,
@ -380,6 +385,93 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
finally:
await iterations.aclose()
def SoundGeneration(self, request, context):
model_name = request.model
try:
if self.processor is None:
if model_name == "":
return backend_pb2.Result(success=False, message="request.model is required")
self.processor = AutoProcessor.from_pretrained(model_name)
if self.model is None:
if model_name == "":
return backend_pb2.Result(success=False, message="request.model is required")
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
inputs = None
if request.text == "":
inputs = self.model.get_unconditional_inputs(num_samples=1)
elif request.HasField('src'):
# TODO SECURITY CODE GOES HERE LOL
# WHO KNOWS IF THIS WORKS???
sample_rate, wsamples = wavfile.read('path_to_your_file.wav')
if request.HasField('src_divisor'):
wsamples = wsamples[: len(wsamples) // request.src_divisor]
inputs = self.processor(
audio=wsamples,
sampling_rate=sample_rate,
text=[request.text],
padding=True,
return_tensors="pt",
)
else:
inputs = self.processor(
text=[request.text],
padding=True,
return_tensors="pt",
)
tokens = 256
if request.HasField('duration'):
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
guidance = 3.0
if request.HasField('temperature'):
guidance = request.temperature
dosample = True
if request.HasField('sample'):
dosample = request.sample
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
sampling_rate = self.model.config.audio_encoder.sampling_rate
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr)
print("[transformers-musicgen] SoundGeneration for", file=sys.stderr)
print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr)
print(request, file=sys.stderr)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(success=True)
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
def TTS(self, request, context):
model_name = request.model
try:
if self.processor is None:
if model_name == "":
return backend_pb2.Result(success=False, message="request.model is required")
self.processor = AutoProcessor.from_pretrained(model_name)
if self.model is None:
if model_name == "":
return backend_pb2.Result(success=False, message="request.model is required")
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
inputs = self.processor(
text=[request.text],
padding=True,
return_tensors="pt",
)
tokens = 512 # No good place to set the "length" in TTS, so use 10s as a sane default
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
sampling_rate = self.model.config.audio_encoder.sampling_rate
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
print("[transformers-musicgen] TTS saved to", request.dst, file=sys.stderr)
print("[transformers-musicgen] TTS for", file=sys.stderr)
print(request, file=sys.stderr)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(success=True)
async def serve(address):
# Start asyncio gRPC server
server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))