feat(diffusers): update, add autopipeline, controlnet (#1432)

* feat(diffusers): update, add autopipeline, controlnet

* tests with AutoPipeline

* simplify logic
Ettore Di Giacinto, 2023-12-13 13:20:22 -05:00 (committed by GitHub)
commit 7641f92cde (parent 72325fd0a3)
19 changed files with 812 additions and 770 deletions


@@ -383,7 +383,7 @@ help: ## Show this help.
 protogen: protogen-go protogen-python

 protogen-go:
-	protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \
+	protoc -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
 		backend/backend.proto

 protogen-python:
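The Go rule now reads the proto through -Ibackend/ and emits into pkg/grpc/proto/, which is why the regenerated stubs further down reference backend.proto instead of backend/backend.proto. For reference, a minimal sketch of the matching Python-side stub generation with grpcio-tools; the output directory is an assumption, not part of this diff:

# Hypothetical sketch: regenerate the Python stubs with the same -Ibackend/ include
# path, so the generated code refers to "backend.proto". Output path is assumed.
from grpc_tools import protoc

protoc.main([
    "protoc",
    "-Ibackend/",
    "--python_out=backend/python/diffusers/",
    "--grpc_python_out=backend/python/diffusers/",
    "backend/backend.proto",
])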


@@ -27,6 +27,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
 	CLIPModel: c.Diffusers.ClipModel,
 	CLIPSubfolder: c.Diffusers.ClipSubFolder,
 	CLIPSkip: int32(c.Diffusers.ClipSkip),
+	ControlNet: c.Diffusers.ControlNet,
 }),
 })


@@ -38,8 +38,7 @@ type Config struct {
 	// Diffusers
 	Diffusers Diffusers `yaml:"diffusers"`
-
-	Step int `yaml:"step"`
+	Step int `yaml:"step"`

 	// GRPC Options
 	GRPC GRPC `yaml:"grpc"`

@@ -77,6 +76,7 @@ type Diffusers struct {
 	ClipSkip int `yaml:"clip_skip"` // Skip every N frames
 	ClipModel string `yaml:"clip_model"` // Clip model to use
 	ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
+	ControlNet string `yaml:"control_net"`
 }

 type LLMConfig struct {


@@ -110,6 +110,7 @@ message ModelOptions {
   string CLIPModel = 31;
   string CLIPSubfolder = 32;
   int32 CLIPSkip = 33;
+  string ControlNet = 48;

   // RWKV
   string Tokenizer = 34;
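The new ControlNet field travels the whole chain: the model YAML's diffusers.control_net value is copied into ModelOptions by the Go code above and read by the Python backend below. A minimal sketch of setting it directly on the message once the stubs are regenerated; the ControlNet repo id is only an example:

import backend_pb2

# Hypothetical request: base model plus a ControlNet repository id (example value).
opts = backend_pb2.ModelOptions(
    Model="runwayml/stable-diffusion-v1-5",
    ControlNet="lllyasviel/sd-controlnet-canny",  # example; any diffusers ControlNet repo id
)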

Two file diffs suppressed because one or more lines are too long.


@@ -18,9 +18,9 @@ import backend_pb2_grpc
 import grpc
 from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
-from diffusers import StableDiffusionImg2ImgPipeline
+from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel
 from diffusers.pipelines.stable_diffusion import safety_checker
+from diffusers.utils import load_image

 from compel import Compel
 from transformers import CLIPTextModel
@@ -30,6 +30,7 @@ from safetensors.torch import load_file
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 COMPEL=os.environ.get("COMPEL", "1") == "1"
 CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
+SAFETENSORS=os.environ.get("SAFETENSORS", "1") == "1"

 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1
 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
@@ -135,8 +136,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         print(f"Loading model {request.Model}...", file=sys.stderr)
         print(f"Request {request}", file=sys.stderr)
         torchType = torch.float32
+        variant = None
         if request.F16Memory:
             torchType = torch.float16
+            variant="fp16"

         local = False
         modelFile = request.Model
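The variant flag pairs with the dtype: when F16Memory is requested, the backend both casts to float16 and asks diffusers for the fp16 weight files of the repository. A hypothetical standalone helper illustrating the pairing (not part of this diff):

import torch

def dtype_and_variant(f16_memory: bool):
    # Mirrors the F16Memory handling above: half precision selects both the torch
    # dtype and the "fp16" repository variant; full precision selects no variant.
    if f16_memory:
        return torch.float16, "fp16"
    return torch.float32, None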
@@ -160,14 +164,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local

-        if request.IMG2IMG and request.PipelineType == "":
-            request.PipelineType == "StableDiffusionImg2ImgPipeline"
-
-        if request.PipelineType == "":
-            request.PipelineType == "StableDiffusionPipeline"
-
         ## img2img
-        if request.PipelineType == "StableDiffusionImg2ImgPipeline":
+        if (request.PipelineType == "StableDiffusionImg2ImgPipeline") or (request.IMG2IMG and request.PipelineType == ""):
             if fromSingleFile:
                 self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
                     torch_dtype=torchType,
@@ -177,12 +175,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                     torch_dtype=torchType,
                     guidance_scale=cfg_scale)
-        if request.PipelineType == "StableDiffusionDepth2ImgPipeline":
+        elif request.PipelineType == "StableDiffusionDepth2ImgPipeline":
             self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
                 torch_dtype=torchType,
                 guidance_scale=cfg_scale)

         ## text2img
-        if request.PipelineType == "StableDiffusionPipeline":
+        elif request.PipelineType == "AutoPipelineForText2Image" or request.PipelineType == "":
+            self.pipe = AutoPipelineForText2Image.from_pretrained(request.Model,
+                torch_dtype=torchType,
+                use_safetensors=SAFETENSORS,
+                variant=variant,
+                guidance_scale=cfg_scale)
+        elif request.PipelineType == "StableDiffusionPipeline":
             if fromSingleFile:
                 self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
                     torch_dtype=torchType,
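AutoPipelineForText2Image is now the default when no PipelineType is given: it inspects the repository's model_index.json and instantiates the matching text-to-image pipeline class. A minimal standalone sketch of what that branch amounts to; the model id, prompt, and fp16 variant are examples, not values taken from this diff:

import torch
from diffusers import AutoPipelineForText2Image

# AutoPipeline resolves the concrete pipeline class (SD, SDXL, ...) from the repo.
pipe = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",   # example model id (same one the tests use)
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",                     # assumes the repo publishes fp16 weight files
)
pipe.to("cuda")

image = pipe("a cat wearing a space suit", num_inference_steps=20).images[0]
image.save("cat.png")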
@@ -191,13 +195,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
                     torch_dtype=torchType,
                     guidance_scale=cfg_scale)
-
-        if request.PipelineType == "DiffusionPipeline":
+        elif request.PipelineType == "DiffusionPipeline":
             self.pipe = DiffusionPipeline.from_pretrained(request.Model,
                 torch_dtype=torchType,
                 guidance_scale=cfg_scale)
-
-        if request.PipelineType == "StableDiffusionXLPipeline":
+        elif request.PipelineType == "StableDiffusionXLPipeline":
             if fromSingleFile:
                 self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
                     torch_dtype=torchType, use_safetensors=True,
@@ -207,21 +209,34 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                     request.Model,
                     torch_dtype=torchType,
                     use_safetensors=True,
-                    # variant="fp16"
+                    variant=variant,
                     guidance_scale=cfg_scale)

-        # https://github.com/huggingface/diffusers/issues/4446
-        # do not use text_encoder in the constructor since then
-        # https://github.com/huggingface/diffusers/issues/3212#issuecomment-1521841481
         if CLIPSKIP and request.CLIPSkip != 0:
-            text_encoder = CLIPTextModel.from_pretrained(clipmodel, num_hidden_layers=request.CLIPSkip, subfolder=clipsubfolder, torch_dtype=torchType)
-            self.pipe.text_encoder=text_encoder
+            self.clip_skip = request.CLIPSkip
+        else:
+            self.clip_skip = 0

         # torch_dtype needs to be customized. float16 for GPU, float32 for CPU
         # TODO: this needs to be customized
         if request.SchedulerType != "":
             self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)

         self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)

+        if request.ControlNet:
+            self.controlnet = ControlNetModel.from_pretrained(
+                request.ControlNet, torch_dtype=torchType, variant=variant
+            )
+            self.pipe.controlnet = self.controlnet
+        else:
+            self.controlnet = None
+
         if request.CUDA:
             self.pipe.to('cuda')
+            if self.controlnet:
+                self.controlnet.to('cuda')

         # Assume directory from request.ModelFile.
         # Only if request.LoraAdapter it's not an absolute path
         if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
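For comparison with the wiring above (which loads a ControlNetModel and attaches it to the already-constructed pipeline), the standard diffusers route is to hand the ControlNet to the pipeline at construction time, so the conditioning image is consumed in the call. A hedged sketch under that standard API; the repo ids and the image path are examples:

import torch
from diffusers import AutoPipelineForText2Image, ControlNetModel
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16   # example ControlNet repo
)
# Passing controlnet= lets AutoPipeline pick the ControlNet variant of the pipeline class.
pipe = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

cond = load_image("./canny_edges.png")   # example conditioning image (e.g. Canny edges)
image = pipe("a futuristic city at night", image=cond, num_inference_steps=20).images[0]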
@@ -316,9 +331,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             "num_inference_steps": steps,
         }

-        if request.src != "":
+        if request.src != "" and not self.controlnet:
             image = Image.open(request.src)
             options["image"] = image
+        elif self.controlnet and request.src:
+            pose_image = load_image(request.src)
+            options["image"] = pose_image
+
+        if CLIPSKIP and self.clip_skip != 0:
+            options["clip_skip"]=self.clip_skip

         # Get the keys that we will build the args for our pipe for
         keys = options.keys()
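Two behavioural changes land here: clip_skip is no longer applied by swapping the text encoder at load time but is passed as a call-time option, and when a ControlNet is attached request.src is treated as the conditioning image rather than an img2img source. A small sketch of the kwargs-dict pattern feeding the pipeline call; the prompt and values are examples:

import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Build only the kwargs that apply, then splat them into the call, mirroring the
# options dict above. clip_skip is a __call__ argument in recent diffusers releases.
options = {"width": 512, "height": 512, "num_inference_steps": 20}
clip_skip = 2                          # e.g. taken from request.CLIPSkip
if clip_skip:
    options["clip_skip"] = clip_skip

image = pipe("a cat", **options).images[0]
image.save("test.jpg")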

File diff suppressed because one or more lines are too long


@@ -53,7 +53,7 @@ class TestBackendServicer(unittest.TestCase):
             self.setUp()
             with grpc.insecure_channel("localhost:50051") as channel:
                 stub = backend_pb2_grpc.BackendStub(channel)
-                response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5", PipelineType="StableDiffusionPipeline"))
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5"))
                 self.assertTrue(response.success)
                 self.assertEqual(response.message, "Model loaded successfully")
         except Exception as err:

@@ -71,7 +71,7 @@ class TestBackendServicer(unittest.TestCase):
             self.setUp()
             with grpc.insecure_channel("localhost:50051") as channel:
                 stub = backend_pb2_grpc.BackendStub(channel)
-                response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5", PipelineType="StableDiffusionPipeline"))
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5"))
                 print(response.message)
                 self.assertTrue(response.success)
                 image_req = backend_pb2.GenerateImageRequest(positive_prompt="cat", width=16,height=16, dst="test.jpg")
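The tests now omit PipelineType and therefore exercise the new AutoPipeline default path. A hypothetical additional test for the ControlNet field, sketched in the same style; it is not part of this commit, and the ControlNet repo id is only an example:

def test_load_model_with_controlnet(self):
    """Hypothetical: load the base model with a ControlNet attached via the new field."""
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = backend_pb2_grpc.BackendStub(channel)
        response = stub.LoadModel(backend_pb2.ModelOptions(
            Model="runwayml/stable-diffusion-v1-5",
            ControlNet="lllyasviel/sd-controlnet-canny",  # example ControlNet repo id
        ))
        self.assertTrue(response.success)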

Eight file diffs suppressed because one or more lines are too long; one file diff suppressed because it is too large.


@@ -1,8 +1,8 @@
 // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
 // versions:
 // - protoc-gen-go-grpc v1.2.0
-// - protoc v4.23.4
-// source: backend/backend.proto
+// - protoc v3.6.1
+// source: backend.proto

 package proto

@@ -453,5 +453,5 @@ var Backend_ServiceDesc = grpc.ServiceDesc{
 			ServerStreams: true,
 		},
 	},
-	Metadata: "backend/backend.proto",
+	Metadata: "backend.proto",
 }