Skip to content

Commit

Permalink
changed to absolute imports and added docs (#881)
Browse files Browse the repository at this point in the history
  • Loading branch information
akshaykulkarni07 authored Feb 17, 2020
1 parent f44dfb3 commit 0ad3e8b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,33 @@
import torch.nn as nn
import torch.nn.functional as F

from parts import DoubleConv, Down, Up
from models.unet.parts import DoubleConv, Down, Up


class UNet(nn.Module):
'''
Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation
Link - https://arxiv.org/abs/1505.04597
Parameters:
num_classes (int) - Number of output classes required (default 19 for KITTI dataset)
bilinear (bool) - Whether to use bilinear interpolation or transposed
convolutions for upsampling.
'''
def __init__(self, num_classes=19, bilinear=False):
super().__init__()
self.bilinear = bilinear
self.num_classes = num_classes
self.layer1 = DoubleConv(3, 64)
self.layer2 = Down(64, 128)
self.layer3 = Down(128, 256)
self.layer4 = Down(256, 512)
self.layer5 = Down(512, 1024)

self.layer6 = Up(1024, 512, bilinear=self.bilinear)
self.layer7 = Up(512, 256, bilinear=self.bilinear)
self.layer8 = Up(256, 128, bilinear=self.bilinear)
self.layer9 = Up(128, 64, bilinear=self.bilinear)
self.layer6 = Up(1024, 512, bilinear=bilinear)
self.layer7 = Up(512, 256, bilinear=bilinear)
self.layer8 = Up(256, 128, bilinear=bilinear)
self.layer9 = Up(128, 64, bilinear=bilinear)

self.layer10 = nn.Conv2d(64, self.num_classes, kernel_size=1)
self.layer10 = nn.Conv2d(64, num_classes, kernel_size=1)

def forward(self, x):
x1 = self.layer1(x)
Expand Down

0 comments on commit 0ad3e8b

Please sign in to comment.