From 8dd4e822f87e1b4259755a2181218797ceecc410 Mon Sep 17 00:00:00 2001 From: Harish Anand Date: Mon, 29 Aug 2022 10:00:33 -0700 Subject: [PATCH] make onnx compatible - remove broadcast_to - remove returning a dictionary --- src/diffusers/models/unet_2d_condition.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 25c4e37d8a6d..616c103ff06f 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -134,7 +134,8 @@ def forward( timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension - timesteps = timesteps.broadcast_to(sample.shape[0]) + #timesteps = timesteps.broadcast_to(sample.shape[0]) + timesteps = timesteps * torch.ones(sample.shape[0]) t_emb = self.time_proj(timesteps) emb = self.time_embedding(t_emb) @@ -181,6 +182,6 @@ def forward( sample = self.conv_act(sample) sample = self.conv_out(sample) - output = {"sample": sample} + #output = {"sample": sample} - return output + return sample