Skip to content

Commit

Permalink
Merge pull request #1 from neubig/lstm-node
Browse files Browse the repository at this point in the history
Fix several compile errors and extra scratch memory
  • Loading branch information
msperber authored Jul 21, 2017
2 parents 6e516cc + 3e54f0b commit 3d1bfe3
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 49 deletions.
20 changes: 10 additions & 10 deletions dynet/matrix-multiply.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ inline void MatrixMultiply(const Device_GPU & dev, const Tensor& l, const Tensor
// -> [x, z*b] = [x, y], [y, z*b]
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
y.d.rows(), y.d.cols() * y.d.batch_elems(), l.d.cols(),
kSCALAR_ONE,
dev.kSCALAR_ONE,
l.v, l.d.rows(),
r.v, r.d.rows(),
acc_scalar, y.v, y.d.rows()));
Expand All @@ -30,7 +30,7 @@ inline void MatrixMultiply(const Device_GPU & dev, const Tensor& l, const Tensor
for(unsigned b = 0; b < y.d.bd; ++b) {
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
y.d.rows(), y.d.cols(), l.d.cols(),
kSCALAR_ONE,
dev.kSCALAR_ONE,
l.batch_ptr(b), l.d.rows(),
r.batch_ptr(b), r.d.rows(),
acc_scalar, y.batch_ptr(b), y.d.rows()));
Expand Down Expand Up @@ -78,18 +78,18 @@ inline void MatrixTranspMultiplyAcc(const dynet::Device_GPU & dev, const dynet::
if(l.d.bd == 1 && y.d.bd == r.d.bd) {
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
y.d.rows(), y.d.cols()*y.d.batch_elems(), l.d.rows(),
kSCALAR_ONE,
dev.kSCALAR_ONE,
l.v, l.d.rows(),
r.v, r.d.rows(),
kSCALAR_ONE, y.v, y.d.rows()));
dev.kSCALAR_ONE, y.v, y.d.rows()));
} else {
for(int b = 0; b < max_b; ++b)
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
y.d.rows(), y.d.cols(), l.d.rows(),
kSCALAR_ONE,
dev.kSCALAR_ONE,
l.batch_ptr(b), l.d.rows(),
r.batch_ptr(b), r.d.rows(),
kSCALAR_ONE, y.batch_ptr(b), y.d.rows()));
dev.kSCALAR_ONE, y.batch_ptr(b), y.d.rows()));
}
}

