
Multi-Class Image Classification using Transfer Learning on Mobile Gallery Images in PyTorch


n0obcoder/Mobile-Gallery-Image-Classification-in-PyTorch


Mobile-Gallery-Image-Classification-in-PyTorch

Multi-Class Image Classification on Mobile Gallery Images using Transfer Learning in PyTorch.

Introduction

Using the images in your mobile gallery to train an image classifier with transfer learning! :D

[Image: a meme that will be used to train the image classifier]

STEP 1: Building a Custom Dataset

The dataset I used is https://www.kaggle.com/n0obcoder/mobile-gallery-image-classification-data. It has 6 classes:

  • Cars
  • Memes
  • Mountains
  • Selfies
  • Trees
  • Whataspp_Screenshots

A few sample images from the training set are shown below.

[Image: sample images from the training set]

STEP 2: Data Pre-Processing and Making DataLoaders

The following transforms are applied to the images, in this order, during training and testing:

  • Resizing to (224, 224)
  • Random Horizontal Flips (Only applied during the training phase)
  • ToTensor (to convert the images into tensors)
  • Normalization (using the ImageNet stats)

STEP 3: Defining a Suitable Model and Making the Necessary Tweaks

Architecture: ResNet-34

I replaced the last linear layer of ResNet-34 with a new linear layer that has 6 output neurons, one for each of the 6 classes in the Mobile Gallery Image Dataset described in STEP 1.

[Image: resnet34c]

STEP 4: Transfer Learning by Freezing and Un-Freezing the Layers

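I used the pretrained weights of the selected architecture.

I freeze the pretrained filters in the early and middle layers and train only the filters in the deeper layers. Early layers tend to learn generic features such as edges and textures, which transfer well to a new dataset, while deeper layers learn more task-specific features that benefit from retraining.

[Image: visualizing_convnet_features]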

STEP 5: Loss Function and Optimizer

Loss Function: Cross Entropy

Optimizer: Adam
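A minimal sketch of this loss/optimizer pairing. The learning rate of `1e-3` is an assumption (Adam's default in PyTorch), not a value taken from this project, and the small `nn.Sequential` model is a hypothetical stand-in for ResNet-34 with its first layer frozen. Only parameters with `requires_grad=True` are handed to Adam:

```python
import torch
import torch.nn as nn


def make_criterion_and_optimizer(model: nn.Module, lr: float = 1e-3):
    """Cross-entropy loss plus Adam over the parameters left trainable.
    lr=1e-3 is an assumed value (PyTorch's Adam default)."""
    criterion = nn.CrossEntropyLoss()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=lr)
    return criterion, optimizer


# Toy model standing in for the network; its first layer is frozen.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 6))
for p in model[0].parameters():
    p.requires_grad = False

criterion, optimizer = make_criterion_and_optimizer(model)

# CrossEntropyLoss takes raw logits of shape (N, C) and integer labels of
# shape (N,); it applies log-softmax internally, so no softmax layer is needed.
logits = model(torch.randn(4, 8))
labels = torch.tensor([0, 1, 2, 5])
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
```

Filtering to trainable parameters keeps Adam from allocating moment buffers for frozen weights.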

STEP 6: Training and Validation

  • Trained 'layer 4' and 'fc' for 5 epochs.

[Image: loss_plots_1]

  • Then trained only 'fc' for 3 more epochs.

[Image: loss_plots_2]
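The train/validate cycle behind these loss plots can be sketched as two functions. This is a generic sketch, not the project's training script: the toy linear model and random `TensorDataset` are hypothetical stand-ins for ResNet-34 and the gallery DataLoaders, and the toy reuses one loader for validation purely for brevity.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device="cpu"):
    """One pass over the training data; returns the mean training loss."""
    model.train()
    total_loss, total = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        total += labels.size(0)
    return total_loss / total


@torch.no_grad()
def evaluate(model, loader, criterion, device="cpu"):
    """Returns (mean validation loss, accuracy) without updating weights."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, correct / total


# Toy stand-ins for the gallery DataLoaders and ResNet-34.
torch.manual_seed(0)
x = torch.randn(64, 10)
y = (x[:, 0] > 0).long()
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

history = []
for epoch in range(5):
    train_loss = train_one_epoch(model, loader, criterion, optimizer)
    val_loss, val_acc = evaluate(model, loader, criterion)
    history.append((train_loss, val_loss, val_acc))
    print(f"epoch {epoch + 1}: train {train_loss:.3f}  val {val_loss:.3f}  acc {val_acc:.2%}")
```

With the real network, the two phases described above would amount to running this loop for 5 epochs with `layer4` and `fc` trainable, then for 3 epochs with only `fc` trainable. Rebuilding the optimizer between phases so it holds only the currently trainable parameters is an assumed design choice here.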

STEP 7: It's Testing Time!

  • Test Image

[Image: test_image]

  • Output

    This Neural Network thinks that the given image belongs to >>> Memes <<< class with confidence of 95.21%

Useful Links