mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat(diffusers): add img2img and clip_skip, support more kernels schedulers (#906)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
ddf9bc2335
commit
2bacd0180d
13 changed files with 435 additions and 213 deletions
|
@ -15,15 +15,108 @@ from torch import autocast
|
|||
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
|
||||
from diffusers.pipelines.stable_diffusion import safety_checker
|
||||
from compel import Compel
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from transformers import CLIPTextModel
|
||||
from enum import Enum
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
COMPEL=os.environ.get("COMPEL", "1") == "1"
|
||||
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
|
||||
|
||||
# https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287
|
||||
def sc(self, clip_input, images):
    """Pass-through stand-in for a safety-checker ``forward`` method.

    ``clip_input`` is ignored. Returns a 2-tuple: ``images`` untouched, and a
    list holding one ``False`` per element of ``images``.
    """
    flags = [False for _ in images]
    return images, flags
|
||||
# edit the StableDiffusionSafetyChecker class so that, when called, it just returns the images and a list of False values (one per image)
|
||||
safety_checker.StableDiffusionSafetyChecker.forward = sc
|
||||
|
||||
from diffusers.schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
|
||||
# Credits to https://github.com/neggles
|
||||
# See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
|
||||
class DiffusionScheduler(str, Enum):
    """Short string identifiers for the supported diffusion schedulers.

    Every member's value equals its name. Because the enum also derives from
    ``str``, members compare equal to the plain strings they wrap. Names with
    a ``k_`` prefix are the Karras-sigma variants of the un-prefixed name.
    The trailing comment on each member is its human-readable label.
    """

    # Schedulers without a Karras variant
    ddim = "ddim"  # DDIM
    pndm = "pndm"  # PNDM
    heun = "heun"  # Heun
    unipc = "unipc"  # UniPC
    euler = "euler"  # Euler
    euler_a = "euler_a"  # Euler a

    # LMS
    lms = "lms"  # LMS
    k_lms = "k_lms"  # LMS Karras

    # DPM2
    dpm_2 = "dpm_2"  # DPM2
    k_dpm_2 = "k_dpm_2"  # DPM2 Karras

    # DPM2 ancestral
    dpm_2_a = "dpm_2_a"  # DPM2 a
    k_dpm_2_a = "k_dpm_2_a"  # DPM2 a Karras

    # DPM++ 2M
    dpmpp_2m = "dpmpp_2m"  # DPM++ 2M
    k_dpmpp_2m = "k_dpmpp_2m"  # DPM++ 2M Karras

    # DPM++ SDE
    dpmpp_sde = "dpmpp_sde"  # DPM++ SDE
    k_dpmpp_sde = "k_dpmpp_sde"  # DPM++ SDE Karras

    # DPM++ 2M SDE
    dpmpp_2m_sde = "dpmpp_2m_sde"  # DPM++ 2M SDE
    k_dpmpp_2m_sde = "k_dpmpp_2m_sde"  # DPM++ 2M SDE Karras
|
||||
|
||||
|
||||
def get_scheduler(name: str, config: dict = None):
    """Instantiate the scheduler identified by ``name``.

    A leading ``k_`` selects the same scheduler as the un-prefixed name, with
    ``use_karras_sigmas`` set to ``True`` in the configuration.

    Args:
        name: A ``DiffusionScheduler`` value such as ``"euler_a"`` or
            ``"k_dpmpp_2m"``.
        config: Base scheduler configuration mapping. It is copied before any
            key is added, so neither the caller's object nor a shared default
            is modified. ``None`` (the default) means an empty configuration.

    Returns:
        Whatever ``from_config`` of the selected scheduler class returns for
        the (possibly augmented) configuration.

    Raises:
        ValueError: If ``name`` does not correspond to a known scheduler.
    """
    # Work on a private copy: the keys set below must not leak into the
    # caller's config or persist between calls through a shared default.
    config = dict(config) if config is not None else {}

    requested = name
    is_karras = name.startswith("k_")
    if is_karras:
        # Strip exactly the "k_" prefix and add the karras sigma flag to config.
        # (str.lstrip("k_") would strip any run of 'k'/'_' characters instead.)
        name = name[len("k_"):]
        config["use_karras_sigmas"] = True

    if name == DiffusionScheduler.ddim:
        sched_class = DDIMScheduler
    elif name == DiffusionScheduler.pndm:
        sched_class = PNDMScheduler
    elif name == DiffusionScheduler.heun:
        sched_class = HeunDiscreteScheduler
    elif name == DiffusionScheduler.unipc:
        sched_class = UniPCMultistepScheduler
    elif name == DiffusionScheduler.euler:
        sched_class = EulerDiscreteScheduler
    elif name == DiffusionScheduler.euler_a:
        sched_class = EulerAncestralDiscreteScheduler
    elif name == DiffusionScheduler.lms:
        sched_class = LMSDiscreteScheduler
    elif name == DiffusionScheduler.dpm_2:
        # Equivalent to `DPM2` in K-Diffusion
        sched_class = KDPM2DiscreteScheduler
    elif name == DiffusionScheduler.dpm_2_a:
        # Equivalent to `DPM2 a` in K-Diffusion
        sched_class = KDPM2AncestralDiscreteScheduler
    elif name == DiffusionScheduler.dpmpp_2m:
        # Equivalent to `DPM++ 2M` in K-Diffusion
        sched_class = DPMSolverMultistepScheduler
        config["algorithm_type"] = "dpmsolver++"
        config["solver_order"] = 2
    elif name == DiffusionScheduler.dpmpp_sde:
        # Equivalent to `DPM++ SDE` in K-Diffusion
        sched_class = DPMSolverSinglestepScheduler
    elif name == DiffusionScheduler.dpmpp_2m_sde:
        # Equivalent to `DPM++ 2M SDE` in K-Diffusion
        sched_class = DPMSolverMultistepScheduler
        config["algorithm_type"] = "sde-dpmsolver++"
    else:
        # Report the name exactly as the caller supplied it.
        raise ValueError(f"Invalid scheduler '{requested}'")

    return sched_class.from_config(config)
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def Health(self, request, context):
|
||||
|
@ -42,39 +135,55 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
cfg_scale = 7
|
||||
if request.CFGScale != 0:
|
||||
cfg_scale = request.CFGScale
|
||||
|
||||
|
||||
clipmodel = "runwayml/stable-diffusion-v1-5"
|
||||
if request.CLIPModel != "":
|
||||
clipmodel = request.CLIPModel
|
||||
clipsubfolder = "text_encoder"
|
||||
if request.CLIPSubfolder != "":
|
||||
clipsubfolder = request.CLIPSubfolder
|
||||
|
||||
# Check if ModelFile exists
|
||||
if request.ModelFile != "":
|
||||
if os.path.exists(request.ModelFile):
|
||||
local = True
|
||||
modelFile = request.ModelFile
|
||||
|
||||
|
||||
fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
|
||||
# If request.Model is a URL, use from_single_file
|
||||
|
||||
|
||||
if request.PipelineType == "":
|
||||
request.PipelineType == "StableDiffusionPipeline"
|
||||
|
||||
if request.PipelineType == "StableDiffusionPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
if request.IMG2IMG:
|
||||
self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
else:
|
||||
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
else:
|
||||
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
|
||||
if request.IMG2IMG:
|
||||
self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
else:
|
||||
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
# https://github.com/huggingface/diffusers/issues/4446
|
||||
# do not use text_encoder in the constructor since then
|
||||
# https://github.com/huggingface/diffusers/issues/3212#issuecomment-1521841481
|
||||
if CLIPSKIP and request.CLIPSkip != 0:
|
||||
text_encoder = CLIPTextModel.from_pretrained(clipmodel, num_hidden_layers=request.CLIPSkip, subfolder=clipsubfolder, torch_dtype=torchType)
|
||||
self.pipe.text_encoder=text_encoder
|
||||
if request.PipelineType == "DiffusionPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = DiffusionPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
else:
|
||||
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
|
||||
if request.PipelineType == "StableDiffusionXLPipeline":
|
||||
if fromSingleFile:
|
||||
|
@ -91,17 +200,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
|
||||
# torch_dtype needs to be customized. float16 for GPU, float32 for CPU
|
||||
# TODO: this needs to be customized
|
||||
if request.SchedulerType == "EulerAncestralDiscreteScheduler":
|
||||
self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
|
||||
if request.SchedulerType == "DPMSolverMultistepScheduler":
|
||||
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
|
||||
if request.SchedulerType == "DPMSolverMultistepScheduler++":
|
||||
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config,algorithm_type="dpmsolver++")
|
||||
if request.SchedulerType == "DPMSolverMultistepSchedulerSDE++":
|
||||
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config, algorithm_type="sde-dpmsolver++")
|
||||
if request.CUDA:
|
||||
self.pipe.to('cuda')
|
||||
|
||||
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
|
||||
self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
@ -117,9 +216,17 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
"negative_prompt": request.negative_prompt,
|
||||
"width": request.width,
|
||||
"height": request.height,
|
||||
"num_inference_steps": request.step
|
||||
"num_inference_steps": request.step,
|
||||
}
|
||||
|
||||
if request.src != "":
|
||||
# open the image with Image.open
|
||||
# convert the image to RGB
|
||||
# resize the image to the request width and height
|
||||
# XXX: untested
|
||||
image = Image.open(request.src).convert("RGB").resize((request.width, request.height))
|
||||
options["image"] = image
|
||||
|
||||
# Get the keys that we will build the args for our pipe for
|
||||
keys = options.keys()
|
||||
|
||||
|
@ -131,6 +238,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
|
||||
# create a dictionary of parameters by using the keys from EnableParameters and the values from defaults
|
||||
kwargs = {key: options[key] for key in keys}
|
||||
|
||||
image = {}
|
||||
if COMPEL:
|
||||
conditioning = self.compel.build_conditioning_tensor(prompt)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue