Mirror of https://github.com/mudler/LocalAI.git
feat(diffusers): be consistent with pipelines, support also depthimg2img (#926)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Parent: 8cb1061c11
Commit: 1079b18ff7
14 changed files with 480 additions and 103 deletions
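As context for the depth-guided pipeline this commit wires up, here is a minimal sketch of how StableDiffusionDepth2ImgPipeline is typically driven standalone. The model id, file names, and prompt are illustrative assumptions, not taken from this commit.

# Minimal sketch of depth-guided img2img with diffusers.
# Model id, file names, and prompt are illustrative assumptions.
import torch
from PIL import Image
from diffusers import StableDiffusionDepth2ImgPipeline

pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-depth",  # assumed model id
    torch_dtype=torch.float16,
)
pipe.to("cuda")

init_image = Image.open("input.png").convert("RGB")
result = pipe(prompt="a fantasy landscape", image=init_image).images[0]
result.save("output.png")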
@@ -12,7 +12,7 @@ import os
 # import diffusers
 import torch
 from torch import autocast
-from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
+from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
 from diffusers.pipelines.stable_diffusion import safety_checker
 from compel import Compel
 from PIL import Image
@@ -150,36 +150,39 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 modelFile = request.ModelFile

             fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
             # If request.Model is a URL, use from_single_file

+            if request.IMG2IMG and request.PipelineType == "":
+                request.PipelineType = "StableDiffusionImg2ImgPipeline"
+
+            if request.PipelineType == "":
+                request.PipelineType = "StableDiffusionPipeline"
+
+            ## img2img
+            if request.PipelineType == "StableDiffusionImg2ImgPipeline":
+                if fromSingleFile:
+                    self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
+                                                                                torch_dtype=torchType,
+                                                                                guidance_scale=cfg_scale)
+                else:
+                    self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(request.Model,
+                                                                               torch_dtype=torchType,
+                                                                               guidance_scale=cfg_scale)
+
+            if request.PipelineType == "StableDiffusionDepth2ImgPipeline":
+                self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
+                                                                             torch_dtype=torchType,
+                                                                             guidance_scale=cfg_scale)
+            ## text2img
             if request.PipelineType == "StableDiffusionPipeline":
                 if fromSingleFile:
-                    if request.IMG2IMG:
-                        self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
-                                                                                    torch_dtype=torchType,
-                                                                                    guidance_scale=cfg_scale)
-                    else:
-                        self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
-                                                                             torch_dtype=torchType,
-                                                                             guidance_scale=cfg_scale)
+                    self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
+                                                                         torch_dtype=torchType,
+                                                                         guidance_scale=cfg_scale)
                 else:
-                    if request.IMG2IMG:
-                        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(request.Model,
-                                                                                   torch_dtype=torchType,
-                                                                                   guidance_scale=cfg_scale)
-                    else:
-                        self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
-                                                                            torch_dtype=torchType,
-                                                                            guidance_scale=cfg_scale)
-                        # https://github.com/huggingface/diffusers/issues/4446
-                        # do not use text_encoder in the constructor since then
-                        # https://github.com/huggingface/diffusers/issues/3212#issuecomment-1521841481
-                        if CLIPSKIP and request.CLIPSkip != 0:
-                            text_encoder = CLIPTextModel.from_pretrained(clipmodel, num_hidden_layers=request.CLIPSkip, subfolder=clipsubfolder, torch_dtype=torchType)
-                            self.pipe.text_encoder=text_encoder
+                    self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
+                                                                        torch_dtype=torchType,
+                                                                        guidance_scale=cfg_scale)
+
             if request.PipelineType == "DiffusionPipeline":
                 self.pipe = DiffusionPipeline.from_pretrained(request.Model,
                                                               torch_dtype=torchType,
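In effect, the hunk above replaces nested IMG2IMG branching with a single dispatch on request.PipelineType. A simplified sketch of the resulting flow follows; it mirrors the backend's call shape (including passing guidance_scale at load time) but is an illustration, not the literal code, and the parameter names are assumptions standing in for torchType, cfg_scale, and modelFile.

# Simplified illustration of the pipeline dispatch introduced above;
# not the literal backend code.
from diffusers import (DiffusionPipeline, StableDiffusionPipeline,
                       StableDiffusionImg2ImgPipeline,
                       StableDiffusionDepth2ImgPipeline)

def build_pipeline(pipeline_type, model, model_file, from_single_file,
                   torch_dtype, cfg_scale):
    if pipeline_type == "StableDiffusionImg2ImgPipeline":
        cls = StableDiffusionImg2ImgPipeline
    elif pipeline_type == "StableDiffusionDepth2ImgPipeline":
        # depth2img has no single-file loading path in this backend
        return StableDiffusionDepth2ImgPipeline.from_pretrained(
            model, torch_dtype=torch_dtype, guidance_scale=cfg_scale)
    elif pipeline_type == "DiffusionPipeline":
        return DiffusionPipeline.from_pretrained(
            model, torch_dtype=torch_dtype, guidance_scale=cfg_scale)
    else:
        cls = StableDiffusionPipeline
    if from_single_file:
        return cls.from_single_file(model_file, torch_dtype=torch_dtype,
                                    guidance_scale=cfg_scale)
    return cls.from_pretrained(model, torch_dtype=torch_dtype,
                               guidance_scale=cfg_scale)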
@@ -197,11 +200,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                                                               use_safetensors=True,
                                                               # variant="fp16"
                                                               guidance_scale=cfg_scale)
+
+            # https://github.com/huggingface/diffusers/issues/4446
+            # do not use text_encoder in the constructor since then
+            # https://github.com/huggingface/diffusers/issues/3212#issuecomment-1521841481
+            if CLIPSKIP and request.CLIPSkip != 0:
+                text_encoder = CLIPTextModel.from_pretrained(clipmodel, num_hidden_layers=request.CLIPSkip, subfolder=clipsubfolder, torch_dtype=torchType)
+                self.pipe.text_encoder=text_encoder
+
             # torch_dtype needs to be customized. float16 for GPU, float32 for CPU
             # TODO: this needs to be customized
-            self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
+            if request.SchedulerType != "":
+                self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
             self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
             if request.CUDA:
                 self.pipe.to('cuda')
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         # Implement your logic here for the LoadModel service
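Two things happen in this hunk: the CLIP-skip text-encoder override moves after pipeline construction (per the linked diffusers issues), and the scheduler override now runs only when the request names a scheduler. For reference, swapping a scheduler on a constructed pipe amounts to roughly the following; EulerAncestralDiscreteScheduler is one of the schedulers this file already imports, while the helper name is hypothetical.

# What the guarded branch amounts to: rebuild a scheduler from the
# pipeline's existing config and install it. Example scheduler class only;
# the backend resolves the class from the request string via get_scheduler().
from diffusers import EulerAncestralDiscreteScheduler

def set_scheduler(pipe, scheduler_cls=EulerAncestralDiscreteScheduler):
    pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
    return pipe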
@@ -220,11 +231,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         }

         if request.src != "":
-            # open the image with Image.open
-            # convert the image to RGB
-            # resize the image to the request width and height
-            # XXX: untested
-            image = Image.open(request.src).convert("RGB").resize((request.width, request.height))
+            image = Image.open(request.src)
             options["image"] = image

         # Get the keys that we will build the args for our pipe for
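The request image is now handed to the pipeline as-is instead of being converted to RGB and resized to the request dimensions. If a caller still wants the old behavior, a small client-side helper like the following (hypothetical, not part of the commit) restores it:

# Hypothetical helper reproducing the removed convert/resize step.
from PIL import Image

def load_image(path, width=None, height=None):
    # Open an image; optionally convert to RGB and resize, as the old code did.
    image = Image.open(path)
    if width and height:
        image = image.convert("RGB").resize((width, height))
    return image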