feat(diffusers): be consistent with pipelines, support also depthimg2img (#926)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-08-18 22:06:24 +02:00 committed by GitHub
parent 8cb1061c11
commit 1079b18ff7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 480 additions and 103 deletions

View file

@ -1,12 +1,11 @@
package openai package openai
import ( import (
"bufio"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"os" "os"
"path"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
@ -52,34 +51,28 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
} }
src := "" src := ""
// retrieve the file data from the request if input.File != "" {
file, err := c.FormFile("src") //base 64 decode the file and write it somewhere
if err == nil { // that we will cleanup
decoded, err := base64.StdEncoding.DecodeString(input.File)
f, err := file.Open()
if err != nil { if err != nil {
return err return err
} }
defer f.Close() // Create a temporary file
outputFile, err := os.CreateTemp(o.ImageDir, "b64")
dir, err := os.MkdirTemp("", "img2img")
if err != nil { if err != nil {
return err return err
} }
defer os.RemoveAll(dir) // write the base64 result
writer := bufio.NewWriter(outputFile)
dst := filepath.Join(dir, path.Base(file.Filename)) _, err = writer.Write(decoded)
dstFile, err := os.Create(dst)
if err != nil { if err != nil {
outputFile.Close()
return err return err
} }
outputFile.Close()
if _, err := io.Copy(dstFile, f); err != nil { src = outputFile.Name()
log.Debug().Msgf("Image file copying error %+v - %+v - err %+v", file.Filename, dst, err) defer os.RemoveAll(src)
return err
}
src = dst
} }
log.Debug().Msgf("Parameter Config: %+v", config) log.Debug().Msgf("Parameter Config: %+v", config)

File diff suppressed because one or more lines are too long

View file

