Skip to content

Commit

Permalink
Update unet.py (Lightning-AI#1955)
Browse files Browse the repository at this point in the history
  • Loading branch information
nanddalal authored May 26, 2020
1 parent d0ec11b commit c967b88
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pl_examples/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
super().__init__()
self.upsample = None
if bilinear:
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(in_ch, in_ch // 2, kernel_size=1),
)
else:
self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)

Expand Down

0 comments on commit c967b88

Please sign in to comment.