Skip to content

Commit

Permalink
Add nf4 bnb quantized format
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonrising committed Aug 19, 2024
1 parent 9c4576c commit 45f5af3
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
4 changes: 2 additions & 2 deletions invokeai/app/invocations/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
"base": {
"repo": "invokeai/flux_dev::t5_xxl_encoder/base",
"repo": "InvokeAI/flux_schnell::t5_xxl_encoder/base",
"name": "t5_base_encoder",
"format": ModelFormat.T5Encoder,
},
"8b_quantized": {
"repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
"repo": "invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8",
"name": "t5_8b_quantized_encoder",
"format": ModelFormat.T5Encoder,
},
Expand Down
21 changes: 18 additions & 3 deletions invokeai/backend/model_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class ModelFormat(str, Enum):
T5Encoder = "t5_encoder"
T5Encoder8b = "t5_encoder_8b"
T5Encoder4b = "t5_encoder_4b"
BnbQuantizednf4b = "bnb_quantized_nf4b"


class SchedulerPredictionType(str, Enum):
Expand Down Expand Up @@ -193,7 +194,7 @@ def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> N
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""

format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint)
config_path: str = Field(description="path to the checkpoint model config file")
converted_at: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time
Expand Down Expand Up @@ -248,7 +249,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models."""

type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

@staticmethod
def get_tag() -> Tag:
Expand Down Expand Up @@ -287,7 +287,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
"""Model config for ControlNet models (diffusers version)."""

type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

@staticmethod
def get_tag() -> Tag:
Expand Down Expand Up @@ -336,6 +335,21 @@ def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")


class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
    """Model config for main checkpoint models quantized with bitsandbytes NF4 (4-bit)."""

    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False
    # Narrow the base-class `format` field declaratively instead of mutating
    # `self.format` after `super().__init__()`: post-init assignment bypasses
    # pydantic field validation and leaves serialization/schemas reporting the
    # wrong default. The base class (CheckpointConfigBase) already widens its
    # Literal to admit BnbQuantizednf4b, so this override is type-compatible.
    format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b

    @staticmethod
    def get_tag() -> Tag:
        """Return the discriminated-union tag selecting this config class."""
        return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")


class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
"""Model config for main diffusers models."""

Expand Down Expand Up @@ -438,6 +452,7 @@ def get_model_discriminator_value(v: Any) -> str:
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Expand Down
8 changes: 6 additions & 2 deletions invokeai/backend/model_manager/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def probe(
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = ModelFormat(fields.get("format")) or probe.get_format()
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)

fields["default_settings"] = fields.get("default_settings")
Expand All @@ -179,7 +179,7 @@ def probe(
# additional fields needed for main and controlnet models
if (
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
and fields["format"] is ModelFormat.Checkpoint
and fields["format"] in [ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b]
):
ckpt_config_path = cls._get_checkpoint_config_path(
model_path,
Expand Down Expand Up @@ -323,6 +323,7 @@ def _get_checkpoint_config_path(

if model_type is ModelType.Main:
if base_type == BaseModelType.Flux:
# TODO: Decide between dev/schnell
config_file = "flux/flux1-schnell.yaml"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
Expand Down Expand Up @@ -422,6 +423,9 @@ def __init__(self, model_path: Path):
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)

def get_format(self) -> ModelFormat:
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
return ModelFormat.BnbQuantizednf4b
return ModelFormat("checkpoint")

def get_variant_type(self) -> ModelVariantType:
Expand Down

0 comments on commit 45f5af3

Please sign in to comment.