Expand All @@ -112,18 +112,18 @@ inline void MatrixMultiplyTranspAcc(const dynet::Device_GPU & dev, const dynet::
if(y.d.bd == 1 && (l.d.bd == r.d.bd)) {
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T,
y.d.rows(), y.d.cols(), l.d.cols() * l.d.batch_elems(),
kSCALAR_ONE,
dev.kSCALAR_ONE,
l.v, l.d.rows(),
r.v, r.d.rows(),
kSCALAR_ONE, y.v, y.d.rows()));
dev.kSCALAR_ONE, y.v, y.d.rows()));
} else {
for(int b = 0; b < max_b; ++b)
CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T,
y.d.rows(), y.d.cols(), l.d.cols(),
kSCALAR_ONE,
dev.kSCALAR_ONE,
l.batch_ptr(b), l.d.rows(),
r.batch_ptr(b), r.d.rows(),
kSCALAR_ONE, y.batch_ptr(b), y.d.rows()));
dev.kSCALAR_ONE, y.batch_ptr(b), y.d.rows()));
}
}
# else
Expand Down
68 changes: 34 additions & 34 deletions dynet/nodes-lstm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,22 @@ namespace dynet {
// [Wx_o] [do . o_t . (1-o_t)]
// [Wx_g] [dg . (1 - g_t^2)]
// note: here Wx is broadcasted over batches
// allocate scratch mem for mult_r (mult_l is the weight matrix *xs[2], used in place)
AlignedMemoryPool* scratch_allocator = fx.device->pools[(int)DeviceMempool::SCS];
Tensor mult_r(Dim({hidden_dim*4, 1},batch_size), nullptr, fx.device, fx.mem_pool);
mult_r.v = static_cast<float*>(scratch_allocator->allocate(mult_r.d.size() * sizeof(float)));
// allocate scratch mem for mult_r (mult_l is the weight matrix *xs[2], used in place)
AlignedMemoryPool* scratch_allocator = fx.device->pools[(int)DeviceMempool::SCS];
Tensor mult_r(Dim({hidden_dim*4, 1},batch_size), nullptr, fx.device, fx.mem_pool);
mult_r.v = static_cast<float*>(scratch_allocator->allocate(mult_r.d.size() * sizeof(float)));

// mult_r = [di . i_t . (1-i_t)]
// [df . f_t . (1-f_t)]
// [do . o_t . (1-o_t)]
// [dg . (1 - g_t^2)]
mult_r.tb<2>().slice(indices_mat_i, sizes_mat_3).device(*dev.edevice) = dEdf.tb<2>().slice(indices_mat_i, sizes_mat_3) * fx.tb<2>().slice(indices_mat_i, sizes_mat_3) * (fx.tb<2>().slice(indices_mat_i, sizes_mat_3).constant(1) - fx.tb<2>().slice(indices_mat_i, sizes_mat_3));
mult_r.tb<2>().slice(indices_mat_g, sizes_mat_1).device(*dev.edevice) = dEdf.tb<2>().slice(indices_mat_g, sizes_mat_1) * (fx.tb<2>().slice(indices_mat_g, sizes_mat_1).constant(1) - fx.tb<2>().slice(indices_mat_g, sizes_mat_1).square());
// mult_r = [di . i_t . (1-i_t)]
// [df . f_t . (1-f_t)]
// [do . o_t . (1-o_t)]
// [dg . (1 - g_t^2)]
mult_r.tb<2>().slice(indices_mat_i, sizes_mat_3).device(*dev.edevice) = dEdf.tb<2>().slice(indices_mat_i, sizes_mat_3) * fx.tb<2>().slice(indices_mat_i, sizes_mat_3) * (fx.tb<2>().slice(indices_mat_i, sizes_mat_3).constant(1) - fx.tb<2>().slice(indices_mat_i, sizes_mat_3));
mult_r.tb<2>().slice(indices_mat_g, sizes_mat_1).device(*dev.edevice) = dEdf.tb<2>().slice(indices_mat_g, sizes_mat_1) * (fx.tb<2>().slice(indices_mat_g, sizes_mat_1).constant(1) - fx.tb<2>().slice(indices_mat_g, sizes_mat_1).square());

// dx_t += mult_l^T * mult_r
MatrixTranspMultiplyAcc(dev, *xs[2], mult_r, dEdxi);
// dx_t += mult_l^T * mult_r
MatrixTranspMultiplyAcc(dev, *xs[2], mult_r, dEdxi);

scratch_allocator->free();
scratch_allocator->free();

} else if(i==1){ // dh_tm1
// goal: dh_tm1 = [Wh_i]^T [di . i_t . (1-i_t)]
Expand All @@ -146,22 +146,22 @@ namespace dynet {
// [Wh_g] [dg . (1 - g_t^2)]
// note: here Wh is broadcasted over batches

// allocate scratch mem for mult_r (mult_l is the weight matrix *xs[3], used in place)
AlignedMemoryPool* scratch_allocator = fx.device->pools[(int)DeviceMempool::SCS];
Tensor mult_r(Dim({hidden_dim*4, 1},batch_size), nullptr, fx.device, fx.mem_pool);
mult_r.v = static_cast<float*>(scratch_allocator->allocate(mult_r.d.size() * sizeof(float)));
// allocate scratch mem for mult_r (mult_l is the weight matrix *xs[3], used in place)
AlignedMemoryPool* scratch_allocator = fx.device->pools[(int)DeviceMempool::SCS];
Tensor mult_r(Dim({hidden_dim*4, 1},batch_size), nullptr, fx.device, fx.mem_pool);
mult_r.v = static_cast<float*>(scratch_allocator->allocate(mult_r.d.size() * sizeof(float)));

// mult_r = [di . i_t . (1-i_t)]
// [df . f_t . (1-f_t)]
// [do . o_t . (1-o_t)]
// [dg . (1 - g_t^2)]
mult_r.tb<2>().slice(indices_mat_i, sizes_mat_3).device(*dev.edevice) = dEdf.tb<2>().slice(indices_mat_i, sizes_mat_3) * fx.tb<2>().slice(indices_mat_i, sizes_mat_3) * (fx.tb<2>().slice(indices_mat_i, sizes_mat_3).constant(1) - fx.tb<2>().slice(indices_mat_i, sizes_mat_3));
mult_r.tb<2>().slice(indices_mat_g, sizes_mat_1).device(*dev.edevice) = dEdf.tb<2>().slice(indices_mat_g, sizes_mat_1) * (fx.tb<2>().slice(indices_mat_g, sizes_mat_1).constant(1) - fx.tb<2>().slice(indices_mat_g, sizes_mat_1).square());
// mult_r = [di . i_t . (1-i_t)]
// [df . f_t . (1-f_t)]
// [do . o_t . (1-o_t)]
// [dg . (1 - g_t^2)]
mult_r.tb<2>().slice(indices_mat_i, sizes_mat_3).device(*dev.edevice) = dEdf.tb<2>().slice(indices_mat_i, sizes_mat_3) * fx.tb<2>().slice(indices_mat_i, sizes_mat_3) * (fx.tb<2>().slice(indices_mat_i, sizes_mat_3).constant(1) - fx.tb<2>().slice(indices_mat_i, sizes_mat_3));
mult_r.tb<2>().slice(indices_mat_g, sizes_mat_1).device(*dev.edevice) = dEdf.tb<2>().slice(indices_mat_g, sizes_mat_1) * (fx.tb<2>().slice(indices_mat_g, sizes_mat_1).constant(1) - fx.tb<2>().slice(indices_mat_g, sizes_mat_1).square());

// dh_tm1 += mult_l^T * mult_r
MatrixTranspMultiplyAcc(dev, *xs[3], mult_r, dEdxi);
// dh_tm1 += mult_l^T * mult_r
MatrixTranspMultiplyAcc(dev, *xs[3], mult_r, dEdxi);

scratch_allocator->free();
scratch_allocator->free();

} else if(i==2){ // dWx
// goal: dWx_i = [di . i_t . (1-i_t)] * x_t (here * is outer product), then sum over batches
Expand Down Expand Up @@ -196,8 +196,8 @@ namespace dynet {
AlignedMemoryPool* scratch_allocator = fx.device->pools[(int)DeviceMempool::SCS];
Tensor mult_l(Dim({hidden_dim*4, 1},batch_size), nullptr, fx.device, fx.mem_pool);
mult_l.v = static_cast<float*>(scratch_allocator->allocate(mult_l.d.size() * sizeof(float)));
Tensor mult_y(Dim({hidden_dim*4, hidden_dim},batch_size), nullptr, fx.device, fx.mem_pool);
mult_y.v = static_cast<float*>(scratch_allocator->allocate(mult_y.d.size() * sizeof(float)));
// Tensor mult_y(Dim({hidden_dim*4, hidden_dim},batch_size), nullptr, fx.device, fx.mem_pool);
// mult_y.v = static_cast<float*>(scratch_allocator->allocate(mult_y.d.size() * sizeof(float)));

// mult_l = [di . i_t . (1-i_t)]
// [df . f_t . (1-f_t)]
Expand Down Expand Up @@ -270,7 +270,7 @@ namespace dynet {
Eigen::DSizes<ptrdiff_t, 2> sizes_1(hidden_dim, static_cast<ptrdiff_t>(fx.d.bd));

fx.tbvec().device(*dev.edevice) = gates_t->tbvec().slice(indices_i, sizes_1) * gates_t->tbvec().slice(indices_g, sizes_1)
+ gates_t->tbvec().slice(indices_f, sizes_1) * c_tm1->tbvec();
+ gates_t->tbvec().slice(indices_f, sizes_1) * c_tm1->tbvec();

}

Expand Down Expand Up @@ -358,14 +358,14 @@ namespace dynet {
// dc_t = dh_t . o_t . (1 - tanh^2(c_t))
// = dh_t . o_t . (1 - (h_t cdiv o_t)^2)
dEdxi.tb<2>().device(*dev.edevice) += dEdf.tb<2>()
* xs[1]->tb<2>().slice(indices_o, sizes_1)
* (xs[0]->tb<2>().constant(1) - xs[0]->tb<2>().tanh().square());
* xs[1]->tb<2>().slice(indices_o, sizes_1)
* (xs[0]->tb<2>().constant(1) - xs[0]->tb<2>().tanh().square());
// TODO: we could use the below..
// - pro: potential speed up (replace tanh by cdiv)
// - con: potential (though unlikely) division by 0
// dEdxi.tb<2>().device(*dev.edevice) += dEdf.tb<2>()
// * xs[1]->tb<2>().slice(indices_o, sizes_1)
// * (xs[0]->tb<2>().constant(1) - (fx.tb<2>() / xs[1]->tb<2>().slice(indices_o, sizes_1)).square());
// dEdxi.tb<2>().device(*dev.edevice) += dEdf.tb<2>()
// * xs[1]->tb<2>().slice(indices_o, sizes_1)
// * (xs[0]->tb<2>().constant(1) - (fx.tb<2>() / xs[1]->tb<2>().slice(indices_o, sizes_1)).square());
} else if(i==1){
// do_t = dh_t . tanh(c_t)
dEdxi.tb<2>().slice(indices_o, sizes_1).device(*dev.edevice) += dEdf.tb<2>() * xs[0]->tb<2>().tanh();
Expand Down
2 changes: 1 addition & 1 deletion dynet/nodes-matrixmultiply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void MatrixMultiply::forward_dev_impl(const MyDevice & dev, const vector<const T
DYNET_ASSERT(xs.size() == 2, "Failed dimension check in MatrixMultiply::forward");
#ifdef __CUDACC__
// fx = 0*fx + xs[0] * xs[1]
MatrixMultiply(dev, *xs[0], *xs[1], fx, kSCALAR_ZERO);
dynet::MatrixMultiply(dev, *xs[0], *xs[1], fx, kSCALAR_ZERO);
#else
DYNET_ASSERT(fx.d.bd == max(xs[0]->d.bd, xs[1]->d.bd), "Failed dimension check in MatrixMultiply::forward");
if(xs[0]->d.bd == 1) {
Expand Down
8 changes: 4 additions & 4 deletions tests/test-rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ BOOST_AUTO_TEST_CASE( lstm_node_h_fwd ) {
Expression c_t = dynet::input(cg, Dim({hidden_dim}, batch_size), {0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.7f});
Expression gates_t = dynet::input(cg, Dim({hidden_dim*4}, batch_size), {0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.f, -0.1f, -0.2f, -0.3f, -0.4f, -0.5f, 0.01f, 0.11f, 0.21f, 0.31f, 0.41f, 0.51f, -0.01f, -0.11f, -0.21f, -0.31f, -0.41f, -0.51f});
Expression h_t = vanilla_lstm_h(c_t, gates_t);
BOOST_CHECK_CLOSE(as_vector(pick_batch_elem(h_t, (unsigned)0).value())[0], 0.0, 0.001);
BOOST_CHECK_SMALL(as_vector(pick_batch_elem(h_t, (unsigned)0).value())[0], (float)1.0e-6);
BOOST_CHECK_CLOSE(as_vector(pick_batch_elem(h_t, (unsigned)0).value())[1], -0.009966799462, 0.001);
BOOST_CHECK_CLOSE(as_vector(pick_batch_elem(h_t, (unsigned)0).value())[2], -0.03947506404, 0.001);
BOOST_CHECK_CLOSE(as_vector(pick_batch_elem(h_t, (unsigned)1).value())[0], -0.002913126125, 0.001);
Expand Down Expand Up @@ -270,9 +270,9 @@ BOOST_AUTO_TEST_CASE( lstm_node_c_fwd ) {
Expression gates_t = dynet::input(cg, Dim({hidden_dim*4}, batch_size), {0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.f, -0.1f, -0.2f, -0.3f, -0.4f, -0.5f,
0.01f, 0.11f, 0.21f, 0.31f, 0.41f, 0.51f, -0.01f, -0.11f, -0.21f, -0.31f, -0.41f, -0.51f});
Expression c_t = vanilla_lstm_c(c_tm1, gates_t);
BOOST_CHECK_CLOSE(as_vector(pick_batch_elem(c_t, (unsigned)0).value())[0], 0, 0.001);
BOOST_CHECK_CLOSE(as_vector(pick_batch_elem(c_t, (unsigned)0).value())[1], 0, 0.001);
BOOST_CHECK_CLOSE(as_vector(pick_batch_elem(c_t, (unsigned)0).value())[2], 0, 0.001);
BOOST_CHECK_SMALL(as_vector(pick_batch_elem(c_t, (unsigned)0).value())[0], (float)1.0e-6);
BOOST_CHECK_SMALL(as_vector(pick_batch_elem(c_t, (unsigned)0).value())[1], (float)1.0e-6);
BOOST_CHECK_SMALL(as_vector(pick_batch_elem(c_t, (unsigned)0).value())[2], (float)1.0e-6);
BOOST_CHECK_CLOSE(as_vector(pick_batch_elem(c_t, (unsigned)1).value())[0], 0.0899, 0.001);
BOOST_CHECK_CLOSE(as_vector(pick_batch_elem(c_t, (unsigned)1).value())[1], 0.1189, 0.001);
BOOST_CHECK_CLOSE(as_vector(pick_batch_elem(c_t, (unsigned)1).value())[2], 0.1479, 0.001);
Expand Down

0 comments on commit 3d1bfe3

Please sign in to comment.