@ -54,6 +54,16 @@ class BackendStub(object):
request_serializer=backend__pb2.TTSRequest.SerializeToString, request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString, response_deserializer=backend__pb2.Result.FromString,
) )
self.TokenizeString = channel.unary_unary(
'/backend.Backend/TokenizeString',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.TokenizationResponse.FromString,
)
self.Status = channel.unary_unary(
'/backend.Backend/Status',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)
class BackendServicer(object): class BackendServicer(object):
@ -107,6 +117,18 @@ class BackendServicer(object):
context.set_details('Method not implemented!') context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!') raise NotImplementedError('Method not implemented!')
def TokenizeString(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Status(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server): def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
@ -150,6 +172,16 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.TTSRequest.FromString, request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString, response_serializer=backend__pb2.Result.SerializeToString,
), ),
'TokenizeString': grpc.unary_unary_rpc_method_handler(
servicer.TokenizeString,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
),
'Status': grpc.unary_unary_rpc_method_handler(
servicer.Status,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers) 'backend.Backend', rpc_method_handlers)
@ -295,3 +327,37 @@ class Backend(object):
backend__pb2.Result.FromString, backend__pb2.Result.FromString,
options, channel_credentials, options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def TokenizeString(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.TokenizationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Status(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

File diff suppressed because one or more lines are too long

View file

@ -54,6 +54,16 @@ class BackendStub(object):
request_serializer=backend__pb2.TTSRequest.SerializeToString, request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString, response_deserializer=backend__pb2.Result.FromString,
) )
self.TokenizeString = channel.unary_unary(
'/backend.Backend/TokenizeString',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.TokenizationResponse.FromString,
)
self.Status = channel.unary_unary(
'/backend.Backend/Status',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)
class BackendServicer(object): class BackendServicer(object):
@ -107,6 +117,18 @@ class BackendServicer(object):
context.set_details('Method not implemented!') context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!') raise NotImplementedError('Method not implemented!')
def TokenizeString(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Status(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server): def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
@ -150,6 +172,16 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.TTSRequest.FromString, request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString, response_serializer=backend__pb2.Result.SerializeToString,
), ),
'TokenizeString': grpc.unary_unary_rpc_method_handler(
servicer.TokenizeString,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
),
'Status': grpc.unary_unary_rpc_method_handler(
servicer.Status,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers) 'backend.Backend', rpc_method_handlers)
@ -295,3 +327,37 @@ class Backend(object):
backend__pb2.Result.FromString, backend__pb2.Result.FromString,
options, channel_credentials, options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def TokenizeString(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.TokenizationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Status(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View file

@ -12,7 +12,7 @@ import os
# import diffusers # import diffusers
import torch import torch
from torch import autocast from torch import autocast
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
from diffusers.pipelines.stable_diffusion import safety_checker from diffusers.pipelines.stable_diffusion import safety_checker
from compel import Compel from compel import Compel
from PIL import Image from PIL import Image
@ -150,36 +150,39 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
modelFile = request.ModelFile modelFile = request.ModelFile
fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
# If request.Model is a URL, use from_single_file
if request.IMG2IMG and request.PipelineType == "":
request.PipelineType == "StableDiffusionImg2ImgPipeline"
if request.PipelineType == "": if request.PipelineType == "":
request.PipelineType == "StableDiffusionPipeline" request.PipelineType == "StableDiffusionPipeline"
## img2img
if request.PipelineType == "StableDiffusionImg2ImgPipeline":
if fromSingleFile:
self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
torch_dtype=torchType,
guidance_scale=cfg_scale)
else:
self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(request.Model,
torch_dtype=torchType,
guidance_scale=cfg_scale)
if request.PipelineType == "StableDiffusionDepth2ImgPipeline":
self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
torch_dtype=torchType,
guidance_scale=cfg_scale)
## text2img
if request.PipelineType == "StableDiffusionPipeline": if request.PipelineType == "StableDiffusionPipeline":
if fromSingleFile: if fromSingleFile:
if request.IMG2IMG: self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile, torch_dtype=torchType,
torch_dtype=torchType, guidance_scale=cfg_scale)
guidance_scale=cfg_scale)
else:
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
torch_dtype=torchType,
guidance_scale=cfg_scale)
else: else:
if request.IMG2IMG: self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(request.Model, torch_dtype=torchType,
torch_dtype=torchType, guidance_scale=cfg_scale)
guidance_scale=cfg_scale)
else:
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
torch_dtype=torchType,
guidance_scale=cfg_scale)
# https://github.com/huggingface/diffusers/issues/4446
# do not use text_encoder in the constructor since then
# https://github.com/huggingface/diffusers/issues/3212#issuecomment-1521841481
if CLIPSKIP and request.CLIPSkip != 0:
text_encoder = CLIPTextModel.from_pretrained(clipmodel, num_hidden_layers=request.CLIPSkip, subfolder=clipsubfolder, torch_dtype=torchType)
self.pipe.text_encoder=text_encoder
if request.PipelineType == "DiffusionPipeline": if request.PipelineType == "DiffusionPipeline":
self.pipe = DiffusionPipeline.from_pretrained(request.Model, self.pipe = DiffusionPipeline.from_pretrained(request.Model,
torch_dtype=torchType, torch_dtype=torchType,
@ -197,11 +200,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
use_safetensors=True, use_safetensors=True,
# variant="fp16" # variant="fp16"
guidance_scale=cfg_scale) guidance_scale=cfg_scale)
# https://github.com/huggingface/diffusers/issues/4446
# do not use text_encoder in the constructor since then
# https://github.com/huggingface/diffusers/issues/3212#issuecomment-1521841481
if CLIPSKIP and request.CLIPSkip != 0:
text_encoder = CLIPTextModel.from_pretrained(clipmodel, num_hidden_layers=request.CLIPSkip, subfolder=clipsubfolder, torch_dtype=torchType)
self.pipe.text_encoder=text_encoder
# torch_dtype needs to be customized. float16 for GPU, float32 for CPU # torch_dtype needs to be customized. float16 for GPU, float32 for CPU
# TODO: this needs to be customized # TODO: this needs to be customized
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config) if request.SchedulerType != "":
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder) self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
if request.CUDA:
self.pipe.to('cuda')
except Exception as err: except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
# Implement your logic here for the LoadModel service # Implement your logic here for the LoadModel service
@ -220,11 +231,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
} }
if request.src != "": if request.src != "":
# open the image with Image.open image = Image.open(request.src)
# convert the image to RGB
# resize the image to the request width and height
# XXX: untested
image = Image.open(request.src).convert("RGB").resize((request.width, request.height))
options["image"] = image options["image"] = image
# Get the keys that we will build the args for our pipe for # Get the keys that we will build the args for our pipe for

