mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 18:45:00 +00:00
feat(diffusers): update, add autopipeline, controlnet (#1432)
* feat(diffusers): update, add autopipeline, controlenet * tests with AutoPipeline * simplify logic
This commit is contained in:
parent
72325fd0a3
commit
7641f92cde
19 changed files with 812 additions and 770 deletions
2
Makefile
2
Makefile
|
@ -383,7 +383,7 @@ help: ## Show this help.
|
||||||
protogen: protogen-go protogen-python
|
protogen: protogen-go protogen-python
|
||||||
|
|
||||||
protogen-go:
|
protogen-go:
|
||||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
protoc -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
|
||||||
backend/backend.proto
|
backend/backend.proto
|
||||||
|
|
||||||
protogen-python:
|
protogen-python:
|
||||||
|
|
|
@ -27,6 +27,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
|
||||||
CLIPModel: c.Diffusers.ClipModel,
|
CLIPModel: c.Diffusers.ClipModel,
|
||||||
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
||||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
||||||
|
ControlNet: c.Diffusers.ControlNet,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,6 @@ type Config struct {
|
||||||
|
|
||||||
// Diffusers
|
// Diffusers
|
||||||
Diffusers Diffusers `yaml:"diffusers"`
|
Diffusers Diffusers `yaml:"diffusers"`
|
||||||
|
|
||||||
Step int `yaml:"step"`
|
Step int `yaml:"step"`
|
||||||
|
|
||||||
// GRPC Options
|
// GRPC Options
|
||||||
|
@ -77,6 +76,7 @@ type Diffusers struct {
|
||||||
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
|
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
|
||||||
ClipModel string `yaml:"clip_model"` // Clip model to use
|
ClipModel string `yaml:"clip_model"` // Clip model to use
|
||||||
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
|
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
|
||||||
|
ControlNet string `yaml:"control_net"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LLMConfig struct {
|
type LLMConfig struct {
|
||||||
|
|
|
@ -110,6 +110,7 @@ message ModelOptions {
|
||||||
string CLIPModel = 31;
|
string CLIPModel = 31;
|
||||||
string CLIPSubfolder = 32;
|
string CLIPSubfolder = 32;
|
||||||
int32 CLIPSkip = 33;
|
int32 CLIPSkip = 33;
|
||||||
|
string ControlNet = 48;
|
||||||
|
|
||||||
// RWKV
|
// RWKV
|
||||||
string Tokenizer = 34;
|
string Tokenizer = 34;
|
||||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -18,9 +18,9 @@ import backend_pb2_grpc
|
||||||
import grpc
|
import grpc
|
||||||
|
|
||||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
|
from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
|
||||||
from diffusers import StableDiffusionImg2ImgPipeline
|
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel
|
||||||
from diffusers.pipelines.stable_diffusion import safety_checker
|
from diffusers.pipelines.stable_diffusion import safety_checker
|
||||||
|
from diffusers.utils import load_image
|
||||||
from compel import Compel
|
from compel import Compel
|
||||||
|
|
||||||
from transformers import CLIPTextModel
|
from transformers import CLIPTextModel
|
||||||
|
@ -30,6 +30,7 @@ from safetensors.torch import load_file
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
COMPEL=os.environ.get("COMPEL", "1") == "1"
|
COMPEL=os.environ.get("COMPEL", "1") == "1"
|
||||||
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
|
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
|
||||||
|
SAFETENSORS=os.environ.get("SAFETENSORS", "1") == "1"
|
||||||
|
|
||||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||||
|
@ -135,8 +136,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
print(f"Loading model {request.Model}...", file=sys.stderr)
|
print(f"Loading model {request.Model}...", file=sys.stderr)
|
||||||
print(f"Request {request}", file=sys.stderr)
|
print(f"Request {request}", file=sys.stderr)
|
||||||
torchType = torch.float32
|
torchType = torch.float32
|
||||||
|
variant = None
|
||||||
|
|
||||||
if request.F16Memory:
|
if request.F16Memory:
|
||||||
torchType = torch.float16
|
torchType = torch.float16
|
||||||
|
variant="fp16"
|
||||||
|
|
||||||
local = False
|
local = False
|
||||||
modelFile = request.Model
|
modelFile = request.Model
|
||||||
|
@ -160,14 +164,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
|
|
||||||
fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
|
fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
|
||||||
|
|
||||||
if request.IMG2IMG and request.PipelineType == "":
|
|
||||||
request.PipelineType == "StableDiffusionImg2ImgPipeline"
|
|
||||||
|
|
||||||
if request.PipelineType == "":
|
|
||||||
request.PipelineType == "StableDiffusionPipeline"
|
|
||||||
|
|
||||||
## img2img
|
## img2img
|
||||||
if request.PipelineType == "StableDiffusionImg2ImgPipeline":
|
if (request.PipelineType == "StableDiffusionImg2ImgPipeline") or (request.IMG2IMG and request.PipelineType == ""):
|
||||||
if fromSingleFile:
|
if fromSingleFile:
|
||||||
self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
|
self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
|
||||||
torch_dtype=torchType,
|
torch_dtype=torchType,
|
||||||
|
@ -177,12 +175,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
torch_dtype=torchType,
|
torch_dtype=torchType,
|
||||||
guidance_scale=cfg_scale)
|
guidance_scale=cfg_scale)
|
||||||
|
|
||||||
if request.PipelineType == "StableDiffusionDepth2ImgPipeline":
|
elif request.PipelineType == "StableDiffusionDepth2ImgPipeline":
|
||||||
self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
|
self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
|
||||||
torch_dtype=torchType,
|
torch_dtype=torchType,
|
||||||
guidance_scale=cfg_scale)
|
guidance_scale=cfg_scale)
|
||||||
## text2img
|
## text2img
|
||||||
if request.PipelineType == "StableDiffusionPipeline":
|
elif request.PipelineType == "AutoPipelineForText2Image" or request.PipelineType == "":
|
||||||
|
self.pipe = AutoPipelineForText2Image.from_pretrained(request.Model,
|
||||||
|
torch_dtype=torchType,
|
||||||
|
use_safetensors=SAFETENSORS,
|
||||||
|
variant=variant,
|
||||||
|
guidance_scale=cfg_scale)
|
||||||
|
elif request.PipelineType == "StableDiffusionPipeline":
|
||||||
if fromSingleFile:
|
if fromSingleFile:
|
||||||
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
|
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
|
||||||
torch_dtype=torchType,
|
torch_dtype=torchType,
|
||||||
|
@ -191,13 +195,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
|
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
|
||||||
torch_dtype=torchType,
|
torch_dtype=torchType,
|
||||||
guidance_scale=cfg_scale)
|
guidance_scale=cfg_scale)
|
||||||
|
elif request.PipelineType == "DiffusionPipeline":
|
||||||
if request.PipelineType == "DiffusionPipeline":
|
|
||||||
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
||||||
torch_dtype=torchType,
|
torch_dtype=torchType,
|
||||||
guidance_scale=cfg_scale)
|
guidance_scale=cfg_scale)
|
||||||
|
elif request.PipelineType == "StableDiffusionXLPipeline":
|
||||||
if request.PipelineType == "StableDiffusionXLPipeline":
|
|
||||||
if fromSingleFile:
|
if fromSingleFile:
|
||||||
self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
|
self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
|
||||||
torch_dtype=torchType, use_safetensors=True,
|
torch_dtype=torchType, use_safetensors=True,
|
||||||
|
@ -207,21 +209,34 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
request.Model,
|
request.Model,
|
||||||
torch_dtype=torchType,
|
torch_dtype=torchType,
|
||||||
use_safetensors=True,
|
use_safetensors=True,
|
||||||
# variant="fp16"
|
variant=variant,
|
||||||
guidance_scale=cfg_scale)
|
guidance_scale=cfg_scale)
|
||||||
# https://github.com/huggingface/diffusers/issues/4446
|
|
||||||
# do not use text_encoder in the constructor since then
|
|
||||||
# https://github.com/huggingface/diffusers/issues/3212#issuecomment-1521841481
|
|
||||||
if CLIPSKIP and request.CLIPSkip != 0:
|
if CLIPSKIP and request.CLIPSkip != 0:
|
||||||
text_encoder = CLIPTextModel.from_pretrained(clipmodel, num_hidden_layers=request.CLIPSkip, subfolder=clipsubfolder, torch_dtype=torchType)
|
self.clip_skip = request.CLIPSkip
|
||||||
self.pipe.text_encoder=text_encoder
|
else:
|
||||||
|
self.clip_skip = 0
|
||||||
|
|
||||||
# torch_dtype needs to be customized. float16 for GPU, float32 for CPU
|
# torch_dtype needs to be customized. float16 for GPU, float32 for CPU
|
||||||
# TODO: this needs to be customized
|
# TODO: this needs to be customized
|
||||||
if request.SchedulerType != "":
|
if request.SchedulerType != "":
|
||||||
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
|
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
|
||||||
|
|
||||||
self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
|
self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
|
||||||
|
|
||||||
|
|
||||||
|
if request.ControlNet:
|
||||||
|
self.controlnet = ControlNetModel.from_pretrained(
|
||||||
|
request.ControlNet, torch_dtype=torchType, variant=variant
|
||||||
|
)
|
||||||
|
self.pipe.controlnet = self.controlnet
|
||||||
|
else:
|
||||||
|
self.controlnet = None
|
||||||
|
|
||||||
if request.CUDA:
|
if request.CUDA:
|
||||||
self.pipe.to('cuda')
|
self.pipe.to('cuda')
|
||||||
|
if self.controlnet:
|
||||||
|
self.controlnet.to('cuda')
|
||||||
# Assume directory from request.ModelFile.
|
# Assume directory from request.ModelFile.
|
||||||
# Only if request.LoraAdapter it's not an absolute path
|
# Only if request.LoraAdapter it's not an absolute path
|
||||||
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
|
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
|
||||||
|
@ -316,9 +331,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
"num_inference_steps": steps,
|
"num_inference_steps": steps,
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.src != "":
|
if request.src != "" and not self.controlnet:
|
||||||
image = Image.open(request.src)
|
image = Image.open(request.src)
|
||||||
options["image"] = image
|
options["image"] = image
|
||||||
|
elif self.controlnet and request.src:
|
||||||
|
pose_image = load_image(request.src)
|
||||||
|
options["image"] = pose_image
|
||||||
|
|
||||||
|
if CLIPSKIP and self.clip_skip != 0:
|
||||||
|
options["clip_skip"]=self.clip_skip
|
||||||
|
|
||||||
# Get the keys that we will build the args for our pipe for
|
# Get the keys that we will build the args for our pipe for
|
||||||
keys = options.keys()
|
keys = options.keys()
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -53,7 +53,7 @@ class TestBackendServicer(unittest.TestCase):
|
||||||
self.setUp()
|
self.setUp()
|
||||||
with grpc.insecure_channel("localhost:50051") as channel:
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
stub = backend_pb2_grpc.BackendStub(channel)
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5", PipelineType="StableDiffusionPipeline"))
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5"))
|
||||||
self.assertTrue(response.success)
|
self.assertTrue(response.success)
|
||||||
self.assertEqual(response.message, "Model loaded successfully")
|
self.assertEqual(response.message, "Model loaded successfully")
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
|
@ -71,7 +71,7 @@ class TestBackendServicer(unittest.TestCase):
|
||||||
self.setUp()
|
self.setUp()
|
||||||
with grpc.insecure_channel("localhost:50051") as channel:
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
stub = backend_pb2_grpc.BackendStub(channel)
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5", PipelineType="StableDiffusionPipeline"))
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5"))
|
||||||
print(response.message)
|
print(response.message)
|
||||||
self.assertTrue(response.success)
|
self.assertTrue(response.success)
|
||||||
image_req = backend_pb2.GenerateImageRequest(positive_prompt="cat", width=16,height=16, dst="test.jpg")
|
image_req = backend_pb2.GenerateImageRequest(positive_prompt="cat", width=16,height=16, dst="test.jpg")
|
||||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load diff
|
@ -1,8 +1,8 @@
|
||||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// - protoc-gen-go-grpc v1.2.0
|
// - protoc-gen-go-grpc v1.2.0
|
||||||
// - protoc v4.23.4
|
// - protoc v3.6.1
|
||||||
// source: backend/backend.proto
|
// source: backend.proto
|
||||||
|
|
||||||
package proto
|
package proto
|
||||||
|
|
||||||
|
@ -453,5 +453,5 @@ var Backend_ServiceDesc = grpc.ServiceDesc{
|
||||||
ServerStreams: true,
|
ServerStreams: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Metadata: "backend/backend.proto",
|
Metadata: "backend.proto",
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue