Make classification models inherit from BaseClassifier #1314

Merged · 7 commits · Jul 31, 2023
Changes from all commits
src/super_gradients/training/models/__init__.py (2 additions, 0 deletions)
@@ -1,6 +1,7 @@
import warnings

from .sg_module import SgModule
+from .classification_models.base_classifer import BaseClassifier

# Classification models
from super_gradients.training.models.classification_models.beit import Beit, BeitLargePatch16_224, BeitBasePatch16_224
@@ -197,6 +198,7 @@ def inner(*args, **kwargs):
    "NDFLHeads",
    "YoloNASPANNeckWithC2",
    "SgModule",
+    "BaseClassifier",
    "Beit",
    "BeitLargePatch16_224",
    "BeitBasePatch16_224",
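With this export in place, BaseClassifier can be imported from the package root alongside the models. A minimal sketch of what the change guarantees (all three names are exported above; BaseClassifier itself extends SgModule, so existing type checks keep passing):

from super_gradients.training.models import BaseClassifier, Beit, SgModule

# Beit is re-parented below; BaseClassifier subclasses SgModule,
# so anything that expected an SgModule still gets one.
assert issubclass(Beit, BaseClassifier)
assert issubclass(BaseClassifier, SgModule)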
src/super_gradients/training/models/classification_models/base_classifer.py
@@ -1,27 +1,25 @@
+from typing import Optional, List
from functools import lru_cache

from super_gradients.training.models import SgModule
from super_gradients.training.pipelines.pipelines import ClassificationPipeline
from super_gradients.training.utils.media.image import ImageSource
-from super_gradients.training.utils.predict import ImagesPredictions
+from super_gradients.training.utils.predict import ImagesClassificationPrediction
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.processing_factory import ProcessingFactory
from super_gradients.training.processing.processing import Processing
-from typing import Optional, List


class BaseClassifier(SgModule):
    def __init__(
        self,
    ):
        self._class_names: Optional[List[str]] = None
        self._image_processor: Optional[Processing] = None
        super(BaseClassifier, self).__init__()

    @resolve_param("image_processor", ProcessingFactory())
-    def set_dataset_processing_params(
-        self,
-        class_names: Optional[List[str]] = None,
-        image_processor: Optional[Processing] = None,
-    ) -> None:
+    def set_dataset_processing_params(self, class_names: Optional[List[str]] = None, image_processor: Optional[Processing] = None) -> None:
        """Set the processing parameters for the dataset.

        :param class_names: (Optional) Names of the dataset the model was trained on.
@@ -48,17 +46,17 @@ def _get_pipeline(self, fuse_model: bool = True) -> ClassificationPipeline:
        )
        return pipeline

-    def predict(self, images: ImageSource, batch_size: int = 32, fuse_model: bool = True) -> ImagesPredictions:
+    def predict(self, images: ImageSource, batch_size: int = 32, fuse_model: bool = True) -> ImagesClassificationPrediction:
        """Predict an image or a list of images.

        :param images:     Images to predict.
        :param batch_size: Maximum number of images to process at the same time.
        :param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
        """
        pipeline = self._get_pipeline(fuse_model=fuse_model)
        return pipeline(images, batch_size=batch_size)  # type: ignore

-    def predict_webcam(self, fuse_model: bool = True):
+    def predict_webcam(self, fuse_model: bool = True) -> None:
        """Predict using webcam.
        :param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
        """
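For context, a hedged sketch of the predict flow these signatures expose: models.get and pretrained_weights are the standard SuperGradients entry points, the checkpoint and image path here are placeholder choices, and a pretrained checkpoint is assumed to carry the class names and image processor that set_dataset_processing_params wires in.

from super_gradients.training import models
from super_gradients.common.object_names import Models

# Any classification model inheriting BaseClassifier now exposes predict().
model = models.get(Models.RESNET18, pretrained_weights="imagenet")

# Returns an ImagesClassificationPrediction, per the new annotation above.
predictions = model.predict("path/to/image.jpg", batch_size=32, fuse_model=True)
predictions.show()  # assumes the prediction object exposes show(), as other SG prediction types do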
src/super_gradients/training/models/classification_models/beit.py
@@ -34,7 +34,7 @@
from super_gradients.training.utils.regularization_utils import DropPath
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.utils import HpmStruct, torch_version_is_greater_or_equal
-from super_gradients.training.models import SgModule
+from super_gradients.training.models import BaseClassifier

logger = get_logger(__name__)

@@ -290,7 +290,7 @@ def forward(self):
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww


-class Beit(SgModule):
+class Beit(BaseClassifier):
    """Vision Transformer with support for patch or hybrid CNN input stage"""

    def __init__(
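Every remaining file in this diff repeats the one-line change made here: the model inherits BaseClassifier (itself an SgModule) instead of SgModule directly, picking up predict() and predict_webcam() without touching the model body. A minimal sketch of the pattern for a hypothetical custom classifier (TinyClassifier is illustrative only, not part of this PR):

import torch.nn as nn

from super_gradients.training.models import BaseClassifier


class TinyClassifier(BaseClassifier):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # global-average-pool to (N, 3, 1, 1)
        self.fc = nn.Linear(3, num_classes)  # classify the pooled RGB features

    def forward(self, x):
        return self.fc(self.pool(x).flatten(1))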
@@ -6,7 +6,7 @@

from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier

"""Densenet-BC model class, based on
"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
@@ -72,7 +72,7 @@ def __init__(self, num_input_features, num_output_features):
        self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))


-class DenseNet(SgModule):
+class DenseNet(BaseClassifier):
    def __init__(self, growth_rate: int, structure: list, num_init_features: int, bn_size: int, drop_rate: float, num_classes: int):
        """
        :param growth_rate: number of filter to add each layer (noted as 'k' in the paper)
@@ -9,7 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier


class Bottleneck(nn.Module):
@@ -42,7 +42,7 @@ def forward(self, x):
        return out


-class DPN(SgModule):
+class DPN(BaseClassifier):
    def __init__(
        self,
        in_planes: Tuple[int, int, int, int],
@@ -28,7 +28,7 @@
from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
from super_gradients.training.utils import HpmStruct
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier

# Parameters for an individual model block
BlockArgs = collections.namedtuple(
@@ -393,7 +393,7 @@ def forward(self, inputs: torch.Tensor, drop_connect_rate: Optional[float] = Non
        return x


-class EfficientNet(SgModule):
+class EfficientNet(BaseClassifier):
    """
    EfficientNet model.

@@ -10,12 +10,12 @@

from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier

GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["log_", "aux_logits2", "aux_logits1"])


-class GoogLeNet(SgModule):
+class GoogLeNet(BaseClassifier):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=True, backbone_mode=False, dropout=0.3):
        super(GoogLeNet, self).__init__()

@@ -5,10 +5,10 @@
"""
import torch.nn as nn
import torch.nn.functional as F
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier


-class LeNet(SgModule):
+class LeNet(BaseClassifier):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
@@ -5,7 +5,7 @@
"""
import torch.nn as nn
import torch.nn.functional as F
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier


class Block(nn.Module):
@@ -24,7 +24,7 @@ def forward(self, x):
        return out


-class MobileNet(SgModule):
+class MobileNet(BaseClassifier):
    # (128,2) means conv planes=128, conv stride=2, by default conv stride=1
    cfg = [64, 128, (128, 2), 256, (256, 2), 512, 512, 512, 512, 512, (512, 2), 1024, (1024, 2)]

@@ -15,11 +15,11 @@

from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier
from super_gradients.training.utils.utils import get_param


-class MobileNetBase(SgModule):
+class MobileNetBase(BaseClassifier):
    def __init__(self):
        super(MobileNetBase, self).__init__()

@@ -7,7 +7,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier


class SepConv(nn.Module):
@@ -71,7 +71,7 @@ def forward(self, x):
        return F.relu(self.bn2(self.conv2(y)))


-class PNASNet(SgModule):
+class PNASNet(BaseClassifier):
    def __init__(self, cell_type, num_cells, num_planes):
        super(PNASNet, self).__init__()
        self.in_planes = num_planes
@@ -9,7 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier


class PreActBlock(nn.Module):
@@ -63,7 +63,7 @@ def forward(self, x):
        return out


-class PreActResNet(SgModule):
+class PreActResNet(BaseClassifier):
    def __init__(self, block, num_blocks, num_classes=10):
        super(PreActResNet, self).__init__()
        self.in_planes = 64
@@ -12,7 +12,7 @@
from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
from super_gradients.modules import Residual
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier
from super_gradients.training.utils.regularization_utils import DropPath
from super_gradients.training.utils.utils import get_param

@@ -110,7 +110,7 @@ def forward(self, x):
        return x


-class AnyNetX(SgModule):
+class AnyNetX(BaseClassifier):
    def __init__(
        self,
        ls_num_blocks,
@@ -15,12 +15,12 @@
from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
from super_gradients.modules import RepVGGBlock, SEBlock
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier
from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches
from super_gradients.training.utils.utils import get_param


-class RepVGG(SgModule):
+class RepVGG(BaseClassifier):
    def __init__(
        self,
        struct,
@@ -13,7 +13,7 @@

import torch.nn as nn
import torch.nn.functional as F
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier

from super_gradients.common.object_names import Models
from super_gradients.common.registry.registry import register_model
@@ -83,7 +83,7 @@ def forward(self, x):
        return out


-class CifarResNet(SgModule):
+class CifarResNet(BaseClassifier):
    def __init__(self, block, num_blocks, num_classes=10, width_mult=1, expansion=1):
        super(CifarResNet, self).__init__()
        self.expansion = expansion
@@ -126,7 +126,7 @@ def forward(self, x):
        return out


-class ResNet(SgModule):
+class ResNet(BaseClassifier):
    def __init__(
        self,
        block,
@@ -9,7 +9,7 @@

from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
@@ -68,7 +68,7 @@ def forward(self, x):
        return out


-class ResNeXt(SgModule):
+class ResNeXt(BaseClassifier):
    def __init__(self, layers, cardinality, bottleneck_width, num_classes=10, replace_stride_with_dilation=None):
        super(ResNeXt, self).__init__()

@@ -7,7 +7,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier


class BasicBlock(nn.Module):
@@ -74,7 +74,7 @@ def forward(self, x):
        return out


-class SENet(SgModule):
+class SENet(BaseClassifier):
    def __init__(self, block, num_blocks, num_classes=10):
        super(SENet, self).__init__()
        self.in_planes = 64
@@ -7,7 +7,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier


class ShuffleBlock(nn.Module):
@@ -51,7 +51,7 @@ def forward(self, x):
        return out


-class ShuffleNet(SgModule):
+class ShuffleNet(BaseClassifier):
    def __init__(self, cfg):
        super(ShuffleNet, self).__init__()
        out_planes = cfg["out_planes"]
@@ -15,7 +15,7 @@
from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
from super_gradients.training.utils import HpmStruct
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier


__all__ = ["ShuffleNetV2Base", "ShufflenetV2_x0_5", "ShufflenetV2_x1_0", "ShufflenetV2_x1_5", "ShufflenetV2_x2_0", "CustomizedShuffleNetV2"]
@@ -112,7 +112,7 @@ def forward(self, x: Tensor) -> Tensor:
        return out


-class ShuffleNetV2Base(SgModule):
+class ShuffleNetV2Base(BaseClassifier):
    def __init__(
        self,
        structure: List[int],
@@ -1,7 +1,7 @@
"""VGG11/13/16/19 in Pytorch. Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py"""
import torch
import torch.nn as nn
-from super_gradients.training.models.sg_module import SgModule
+from super_gradients.training.models import BaseClassifier

cfg = {
    "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
@@ -11,7 +11,7 @@
}


-class VGG(SgModule):
+class VGG(BaseClassifier):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
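Taken together, the whole classification family now shares one base. A quick hedged spot-check (it assumes these class names are exported from the package root, as the __init__.py hunk at the top does for Beit and BaseClassifier):

from super_gradients.training import models
from super_gradients.training.models import BaseClassifier

for cls_name in ("Beit", "DenseNet", "EfficientNet", "ResNet", "VGG"):
    cls = getattr(models, cls_name)  # relies on the assumed package-root exports
    assert issubclass(cls, BaseClassifier), f"{cls_name} should inherit BaseClassifier"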