Working inference node with quantized bnb nf4 checkpoint
brandonrising committed Aug 19, 2024
1 parent a1c6213 commit 00b63af
Showing 2 changed files with 65 additions and 11 deletions.
14 changes: 8 additions & 6 deletions invokeai/app/invocations/flux_text_to_image.py
@@ -89,7 +89,7 @@ def _run_diffusion(
img, img_ids = self._prepare_latent_img_patches(x)

# HACK(ryand): Find a better way to determine if this is a schnell model or not.
is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
is_schnell = "schnell" in transformer_info.config.path if transformer_info.config else ""
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],
@@ -139,9 +139,9 @@ def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.T
img = repeat(img, "1 ... -> bs ...", bs=bs)

# Generate patch position ids.
- img_ids = torch.zeros(h // 2, w // 2, 3)
- img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
- img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+ img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

return img, img_ids
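
For reference, a minimal standalone sketch of the patch position-id construction above, with every tensor created directly on the latent image's device; the function name and signature are illustrative, not InvokeAI's actual API:

```python
import torch
from einops import repeat


def make_patch_ids(latent_img: torch.Tensor, bs: int) -> torch.Tensor:
    """Build (bs, (h/2)*(w/2), 3) position ids for 2x2 latent patches.

    Creating the tensors on latent_img.device avoids the implicit CPU->GPU
    transfers that the pre-change code depended on.
    """
    _, _, h, w = latent_img.shape
    ids = torch.zeros(h // 2, w // 2, 3, device=latent_img.device)
    ids[..., 1] += torch.arange(h // 2, device=latent_img.device)[:, None]  # patch row index
    ids[..., 2] += torch.arange(w // 2, device=latent_img.device)[None, :]  # patch column index
    return repeat(ids, "h w c -> b (h w) c", b=bs)
```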
@@ -155,8 +155,10 @@ def _run_vae_decoding(
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
# TODO(ryand): Test that this works with both float16 and bfloat16.
- with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
- img = vae.decode(latents)
+ # with torch.autocast(device_type=latents.device.type, dtype=torch.float32):
+ vae.to(torch.float32)
+ latents.to(torch.float32)
+ img = vae.decode(latents)

img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
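
A hedged sketch of the decode path after this change: rather than relying on autocast, both the VAE and the latents are cast to float32 before decoding. Note that Tensor.to() returns a new tensor rather than casting in place, so the sketch reassigns the result; the helper name and the no_grad wrapper are assumptions, not part of this commit.

```python
import torch


def decode_latents_fp32(vae, latents: torch.Tensor) -> torch.Tensor:
    """Decode FLUX latents with the VAE held in float32 (no autocast)."""
    vae.to(torch.float32)
    latents = latents.to(torch.float32)  # .to() is not in-place; keep the returned tensor
    with torch.no_grad():
        img = vae.decode(latents)
    return img.clamp(-1, 1)
```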
62 changes: 57 additions & 5 deletions invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -1,6 +1,8 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""

import accelerate
import torch
from dataclasses import fields
from pathlib import Path
from typing import Any, Optional
@@ -24,13 +26,15 @@
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
MainCheckpointConfig,
MainBnbQuantized4bCheckpointConfig,
T5EncoderConfig,
VAECheckpointConfig,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4

app_config = get_config()

@@ -62,7 +66,7 @@ def _load_model(
with SilenceWarnings():
model = load_class(params).to(self._torch_dtype)
# load_sft doesn't support torch.device
- sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
+ sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)

return model
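
As a rough sketch of the loading pattern used here (a standalone helper assumed for illustration; the real code lives inside the loader class): safetensors' load_file() returns CPU tensors when no device is passed, and load_state_dict(..., assign=True) lets the module adopt those tensors directly instead of copying them into pre-allocated parameters.

```python
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_single_file_weights(model: torch.nn.Module, model_path: Path) -> torch.nn.Module:
    """Load a single-file .safetensors checkpoint into an existing module."""
    sd = load_file(str(model_path))  # no device argument -> tensors are loaded on CPU
    model.load_state_dict(sd, strict=False, assign=True)  # adopt tensors rather than copy
    return model
```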
@@ -105,9 +109,9 @@ def _load_model(

match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install

raise Exception("Only Checkpoint Flux models are currently supported.")

@@ -152,7 +156,55 @@ def _load_from_singlefile(

with SilenceWarnings():
model = load_class(params).to(self._torch_dtype)
- # load_sft doesn't support torch.device
- sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
+ sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)
return model


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
"""Class to load main models."""

def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise Exception("Only Checkpoint Flux models are currently supported.")
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
try:
flux_conf = yaml.safe_load(stream)
except:
raise

match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)

raise Exception("Only Checkpoint Flux models are currently supported.")

def _load_from_singlefile(
self,
config: AnyModelConfig,
flux_conf: Any,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
load_class = Flux
params = None
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = FluxParams(**filtered_data)

with SilenceWarnings():
with accelerate.init_empty_weights():
model = load_class(params)
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
# this on GPUs without bfloat16 support.
sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)
return model
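
Putting the new loader's pieces together, a hedged end-to-end sketch of the NF4 path: instantiate the transformer on the meta device so no full-precision weights are allocated, swap its layers for bitsandbytes NF4 equivalents, then materialize the pre-quantized weights from the single-file checkpoint. The import paths for Flux and FluxParams are assumptions here; quantize_model_nf4 is the helper this commit imports.

```python
from pathlib import Path

import accelerate
import torch
from safetensors.torch import load_file

# The two import paths below are assumed for illustration.
from invokeai.backend.flux.model import Flux, FluxParams
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4


def load_flux_nf4(params: FluxParams, checkpoint: Path) -> Flux:
    """Load a bnb-NF4-quantized FLUX transformer from a single-file checkpoint."""
    with accelerate.init_empty_weights():
        model = Flux(params)  # meta tensors only; no real memory allocated yet
    model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
    sd = load_file(str(checkpoint))  # checkpoint already stores NF4-quantized weights
    model.load_state_dict(sd, strict=False, assign=True)
    return model
```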
