feat(diffusers): be consistent with pipelines, support also depthimg2img (#926)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-08-18 22:06:24 +02:00 committed by GitHub
parent 8cb1061c11
commit 1079b18ff7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 480 additions and 103 deletions

File diff suppressed because one or more lines are too long

View file

@ -54,6 +54,16 @@ class BackendStub(object):
request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString,
)
self.TokenizeString = channel.unary_unary(
'/backend.Backend/TokenizeString',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.TokenizationResponse.FromString,
)
self.Status = channel.unary_unary(
'/backend.Backend/Status',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)
class BackendServicer(object):
@ -107,6 +117,18 @@ class BackendServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def TokenizeString(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Status(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = {
@ -150,6 +172,16 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString,
),
'TokenizeString': grpc.unary_unary_rpc_method_handler(
servicer.TokenizeString,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
),
'Status': grpc.unary_unary_rpc_method_handler(
servicer.Status,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers)
@ -295,3 +327,37 @@ class Backend(object):
backend__pb2.Result.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def TokenizeString(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.TokenizationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Status(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

File diff suppressed because one or more lines are too long

View file

@ -54,6 +54,16 @@ class BackendStub(object):
request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString,
)
self.TokenizeString = channel.unary_unary(
'/backend.Backend/TokenizeString',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.TokenizationResponse.FromString,
)
self.Status = channel.unary_unary(
'/backend.Backend/Status',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)
class BackendServicer(object):
@ -107,6 +117,18 @@ class BackendServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def TokenizeString(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Status(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = {
@ -150,6 +172,16 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString,
),
'TokenizeString': grpc.unary_unary_rpc_method_handler(
servicer.TokenizeString,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
),
'Status': grpc.unary_unary_rpc_method_handler(
servicer.Status,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers)
@ -295,3 +327,37 @@ class Backend(object):
backend__pb2.Result.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def TokenizeString(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.TokenizationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Status(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View file

@ -12,7 +12,7 @@ import os
# import diffusers
import torch
from torch import autocast
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
from diffusers.pipelines.stable_diffusion import safety_checker
from compel import Compel
from PIL import Image
@ -150,36 +150,39 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
modelFile = request.ModelFile
fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
# If request.Model is a URL, use from_single_file
if request.IMG2IMG and request.PipelineType == "":
request.PipelineType == "StableDiffusionImg2ImgPipeline"
if request.PipelineType == "":
request.PipelineType == "StableDiffusionPipeline"
## img2img
if request.PipelineType == "StableDiffusionImg2ImgPipeline":
if fromSingleFile:
self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
torch_dtype=torchType,
guidance_scale=cfg_scale)
else:
self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(request.Model,
torch_dtype=torchType,
guidance_scale=cfg_scale)
if request.PipelineType == "StableDiffusionDepth2ImgPipeline":
self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
torch_dtype=torchType,
guidance_scale=cfg_scale)
## text2img
if request.PipelineType == "StableDiffusionPipeline":
if fromSingleFile:
if request.IMG2IMG:
self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
torch_dtype=torchType,
guidance_scale=cfg_scale)
else:
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
torch_dtype=torchType,
guidance_scale=cfg_scale)
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
torch_dtype=torchType,
guidance_scale=cfg_scale)
else:
if request.IMG2IMG:
self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(request.Model,
torch_dtype=torchType,
guidance_scale=cfg_scale)
else:
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
torch_dtype=torchType,
guidance_scale=cfg_scale)
# https://github.com/huggingface/diffusers/issues/4446
# do not use text_encoder in the constructor since then
# https://github.com/huggingface/diffusers/issues/3212#issuecomment-1521841481
if CLIPSKIP and request.CLIPSkip != 0:
text_encoder = CLIPTextModel.from_pretrained(clipmodel, num_hidden_layers=request.CLIPSkip, subfolder=clipsubfolder, torch_dtype=torchType)
self.pipe.text_encoder=text_encoder
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
torch_dtype=torchType,
guidance_scale=cfg_scale)
if request.PipelineType == "DiffusionPipeline":
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
torch_dtype=torchType,
@ -197,11 +200,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
use_safetensors=True,
# variant="fp16"
guidance_scale=cfg_scale)
# https://github.com/huggingface/diffusers/issues/4446
# do not use text_encoder in the constructor since then
# https://github.com/huggingface/diffusers/issues/3212#issuecomment-1521841481
if CLIPSKIP and request.CLIPSkip != 0:
text_encoder = CLIPTextModel.from_pretrained(clipmodel, num_hidden_layers=request.CLIPSkip, subfolder=clipsubfolder, torch_dtype=torchType)
self.pipe.text_encoder=text_encoder
# torch_dtype needs to be customized. float16 for GPU, float32 for CPU
# TODO: this needs to be customized
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
if request.SchedulerType != "":
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
if request.CUDA:
self.pipe.to('cuda')
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
# Implement your logic here for the LoadModel service
@ -220,11 +231,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
}
if request.src != "":
# open the image with Image.open
# convert the image to RGB
# resize the image to the request width and height
# XXX: untested
image = Image.open(request.src).convert("RGB").resize((request.width, request.height))
image = Image.open(request.src)
options["image"] = image
# Get the keys that we will build the args for our pipe for

