mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat(diffusers): update, add autopipeline, controlnet (#1432)
* feat(diffusers): update, add autopipeline, controlnet * tests with AutoPipeline * simplify logic
This commit is contained in:
parent
72325fd0a3
commit
7641f92cde
19 changed files with 812 additions and 770 deletions
|
@ -18,9 +18,9 @@ import backend_pb2_grpc
|
|||
import grpc
|
||||
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import safety_checker
|
||||
|
||||
from diffusers.utils import load_image
|
||||
from compel import Compel
|
||||
|
||||
from transformers import CLIPTextModel
|
||||
|
@ -30,6 +30,7 @@ from safetensors.torch import load_file
|
|||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
COMPEL=os.environ.get("COMPEL", "1") == "1"
|
||||
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
|
||||
SAFETENSORS=os.environ.get("SAFETENSORS", "1") == "1"
|
||||
|
||||
# If PYTHON_GRPC_MAX_WORKERS is specified in the environment, use it; otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
@ -135,8 +136,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
print(f"Loading model {request.Model}...", file=sys.stderr)
|
||||
print(f"Request {request}", file=sys.stderr)
|
||||
torchType = torch.float32
|
||||
variant = None
|
||||
|
||||
if request.F16Memory:
|
||||
torchType = torch.float16
|
||||
variant="fp16"
|
||||
|
||||
local = False
|
||||
modelFile = request.Model
|
||||
|
@ -160,14 +164,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
|
||||
fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
|
||||
|
||||
if request.IMG2IMG and request.PipelineType == "":
|
||||
request.PipelineType == "StableDiffusionImg2ImgPipeline"
|
||||
|
||||
if request.PipelineType == "":
|
||||
request.PipelineType == "StableDiffusionPipeline"
|
||||
|
||||
## img2img
|
||||
if request.PipelineType == "StableDiffusionImg2ImgPipeline":
|
||||
if (request.PipelineType == "StableDiffusionImg2ImgPipeline") or (request.IMG2IMG and request.PipelineType == ""):
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
|
@ -177,12 +175,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
|
||||
if request.PipelineType == "StableDiffusionDepth2ImgPipeline":
|
||||
elif request.PipelineType == "StableDiffusionDepth2ImgPipeline":
|
||||
self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
## text2img
|
||||
if request.PipelineType == "StableDiffusionPipeline":
|
||||
elif request.PipelineType == "AutoPipelineForText2Image" or request.PipelineType == "":
|
||||
self.pipe = AutoPipelineForText2Image.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=SAFETENSORS,
|
||||
variant=variant,
|
||||
guidance_scale=cfg_scale)
|
||||
elif request.PipelineType == "StableDiffusionPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
|
@ -191,13 +195,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
|
||||
if request.PipelineType == "DiffusionPipeline":
|
||||
elif request.PipelineType == "DiffusionPipeline":
|
||||
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
|
||||
if request.PipelineType == "StableDiffusionXLPipeline":
|
||||
elif request.PipelineType == "StableDiffusionXLPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType, use_safetensors=True,
|
||||
|
@ -207,21 +209,34 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True,
|
||||
# variant="fp16"
|
||||
variant=variant,
|
||||
guidance_scale=cfg_scale)
|
||||
# https://github.com/huggingface/diffusers/issues/4446
|
||||
# do not use text_encoder in the constructor since then
|
||||
# https://github.com/huggingface/diffusers/issues/3212#issuecomment-1521841481
|
||||
|
||||
if CLIPSKIP and request.CLIPSkip != 0:
|
||||
text_encoder = CLIPTextModel.from_pretrained(clipmodel, num_hidden_layers=request.CLIPSkip, subfolder=clipsubfolder, torch_dtype=torchType)
|
||||
self.pipe.text_encoder=text_encoder
|
||||
self.clip_skip = request.CLIPSkip
|
||||
else:
|
||||
self.clip_skip = 0
|
||||
|
||||
# torch_dtype needs to be customized. float16 for GPU, float32 for CPU
|
||||
# TODO: this needs to be customized
|
||||
if request.SchedulerType != "":
|
||||
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
|
||||
|
||||
self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
|
||||
|
||||
|
||||
if request.ControlNet:
|
||||
self.controlnet = ControlNetModel.from_pretrained(
|
||||
request.ControlNet, torch_dtype=torchType, variant=variant
|
||||
)
|
||||
self.pipe.controlnet = self.controlnet
|
||||
else:
|
||||
self.controlnet = None
|
||||
|
||||
if request.CUDA:
|
||||
self.pipe.to('cuda')
|
||||
if self.controlnet:
|
||||
self.controlnet.to('cuda')
|
||||
# Assume directory from request.ModelFile.
|
||||
# Only if request.LoraAdapter is not an absolute path
|
||||
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
|
||||
|
@ -316,9 +331,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
"num_inference_steps": steps,
|
||||
}
|
||||
|
||||
if request.src != "":
|
||||
if request.src != "" and not self.controlnet:
|
||||
image = Image.open(request.src)
|
||||
options["image"] = image
|
||||
elif self.controlnet and request.src:
|
||||
pose_image = load_image(request.src)
|
||||
options["image"] = pose_image
|
||||
|
||||
if CLIPSKIP and self.clip_skip != 0:
|
||||
options["clip_skip"]=self.clip_skip
|
||||
|
||||
# Get the keys that we will build the args for our pipe for
|
||||
keys = options.keys()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue