Skip to content

Commit

Permalink
docs: Update ControlNet use case docs (#4519)
Browse files Browse the repository at this point in the history
* Update ControlNext use case docs

Signed-off-by: Sherlock113 <[email protected]>

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

---------

Signed-off-by: Sherlock113 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Sherlock113 and pre-commit-ci[bot] authored Feb 22, 2024
1 parent 504ff63 commit b64ce64
Showing 1 changed file with 20 additions and 47 deletions.
67 changes: 20 additions & 47 deletions docs/source/use-cases/diffusion-models/controlnet.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,10 @@ Create BentoML :doc:`/guides/services` in a ``service.py`` file to specify the s
import typing as t
import cv2
import numpy as np
import PIL
from PIL.Image import Image as PIL_Image
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
from pydantic import BaseModel
import bentoml
Expand All @@ -55,21 +52,21 @@ Create BentoML :doc:`/guides/services` in a ``service.py`` file to specify the s
VAE_MODEL_ID = "madebyollin/sdxl-vae-fp16-fix"
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
@bentoml.service(
traffic={"timeout": 600},
workers=1,
resources={
"gpu": "1",
"gpu": 1,
"gpu_type": "nvidia-l4",
# we can also specify GPU memory requirement:
# "memory": "16Gi",
}
)
class SDXLControlNetService:
class ControlNet:
def __init__(self) -> None:
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
if torch.cuda.is_available():
self.device = "cuda"
self.dtype = torch.float16
Expand All @@ -94,58 +91,34 @@ Create BentoML :doc:`/guides/services` in a ``service.py`` file to specify the s
torch_dtype=self.dtype
).to(self.device)
@bentoml.api
async def generate(
self,
prompt: str,
arr: np.ndarray[t.Any, np.uint8],
**kwargs,
):
image = PIL.Image.fromarray(arr)
return self.pipe(prompt, image=image, **kwargs).to_tuple()
class Params(BaseModel):
prompt: str
negative_prompt: t.Optional[str]
controlnet_conditioning_scale: float = 0.5
num_inference_steps: int = 25
@bentoml.service(
name="sdxl-controlnet-service",
traffic={"timeout": 600},
workers=8,
resources={"cpu": "1"}
)
class ControlNet:
controlnet_service: SDXLControlNetService = bentoml.depends(SDXLControlNetService)
@bentoml.api
async def generate(self, image: PIL_Image, params: Params) -> PIL_Image:
import cv2
arr = np.array(image)
arr = cv2.Canny(arr, 100, 200)
arr = arr[:, :, None]
arr = np.concatenate([arr, arr, arr], axis=2)
params_d = params.dict()
prompt = params_d.pop("prompt")
res = await self.controlnet_service.generate(
image = PIL.Image.fromarray(arr)
return self.pipe(
prompt,
arr=arr,
image=image,
**params_d
)
return res[0][0]
This file defines the following classes:
).to_tuple()[0][0]
* ``SDXLControlNetService``: A BentoML Service with custom configurations in timeout, worker count, and resources.
class Params(BaseModel):
prompt: str
negative_prompt: t.Optional[str]
controlnet_conditioning_scale: float = 0.5
num_inference_steps: int = 25
- It loads the three pre-trained models and configures them to use GPU if available. The main pipeline (``StableDiffusionXLControlNetPipeline``) integrates these models.
- It defines an API endpoint ``generate`` to process a text prompt and an image array. The processed image is converted to a tuple and returned.
This file defines a BentoML Service ``ControlNet`` with custom :doc:`configurations </guides/configurations>` in timeout, worker count, and resources.

* ``Params``: This is a ``pydantic`` model defining the structure for input parameters.
* ``ControlNet``: A BentoML Service with custom configurations in timeout, worker count, and resources. ``ControlNet`` doesn't create images itself. Instead, it preprocesses the image and forwards it along with the text prompt to the ``SDXLControlNetService`` Service. The ``generate`` method in ``ControlNet`` then returns the final generated image.
- It loads the three pre-trained models and configures them to use GPU if available. The main pipeline (``StableDiffusionXLControlNetPipeline``) integrates these models.
- It defines an asynchronous API endpoint ``generate``, which takes an image and a set of parameters as input. The parameters for the generation process are extracted from a ``Params`` instance, a Pydantic model that provides automatic data validation.
- The ``generate`` method returns the generated image by calling the pipeline with the processed image and text prompts.

Run ``bentoml serve`` in your project directory to start the BentoML server.

Expand Down

0 comments on commit b64ce64

Please sign in to comment.