Skip to content

Commit

Permalink
make onnx compatible
Browse files Browse the repository at this point in the history
- remove broadcast_to
- remove returning a dictionary
  • Loading branch information
harishanand95 committed Aug 29, 2022
1 parent efa773a commit 8dd4e82
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def forward(
timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension
timesteps = timesteps.broadcast_to(sample.shape[0])
#timesteps = timesteps.broadcast_to(sample.shape[0])
timesteps = timesteps * torch.ones(sample.shape[0])

t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)
Expand Down Expand Up @@ -181,6 +182,6 @@ def forward(
sample = self.conv_act(sample)
sample = self.conv_out(sample)

output = {"sample": sample}
#output = {"sample": sample}

return output
return sample

0 comments on commit 8dd4e82

Please sign in to comment.