diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index ce173a49a15..955d864b166 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -1,14 +1,9 @@
-from pathlib import Path
-
 import torch
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from optimum.quanto import qfloat8
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.model import CLIPField, T5EncoderField
-from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
-from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -40,7 +35,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
     # compatible with other ConditioningOutputs.
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-
         t5_embeddings, clip_embeddings = self._encode_prompt(context)
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
@@ -48,7 +42,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:
 
         conditioning_name = context.conditioning.save(conditioning_data)
         return ConditioningOutput.build(conditioning_name)
-
+
     def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
         # TODO: Determine the T5 max sequence length based on the model.
         # if self.model == "flux-schnell":
diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index 8ad9a7d4495..bfb1484ed12 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -1,13 +1,9 @@
-from pathlib import Path
 from typing import Literal
-from pydantic import Field
 
 import torch
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from invokeai.app.invocations.model import TransformerField, VAEField
-from optimum.quanto import qfloat8
 from PIL import Image
 from transformers.models.auto import AutoModelForTextEncoding
 
@@ -19,8 +15,8 @@
     InputField,
     WithBoard,
     WithMetadata,
-    UIType,
 )
+from invokeai.app.invocations.model import TransformerField, VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
@@ -72,7 +68,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-
         # Load the conditioning data.
         cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
         assert len(cond_data.conditionings) == 1
diff --git a/invokeai/backend/quantization/fast_quantized_diffusion_model.py b/invokeai/backend/quantization/fast_quantized_diffusion_model.py
index b1531094d13..65b64a69a17 100644
--- a/invokeai/backend/quantization/fast_quantized_diffusion_model.py
+++ b/invokeai/backend/quantization/fast_quantized_diffusion_model.py
@@ -3,6 +3,7 @@
 from typing import Union
 
 from diffusers.models.model_loading_utils import load_state_dict
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.utils import (
     CONFIG_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
@@ -12,7 +13,6 @@
 )
 from optimum.quanto.models import QuantizedDiffusersModel
 from optimum.quanto.models.shared_dict import ShardedStateDict
-from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 
 from invokeai.backend.requantize import requantize
 
diff --git a/invokeai/backend/quantization/fast_quantized_transformers_model.py b/invokeai/backend/quantization/fast_quantized_transformers_model.py
index 5f16bae611b..72636a43fb1 100644
--- a/invokeai/backend/quantization/fast_quantized_transformers_model.py
+++ b/invokeai/backend/quantization/fast_quantized_transformers_model.py
@@ -1,14 +1,13 @@
 import json
 import os
-import torch
 from typing import Union
 
 from optimum.quanto.models import QuantizedTransformersModel
 from optimum.quanto.models.shared_dict import ShardedStateDict
 from transformers import AutoConfig
 from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
 from transformers.models.auto import AutoModelForTextEncoding
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
 
 from invokeai.backend.requantize import requantize
 