fix: Lora loading (#2893)

- Fixed Lora loading

Co-authored-by: Alex <alex@akhbar.home>
This commit is contained in:
vaaale 2024-07-16 18:58:45 +02:00 committed by GitHub
parent f521e50fa8
commit 4e84764787
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 72 additions and 61 deletions

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3
from concurrent import futures
import traceback
import argparse
from collections import defaultdict
from enum import Enum
@ -17,7 +17,8 @@ import backend_pb2_grpc
import grpc
from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
EulerAncestralDiscreteScheduler
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from diffusers.utils import load_image, export_to_video
@ -26,7 +27,6 @@ from compel import Compel, ReturnedEmbeddingsType
from transformers import CLIPTextModel
from safetensors.torch import load_file
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
COMPEL = os.environ.get("COMPEL", "0") == "1"
XPU = os.environ.get("XPU", "0") == "1"
@ -39,13 +39,17 @@ FRAMES=os.environ.get("FRAMES", "64")
if XPU:
import intel_extension_for_pytorch as ipex
print(ipex.xpu.get_device_name(0))
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287
def sc(self, clip_input, images): return images, [False for i in images]
# edit the StableDiffusionSafetyChecker class so that, when called, it just returns the images and an array of True values
safety_checker.StableDiffusionSafetyChecker.forward = sc
@ -62,6 +66,8 @@ from diffusers.schedulers import (
PNDMScheduler,
UniPCMultistepScheduler,
)
# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
# Credits to https://github.com/neggles
# See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
@ -136,10 +142,12 @@ def get_scheduler(name: str, config: dict = {}):
return sched_class.from_config(config)
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context):
try:
print(f"Loading model {request.Model}...", file=sys.stderr)
@ -255,7 +263,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
requires_pooled=[False, True]
)
if request.ControlNet:
self.controlnet = ControlNetModel.from_pretrained(
request.ControlNet, torch_dtype=torchType, variant=variant
@ -263,13 +270,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.pipe.controlnet = self.controlnet
else:
self.controlnet = None
if request.CUDA:
self.pipe.to('cuda')
if self.controlnet:
self.controlnet.to('cuda')
if XPU:
self.pipe = self.pipe.to("xpu")
# Assume directory from request.ModelFile.
# Only if request.LoraAdapter it's not an absolute path
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
@ -282,10 +282,17 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.LoraAdapter:
# Check if its a local file and not a directory ( we load lora differently for a safetensor file )
if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
# self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
self.pipe.load_lora_weights(request.LoraAdapter)
else:
self.pipe.unet.load_attn_procs(request.LoraAdapter)
if request.CUDA:
self.pipe.to('cuda')
if self.controlnet:
self.controlnet.to('cuda')
if XPU:
self.pipe = self.pipe.to("xpu")
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
# Implement your logic here for the LoadModel service
@ -430,6 +437,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(message="Media generated", success=True)
def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
@ -453,6 +461,7 @@ def serve(address):
except KeyboardInterrupt:
server.stop(0)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.")
parser.add_argument(

View file

@ -1,5 +1,7 @@
setuptools
accelerate
compel
peft
diffusers
grpcio==1.65.0
opencv-python