File diff suppressed because one or more lines are too long

View file

@ -54,6 +54,16 @@ class BackendStub(object):
request_serializer=backend__pb2.TTSRequest.SerializeToString, request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString, response_deserializer=backend__pb2.Result.FromString,
) )
self.TokenizeString = channel.unary_unary(
'/backend.Backend/TokenizeString',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.TokenizationResponse.FromString,
)
self.Status = channel.unary_unary(
'/backend.Backend/Status',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)
class BackendServicer(object): class BackendServicer(object):
@ -107,6 +117,18 @@ class BackendServicer(object):
context.set_details('Method not implemented!') context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!') raise NotImplementedError('Method not implemented!')
def TokenizeString(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Status(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server): def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
@ -150,6 +172,16 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.TTSRequest.FromString, request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString, response_serializer=backend__pb2.Result.SerializeToString,
), ),
'TokenizeString': grpc.unary_unary_rpc_method_handler(
servicer.TokenizeString,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
),
'Status': grpc.unary_unary_rpc_method_handler(
servicer.Status,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers) 'backend.Backend', rpc_method_handlers)
@ -295,3 +327,37 @@ class Backend(object):
backend__pb2.Result.FromString, backend__pb2.Result.FromString,
options, channel_credentials, options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def TokenizeString(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.TokenizationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Status(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

File diff suppressed because one or more lines are too long

View file

@ -54,6 +54,16 @@ class BackendStub(object):
request_serializer=backend__pb2.TTSRequest.SerializeToString, request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString, response_deserializer=backend__pb2.Result.FromString,
) )
self.TokenizeString = channel.unary_unary(
'/backend.Backend/TokenizeString',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.TokenizationResponse.FromString,
)
self.Status = channel.unary_unary(
'/backend.Backend/Status',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)
class BackendServicer(object): class BackendServicer(object):
@ -107,6 +117,18 @@ class BackendServicer(object):
context.set_details('Method not implemented!') context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!') raise NotImplementedError('Method not implemented!')
def TokenizeString(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Status(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server): def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
@ -150,6 +172,16 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.TTSRequest.FromString, request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString, response_serializer=backend__pb2.Result.SerializeToString,
), ),
'TokenizeString': grpc.unary_unary_rpc_method_handler(
servicer.TokenizeString,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
),
'Status': grpc.unary_unary_rpc_method_handler(
servicer.Status,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers) 'backend.Backend', rpc_method_handlers)
@ -295,3 +327,37 @@ class Backend(object):
backend__pb2.Result.FromString, backend__pb2.Result.FromString,
options, channel_credentials, options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def TokenizeString(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.TokenizationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Status(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

File diff suppressed because one or more lines are too long

View file

@ -54,6 +54,16 @@ class BackendStub(object):
request_serializer=backend__pb2.TTSRequest.SerializeToString, request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString, response_deserializer=backend__pb2.Result.FromString,
) )
self.TokenizeString = channel.unary_unary(
'/backend.Backend/TokenizeString',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.TokenizationResponse.FromString,
)
self.Status = channel.unary_unary(
'/backend.Backend/Status',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)
class BackendServicer(object): class BackendServicer(object):
@ -107,6 +117,18 @@ class BackendServicer(object):
context.set_details('Method not implemented!') context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!') raise NotImplementedError('Method not implemented!')
def TokenizeString(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Status(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server): def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
@ -150,6 +172,16 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.TTSRequest.FromString, request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString, response_serializer=backend__pb2.Result.SerializeToString,
), ),
'TokenizeString': grpc.unary_unary_rpc_method_handler(
servicer.TokenizeString,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
),
'Status': grpc.unary_unary_rpc_method_handler(
servicer.Status,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers) 'backend.Backend', rpc_method_handlers)
@ -295,3 +327,37 @@ class Backend(object):
backend__pb2.Result.FromString, backend__pb2.Result.FromString,
options, channel_credentials, options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def TokenizeString(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.TokenizationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Status(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View file

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.27.1 // protoc-gen-go v1.26.0
// protoc v3.12.4 // protoc v3.15.8
// source: pkg/grpc/proto/backend.proto // source: pkg/grpc/proto/backend.proto
package proto package proto

View file

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.3.0 // - protoc-gen-go-grpc v1.2.0
// - protoc v3.12.4 // - protoc v3.15.8
// source: pkg/grpc/proto/backend.proto // source: pkg/grpc/proto/backend.proto
package proto package proto
@ -18,19 +18,6 @@ import (
// Requires gRPC-Go v1.32.0 or later. // Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7 const _ = grpc.SupportPackageIsVersion7
const (
Backend_Health_FullMethodName = "/backend.Backend/Health"
Backend_Predict_FullMethodName = "/backend.Backend/Predict"
Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel"
Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream"
Backend_Embedding_FullMethodName = "/backend.Backend/Embedding"
Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage"
Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription"
Backend_TTS_FullMethodName = "/backend.Backend/TTS"
Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString"
Backend_Status_FullMethodName = "/backend.Backend/Status"
)
// BackendClient is the client API for Backend service. // BackendClient is the client API for Backend service.
// //
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
@ -57,7 +44,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient {
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply) out := new(Reply)
err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...) err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -66,7 +53,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply) out := new(Reply)
err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...) err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -75,7 +62,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ..
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
out := new(Result) out := new(Result)
err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...) err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -83,7 +70,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ..
} }
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) { func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...) stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -116,7 +103,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) {
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
out := new(EmbeddingResult) out := new(EmbeddingResult)
err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...) err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -125,7 +112,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) { func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result) out := new(Result)
err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...) err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -134,7 +121,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) { func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) {
out := new(TranscriptResult) out := new(TranscriptResult)
err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...) err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -143,7 +130,7 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) { func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result) out := new(Result)
err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...) err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -152,7 +139,7 @@ func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.Ca
func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) { func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) {
out := new(TokenizationResponse) out := new(TokenizationResponse)
err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...) err := c.cc.Invoke(ctx, "/backend.Backend/TokenizeString", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -161,7 +148,7 @@ func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions,
func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) { func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) {
out := new(StatusResponse) out := new(StatusResponse)
err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...) err := c.cc.Invoke(ctx, "/backend.Backend/Status", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -242,7 +229,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: Backend_Health_FullMethodName, FullMethod: "/backend.Backend/Health",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Health(ctx, req.(*HealthMessage)) return srv.(BackendServer).Health(ctx, req.(*HealthMessage))
@ -260,7 +247,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: Backend_Predict_FullMethodName, FullMethod: "/backend.Backend/Predict",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions)) return srv.(BackendServer).Predict(ctx, req.(*PredictOptions))
@ -278,7 +265,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: Backend_LoadModel_FullMethodName, FullMethod: "/backend.Backend/LoadModel",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions)) return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions))
@ -317,7 +304,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: Backend_Embedding_FullMethodName, FullMethod: "/backend.Backend/Embedding",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions)) return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions))
@ -335,7 +322,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: Backend_GenerateImage_FullMethodName, FullMethod: "/backend.Backend/GenerateImage",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest)) return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest))
@ -353,7 +340,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: Backend_AudioTranscription_FullMethodName, FullMethod: "/backend.Backend/AudioTranscription",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest)) return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest))
@ -371,7 +358,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: Backend_TTS_FullMethodName, FullMethod: "/backend.Backend/TTS",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest)) return srv.(BackendServer).TTS(ctx, req.(*TTSRequest))
@ -389,7 +376,7 @@ func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec f
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: Backend_TokenizeString_FullMethodName, FullMethod: "/backend.Backend/TokenizeString",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions)) return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions))
@ -407,7 +394,7 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: Backend_Status_FullMethodName, FullMethod: "/backend.Backend/Status",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Status(ctx, req.(*HealthMessage)) return srv.(BackendServer).Status(ctx, req.(*HealthMessage))