fix: Lora loading (#2893)

- Fixed Lora loading

Co-authored-by: Alex <alex@akhbar.home>
vaaale authored on 2024-07-16 18:58:45 +02:00, committed by GitHub
parent f521e50fa8
commit 4e84764787
2 changed files with 72 additions and 61 deletions
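
In short, the commit makes two changes to the diffusers backend script. First, when request.LoraAdapter points at a single safetensors file, the adapter is now loaded through the pipeline-level load_lora_weights API instead of the backend's old custom self.load_lora_weights helper. Second, the CUDA/XPU device placement is moved from before the LoRA handling to after it. A minimal sketch of the resulting order of operations, assuming a plain diffusers text-to-image pipeline; the model id and adapter path are illustrative, not taken from the commit:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative model id and adapter path (not from the commit).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Load the LoRA adapter first; for a single .safetensors file the backend
# now uses this pipeline-level API (which relies on the peft package).
pipe.load_lora_weights("/models/my-lora.safetensors")

# Only then move the pipeline, adapter weights included, to the accelerator,
# mirroring the relocated `if request.CUDA:` / `if XPU:` block in the diff below.
pipe.to("cuda")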

Changed file 1 of 2: the Python diffusers gRPC backend script

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 from concurrent import futures
+import traceback
 import argparse
 from collections import defaultdict
 from enum import Enum
@@ -17,7 +17,8 @@ import backend_pb2_grpc
 import grpc
-from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
+from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
+    EulerAncestralDiscreteScheduler
 from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
 from diffusers.pipelines.stable_diffusion import safety_checker
 from diffusers.utils import load_image, export_to_video
@@ -26,7 +27,6 @@ from compel import Compel, ReturnedEmbeddingsType
 from transformers import CLIPTextModel
 from safetensors.torch import load_file
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 COMPEL = os.environ.get("COMPEL", "0") == "1"
 XPU = os.environ.get("XPU", "0") == "1"
@@ -39,13 +39,17 @@ FRAMES=os.environ.get("FRAMES", "64")
 if XPU:
     import intel_extension_for_pytorch as ipex
     print(ipex.xpu.get_device_name(0))
 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1
 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
 # https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287
 def sc(self, clip_input, images): return images, [False for i in images]
 # edit the StableDiffusionSafetyChecker class so that, when called, it just returns the images and an array of True values
 safety_checker.StableDiffusionSafetyChecker.forward = sc
@@ -62,6 +66,8 @@ from diffusers.schedulers import (
     PNDMScheduler,
     UniPCMultistepScheduler,
 )
 # The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
 # Credits to https://github.com/neggles
 # See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
@@ -136,10 +142,12 @@ def get_scheduler(name: str, config: dict = {}):
     return sched_class.from_config(config)
 # Implement the BackendServicer class with the service methods
 class BackendServicer(backend_pb2_grpc.BackendServicer):
     def Health(self, request, context):
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
     def LoadModel(self, request, context):
         try:
             print(f"Loading model {request.Model}...", file=sys.stderr)
@@ -255,7 +263,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                     requires_pooled=[False, True]
                 )
             if request.ControlNet:
                 self.controlnet = ControlNetModel.from_pretrained(
                     request.ControlNet, torch_dtype=torchType, variant=variant
@@ -263,13 +270,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 self.pipe.controlnet = self.controlnet
             else:
                 self.controlnet = None
-            if request.CUDA:
-                self.pipe.to('cuda')
-                if self.controlnet:
-                    self.controlnet.to('cuda')
-            if XPU:
-                self.pipe = self.pipe.to("xpu")
             # Assume directory from request.ModelFile.
             # Only if request.LoraAdapter it's not an absolute path
             if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
@@ -282,10 +282,17 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             if request.LoraAdapter:
                 # Check if its a local file and not a directory ( we load lora differently for a safetensor file )
                 if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
-                    self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
+                    # self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
+                    self.pipe.load_lora_weights(request.LoraAdapter)
                 else:
                     self.pipe.unet.load_attn_procs(request.LoraAdapter)
+            if request.CUDA:
+                self.pipe.to('cuda')
+                if self.controlnet:
+                    self.controlnet.to('cuda')
+            if XPU:
+                self.pipe = self.pipe.to("xpu")
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         # Implement your logic here for the LoadModel service
@@ -430,6 +437,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         return backend_pb2.Result(message="Media generated", success=True)
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
     backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
@@ -453,6 +461,7 @@ def serve(address):
     except KeyboardInterrupt:
         server.stop(0)
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the gRPC server.")
     parser.add_argument(
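
Note on the two LoadModel hunks above: the `if request.CUDA:` / `if XPU:` block is not deleted but relocated from before the LoRA handling to after it, so the pipeline is moved to the accelerator only once the adapter weights have been loaded into it. Combined with switching the safetensors branch to `self.pipe.load_lora_weights(...)`, this is what fixes the LoRA loading.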

Changed file 2 of 2: the Python requirements file

@@ -1,5 +1,7 @@
+setuptools
 accelerate
 compel
+peft
 diffusers
 grpcio==1.65.0
 opencv-python
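
The requirements additions pair with the code change: recent diffusers releases implement `load_lora_weights` on top of peft, so the backend now declares it explicitly; setuptools is presumably added to satisfy a build-time dependency of the stack (an assumption, the commit message does not say). A quick sanity check that the LoRA code path has its dependencies available, a sketch assuming the packages above are installed:

import importlib

# Each import raises ImportError if the backend would fail at startup;
# note that grpcio installs as the `grpc` module.
for pkg in ("peft", "diffusers", "compel", "grpc"):
    importlib.import_module(pkg)
print("LoRA dependencies importable")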