mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 18:45:00 +00:00
fix: Lora loading (#2893)
- Fixed Lora loading Co-authored-by: Alex <alex@akhbar.home>
This commit is contained in:
parent
f521e50fa8
commit
4e84764787
2 changed files with 72 additions and 61 deletions
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
|
import traceback
|
||||||
import argparse
|
import argparse
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
@ -17,7 +17,8 @@ import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
|
||||||
from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
|
from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
|
||||||
|
EulerAncestralDiscreteScheduler
|
||||||
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
|
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
|
||||||
from diffusers.pipelines.stable_diffusion import safety_checker
|
from diffusers.pipelines.stable_diffusion import safety_checker
|
||||||
from diffusers.utils import load_image, export_to_video
|
from diffusers.utils import load_image, export_to_video
|
||||||
|
@ -26,7 +27,6 @@ from compel import Compel, ReturnedEmbeddingsType
|
||||||
from transformers import CLIPTextModel
|
from transformers import CLIPTextModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
COMPEL = os.environ.get("COMPEL", "0") == "1"
|
COMPEL = os.environ.get("COMPEL", "0") == "1"
|
||||||
XPU = os.environ.get("XPU", "0") == "1"
|
XPU = os.environ.get("XPU", "0") == "1"
|
||||||
|
@ -39,13 +39,17 @@ FRAMES=os.environ.get("FRAMES", "64")
|
||||||
|
|
||||||
if XPU:
|
if XPU:
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
print(ipex.xpu.get_device_name(0))
|
print(ipex.xpu.get_device_name(0))
|
||||||
|
|
||||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287
|
# https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287
|
||||||
def sc(self, clip_input, images): return images, [False for i in images]
|
def sc(self, clip_input, images): return images, [False for i in images]
|
||||||
|
|
||||||
|
|
||||||
# edit the StableDiffusionSafetyChecker class so that, when called, it just returns the images and an array of True values
|
# edit the StableDiffusionSafetyChecker class so that, when called, it just returns the images and an array of True values
|
||||||
safety_checker.StableDiffusionSafetyChecker.forward = sc
|
safety_checker.StableDiffusionSafetyChecker.forward = sc
|
||||||
|
|
||||||
|
@ -62,6 +66,8 @@ from diffusers.schedulers import (
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
UniPCMultistepScheduler,
|
UniPCMultistepScheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
|
# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
|
||||||
# Credits to https://github.com/neggles
|
# Credits to https://github.com/neggles
|
||||||
# See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
|
# See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
|
||||||
|
@ -136,10 +142,12 @@ def get_scheduler(name: str, config: dict = {}):
|
||||||
|
|
||||||
return sched_class.from_config(config)
|
return sched_class.from_config(config)
|
||||||
|
|
||||||
|
|
||||||
# Implement the BackendServicer class with the service methods
|
# Implement the BackendServicer class with the service methods
|
||||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
def Health(self, request, context):
|
def Health(self, request, context):
|
||||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||||
|
|
||||||
def LoadModel(self, request, context):
|
def LoadModel(self, request, context):
|
||||||
try:
|
try:
|
||||||
print(f"Loading model {request.Model}...", file=sys.stderr)
|
print(f"Loading model {request.Model}...", file=sys.stderr)
|
||||||
|
@ -255,7 +263,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
requires_pooled=[False, True]
|
requires_pooled=[False, True]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if request.ControlNet:
|
if request.ControlNet:
|
||||||
self.controlnet = ControlNetModel.from_pretrained(
|
self.controlnet = ControlNetModel.from_pretrained(
|
||||||
request.ControlNet, torch_dtype=torchType, variant=variant
|
request.ControlNet, torch_dtype=torchType, variant=variant
|
||||||
|
@ -263,13 +270,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
self.pipe.controlnet = self.controlnet
|
self.pipe.controlnet = self.controlnet
|
||||||
else:
|
else:
|
||||||
self.controlnet = None
|
self.controlnet = None
|
||||||
|
|
||||||
if request.CUDA:
|
|
||||||
self.pipe.to('cuda')
|
|
||||||
if self.controlnet:
|
|
||||||
self.controlnet.to('cuda')
|
|
||||||
if XPU:
|
|
||||||
self.pipe = self.pipe.to("xpu")
|
|
||||||
# Assume directory from request.ModelFile.
|
# Assume directory from request.ModelFile.
|
||||||
# Only if request.LoraAdapter it's not an absolute path
|
# Only if request.LoraAdapter it's not an absolute path
|
||||||
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
|
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
|
||||||
|
@ -282,10 +282,17 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
if request.LoraAdapter:
|
if request.LoraAdapter:
|
||||||
# Check if its a local file and not a directory ( we load lora differently for a safetensor file )
|
# Check if its a local file and not a directory ( we load lora differently for a safetensor file )
|
||||||
if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
|
if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
|
||||||
self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
|
# self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
|
||||||
|
self.pipe.load_lora_weights(request.LoraAdapter)
|
||||||
else:
|
else:
|
||||||
self.pipe.unet.load_attn_procs(request.LoraAdapter)
|
self.pipe.unet.load_attn_procs(request.LoraAdapter)
|
||||||
|
|
||||||
|
if request.CUDA:
|
||||||
|
self.pipe.to('cuda')
|
||||||
|
if self.controlnet:
|
||||||
|
self.controlnet.to('cuda')
|
||||||
|
if XPU:
|
||||||
|
self.pipe = self.pipe.to("xpu")
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
# Implement your logic here for the LoadModel service
|
# Implement your logic here for the LoadModel service
|
||||||
|
@ -430,6 +437,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
|
|
||||||
return backend_pb2.Result(message="Media generated", success=True)
|
return backend_pb2.Result(message="Media generated", success=True)
|
||||||
|
|
||||||
|
|
||||||
def serve(address):
|
def serve(address):
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
|
@ -453,6 +461,7 @@ def serve(address):
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
server.stop(0)
|
server.stop(0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
|
setuptools
|
||||||
accelerate
|
accelerate
|
||||||
compel
|
compel
|
||||||
|
peft
|
||||||
diffusers
|
diffusers
|
||||||
grpcio==1.65.0
|
grpcio==1.65.0
|
||||||
opencv-python
|
opencv-python
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue