diff --git a/extra/grpc/diffusers/backend_diffusers.py b/extra/grpc/diffusers/backend_diffusers.py index 8dab0b57..506eb5fd 100755 --- a/extra/grpc/diffusers/backend_diffusers.py +++ b/extra/grpc/diffusers/backend_diffusers.py @@ -14,8 +14,10 @@ import torch from torch import autocast from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler from diffusers.pipelines.stable_diffusion import safety_checker +from compel import Compel _ONE_DAY_IN_SECONDS = 60 * 60 * 24 +COMPEL=os.environ.get("COMPEL", "1") == "1" # https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287 def sc(self, clip_input, images) : return images, [False for i in images] @@ -99,6 +101,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config, algorithm_type="sde-dpmsolver++") if request.CUDA: self.pipe.to('cuda') + + self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") # Implement your logic here for the LoadModel service @@ -127,12 +131,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): # create a dictionary of parameters by using the keys from EnableParameters and the values from defaults kwargs = {key: options[key] for key in keys} - - # pass the kwargs dictionary to the self.pipe method - image = self.pipe( - prompt, - **kwargs - ).images[0] + image = {} + if COMPEL: + conditioning = self.compel.build_conditioning_tensor(prompt) + kwargs["prompt_embeds"]= conditioning + # pass the kwargs dictionary to the self.pipe method + image = self.pipe( + **kwargs + ).images[0] + else: + # pass the kwargs dictionary to the self.pipe method + image = self.pipe( + prompt, + **kwargs + ).images[0] # save the result image.save(request.dst) diff --git a/extra/requirements.txt b/extra/requirements.txt index 96ab2eaa..fb3cc012 100644 --- a/extra/requirements.txt +++ b/extra/requirements.txt @@ -3,4 +3,5 @@ grpcio google protobuf six -omegaconf \ No newline at end of file +omegaconf +compel \ No newline at end of file