diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 25c4e37d8a6d..616c103ff06f 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -134,7 +134,8 @@ def forward( timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension - timesteps = timesteps.broadcast_to(sample.shape[0]) + #timesteps = timesteps.broadcast_to(sample.shape[0]) + timesteps = timesteps * torch.ones(sample.shape[0]) t_emb = self.time_proj(timesteps) emb = self.time_embedding(t_emb) @@ -181,6 +182,6 @@ def forward( sample = self.conv_act(sample) sample = self.conv_out(sample) - output = {"sample": sample} + #output = {"sample": sample} - return output + return sample