feat(diffusers): Add lora (#965)

**Description**

This PR fixes #914 

Now diffusers respects the `lora_adapter` configuration parameter.
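
For reference, a minimal sketch of a model config that exercises the new parameter. The `lora_adapter`/`lora_base` keys mirror the `LoraAdapter`/`LoraBase` fields threaded through the Go code below; the model name, file layout, and the `parameters` block are hypothetical:

```yaml
# hypothetical model config -- file names and model id are placeholders
name: sd-with-lora
backend: diffusers
parameters:
  model: runwayml/stable-diffusion-v1-5
# a relative path is resolved against the model file's directory (see LoadModel below)
lora_adapter: my-style-lora.safetensors
lora_base: ""
```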

---------

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Ettore Di Giacinto <mudler@users.noreply.github.com>, 2023-08-27 10:11:16 +02:00, committed by GitHub
parent 9e5fb29965
commit 02704e38d3
2 changed files with 75 additions and 1 deletion


@@ -20,6 +20,8 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
 	SchedulerType: c.Diffusers.SchedulerType,
 	PipelineType:  c.Diffusers.PipelineType,
 	CFGScale:      c.Diffusers.CFGScale,
+	LoraAdapter:   c.LoraAdapter,
+	LoraBase:      c.LoraBase,
 	IMG2IMG:       c.Diffusers.IMG2IMG,
 	CLIPModel:     c.Diffusers.ClipModel,
 	CLIPSubfolder: c.Diffusers.ClipSubFolder,


@@ -20,7 +20,8 @@ from io import BytesIO
 from diffusers import StableDiffusionImg2ImgPipeline
 from transformers import CLIPTextModel
 from enum import Enum
+from collections import defaultdict
+from safetensors.torch import load_file
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 COMPEL=os.environ.get("COMPEL", "1") == "1"
 CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
@@ -213,11 +214,82 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
             if request.CUDA:
                 self.pipe.to('cuda')
+            # If request.LoraAdapter is a relative path, resolve it against
+            # the directory containing request.ModelFile.
+            if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter):
+                # get the base path of the model file
+                modelFileBase = os.path.dirname(request.ModelFile)
+                # make LoraAdapter relative to modelFileBase
+                request.LoraAdapter = os.path.join(modelFileBase, request.LoraAdapter)
+            if request.LoraAdapter:
+                device = "cpu" if not request.CUDA else "cuda"
+                # A plain file is assumed to be a .safetensors checkpoint and is
+                # merged manually; a directory is handed to diffusers' attention
+                # processor loader instead.
+                if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
+                    self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
+                else:
+                    self.pipe.unet.load_attn_procs(request.LoraAdapter)
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         # Implement your logic here for the LoadModel service
         # Replace this with your desired response
         return backend_pb2.Result(message="Model loaded successfully", success=True)
+    # Manually merge LoRA weights from a .safetensors checkpoint into the
+    # pipeline's unet / text encoder, following
+    # https://github.com/huggingface/diffusers/issues/3064
+    def load_lora_weights(self, checkpoint_path, multiplier, device, dtype):
+        LORA_PREFIX_UNET = "lora_unet"
+        LORA_PREFIX_TEXT_ENCODER = "lora_te"
+        # load the LoRA weights from the .safetensors file
+        state_dict = load_file(checkpoint_path, device=device)
+        # group tensors by layer; keys typically look like
+        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+        updates = defaultdict(dict)
+        for key, value in state_dict.items():
+            layer, elem = key.split('.', 1)
+            updates[layer][elem] = value
+        # directly update the weights in the diffusers model
+        for layer, elems in updates.items():
+            if "text" in layer:
+                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+                curr_layer = self.pipe.text_encoder
+            else:
+                layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
+                curr_layer = self.pipe.unet
+            # walk the module tree to find the target layer; attribute names may
+            # themselves contain underscores, so recombine segments on failure
+            temp_name = layer_infos.pop(0)
+            while True:
+                try:
+                    curr_layer = curr_layer.__getattr__(temp_name)
+                    if len(layer_infos) > 0:
+                        temp_name = layer_infos.pop(0)
+                    elif len(layer_infos) == 0:
+                        break
+                except Exception:
+                    if len(temp_name) > 0:
+                        temp_name += "_" + layer_infos.pop(0)
+                    else:
+                        temp_name = layer_infos.pop(0)
+            # get the up/down projection factors for this layer
+            weight_up = elems['lora_up.weight'].to(dtype)
+            weight_down = elems['lora_down.weight'].to(dtype)
+            alpha = elems.get('alpha')  # alpha may be absent from the checkpoint
+            if alpha:
+                alpha = alpha.item() / weight_up.shape[1]
+            else:
+                alpha = 1.0
+            # merge the low-rank update: W += multiplier * alpha * (up @ down)
+            if len(weight_up.shape) == 4:
+                # conv weights: squeeze the 1x1 spatial dims before the matmul
+                curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+            else:
+                curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
     def GenerateImage(self, request, context):
         prompt = request.positive_prompt
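
As an aside, the update `load_lora_weights` applies is the standard low-rank LoRA merge: each layer's checkpoint stores an `up` and a `down` factor, and merging adds `multiplier * alpha * (up @ down)` onto the frozen weight. A self-contained sketch of the same arithmetic on a dummy linear weight (the dimensions and values here are made up for illustration):

```python
import torch

out_features, in_features, rank = 8, 8, 2

weight = torch.zeros(out_features, in_features)  # frozen base weight W
weight_up = torch.randn(out_features, rank)      # "lora_up.weight"
weight_down = torch.randn(rank, in_features)     # "lora_down.weight"
alpha_tensor = torch.tensor(4.0)                 # "alpha" as stored in the file

multiplier = 1.0
# the backend divides alpha by the rank (weight_up.shape[1] == rank)
alpha = alpha_tensor.item() / weight_up.shape[1]

# W' = W + multiplier * alpha * (up @ down) -- the same update as the backend
weight += multiplier * alpha * torch.mm(weight_up, weight_down)

# the added term has rank at most `rank`, which is why shipping the two small
# factors is far cheaper than shipping a full fine-tuned weight matrix
print(torch.linalg.matrix_rank(weight))  # tensor(2), almost surely
```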