mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-31 07:54:59 +00:00
feat: add vall-e-x (#1007)
**Description** This PR fixes #985 **Notes for Reviewers** **[Signed commits](../CONTRIBUTING.md#signing-off-on-commits-developer-certificate-of-origin)** - [ ] Yes, I signed my commits. <!-- Thank you for contributing to LocalAI! Contributing Conventions: 1. Include descriptive PR titles with [<component-name>] prepended. 2. Build and test your changes before submitting a PR. 3. Sign your commits By following the community's contribution conventions upfront, the review process will be accelerated and your PR merged more quickly. --> Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
e7981152b2
commit
dc307a1cc0
14 changed files with 809 additions and 266 deletions
61
extra/grpc/vall-e-x/backend_pb2.py
Normal file
61
extra/grpc/vall-e-x/backend_pb2.py
Normal file
File diff suppressed because one or more lines are too long
363
extra/grpc/vall-e-x/backend_pb2_grpc.py
Normal file
363
extra/grpc/vall-e-x/backend_pb2_grpc.py
Normal file
|
@ -0,0 +1,363 @@
|
|||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
import backend_pb2 as backend__pb2
|
||||
|
||||
|
||||
class BackendStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Health = channel.unary_unary(
|
||||
'/backend.Backend/Health',
|
||||
request_serializer=backend__pb2.HealthMessage.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Predict = channel.unary_unary(
|
||||
'/backend.Backend/Predict',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.LoadModel = channel.unary_unary(
|
||||
'/backend.Backend/LoadModel',
|
||||
request_serializer=backend__pb2.ModelOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.PredictStream = channel.unary_stream(
|
||||
'/backend.Backend/PredictStream',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Embedding = channel.unary_unary(
|
||||
'/backend.Backend/Embedding',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.EmbeddingResult.FromString,
|
||||
)
|
||||
self.GenerateImage = channel.unary_unary(
|
||||
'/backend.Backend/GenerateImage',
|
||||
request_serializer=backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.AudioTranscription = channel.unary_unary(
|
||||
'/backend.Backend/AudioTranscription',
|
||||
request_serializer=backend__pb2.TranscriptRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.TranscriptResult.FromString,
|
||||
)
|
||||
self.TTS = channel.unary_unary(
|
||||
'/backend.Backend/TTS',
|
||||
request_serializer=backend__pb2.TTSRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.TokenizeString = channel.unary_unary(
|
||||
'/backend.Backend/TokenizeString',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.TokenizationResponse.FromString,
|
||||
)
|
||||
self.Status = channel.unary_unary(
|
||||
'/backend.Backend/Status',
|
||||
request_serializer=backend__pb2.HealthMessage.SerializeToString,
|
||||
response_deserializer=backend__pb2.StatusResponse.FromString,
|
||||
)
|
||||
|
||||
|
||||
class BackendServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def Health(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Predict(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def PredictStream(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Embedding(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GenerateImage(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def AudioTranscription(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def TTS(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def TokenizeString(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Status(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_BackendServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Health': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Health,
|
||||
request_deserializer=backend__pb2.HealthMessage.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Predict': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Predict,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'LoadModel': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.LoadModel,
|
||||
request_deserializer=backend__pb2.ModelOptions.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'PredictStream': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.PredictStream,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Embedding': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Embedding,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.EmbeddingResult.SerializeToString,
|
||||
),
|
||||
'GenerateImage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GenerateImage,
|
||||
request_deserializer=backend__pb2.GenerateImageRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'AudioTranscription': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.AudioTranscription,
|
||||
request_deserializer=backend__pb2.TranscriptRequest.FromString,
|
||||
response_serializer=backend__pb2.TranscriptResult.SerializeToString,
|
||||
),
|
||||
'TTS': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.TTS,
|
||||
request_deserializer=backend__pb2.TTSRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'TokenizeString': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.TokenizeString,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
|
||||
),
|
||||
'Status': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Status,
|
||||
request_deserializer=backend__pb2.HealthMessage.FromString,
|
||||
response_serializer=backend__pb2.StatusResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'backend.Backend', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class Backend(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def Health(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health',
|
||||
backend__pb2.HealthMessage.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Predict(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def LoadModel(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel',
|
||||
backend__pb2.ModelOptions.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def PredictStream(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Embedding(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.EmbeddingResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def GenerateImage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage',
|
||||
backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def AudioTranscription(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription',
|
||||
backend__pb2.TranscriptRequest.SerializeToString,
|
||||
backend__pb2.TranscriptResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def TTS(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS',
|
||||
backend__pb2.TTSRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def TokenizeString(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.TokenizationResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Status(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
|
||||
backend__pb2.HealthMessage.SerializeToString,
|
||||
backend__pb2.StatusResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
91
extra/grpc/vall-e-x/ttsvalle.py
Normal file
91
extra/grpc/vall-e-x/ttsvalle.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
#!/usr/bin/env python3
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from utils.generation import SAMPLE_RATE, generate_audio, preload_models
|
||||
from scipy.io.wavfile import write as write_wav
|
||||
from utils.prompt_making import make_prompt
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
def LoadModel(self, request, context):
|
||||
model_name = request.Model
|
||||
try:
|
||||
print("Preparing models, please wait", file=sys.stderr)
|
||||
# download and load all models
|
||||
preload_models()
|
||||
|
||||
if request.AudioPath != "":
|
||||
print("Generating model", file=sys.stderr)
|
||||
make_prompt(name=model_name, audio_prompt_path=request.AudioPath)
|
||||
### Use given transcript
|
||||
##make_prompt(name=model_name, audio_prompt_path="paimon_prompt.wav",
|
||||
## transcript="Just, what was that? Paimon thought we were gonna get eaten.")
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
# Implement your logic here for the LoadModel service
|
||||
# Replace this with your desired response
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def TTS(self, request, context):
|
||||
model = request.model
|
||||
print(request, file=sys.stderr)
|
||||
try:
|
||||
audio_array = None
|
||||
if model != "":
|
||||
audio_array = generate_audio(request.text, prompt=model)
|
||||
else:
|
||||
audio_array = generate_audio(request.text)
|
||||
print("saving to", request.dst, file=sys.stderr)
|
||||
# save audio to disk
|
||||
write_wav(request.dst, SAMPLE_RATE, audio_array)
|
||||
print("saved to", request.dst, file=sys.stderr)
|
||||
print("tts for", file=sys.stderr)
|
||||
print(request, file=sys.stderr)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
Loading…
Add table
Add a link
Reference in a new issue