feat(diffusers): various enhancements (#895)
https://github.com/mudler/LocalAI.git

parent 77e1ae3d70
commit a96c3bc885

11 changed files with 165 additions and 101 deletions
@@ -13,9 +13,15 @@ import os
 import torch
 from torch import autocast
 from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
+from diffusers.pipelines.stable_diffusion import safety_checker
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 
+# https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287
+def sc(self, clip_input, images): return images, [False for i in images]
+# edit the StableDiffusionSafetyChecker class so that, when called, it just returns the images and an array of False values (no NSFW concepts detected)
+safety_checker.StableDiffusionSafetyChecker.forward = sc
+
 # Implement the BackendServicer class with the service methods
 class BackendServicer(backend_pb2_grpc.BackendServicer):
     def Health(self, request, context):
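This first hunk disables diffusers' NSFW filter: `sc` replaces `StableDiffusionSafetyChecker.forward` and returns every image untouched together with a per-image `False` flag, i.e. "no NSFW concept detected", so the checker never blacks out results. A minimal, self-contained sketch of the same monkey-patching pattern (`DummyChecker` is a hypothetical stand-in, not a diffusers class):

```python
# Sketch of the monkey-patch above; DummyChecker stands in for
# StableDiffusionSafetyChecker so the example runs without diffusers.
class DummyChecker:
    def forward(self, clip_input, images):
        # the real checker may replace flagged images and return True flags
        return [None for _ in images], [True for _ in images]

def sc(self, clip_input, images):
    # pass images through untouched; False means "no NSFW concept detected"
    return images, [False for _ in images]

# Rebinding the class attribute affects every instance, exactly like
# safety_checker.StableDiffusionSafetyChecker.forward = sc in the diff.
DummyChecker.forward = sc

images, flags = DummyChecker().forward(clip_input=None, images=["a", "b"])
assert images == ["a", "b"] and flags == [False, False]
```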
@@ -28,24 +34,48 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if request.F16Memory:
             torchType = torch.float16
 
+        local = False
+        modelFile = request.Model
+        # Check if ModelFile exists on disk; if so, load from it instead of request.Model
+        if request.ModelFile != "":
+            if os.path.exists(request.ModelFile):
+                local = True
+                modelFile = request.ModelFile
+
+        # If request.Model is a URL, an absolute path, or a local file, use from_single_file
+        fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
+
+
+        if request.PipelineType == "":
+            request.PipelineType = "StableDiffusionPipeline"
 
         if request.PipelineType == "StableDiffusionPipeline":
-            self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
+            if fromSingleFile:
+                self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
+                                                                     torch_dtype=torchType)
+            else:
+                self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
                                                                     torch_dtype=torchType)
 
         if request.PipelineType == "DiffusionPipeline":
-            self.pipe = DiffusionPipeline.from_pretrained(request.Model,
+            if fromSingleFile:
+                self.pipe = DiffusionPipeline.from_single_file(modelFile,
+                                                               torch_dtype=torchType)
+            else:
+                self.pipe = DiffusionPipeline.from_pretrained(request.Model,
                                                               torch_dtype=torchType)
 
         if request.PipelineType == "StableDiffusionXLPipeline":
-            self.pipe = StableDiffusionXLPipeline.from_pretrained(
-                request.Model,
-                torch_dtype=torchType,
-                use_safetensors=True,
-                # variant="fp16"
-            )
+            if fromSingleFile:
+                self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
+                                                                       torch_dtype=torchType, use_safetensors=True)
+            else:
+                self.pipe = StableDiffusionXLPipeline.from_pretrained(
+                    request.Model,
+                    torch_dtype=torchType,
+                    use_safetensors=True,
+                    # variant="fp16"
+                )
 
         # torch_dtype needs to be customized: float16 for GPU, float32 for CPU
         # TODO: this needs to be customized
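The loading hunk routes single-file checkpoints (URLs, absolute paths, or an existing `ModelFile` on disk) through `from_single_file`, and Hugging Face hub IDs through `from_pretrained`, defaulting `PipelineType` to `StableDiffusionPipeline` when unset. A hedged sketch of just that dispatch decision; the `Request` dataclass is an illustrative stand-in for the gRPC request, and the model names are only examples:

```python
import os
from dataclasses import dataclass

@dataclass
class Request:          # illustrative stand-in for the gRPC LoadModel request
    Model: str = ""
    ModelFile: str = ""
    PipelineType: str = ""

def plan_load(request: Request) -> tuple[str, str, str]:
    """Return (pipeline_type, loader, source) the way the diff decides them."""
    local = False
    modelFile = request.Model
    # a ModelFile that exists on disk takes precedence over request.Model
    if request.ModelFile != "" and os.path.exists(request.ModelFile):
        local = True
        modelFile = request.ModelFile
    # URLs, absolute paths, and local files are single-file checkpoints
    fromSingleFile = (request.Model.startswith("http")
                      or request.Model.startswith("/")
                      or local)
    pipeline_type = request.PipelineType or "StableDiffusionPipeline"
    loader = "from_single_file" if fromSingleFile else "from_pretrained"
    return pipeline_type, loader, modelFile if fromSingleFile else request.Model

# hub ID -> from_pretrained; local .safetensors checkpoint -> from_single_file
print(plan_load(Request(Model="runwayml/stable-diffusion-v1-5")))
print(plan_load(Request(Model="/models/sd_xl_base_1.0.safetensors",
                        PipelineType="StableDiffusionXLPipeline")))
```

Keying the branch on how the model is referenced (path or URL vs. hub ID) lets one backend serve both standalone checkpoint files and full hub repositories.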
@@ -64,19 +94,34 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
     def GenerateImage(self, request, context):
 
         prompt = request.positive_prompt
-        negative_prompt = request.negative_prompt
 
+        # create a dictionary of default values for the supported parameters
+        options = {
+            "negative_prompt": request.negative_prompt,
+            "width": request.width,
+            "height": request.height,
+            "num_inference_steps": request.step
+        }
+
+        # the keys we will build the pipeline kwargs from
+        keys = options.keys()
+
+        if request.EnableParameters != "":
+            keys = request.EnableParameters.split(",")
+
+        if request.EnableParameters == "none":
+            keys = []
+
+        # build the parameter dictionary from the keys in EnableParameters and the defaults above
+        kwargs = {key: options[key] for key in keys}
+
+        # pass the kwargs dictionary to the self.pipe method
         image = self.pipe(
             prompt,
-            negative_prompt=negative_prompt,
-            width=request.width,
-            height=request.height,
-            # guidance_scale=12,
-            target_size=(request.width,request.height),
-            original_size=(4096,4096),
-            num_inference_steps=request.step
+            **kwargs
         ).images[0]
 
         # save the result
         image.save(request.dst)
 
         return backend_pb2.Result(message="Image generated successfully", success=True)
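`GenerateImage` no longer hard-codes the pipeline arguments: all supported parameters go into `options`, `EnableParameters` acts as a comma-separated allowlist (the literal `"none"` disables every optional parameter), and the surviving entries are splatted into `self.pipe(prompt, **kwargs)`. A standalone sketch of that filtering, with made-up default values (`build_pipe_kwargs` is a hypothetical helper, not part of the backend):

```python
def build_pipe_kwargs(enable_parameters: str, negative_prompt: str = "",
                      width: int = 512, height: int = 512, step: int = 25) -> dict:
    # same shape as the options dict in the diff; values here are made up
    options = {
        "negative_prompt": negative_prompt,
        "width": width,
        "height": height,
        "num_inference_steps": step,
    }
    keys = options.keys()                    # default: pass every parameter
    if enable_parameters != "":
        keys = enable_parameters.split(",")  # explicit allowlist
    if enable_parameters == "none":
        keys = []                            # rely on pipeline defaults
    return {key: options[key] for key in keys}

print(build_pipe_kwargs(""))              # all four parameters
print(build_pipe_kwargs("width,height"))  # {'width': 512, 'height': 512}
print(build_pipe_kwargs("none"))          # {}
# the backend then calls: self.pipe(prompt, **kwargs).images[0]
```

Note that a key listed in `EnableParameters` but absent from `options` raises a `KeyError`, so the allowlist is effectively limited to the four parameters above.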