Merge pull request #3612 from kashefy/tied_weights_ip_transpose
Tied weights with transpose flag for InnerProduct layer
jeffdonahue committed Feb 25, 2016
2 parents 3b2e733 + 8f847fa commit fe0f441
Showing 5 changed files with 304 additions and 15 deletions.
1 change: 1 addition & 0 deletions include/caffe/layers/inner_product_layer.hpp
@@ -44,6 +44,7 @@ class InnerProductLayer : public Layer<Dtype> {
int N_;
bool bias_term_;
Blob<Dtype> bias_multiplier_;
+ bool transpose_;  ///< if true, assume transposed weights
};

} // namespace caffe
42 changes: 33 additions & 9 deletions src/caffe/layers/inner_product_layer.cpp
@@ -11,6 +11,7 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const int num_output = this->layer_param_.inner_product_param().num_output();
bias_term_ = this->layer_param_.inner_product_param().bias_term();
+ transpose_ = this->layer_param_.inner_product_param().transpose();
N_ = num_output;
const int axis = bottom[0]->CanonicalAxisIndex(
this->layer_param_.inner_product_param().axis());
@@ -27,10 +28,15 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
} else {
this->blobs_.resize(1);
}
- // Intialize the weight
+ // Initialize the weights
vector<int> weight_shape(2);
- weight_shape[0] = N_;
- weight_shape[1] = K_;
+ if (transpose_) {
+   weight_shape[0] = K_;
+   weight_shape[1] = N_;
+ } else {
+   weight_shape[0] = N_;
+   weight_shape[1] = K_;
+ }
this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
// fill the weights
shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
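For reference (an editorial note, not part of the diff): writing K for the input dimension K_ and N for num_output N_, the stored weight blob is

$$ W \in \mathbb{R}^{N \times K} \ (\texttt{transpose: false}), \qquad W \in \mathbb{R}^{K \times N} \ (\texttt{transpose: true}). $$

Either way the layer still maps a K-dimensional input to an N-dimensional output; only the storage layout of the weight blob changes, which is what allows two layers to share one blob while using it in opposite directions.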
@@ -80,7 +86,8 @@ void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
const Dtype* weight = this->blobs_[0]->cpu_data();
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1.,
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, transpose_ ? CblasNoTrans : CblasTrans,
+     M_, N_, K_, (Dtype)1.,
bottom_data, weight, (Dtype)0., top_data);
if (bias_term_) {
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
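A sketch of what the forward gemm computes (not part of the diff): with X the M x K input matrix (bottom_data) and Y the M x N output (top_data),

$$ Y = X W^{\top} \quad (\texttt{transpose: false}), \qquad Y = X W \quad (\texttt{transpose: true}), $$

so only the second CBLAS transpose argument changes; the weight blob itself is never copied or physically transposed.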
@@ -97,8 +104,17 @@ void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->cpu_diff();
const Dtype* bottom_data = bottom[0]->cpu_data();
// Gradient with respect to weight
- caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
-     top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_cpu_diff());
+ if (transpose_) {
+   caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+       K_, N_, M_,
+       (Dtype)1., bottom_data, top_diff,
+       (Dtype)1., this->blobs_[0]->mutable_cpu_diff());
+ } else {
+   caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+       N_, K_, M_,
+       (Dtype)1., top_diff, bottom_data,
+       (Dtype)1., this->blobs_[0]->mutable_cpu_diff());
+ }
}
if (bias_term_ && this->param_propagate_down_[1]) {
const Dtype* top_diff = top[0]->cpu_diff();
@@ -110,9 +126,17 @@ void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
if (propagate_down[0]) {
const Dtype* top_diff = top[0]->cpu_diff();
// Gradient with respect to bottom data
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
-     top_diff, this->blobs_[0]->cpu_data(), (Dtype)0.,
-     bottom[0]->mutable_cpu_diff());
+ if (transpose_) {
+   caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans,
+       M_, K_, N_,
+       (Dtype)1., top_diff, this->blobs_[0]->cpu_data(),
+       (Dtype)0., bottom[0]->mutable_cpu_diff());
+ } else {
+   caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+       M_, K_, N_,
+       (Dtype)1., top_diff, this->blobs_[0]->cpu_data(),
+       (Dtype)0., bottom[0]->mutable_cpu_diff());
+ }
}
}
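The two backward branches follow the same convention (again a sketch, writing $\partial L/\partial Y$ for top_diff):

$$ \texttt{transpose: false}: \quad \frac{\partial L}{\partial W} = \Big(\frac{\partial L}{\partial Y}\Big)^{\!\top} X \in \mathbb{R}^{N \times K}, \qquad \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y}\, W \in \mathbb{R}^{M \times K}; $$

$$ \texttt{transpose: true}: \quad \frac{\partial L}{\partial W} = X^{\top} \frac{\partial L}{\partial Y} \in \mathbb{R}^{K \times N}, \qquad \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y}\, W^{\top} \in \mathbb{R}^{M \times K}. $$

In both branches the weight gradient is accumulated (beta = 1) while the bottom gradient overwrites its buffer (beta = 0), matching the final scalar argument of each gemm call above.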

31 changes: 25 additions & 6 deletions src/caffe/layers/inner_product_layer.cu
@@ -19,7 +19,9 @@ void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
caffe_gpu_axpy<Dtype>(N_, bias_multiplier_.cpu_data()[0],
this->blobs_[1]->gpu_data(), top_data);
} else {
- caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1.,
+ caffe_gpu_gemm<Dtype>(CblasNoTrans,
+     transpose_ ? CblasNoTrans : CblasTrans,
+     M_, N_, K_, (Dtype)1.,
bottom_data, weight, (Dtype)0., top_data);
if (bias_term_)
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
@@ -36,8 +38,17 @@ void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
// Gradient with respect to weight
- caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
-     top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_gpu_diff());
+ if (transpose_) {
+   caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+       K_, N_, M_,
+       (Dtype)1., bottom_data, top_diff,
+       (Dtype)1., this->blobs_[0]->mutable_gpu_diff());
+ } else {
+   caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+       N_, K_, M_,
+       (Dtype)1., top_diff, bottom_data,
+       (Dtype)1., this->blobs_[0]->mutable_gpu_diff());
+ }
}
if (bias_term_ && this->param_propagate_down_[1]) {
const Dtype* top_diff = top[0]->gpu_diff();
@@ -49,9 +60,17 @@ void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
if (propagate_down[0]) {
const Dtype* top_diff = top[0]->gpu_diff();
// Gradient with respect to bottom data
- caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
-     top_diff, this->blobs_[0]->gpu_data(), (Dtype)0.,
-     bottom[0]->mutable_gpu_diff());
+ if (transpose_) {
+   caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans,
+       M_, K_, N_,
+       (Dtype)1., top_diff, this->blobs_[0]->gpu_data(),
+       (Dtype)0., bottom[0]->mutable_gpu_diff());
+ } else {
+   caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+       M_, K_, N_,
+       (Dtype)1., top_diff, this->blobs_[0]->gpu_data(),
+       (Dtype)0., bottom[0]->mutable_gpu_diff());
+ }
}
}

5 changes: 5 additions & 0 deletions src/caffe/proto/caffe.proto
@@ -786,6 +786,11 @@ message InnerProductParameter {
// all preceding axes are retained in the output.
// May be negative to index from the end (e.g., -1 for the last axis).
optional int32 axis = 5 [default = 1];
+ // Specify whether to transpose the weight matrix or not.
+ // If transpose == true, any operations will be performed on the transpose
+ // of the weight matrix. The weight matrix itself is not going to be transposed
+ // but rather the transpose flag of operations will be toggled accordingly.
+ optional bool transpose = 6 [default = false];
}

// Message that stores parameters used by LogLayer
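A minimal usage sketch (not part of this commit) showing how the new flag combines with Caffe's named-parameter sharing to tie an encoder and its decoder; the layer names, blob names, and sizes below are illustrative and assume 256-dimensional input data:

layer {
  name: "encoder"
  type: "InnerProduct"
  bottom: "data"
  top: "enc"
  param { name: "tied_weights" }       # weight blob registered under a shared name
  inner_product_param { num_output: 64 }
}
layer {
  name: "decoder"
  type: "InnerProduct"
  bottom: "enc"
  top: "dec"
  param { name: "tied_weights" }       # same name -> same underlying 64 x 256 blob
  inner_product_param {
    num_output: 256                    # must equal the encoder's input dimension
    transpose: true                    # use the shared blob as its transpose
  }
}

The encoder stores its weights as a 64 x 256 blob (num_output x input dimension); with transpose: true the decoder expects exactly that shape, so net initialization can share a single parameter blob between the two layers.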