mirror of
https://github.com/mudler/LocalAI.git
synced 2025-06-06 10:54:59 +00:00
feat(chatterbox): add new backend (#5524)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
dd7fa6b9f7
commit
d5c9c717b5
15 changed files with 330 additions and 3 deletions
29
backend/python/chatterbox/Makefile
Normal file
29
backend/python/chatterbox/Makefile
Normal file
|
@ -0,0 +1,29 @@
|
|||
.PHONY: coqui
|
||||
coqui: protogen
|
||||
bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: protogen
|
||||
@echo "Running coqui..."
|
||||
bash run.sh
|
||||
@echo "coqui run."
|
||||
|
||||
.PHONY: test
|
||||
test: protogen
|
||||
@echo "Testing coqui..."
|
||||
bash test.sh
|
||||
@echo "coqui tested."
|
||||
|
||||
.PHONY: protogen
|
||||
protogen: backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
backend_pb2_grpc.py backend_pb2.py:
|
||||
python3 -m grpc_tools.protoc -I../.. --python_out=. --grpc_python_out=. backend.proto
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
117
backend/python/chatterbox/backend.py
Normal file
117
backend/python/chatterbox/backend.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
This is an extra gRPC server of LocalAI for Bark TTS
|
||||
"""
|
||||
from concurrent import futures
|
||||
import time
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import torch
|
||||
import torchaudio as ta
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
COQUI_LANGUAGE = os.environ.get('COQUI_LANGUAGE', None)
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
BackendServicer is the class that implements the gRPC service
|
||||
"""
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
def LoadModel(self, request, context):
|
||||
|
||||
# Get device
|
||||
# device = "cuda" if request.CUDA else "cpu"
|
||||
if torch.cuda.is_available():
|
||||
print("CUDA is available", file=sys.stderr)
|
||||
device = "cuda"
|
||||
else:
|
||||
print("CUDA is not available", file=sys.stderr)
|
||||
device = "cpu"
|
||||
|
||||
if not torch.cuda.is_available() and request.CUDA:
|
||||
return backend_pb2.Result(success=False, message="CUDA is not available")
|
||||
|
||||
self.AudioPath = None
|
||||
|
||||
if os.path.isabs(request.AudioPath):
|
||||
self.AudioPath = request.AudioPath
|
||||
elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
|
||||
# get base path of modelFile
|
||||
modelFileBase = os.path.dirname(request.ModelFile)
|
||||
# modify LoraAdapter to be relative to modelFileBase
|
||||
self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
|
||||
|
||||
try:
|
||||
print("Preparing models, please wait", file=sys.stderr)
|
||||
self.model = ChatterboxTTS.from_pretrained(device=device)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
# Implement your logic here for the LoadModel service
|
||||
# Replace this with your desired response
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def TTS(self, request, context):
|
||||
try:
|
||||
# Generate audio using ChatterboxTTS
|
||||
if self.AudioPath is not None:
|
||||
wav = self.model.generate(request.text, audio_prompt_path=self.AudioPath)
|
||||
else:
|
||||
wav = self.model.generate(request.text)
|
||||
|
||||
# Save the generated audio
|
||||
ta.save(request.dst, wav, self.model.sr)
|
||||
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
14
backend/python/chatterbox/install.sh
Executable file
14
backend/python/chatterbox/install.sh
Executable file
|
@ -0,0 +1,14 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
source $(dirname $0)/../common/libbackend.sh
|
||||
|
||||
# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
|
||||
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
|
||||
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
|
||||
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
|
||||
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
installRequirements
|
5
backend/python/chatterbox/requirements-cpu.txt
Normal file
5
backend/python/chatterbox/requirements-cpu.txt
Normal file
|
@ -0,0 +1,5 @@
|
|||
accelerate
|
||||
torch==2.6.0
|
||||
torchaudio==2.6.0
|
||||
transformers==4.46.3
|
||||
chatterbox-tts
|
6
backend/python/chatterbox/requirements-cublas11.txt
Normal file
6
backend/python/chatterbox/requirements-cublas11.txt
Normal file
|
@ -0,0 +1,6 @@
|
|||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
torch==2.6.0+cu118
|
||||
torchaudio==2.6.0+cu118
|
||||
transformers==4.46.3
|
||||
chatterbox-tts
|
||||
accelerate
|
5
backend/python/chatterbox/requirements-cublas12.txt
Normal file
5
backend/python/chatterbox/requirements-cublas12.txt
Normal file
|
@ -0,0 +1,5 @@
|
|||
torch==2.6.0
|
||||
torchaudio==2.6.0
|
||||
transformers==4.46.3
|
||||
chatterbox-tts
|
||||
accelerate
|
6
backend/python/chatterbox/requirements-hipblas.txt
Normal file
6
backend/python/chatterbox/requirements-hipblas.txt
Normal file
|
@ -0,0 +1,6 @@
|
|||
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
||||
torch==2.6.0+rocm6.0
|
||||
torchaudio==2.6.0+rocm6.0
|
||||
transformers==4.46.3
|
||||
chatterbox-tts
|
||||
accelerate
|
12
backend/python/chatterbox/requirements-intel.txt
Normal file
12
backend/python/chatterbox/requirements-intel.txt
Normal file
|
@ -0,0 +1,12 @@
|
|||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
torch==2.3.1+cxx11.abi
|
||||
torchaudio==2.3.1+cxx11.abi
|
||||
transformers==4.46.3
|
||||
chatterbox-tts
|
||||
accelerate
|
||||
oneccl_bind_pt==2.3.100+xpu
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
transformers==4.48.3
|
||||
accelerate
|
5
backend/python/chatterbox/requirements.txt
Normal file
5
backend/python/chatterbox/requirements.txt
Normal file
|
@ -0,0 +1,5 @@
|
|||
grpcio==1.72.0
|
||||
protobuf
|
||||
certifi
|
||||
packaging
|
||||
setuptools
|
4
backend/python/chatterbox/run.sh
Executable file
4
backend/python/chatterbox/run.sh
Executable file
|
@ -0,0 +1,4 @@
|
|||
#!/bin/bash
|
||||
source $(dirname $0)/../common/libbackend.sh
|
||||
|
||||
startBackend $@
|
82
backend/python/chatterbox/test.py
Normal file
82
backend/python/chatterbox/test.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
"""
|
||||
A test script to test the gRPC service
|
||||
"""
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
"""
|
||||
TestBackendServicer is the class that tests the gRPC service
|
||||
"""
|
||||
def setUp(self):
|
||||
"""
|
||||
This method sets up the gRPC service by starting the server
|
||||
"""
|
||||
self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
|
||||
time.sleep(30)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""
|
||||
This method tears down the gRPC service by terminating the server
|
||||
"""
|
||||
self.service.terminate()
|
||||
self.service.wait()
|
||||
|
||||
def test_server_startup(self):
|
||||
"""
|
||||
This method tests if the server starts up successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.Health(backend_pb2.HealthMessage())
|
||||
self.assertEqual(response.message, b'OK')
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("Server failed to start")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_load_model(self):
|
||||
"""
|
||||
This method tests if the model is loaded successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions())
|
||||
print(response)
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(response.message, "Model loaded successfully")
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("LoadModel service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_tts(self):
|
||||
"""
|
||||
This method tests if the embeddings are generated successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions())
|
||||
self.assertTrue(response.success)
|
||||
tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
|
||||
tts_response = stub.TTS(tts_request)
|
||||
self.assertIsNotNone(tts_response)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("TTS service failed")
|
||||
finally:
|
||||
self.tearDown()
|
6
backend/python/chatterbox/test.sh
Executable file
6
backend/python/chatterbox/test.sh
Executable file
|
@ -0,0 +1,6 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
source $(dirname $0)/../common/libbackend.sh
|
||||
|
||||
runUnittests
|
Loading…
Add table
Add a link
Reference in a new issue