feat(diffusers): various enhancements (#895)

This commit is contained in:
Ettore Di Giacinto 2023-08-14 23:12:00 +02:00 committed by GitHub
parent 77e1ae3d70
commit a96c3bc885
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 165 additions and 101 deletions

View file

@ -13,9 +13,15 @@ import os
import torch
from torch import autocast
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
from diffusers.pipelines.stable_diffusion import safety_checker
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287
def sc(self, clip_input, images) : return images, [False for i in images]
# edit the StableDiffusionSafetyChecker class so that, when called, it just returns the images and an array of True values
safety_checker.StableDiffusionSafetyChecker.forward = sc
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context):
@ -28,24 +34,48 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.F16Memory:
torchType = torch.float16
local = False
modelFile = request.Model
# Check if ModelFile exists
if request.ModelFile != "":
if os.path.exists(request.ModelFile):
local = True
modelFile = request.ModelFile
fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
# If request.Model is a URL, use from_single_file
if request.PipelineType == "":
request.PipelineType == "StableDiffusionPipeline"
if request.PipelineType == "StableDiffusionPipeline":
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
if fromSingleFile:
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
torch_dtype=torchType)
else:
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
torch_dtype=torchType)
if request.PipelineType == "DiffusionPipeline":
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
if fromSingleFile:
self.pipe = DiffusionPipeline.from_single_file(modelFile,
torch_dtype=torchType)
else:
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
torch_dtype=torchType)
if request.PipelineType == "StableDiffusionXLPipeline":
self.pipe = StableDiffusionXLPipeline.from_pretrained(
request.Model,
torch_dtype=torchType,
use_safetensors=True,
# variant="fp16"
)
if fromSingleFile:
self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
torch_dtype=torchType, use_safetensors=True)
else:
self.pipe = StableDiffusionXLPipeline.from_pretrained(
request.Model,
torch_dtype=torchType,
use_safetensors=True,
# variant="fp16"
)
# torch_dtype needs to be customized. float16 for GPU, float32 for CPU
# TODO: this needs to be customized
@ -64,19 +94,34 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def GenerateImage(self, request, context):
prompt = request.positive_prompt
negative_prompt = request.negative_prompt
# create a dictionary of values for the parameters
options = {
"negative_prompt": request.negative_prompt,
"width": request.width,
"height": request.height,
"num_inference_steps": request.step
}
# Get the keys that we will build the args for our pipe for
keys = options.keys()
if request.EnableParameters != "":
keys = request.EnableParameters.split(",")
if request.EnableParameters == "none":
keys = []
# create a dictionary of parameters by using the keys from EnableParameters and the values from defaults
kwargs = {key: options[key] for key in keys}
# pass the kwargs dictionary to the self.pipe method
image = self.pipe(
prompt,
negative_prompt=negative_prompt,
width=request.width,
height=request.height,
# guidance_scale=12,
target_size=(request.width,request.height),
original_size=(4096,4096),
num_inference_steps=request.step
**kwargs
).images[0]
# save the result
image.save(request.dst)
return backend_pb2.Result(message="Model loaded successfully", success=True)