small lesson about problems during train my own dataset #128

Closed
LongxingTan opened this issue Dec 17, 2019 · 23 comments
Labels
training Training Related Questions

Comments

@LongxingTan

LongxingTan commented Dec 17, 2019

Thanks zzh8829 for sharing the code; it's really nicely written and I like it.
When I used it to train on my own dataset, I ran into some problems. Here is how I solved them.
I hope this saves others some time.

  • nan loss
  1. First, I made the learning rate smaller.
  2. I found that the labels produced by VoTT and by labelImg are different, so make sure the input boxes are correct (no nan values, and every box lies within the image width and height). Check carefully whether the box format is x1,y1,x2,y2, or x,y,w,h, or x1/width,y1/height,x2/width,y2/height.
  • loss is unbelievably large
  1. The loss at the first step is fine, but after the 2nd step it becomes very large and never converges again. I changed the backbone according to other YOLOv3 repositories, and that solved it.
  • hard to converge
  1. Remove the sigmoid operator in class_prob_loss.
  2. Add conf_focal=tf.pow(true_obj-pred_obj , 2) as a multiplier in confidence_loss.
  • Resize the image by resize or pad
  1. I checked the process for training on VOC2012: if you use voc2012.py to save the tf-record, there is no problem. In object detection, if you resize the image with padding, you have to shift and scale the labelled boxes in the same way. But if you use a plain resize function from cv or tf and the labels are relative (0,1), there is no need to adjust them.
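As an illustration of the pad-versus-resize point in the last item, here is a minimal sketch in plain Python (not code from this repo) of how an absolute x1,y1,x2,y2 box moves when the image is letterboxed, i.e. resized with its aspect ratio kept and then padded to a square. The function name `letterbox_box` and the centred-padding convention are assumptions for illustration; some pipelines pad only on the right and bottom instead.

```python
def letterbox_box(box, src_w, src_h, dst_size):
    """Map an absolute (x1, y1, x2, y2) box from a src_w x src_h image onto a
    dst_size x dst_size canvas made by aspect-preserving resize + centred padding.
    Hypothetical helper for illustration only."""
    scale = min(dst_size / src_w, dst_size / src_h)  # one factor for both axes
    pad_x = (dst_size - src_w * scale) / 2  # padding added on the left
    pad_y = (dst_size - src_h * scale) / 2  # padding added on the top
    x1, y1, x2, y2 = box
    return (x1 * scale + pad_x, y1 * scale + pad_y,
            x2 * scale + pad_x, y2 * scale + pad_y)


# A 640x480 image letterboxed to 416x416: scale 0.65, 52 px of padding top and bottom.
print(letterbox_box((0, 0, 640, 480), 640, 480, 416))
```

With a plain stretching resize, each absolute coordinate and the corresponding image side are scaled by the same per-axis factor, so relative values such as x1/width do not change; that is why relative labels need no adjustment in that case.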
@zzh8829
Owner

zzh8829 commented Dec 18, 2019

Thanks for the update. I will add these to the README; a lot of people are having the nan loss problem.

@AnaRhisT94

AnaRhisT94 commented Dec 18, 2019

@zzh8829 I've provided another detailed explanation of nan loss here; it can be found in almost every nan loss issue. Hopefully you can add it too.

@zzh8829
Owner

zzh8829 commented Dec 18, 2019

I added a section to the README with @LongxingTan's insight. @AnaRhisT94, is it possible for you to make a pull request on the README file with your detailed explanation? I am not sure which one specifically you are referring to. I really appreciate you helping other people solve their training problems. It would be great if we could share that knowledge with everyone else too.

@antongisli

antongisli commented Dec 19, 2019 via email

@zzh8829
Owner

zzh8829 commented Dec 20, 2019

@antongisli For sure, that would be amazing.

@zzh8829 zzh8829 added the training Training Related Questions label Dec 20, 2019
@zzh8829
Owner

zzh8829 commented Dec 21, 2019

I compiled a full tutorial on custom training at https://github.com/zzh8829/yolov3-tf2/blob/master/docs/training_voc.md. You are welcome to add what you have learned to it.

@zzh8829 zzh8829 closed this as completed Dec 21, 2019
@nicolefinnie

nicolefinnie commented Dec 31, 2019

@LongxingTan @zzh8829
Thanks for your insight. I'm just wondering how you came up with the idea of conf_focal. Is there any mathematical reason behind it?

I applied this idea to the code as follows, and it did converge a lot faster. We can also see that obj_loss started out many orders of magnitude larger than the other terms, which made the total loss very imbalanced. Bringing obj_loss down probably makes its contribution to the total loss fairer.

conf_focal = tf.squeeze(tf.pow(true_obj-pred_obj, 2), -1)
obj_loss = binary_crossentropy(true_obj, pred_obj) * conf_focal

Can you shed some light on it?

If I'm not mistaken, this is the focal loss for gamma=2.
I'm also thinking of using focal loss instead of cross entropy; the paper Focal Loss for Dense Object Detection shows some improvement over cross entropy.
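To check the gamma=2 observation numerically, here is a small plain-Python sketch (not repo code) for a single binary objectness label. Writing p_t = p for y = 1 and p_t = 1 - p for y = 0, the multiplier (y - p)**2 equals (1 - p_t)**2, so (y - p)**2 * BCE is the focal loss -(1 - p_t)**gamma * log(p_t) with gamma = 2 and without the paper's alpha-balancing term. The function names are hypothetical.

```python
import math


def bce(y, p, eps=1e-7):
    """Binary cross-entropy for one label y in {0, 1} and one probability p."""
    p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def conf_focal_loss(y, p):
    """BCE weighted by (y - p)**2, i.e. the conf_focal multiplier."""
    return (y - p) ** 2 * bce(y, p)


def focal_loss(y, p, gamma=2.0, eps=1e-7):
    """Focal loss -(1 - p_t)**gamma * log(p_t), without the alpha term."""
    p = min(max(p, eps), 1 - eps)
    p_t = p if y == 1 else 1 - p  # probability assigned to the true label
    return -((1 - p_t) ** gamma) * math.log(p_t)


for y in (0, 1):
    for p in (0.1, 0.5, 0.9):
        print(y, p, conf_focal_loss(y, p), focal_loss(y, p))
```

Because the weight (y - p)**2 is at most 1, the weighted loss is never larger than plain BCE, and it is much smaller for well-classified examples (for y = 1 and p = 0.9 the weight is about 0.01).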

@nicolefinnie

@zzh8829 I'm just wondering why you chose sigmoid for class_prob_loss. It doesn't make much sense for sparse_categorical_crossentropy when there are more than 2 classes. In my experiments, if you pass the sigmoid output, the trained model actually acts like a binary classification model that can only detect the two strongest classes with the most training data (the others are possibly ignored due to a low objectness score). I guess it might work fine if the training data is equally distributed among all classes.
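The distinction being discussed can be seen on a single box with made-up logits: independent sigmoids score every class separately (a multi-label view) and their outputs need not sum to 1, while softmax yields one distribution over mutually exclusive classes, which is what a categorical cross-entropy expects. This is a plain-Python sketch; the logit values are arbitrary.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, -1.0]  # hypothetical class logits for one box
sig = [sigmoid(z) for z in logits]  # independent per-class probabilities
soft = softmax(logits)  # a single distribution over the classes

print("sigmoid:", [round(v, 3) for v in sig], "sum =", round(sum(sig), 3))
print("softmax:", [round(v, 3) for v in soft], "sum =", round(sum(soft), 3))
```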

@LongxingTan
Author

LongxingTan commented Jan 19, 2020

@nicolefinnie
Hi nicolefinnie, good to hear that it might help you a little.
What I use in my code is similar:

conf_focal=tf.pow(obj_mask-tf.squeeze(tf.sigmoid(pred_obj),-1),2)
loss_obj=tf.squeeze(tf.nn.sigmoid_cross_entropy_with_logits(true_obj,pred_obj),axis=-1)
loss_obj=conf_focal*(obj_mask*loss_obj + noobj_mask*loss_obj)  # batch * grid * grid * anchors_per_scale

You are right that this is focal loss with gamma=2.
I also read that paper, which says it helps with detecting difficult boxes and with class imbalance. To be honest, I have no idea why it helps convergence.
My guess is that the coefficient is always less than 1, so the loss itself is smaller, and a smaller loss makes it look like it converges faster when it actually doesn't.
If you have any idea, I'd be happy to hear it.
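For readers who want to follow the arithmetic of the snippet above without TensorFlow, here is a plain-Python sketch over a flat list of anchors. It uses the numerically stable formula max(x, 0) - x*z + log(1 + exp(-|x|)) that the TensorFlow documentation gives for `tf.nn.sigmoid_cross_entropy_with_logits`. It assumes noobj_mask = (1 - obj_mask) * ignore_mask, as in this repo's YoloLoss; the thread does not say how noobj_mask is built in LongxingTan's code, and the function names are hypothetical.

```python
import math


def sigmoid(x):
    """Sigmoid that does not overflow for large negative x."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sigmoid_ce_with_logits(z, x):
    """Stable sigmoid cross-entropy for label z and logit x."""
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


def focal_obj_loss(obj_mask, logits, ignore_mask):
    """Focal-weighted objectness loss summed over a flat list of anchors.

    obj_mask[i]    = 1 if anchor i is assigned to a ground-truth box, else 0.
    ignore_mask[i] = 1 if anchor i may be penalised as background, else 0.
    """
    total = 0.0
    for z, x, keep in zip(obj_mask, logits, ignore_mask):
        conf_focal = (z - sigmoid(x)) ** 2  # focal weight with gamma = 2
        ce = sigmoid_ce_with_logits(z, x)
        noobj = (1 - z) * keep  # assumed noobj_mask
        total += conf_focal * (z * ce + noobj * ce)
    return total
```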

@nicolefinnie

Oops, @LongxingTan thanks! I overlooked your reply. Adding sigmoid() is a good idea to redistribute the loss. I don't think adding the focal loss necessarily makes the loss converge faster, but it makes sense in your case: after you redistributed your loss, the 2nd term of your loss, noobj_mask*loss_obj (the loss for false positives), may be many orders of magnitude greater than the 1st term (the loss for true positives), so the focal weight makes the two terms more equal. (See my reference below.)

But if we add the focal loss, the first term gets more weight and is penalized more, so adding the focal loss actually shifts the focus (no pun intended) toward the loss of true positives, while false positives go unpunished. However, I jumped to this conclusion from my experimental results comparing your loss and mine. DL is often not explainable, though we try hard.

Reference:
In the original implementation by the paper's author (darknet), both terms are treated equally.

// calculating the best IOU, skip...

// calculate the loss of false positives, ignore true positives when it crosses the ignore threshold, equivalent to `(1-obj_mask)*ignore_mask*obj_loss` in this repo
l.delta[obj_index] = 0 - l.output[obj_index];
if (best_iou > l.ignore_thresh) {
    l.delta[obj_index] = 0;
}
// calculate the loss of the true positives above the threshold, equivalent to `obj_mask*obj_loss` in this repo
if (best_iou > l.truth_thresh) {
    l.delta[obj_index] = 1 - l.output[obj_index];
}
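The branch order in the darknet excerpt can be restated as a tiny Python function for a single prediction (a sketch of this logic only; darknet itself is C and handles coordinates and classes elsewhere). `output` is the predicted objectness and `best_iou` its best IoU with any ground-truth box; the function name and the threshold values in the example are illustrative.

```python
def darknet_obj_delta(output, best_iou, ignore_thresh, truth_thresh):
    """Objectness gradient target for one prediction, following the C excerpt:
    push toward 0 by default, give no gradient when best_iou > ignore_thresh,
    and push toward 1 when best_iou > truth_thresh."""
    delta = 0.0 - output          # treat as background
    if best_iou > ignore_thresh:  # overlaps a ground truth enough: ignore
        delta = 0.0
    if best_iou > truth_thresh:   # very good match: treat as an object
        delta = 1.0 - output
    return delta


print(darknet_obj_delta(0.3, 0.2, ignore_thresh=0.7, truth_thresh=1.0))
```

Note that with truth_thresh = 1 the last branch can never fire, because an IoU cannot exceed 1.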

@zazeng

zazeng commented Feb 18, 2020

@LongxingTan Is the ignore_thresh mask already accounted for in noobj_mask? If not, is there a reason for not using it? Thanks!
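For context on this question: the excerpts later in this thread show how this repo uses the ignore mask, namely ignore_mask = tf.cast(best_iou < ignore_thresh, tf.float32), where best_iou is the best IoU between a predicted box and the ground-truth boxes, applied as (1 - obj_mask) * ignore_mask * obj_loss. A plain-Python sketch of building that mask, with hypothetical function names:

```python
def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(ix2 - ix1, 0.0) * max(iy2 - iy1, 0.0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def build_ignore_mask(pred_boxes, true_boxes, ignore_thresh):
    """1.0 for predictions whose best IoU with any ground truth is below
    ignore_thresh (still penalised as background), else 0.0 (ignored)."""
    mask = []
    for p in pred_boxes:
        best = max((iou(p, t) for t in true_boxes), default=0.0)
        mask.append(1.0 if best < ignore_thresh else 0.0)
    return mask


print(build_ignore_mask([(0, 0, 2, 2), (10, 10, 12, 12)], [(0, 0, 2, 2)], 0.5))
```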

@walidayech

The validation loss is zero from the first iteration to the last, and no boxes are drawn in output.jpg when I use the model trained for 7 epochs (yolov3_train_7.tf).

@procule

procule commented Mar 7, 2020

There are so many different code bases and formats out there, and most are just meant to get a PoC.

check carefully the box format is x1,y1,x2,y2, or x,y,w,h, or x1/width,y1/height,x2/width,y2/height

What does that mean? Which one?
Normalized YOLO format (x,y,w,h all normalized):

# size = (height, width); (x1, y1, x2, y2) are absolute pixel corners
dw = size[1]  # image width
dh = size[0]  # image height
x = (x1 + x2)/2.0  # box centre
y = (y1 + y2)/2.0
w = x2 - x1  # box width and height
h = y2 - y1
output = str(x/dw) + " " + str(y/dh) + " " + str(w/dw) + " " + str(h/dh)

Or, as the FAQ says, is it this?

w = size[1]
h = size[0]
xmin = x1/w
ymin = y1/h
xmax = x2/w
ymax = y2/h

@procule

procule commented Mar 7, 2020

Got it: it's the second one (normalized xmin, ymin, xmax, ymax).
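For anyone else converting labels, here is a small plain-Python sketch of the two conversions discussed above, starting from absolute VOC-style corners, plus a sanity check along the lines of the nan-loss advice in the first post (finite values, inside the image, min < max). The function names are hypothetical; per the answer above, this repo expects the normalized xmin, ymin, xmax, ymax form.

```python
import math


def voc_to_normalized(box, width, height):
    """Absolute (x1, y1, x2, y2) -> normalized (xmin, ymin, xmax, ymax)."""
    x1, y1, x2, y2 = box
    return (x1 / width, y1 / height, x2 / width, y2 / height)


def voc_to_yolo(box, width, height):
    """Absolute (x1, y1, x2, y2) -> normalized YOLO (x_center, y_center, w, h)."""
    x1, y1, x2, y2 = box
    return ((x1 + x2) / 2 / width, (y1 + y2) / 2 / height,
            (x2 - x1) / width, (y2 - y1) / height)


def is_valid_normalized(box):
    """True if all values are finite, inside [0, 1], and min < max on both axes."""
    if not all(math.isfinite(v) for v in box):
        return False  # catches nan and inf labels
    xmin, ymin, xmax, ymax = box
    return 0.0 <= xmin < xmax <= 1.0 and 0.0 <= ymin < ymax <= 1.0


box = voc_to_normalized((50, 100, 150, 300), width=200, height=400)
print(box, is_valid_normalized(box))
```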

@Robin2091

Robin2091 commented Mar 15, 2020

@LongxingTan Hi, I am experiencing the problem of the validation loss exploding and then not converging properly on my custom dataset. In the first few epochs, the validation loss is reasonable. Then it explodes to a large number like 2000000. You mentioned that you made changes to the backbone. Can you guide me through making those changes so I can see whether they help with my problem?

Thank you.

Edit: I would also like to add that I successfully trained the yolov3 tiny model on my custom dataset. This problem only seems to be happening with yolov3.

@LongxingTan
Author

@Robin2091

Yeah, it looks like the same phenomenon I had.
I can paste my code here; could you please check the differences? To be honest, I changed a lot little by little, so I don't remember exactly which change solved it. I hope it gives you some hints for solving your problem.


# model/backbones/common.py
import tensorflow as tf


def conv_block(inputs, kernel_size, filters, strides=1, padding='same', downsample=False, activate=True, bn=True):
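    # basic block: Conv2D -> optional BatchNormalization -> optional leaky ReLU.
    # Note: `downsample` overrides the `strides` and `padding` arguments.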
    if downsample:
        inputs = tf.keras.layers.ZeroPadding2D(((1, 0), (1, 0)))(inputs)
        padding = 'valid'
        strides = 2
    else:
        strides = 1
        padding = 'same'

    conv = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding,
                                  use_bias=not bn, kernel_regularizer=tf.keras.regularizers.l2(0.0005),
                                  kernel_initializer=tf.random_normal_initializer(stddev=0.01),
                                  bias_initializer=tf.constant_initializer(0.))(inputs)
    if bn:
        conv = BatchNormalization()(conv)
    if activate:
        conv = tf.nn.leaky_relu(conv, alpha=0.1)
    return conv


class BatchNormalization(tf.keras.layers.BatchNormalization):
    """
    tf.keras.layers.BatchNormalization doesn't work very well for transfer learning,
    "Frozen state" and "inference mode" are two separate concepts.
    `layer.trainable = False` is to freeze the layer, so the layer will use
    stored moving `var` and `mean` in the "inference mode", and both `gamma`
    and `beta` will not be updated!
    """
    def call(self, x, training=False):
        if not training:
            training = tf.constant(False)
        training = tf.logical_and(training, self.trainable)
        return super().call(x, training)
    
    
def residual_block(inputs, filter_num1, filter_num2):
    shortcut = inputs
    conv = conv_block(inputs, 1, filter_num1)
    conv = conv_block(conv, 3, filter_num2)
    residual_output = shortcut + conv
    return residual_output


def upsample(inputs):
    return tf.image.resize(inputs, (inputs.shape[1] * 2, inputs.shape[2] * 2), method='nearest')


# model/backbones/darknet53.py
import tensorflow as tf
from model.backbones.common import *


class DarkNet(object):
    def __init__(self):
        pass

    def __call__(self,name):
        x=inputs = tf.keras.layers.Input([416, 416, 3])
        x=conv_block(x, 3, 32)  # => batch_size * 416 * 416 * 32
        x=conv_block(x, 3, 64, downsample=True)  # => batch_size * 208 * 208 * 64

        for _ in range(1):
            x=residual_block(x,32,64)  # => batch_size * 208 * 208 * 64

        x=conv_block(x,3,128,downsample=True)  # => batch_size * 104 * 104 * 128

        for _ in range(2):
            x=residual_block(x,64,128)  # => batch_size * 104 * 104 * 128

        x=conv_block(x,3,256,downsample=True)  # => batch_size * 52 * 52 * 256

        for _ in range(8):
            x=residual_block(x,128,256)  # => batch_size * 52 * 52 * 256
        route_1=x  # => batch_size * 52 *52 * 256

        x=conv_block(x,3,512,downsample=True)  # => batch_size * 26 * 26 * 512

        for _ in range(8):
            x=residual_block(x,256,512)
        route_2=x  # => batch_size * 26 * 26 * 512

        x=conv_block(x,3,1024,downsample=True)

        for _ in range(4):
            x=residual_block(x,512,1024)
        route_3=x  # => batch_size * 13 * 13 * 1024

        return tf.keras.Model(inputs,(route_1,route_2,route_3),name=name)


# YOLO heads, decoding and NMS (file name not given in the thread)
import numpy as np
import tensorflow as tf
from model.backbones.darknet53 import DarkNet
from model.backbones.common import conv_block,upsample


# https://github.com/HKU-ICRA/YoloV3_TF2.0/blob/master/yolo.py

class YoloPreOut(object):
    def __init__(self):
        pass

    def __call__(self, input, skip_filters=None,ratio=1,name=None):
        if isinstance(input,tuple):
            model_inputs=tf.keras.layers.Input(input[0].shape[1:]),tf.keras.layers.Input(input[1].shape[1:])
            x,skip=model_inputs

            x=conv_block(x,1,skip_filters)
            x=upsample(x)
            x=tf.concat([x,skip],axis=-1)
        else:
            x=model_inputs=tf.keras.layers.Input(input.shape[1:])

        x=conv_block(x,1,512//ratio)
        x=conv_block(x,3,1024//ratio)
        x=conv_block(x,1,512//ratio)
        x=conv_block(x,3,1024//ratio)
        x=conv_block(x,1,512//ratio)
        return tf.keras.Model(model_inputs,x,name=name)(input)


class YoloOutput(object):
    def __init__(self):
        pass

    def __call__(self,input,kernel_sizes,filters,grid_channels,name):
        x=tf.keras.layers.Input(input.shape[1:])
        output=conv_block(x,kernel_sizes,filters)
        output=conv_block(output,1,3*grid_channels,activate=False,bn=False)
        output=tf.reshape(output,(-1,output.shape[1],output.shape[2],3,grid_channels)) 
        return tf.keras.Model(x,output,name=name)(input)

    
class YoloDecode(object):
    def __init__(self):
        pass

    def __call__(self, anchors, image_input_size, num_classes,inputs_shape, name='decode'):
        '''
        x: batch * grid_size * grid_size * 3 * 5+num_classes  Todo: format is different
        '''
        x=tf.keras.layers.Input(inputs_shape)

        box_xy, box_wh, objectness, class_probs = tf.split(x, (2, 2, 1, num_classes), axis=-1)
        box_xy = tf.sigmoid(box_xy)
        objectness = tf.sigmoid(objectness)
        class_probs = tf.sigmoid(class_probs)

        grid_size = inputs_shape[0]
        grid_xy = tf.meshgrid(tf.range(grid_size), tf.range(grid_size))
        grid_xy = tf.cast(tf.expand_dims(tf.stack(grid_xy, axis=-1), axis=2), tf.float32)
        strides = tf.cast(image_input_size / grid_size, tf.float32)

        box_xy = (box_xy + grid_xy) * strides /image_input_size
        box_wh = tf.exp(box_wh) * anchors / image_input_size

        box_x1y1 = box_xy - box_wh / 2
        box_x2y2 = box_xy + box_wh / 2
        bbox = tf.concat([box_x1y1, box_x2y2], axis=-1)
        outputs=tf.concat([bbox, objectness, class_probs], axis=-1)

        return tf.keras.Model(x,outputs,name=name)
    
    
def yolo_nms(outputs,num_classes,iou_threshold,score_threshold,nms_max_bbox):
    # boxes, conf, type
    b, c, t = [], [], []

    for o in outputs:
        b.append(tf.reshape(o[...,0:4], [tf.shape(o)[0], -1, 4]))
        c.append(tf.reshape(o[...,4:5], (tf.shape(o)[0], -1, 1)))
        t.append(tf.reshape(o[...,5:], (tf.shape(o)[0], -1, num_classes)))

    bbox = tf.concat(b, axis=1)
    confidence = tf.concat(c, axis=1)
    class_probs = tf.concat(t, axis=1)

    if num_classes>1:
        scores = confidence * class_probs
    else:
        scores = confidence
    boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
        boxes=tf.reshape(bbox, (tf.shape(bbox)[0], -1, 1, 4)),
        scores=tf.reshape(scores, (tf.shape(scores)[0], -1, tf.shape(scores)[-1])),
        max_output_size_per_class=nms_max_bbox,
        max_total_size=nms_max_bbox,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold
    )

    return boxes, scores, classes, valid_detections


class YoloNet(object):
    def __init__(self,params):
        self.params=params
        self.darknet=DarkNet()
        self.yolo_preout=YoloPreOut()
        self.yolo_output=YoloOutput()
        self.decode=YoloDecode()
        self.anchors = np.array(params['anchors'], dtype=np.float32).reshape(9, 2)
        self.anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]

    def __call__(self,training=False):
        x = tf.keras.layers.Input([416, 416, 3])
        x_52, x_26, x_13 = self.darknet(name='yolo_darknet')(x)

        grid_channels = self.params['num_classes'] + 5
        out_13 = self.yolo_preout(input=x_13, name='preout_13')  # => batch_size * 13 * 13 * 512
        output_13 = self.yolo_output(input=out_13,
                                     kernel_sizes=3,
                                     filters=1024,
                                     grid_channels=grid_channels,
                                     name='output_13')

        # => batch_size * 26 * 26 *256
        out_26 = self.yolo_preout(input=(out_13, x_26), skip_filters=256, ratio=2, name='preout_26')
        output_26 = self.yolo_output(input=out_26,
                                     kernel_sizes=3,
                                     filters=512,
                                     grid_channels=grid_channels,
                                     name='output_26')

        # => batch * 52 * 52 * 128
        out_52 = self.yolo_preout(input=(out_26, x_52), skip_filters=128, ratio=4, name='preout_52')
        output_52 = self.yolo_output(input=out_52,
                                     kernel_sizes=3,
                                     filters=256,
                                     grid_channels=grid_channels,
                                     name='output_52')

        output_tensors = (output_13, output_26, output_52)
        if training:
            return tf.keras.Model(x,output_tensors,name='yolo_net')
        
        bbox_tensors = []
        for i, feature_map in enumerate(output_tensors):
            bbox_tensor = self.decode(self.anchors[self.anchors_mask[i]],
                                      self.params['image_input_size'],
                                      self.params['num_classes'],
                                      feature_map.get_shape().as_list()[1:],
                                      name="decode_{}".format(i))(feature_map)
            bbox_tensors.append(bbox_tensor)

        output = yolo_nms(bbox_tensors,
                          num_classes=self.params['num_classes'],
                          iou_threshold=self.params['iou_threshold'],
                          score_threshold=self.params['score_threshold'],
                          nms_max_bbox=self.params['nms_max_bbox'])
        return tf.keras.Model(inputs=x, outputs=output,name='yolo_predict')
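To make the YoloDecode arithmetic above easier to follow, here is a plain-Python sketch that decodes one raw prediction (tx, ty, tw, th) at one grid cell: the sigmoid keeps the centre inside its cell, the cell index shifts it, the stride converts it to pixels, and exp(tw), exp(th) scale the anchor. The function name, the argument layout and the example anchor are assumptions for illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def decode_cell(t, cell_x, cell_y, grid_size, anchor_wh, image_size):
    """Decode raw (tx, ty, tw, th) at grid cell (cell_x, cell_y) into a
    normalized (x1, y1, x2, y2) box, mirroring the YoloDecode steps."""
    tx, ty, tw, th = t
    stride = image_size / grid_size  # pixels per grid cell
    bx = (sigmoid(tx) + cell_x) * stride / image_size  # centre x in [0, 1]
    by = (sigmoid(ty) + cell_y) * stride / image_size  # centre y in [0, 1]
    bw = math.exp(tw) * anchor_wh[0] / image_size      # width relative to image
    bh = math.exp(th) * anchor_wh[1] / image_size      # height relative to image
    return (bx - bw / 2, by - bh / 2, bx + bw / 2, by + bh / 2)


# Zero offsets at the centre cell of a 13x13 grid on a 416x416 input.
print(decode_cell((0.0, 0.0, 0.0, 0.0), 6, 6, 13, (116, 90), 416))
```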

@Robin2091

Robin2091 commented Mar 16, 2020

@LongxingTan Thank you so much for the code. I will look into it and try to locate the error. Also, if you have time, could you check the new issue I created, #206? It discusses the problem I am experiencing and includes a log of my training output for further inspection.

Edit: Quick question, did you get the backbone code from a single repository or did you combine them from multiple?

Thank you for your help!

@XeonAngel

  • hard to converge

    remove the sigmoid operator of class_prob_loss
    add the conf_focal=tf.pow(true_obj-pred_obj , 2) as a multiplier in confidence_loss

@LongxingTan where exactly is class_prob_loss located, and how do I remove the sigmoid operator?
Where can I find confidence_loss to change the multiplier to this:

> conf_focal=tf.pow(obj_mask-tf.squeeze(tf.sigmoid(pred_obj),-1),2)
> loss_obj=tf.squeeze(tf.nn.sigmoid_cross_entropy_with_logits(true_obj,pred_obj),axis=-1)
> loss_obj=conf_focal*(obj_mask*loss_obj + noobj_mask*loss_obj)  # batch * grid * grid * anchors_per_scale

@LongxingTan
Author


@XeonAngel

Maybe like this: when we parse/decode the network output to calculate the loss, we have to convert the anchor-relative values into real coordinates, so the decode function applies a sigmoid to change the scale of the values.
But whether removing it really helps depends on your case.

In this repo, you can find the yolo_boxes function in models.py:

box_xy = tf.sigmoid(box_xy)
objectness = tf.sigmoid(objectness)
class_probs = tf.sigmoid(class_probs)
pred_box = tf.concat((box_xy, box_wh), axis=-1)  # original xywh for loss

and the YoloLoss function in models.py:

obj_loss = binary_crossentropy(true_obj, pred_obj)
obj_loss = obj_mask * obj_loss + \
            (1 - obj_mask) * ignore_mask * obj_loss
# TODO: use binary_crossentropy instead
class_loss = obj_mask * sparse_categorical_crossentropy(
    true_class_idx, pred_class)

The last item, class_loss, is the class_prob_loss.
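One way to read "remove the sigmoid" is to pass the raw class logits to a softmax cross-entropy instead of sigmoid probabilities; in Keras that corresponds to calling `sparse_categorical_crossentropy(..., from_logits=True)`. Below is a plain-Python sketch of that per-box computation using the log-sum-exp trick (hypothetical function name, not repo code).

```python
import math


def sparse_ce_from_logits(label, logits):
    """-log(softmax(logits)[label]), computed stably with log-sum-exp."""
    m = max(logits)  # shift so the largest exponent is exp(0)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


print(sparse_ce_from_logits(0, [2.0, 1.0, -1.0]))
```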

@yjwong1999

Hi @LongxingTan, may I know what hyperparameters you used to train the model?
I am training the model using the VOC2012 dataset.

I'm using:
learning rate: 1e-3
batch size: 8
epochs: 50

In my implementation, I have to lower the confidence threshold to 0.1.
I'm not sure whether the batch size has to be bigger or whether I should train the model longer.

Many thanks in advance!

@yyccR

yyccR commented Sep 27, 2021

Actually, I have tried the following code to deal with the class imbalance problem, and it works well.
First, get the number of negative and positive objects:

obj_mask = tf.squeeze(true_obj, -1)
positive_num = tf.cast(tf.reduce_sum(obj_mask), tf.int32) + 1
negative_num = 10 * positive_num

Then update the ignore_mask using the negative count:

ignore_mask = tf.cast(best_iou < ignore_thresh, tf.float32)
ignore_num = tf.cast(tf.reduce_sum(ignore_mask),tf.int32)
if ignore_num > negative_num:
    neg_inds = tf.random.shuffle(tf.where(ignore_mask))[:negative_num]
    neg_inds = tf.expand_dims(neg_inds, axis=1)
    ones = tf.ones(tf.shape(neg_inds)[0], tf.float32)
    ones = tf.expand_dims(ones, axis=1)
    ignore_mask = tf.zeros_like(ignore_mask, tf.float32)
    ignore_mask = tf.tensor_scatter_nd_add(ignore_mask, neg_inds, ones)

The final obj_loss is:

conf_focal = tf.pow(obj_mask - tf.squeeze(pred_obj, -1), 2)
obj_loss = tf.keras.losses.binary_crossentropy(true_obj, pred_obj)
obj_loss = conf_focal * (obj_mask * obj_loss + (1 - obj_mask) * ignore_mask * obj_loss)
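The negative-sampling step above can be restated in plain Python on flat lists, which may make the tensor_scatter_nd_add part easier to follow: if there are more candidate negatives than 10 × (positives + 1), keep a random subset of that size and zero out the rest. The function name and the `seed` argument are additions for illustration.

```python
import random


def sample_negatives(obj_mask, ignore_mask, ratio=10, seed=None):
    """Keep at most ratio * (num_positives + 1) entries of ignore_mask set,
    chosen at random; mirrors the TensorFlow snippet on flat lists."""
    positive_num = int(sum(obj_mask)) + 1
    negative_num = ratio * positive_num
    neg_inds = [i for i, v in enumerate(ignore_mask) if v > 0]
    if len(neg_inds) <= negative_num:
        return list(ignore_mask)  # few enough negatives: keep them all
    kept = set(random.Random(seed).sample(neg_inds, negative_num))
    return [1.0 if i in kept else 0.0 for i in range(len(ignore_mask))]


obj = [1.0] + [0.0] * 49
ign = [1.0] * 50
print(sum(sample_negatives(obj, ign, seed=0)))
```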

@yjwong1999

Hi @yyccR
I tried out your code, and it helps with class imbalance issues!
May I know what this code does? Is it a form of focal loss?

@yyccR

yyccR commented Oct 15, 2021

@yjwong1999
conf_focal is a form of focal loss, and I also use the ignore_mask to balance the numbers of positive and negative samples.
