Skip to content

Commit

Permalink
Merge pull request #11 from fuzailpalnak/refactor/refinenet
Browse files Browse the repository at this point in the history
Refactor/refinenet
  • Loading branch information
fuzailpalnak authored Nov 2, 2020
2 parents e790a55 + a1e6a0a commit f6c8e64
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def forward(self, backbone_features, refine_block_features=None):
class ReFineNet(nn.Module):
def __init__(
self,
res_net_to_use,
pre_trained_image_net,
res_net_to_use="resnet34",
pre_trained_image_net=True,
top_layers_trainable=True,
num_classes=1,
):
Expand Down Expand Up @@ -236,7 +236,7 @@ def __init__(
in_planes=256, out_planes=256
)

self.final_layer = convolution_3x3(in_planes=256, out_planes=self.num_classes)
self.final_layer = convolution_1x1(in_planes=256, out_planes=self.num_classes)

def forward(self, input_feature: Tensor) -> Tensor:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def forward(self, backbone_features, refine_block_features=None):
class ReFineNetLite(nn.Module):
def __init__(
self,
res_net_to_use,
pre_trained_image_net,
res_net_to_use="resnet34",
pre_trained_image_net=True,
top_layers_trainable=True,
num_classes=1,
):
Expand Down

0 comments on commit f6c8e64

Please sign in to comment.