Skip to content

Commit

Permalink
disable early stopping; there is a bug when validation percentage is set
Browse files Browse the repository at this point in the history
will hopefully be fixed
(Lightning-AI/pytorch-lightning#524)
add 3d model without dropout
  • Loading branch information
fellnerse committed Nov 27, 2019
1 parent 54d081d commit 2f378d5
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
4 changes: 4 additions & 0 deletions src/forgery_detection/lightning/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,17 @@
from forgery_detection.models.image.multi_class_classification import Resnet18MultiHead
from forgery_detection.models.utils import SequenceClassificationModel
from forgery_detection.models.video.multi_class_classification import Resnet183D
from forgery_detection.models.video.multi_class_classification import (
Resnet183DNoDropout,
)
from forgery_detection.utils import cl_logger


class Supervised(pl.LightningModule):
MODEL_DICT = {
"resnet18multiclassdropout": Resnet18MultiClassDropout,
"resnet183d": Resnet183D,
"resnet183dnodropout": Resnet183DNoDropout,
"resnet18heads": Resnet18MultiHead,
}

Expand Down
12 changes: 6 additions & 6 deletions src/forgery_detection/lightning/train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import click
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

from forgery_detection.lightning.logging.utils import get_logger_and_checkpoint_callback
from forgery_detection.lightning.logging.utils import SystemMode
from forgery_detection.lightning.system import Supervised
from forgery_detection.lightning.utils import PythonLiteralOptionGPUs
from forgery_detection.lightning.utils import VAL_ACC


@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
Expand Down Expand Up @@ -74,10 +72,12 @@ def run_lightning(*args, **kwargs):
model = Supervised(kwargs)

# early stopping
# todo do i neee to have a filter for early stopping? or only for saving checkpoints
early_stopping_callback = EarlyStopping(
monitor=VAL_ACC, patience=10, verbose=True, mode="max"
)
# somehow does not work any more, the logs it receives are from train and not from
# val
early_stopping_callback = None
# EarlyStopping(
# monitor=VAL_ACC, patience=1, verbose=True, mode="max"
# )

trainer = Trainer(
gpus=kwargs["gpus"],
Expand Down
29 changes: 28 additions & 1 deletion src/forgery_detection/models/video/multi_class_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class Resnet183D(PretrainedResnet18):
def __init__(self,):
def __init__(self):
super().__init__(num_classes=5, sequence_length=8, contains_dropout=True)
self.resnet = resnet18(pretrained=True, num_classes=1000)

Expand Down Expand Up @@ -35,3 +35,30 @@ def forward(self, x):
x = self.resnet.fc(x)

return x


class Resnet183DNoDropout(PretrainedResnet18):
def __init__(self):
super().__init__(num_classes=5, sequence_length=8, contains_dropout=False)
self.resnet = resnet18(pretrained=True, num_classes=1000)

self.resnet.conv1 = nn.Conv3d(8, 64, kernel_size=(3, 7, 7), bias=False)
self.resnet.layer4 = nn.Identity()
self.resnet.fc = nn.Linear(256, self.num_classes)

def forward(self, x):
x = self.resnet.conv1(x).squeeze(2)
x = self.resnet.bn1(x)
x = self.resnet.relu(x)
x = self.resnet.maxpool(x)

x = self.resnet.layer1(x)
x = self.resnet.layer2(x)
x = self.resnet.layer3(x)
x = self.resnet.layer4(x)

x = self.resnet.avgpool(x)
x = torch.flatten(x, 1)
x = self.resnet.fc(x)

return x

0 comments on commit 2f378d5

Please sign in to comment.