File diff suppressed because one or more lines are too long

View file

@ -54,6 +54,16 @@ class BackendStub(object):
request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString,
)
self.TokenizeString = channel.unary_unary(
'/backend.Backend/TokenizeString',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.TokenizationResponse.FromString,
)
self.Status = channel.unary_unary(
'/backend.Backend/Status',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)
class BackendServicer(object):
@ -107,6 +117,18 @@ class BackendServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def TokenizeString(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Status(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = {
@ -150,6 +172,16 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString,
),
'TokenizeString': grpc.unary_unary_rpc_method_handler(
servicer.TokenizeString,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
),
'Status': grpc.unary_unary_rpc_method_handler(
servicer.Status,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers)
@ -295,3 +327,37 @@ class Backend(object):
backend__pb2.Result.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def TokenizeString(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.TokenizationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Status(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

File diff suppressed because one or more lines are too long

View file

@ -54,6 +54,16 @@ class BackendStub(object):
request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString,
)
self.TokenizeString = channel.unary_unary(
'/backend.Backend/TokenizeString',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.TokenizationResponse.FromString,
)
self.Status = channel.unary_unary(
'/backend.Backend/Status',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)
class BackendServicer(object):
@ -107,6 +117,18 @@ class BackendServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def TokenizeString(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Status(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = {
@ -150,6 +172,16 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString,
),
'TokenizeString': grpc.unary_unary_rpc_method_handler(
servicer.TokenizeString,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
),
'Status': grpc.unary_unary_rpc_method_handler(
servicer.Status,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers)
@ -295,3 +327,37 @@ class Backend(object):
backend__pb2.Result.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def TokenizeString(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.TokenizationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Status(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

File diff suppressed because one or more lines are too long

View file

@ -54,6 +54,16 @@ class BackendStub(object):
request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString,
)
self.TokenizeString = channel.unary_unary(
'/backend.Backend/TokenizeString',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.TokenizationResponse.FromString,
)
self.Status = channel.unary_unary(
'/backend.Backend/Status',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)
class BackendServicer(object):
@ -107,6 +117,18 @@ class BackendServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def TokenizeString(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Status(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = {
@ -150,6 +172,16 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString,
),
'TokenizeString': grpc.unary_unary_rpc_method_handler(
servicer.TokenizeString,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
),
'Status': grpc.unary_unary_rpc_method_handler(
servicer.Status,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers)
@ -295,3 +327,37 @@ class Backend(object):
backend__pb2.Result.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def TokenizeString(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.TokenizationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Status(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)