From 02704e38d35b69bd7aee5e3ac8300f625f3ee57d Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 27 Aug 2023 10:11:16 +0200 Subject: [PATCH] feat(diffusers): Add lora (#965) **Description** This PR fixes #914 Now diffusers respects the `lora_adapter` configuration parameter. --------- Signed-off-by: Ettore Di Giacinto --- api/backend/image.go | 2 + extra/grpc/diffusers/backend_diffusers.py | 74 ++++++++++++++++++++++- 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/api/backend/image.go b/api/backend/image.go index ea3f2069..1c415fd1 100644 --- a/api/backend/image.go +++ b/api/backend/image.go @@ -20,6 +20,8 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat SchedulerType: c.Diffusers.SchedulerType, PipelineType: c.Diffusers.PipelineType, CFGScale: c.Diffusers.CFGScale, + LoraAdapter: c.LoraAdapter, + LoraBase: c.LoraBase, IMG2IMG: c.Diffusers.IMG2IMG, CLIPModel: c.Diffusers.ClipModel, CLIPSubfolder: c.Diffusers.ClipSubFolder, diff --git a/extra/grpc/diffusers/backend_diffusers.py b/extra/grpc/diffusers/backend_diffusers.py index a005d7f4..8f8e2b5c 100755 --- a/extra/grpc/diffusers/backend_diffusers.py +++ b/extra/grpc/diffusers/backend_diffusers.py @@ -20,7 +20,8 @@ from io import BytesIO from diffusers import StableDiffusionImg2ImgPipeline from transformers import CLIPTextModel from enum import Enum - +from collections import defaultdict +from safetensors.torch import load_file _ONE_DAY_IN_SECONDS = 60 * 60 * 24 COMPEL=os.environ.get("COMPEL", "1") == "1" CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1" @@ -213,11 +214,82 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder) if request.CUDA: self.pipe.to('cuda') + # Assume directory from request.ModelFile. 
# Only if request.LoraAdapter is not an absolute path
# find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + # get elements for this layer + weight_up = elems['lora_up.weight'].to(dtype) + weight_down = elems['lora_down.weight'].to(dtype) + alpha = elems['alpha'] + if alpha: + alpha = alpha.item() / weight_up.shape[1] + else: + alpha = 1.0 + + # update weight + if len(weight_up.shape) == 4: + curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + else: + curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down) + def GenerateImage(self, request, context): prompt = request.positive_prompt