Add & test regularizer class hierarchy: L1, L2 & skeleton of MaxNorm #113

Closed

wants to merge 98 commits into from

Conversation

kloudkl
Contributor

@kloudkl kloudkl commented Feb 15, 2014

The design and implementation in this pull request are heavily inspired by their counterparts in DeCAF. Thanks to the original author(s).

The Regularizer has not yet been integrated into the Layer::Backward* methods where it actually does its work, and the MaxNorm regularizer is still only half-baked. Those pieces will be finished based on feedback from the community.

Related issues:
#60: Sparsity penalties for unsupervised learning
#109: Alternative to weight decay: max column norm

return (Dtype(0) < val) - (val < Dtype(0));
}

#define MAKE_REGULARIZER_CLASS(type) \
Member

MAKE_SIMPLE_REGULARIZER_CLASS

since it only covers the very basic declaration.
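
For context, a macro like this presumably expands to little more than the class declaration itself, along these lines (a sketch; the real macro body is cut off in this hunk, and RegularizerParameter is an assumed name):

// Hypothetical expansion: declares a thin subclass of Regularizer<Dtype>
// with a Regularize() override and nothing more, hence the suggestion to
// call it "simple".
#define MAKE_SIMPLE_REGULARIZER_CLASS(type) \
  template <typename Dtype> \
  class type##Regularizer : public Regularizer<Dtype> { \
   public: \
    explicit type##Regularizer(const RegularizerParameter& param) \
        : Regularizer<Dtype>(param) {} \
    virtual Dtype Regularize(Blob<Dtype>* bottom); \
  }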

@kloudkl
Contributor Author

kloudkl commented Feb 20, 2014

Integration is finished. Multiple regularizers can be used in one layer.
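
Roughly, each layer now builds its list of regularizers from its LayerParameter during setup, along these lines (a sketch of the intent; GetRegularizer is a hypothetical factory and the field accessors are illustrative):

// Sketch of the setup side: the layer keeps one Regularizer instance per
// regularizer entry in its LayerParameter, which is what allows several
// regularizers to be combined in one layer.
regularizers_.resize(layer_param_.regularizer_size());
for (int i = 0; i < layer_param_.regularizer_size(); ++i) {
  regularizers_[i].reset(GetRegularizer<Dtype>(layer_param_.regularizer(i)));
}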

@aravindhm, does this PR fit the sparse convolutional autoencoder architecture you are working on (#60)?

@tdomhan, if this is OK with you, please add the max column norm you wanted in #109 after this is merged.

for (int i = 0; i < layer_param_.regularizer_size(); ++i) {
regularizers_[i]->Regularize(bottom->at(0));
}
}
Contributor

This code will never be executed, because you already returned in the switch statement.

Contributor Author

Thanks! Fixed in fdb67fc. Also added a return value to Regularize.
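
The shape after the fix is roughly the following (a sketch of the idea rather than the exact code in fdb67fc):

// Sketch: break out of the switch instead of returning from it, run the
// attached regularizers, then return the accumulated loss (hence the new
// return value of Regularize).
Dtype loss = Dtype(0);
switch (Caffe::mode()) {
case Caffe::CPU:
  loss = Backward_cpu(top, propagate_down, bottom);
  break;
case Caffe::GPU:
  loss = Backward_gpu(top, propagate_down, bottom);
  break;
default:
  LOG(FATAL) << "Unknown caffe mode.";
}
for (int i = 0; i < layer_param_.regularizer_size(); ++i) {
  loss += regularizers_[i]->Regularize(bottom->at(0));
}
return loss;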

@tdomhan
Contributor

tdomhan commented Feb 20, 2014

Thanks for adding this. However, I would not put the regularization in the backward function of the layers. It is something that should be executed after the update, so the network would be a more appropriate location, especially because that is where the current regularization, weight decay, happens. Weight decay should also be integrated into the same class structure.

@kloudkl
Contributor Author

kloudkl commented Feb 20, 2014

At first, I tried to regularize the weight parameters of the network. But I found that in DeCAF, regularization is executed in the backward methods of the convolution, inner product, and deconvolution layers.

If each parameter blob is regularized independently, the two approaches are almost equivalent. The Regularizer::Regularize(Blob* bottom) API makes it easy to place the call in SGDSolver::ComputeUpdateValue, either before or after the weight decay.
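
For example, it could sit right next to the existing weight decay (a sketch only; a solver-side regularizers_ list is an assumption, not something this PR adds):

// Hypothetical placement inside SGDSolver<Dtype>::ComputeUpdateValue().
// Each regularizer only adds to the parameter blob's diff, so it can run
// either before or after the weight-decay caffe_axpy with the same result.
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
  for (int r = 0; r < regularizers_.size(); ++r) {  // regularizers_ is illustrative
    regularizers_[r]->Regularize(net_params[param_id].get());  // writes only to diff
  }
}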

Once a counterpart of DeCAF's AutoencoderLossLayer is in place, I will add a demo_sparse_autoencoder to test which is more effective. I expect there will be little difference.

@tdomhan
Contributor

tdomhan commented Feb 20, 2014

In any case, weight decay should be just one of many regularizers, rather than applying weight decay first and then another regularizer.

@aravindhm

@kloudkl This fits with the sparse convolutional auto-encoder. I'm still tweaking the learning rate to make that work. Thanks for this feature!

@tdomhan The weight decay only regularizes the parameters. The regularizer-as-loss regularizes the features. The latter should change the backprop gradient and affect the parameters of the layers below. It is more convenient to do in the backward pass because it depends on the type of the layer preceding the regularizer-as-loss layer. In particular, L1 regularization on the output of a fully connected network will affect the W matrix differently from a situation in which the fully connected network is followed by a tanh and then an L1 regularization. In my own attempt at this feature I tried putting the regularizer in the layer itself (a layer "has a" regularization on its parameters/blobs), but this didn't work because the code changes completely with the layer type.
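
To make the distinction concrete, the feature-side penalty contributes roughly the following to the gradient of its bottom blob, and ordinary backprop through the preceding layer does the rest (a sketch; the function name and lambda are illustrative):

// Sketch: an L1 penalty on the features adds lambda * sign(activation) to
// the bottom diff; backprop through whatever precedes it (inner product,
// tanh, ...) then turns this into the actual parameter gradients, which is
// why the result depends on the preceding layer type.
template <typename Dtype>
Dtype L1FeaturePenalty(const Dtype lambda, Blob<Dtype>* bottom) {
  const Dtype* data = bottom->cpu_data();
  Dtype* diff = bottom->mutable_cpu_diff();
  Dtype penalty = Dtype(0);
  for (int i = 0; i < bottom->count(); ++i) {
    penalty += lambda * (data[i] >= Dtype(0) ? data[i] : -data[i]);
    diff[i] += lambda * ((Dtype(0) < data[i]) - (data[i] < Dtype(0)));  // lambda * sign
  }
  return penalty;
}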

@tdomhan
Contributor

tdomhan commented Feb 20, 2014

@aravindhm That makes sense. I guess we are talking about two different features here: first, a new layer that backpropagates some regularization penalty; and second, what I'm interested in, replacing weight decay with other regularizers.

For the RegularizationAsLoss layer it of course makes perfect sense to put this in the backward pass. However, in that case you only need to backpropagate from the RegularizationAsLoss layer, not from all the other layers such as the inner product or convolution layers, or is that wrong? I'm asking because the regularizers are added to the backpropagation of every layer in Caffe, not just the RegularizationAsLoss layer.

I would argue that regularization of the parameters should stay in the network, where it is right now. That of course doesn't mean there can't be a RegularizationAsLoss layer with a regularizer in its backprop step.

@aravindhm

@tdomhan Agree with everything.

@kloudkl
Contributor Author

kloudkl commented Feb 21, 2014

The Regularize methods do not change the weight parameters of the layer being regularized; only the diff blobs are changed. Thus the order in which weight decay and any number of other regularizers are applied makes no difference.

The regularization result of placing a RegularizationAsLoss layer, which simply wraps a set of Regularizers, after a layer should likewise be the same as embedding that set of Regularizers in the layer itself.
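
In other words, every penalty reduces to an accumulation into the diff, and those additions commute (a sketch of the intended behaviour rather than the exact code in this PR):

// Sketch: because every regularizer (and weight decay) only does
// diff += something(data), the combined gradient is a plain sum and the
// application order is irrelevant.
template <typename Dtype>
void ApplyPenalties(const Dtype l1_coeff, const Dtype l2_coeff, Blob<Dtype>* param) {
  const Dtype* data = param->cpu_data();
  Dtype* diff = param->mutable_cpu_diff();
  for (int i = 0; i < param->count(); ++i) {
    const Dtype sign = (Dtype(0) < data[i]) - (data[i] < Dtype(0));
    diff[i] += l1_coeff * sign + l2_coeff * data[i];  // order of the two terms is immaterial
  }
}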

In SGDSolver::ComputeUpdateValue, the weight decays are the products of the solver-wide global weight decay and the parameter-wise local ones.

vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
...
  for (int param_id = 0; param_id < net_params.size(); ++param_id) {
    // Compute the value to history, and then copy them to the blob's diff.
    Dtype local_rate = rate * net_params_lr[param_id];
    Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
    caffe_axpby(net_params[param_id]->count(), local_rate,
        net_params[param_id]->cpu_diff(), momentum,
        history_[param_id]->mutable_cpu_data());
    if (local_decay) {
      // add weight decay
      caffe_axpy(net_params[param_id]->count(),
          local_decay * local_rate,
          net_params[param_id]->cpu_data(),
          history_[param_id]->mutable_cpu_data());
    }
    // copy
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
  }

Where do the local ones come from?
In Net::GetLearningRateAndWeightDecay, the network just collects them from the layers.

    if (layers_[i]->layer_param().weight_decay_size()) {
      CHECK_EQ(layers_[i]->layer_param().weight_decay_size(),
          layer_blobs.size());
      for (int j = 0; j < layer_blobs.size(); ++j) {
        float local_decay = layers_[i]->layer_param().weight_decay(j);
        CHECK_GE(local_decay, 0.);
        params_weight_decay_.push_back(local_decay);
      }
    } else {
      for (int j = 0; j < layer_blobs.size(); ++j) {
        params_weight_decay_.push_back(1.);
      }
    }

src/caffe/proto/caffe.proto

message LayerParameter {
  // The weight decay that is multiplied on the global weight decay.
  repeated float weight_decay = 52;

The motivation for the current design is perhaps a wish not to scatter the weight decay code across the layers. The Backward method of the Layer base class is another way to address that concern.

The original author has the most thorough understanding of the issue. @Yangqing, would you like to make a comment?

@shelhamer
Member

@kloudkl please rebase this for further review. @Yangqing could you comment on the regularization vs. weight decay and learning rate choices made in this implementation?

Note this is slated for 1.1 release.

Thanks all.

@kloudkl
Contributor Author

kloudkl commented Mar 15, 2014

Shouldn't this be replaced by a new PR targeting dev?

@shelhamer
Member

Yes, please open a new PR against dev, then close this with reference to the new PR. Thanks.

kloudkl and others added 25 commits March 23, 2014 21:24
Add more convenience math functions and all tests pass
@kloudkl kloudkl mentioned this pull request Mar 25, 2014
@kloudkl
Contributor Author

kloudkl commented Mar 25, 2014

Closing. This PR is replaced by #258.

@kloudkl kloudkl closed this Mar 25, 2014
@shelhamer shelhamer removed this from the 1.1 milestone Mar 28, 2014
thatguymike pushed a commit to thatguymike/caffe that referenced this pull request Mar 16, 2016