initial light and dynamic convolution kernels (#547)
Summary:
CUDA code for the lightconv/dynamicconv kernels, including the corresponding PyTorch modules. The modules can be built by running setup.py in each respective folder and can then be imported and used like any other module.
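For illustration, a minimal usage sketch of the new dispatch modules (tensor sizes are made up, the time x batch x channels layout is assumed from the existing *1dTBC module names, and the keyword arguments mirror those used in fairseq/models/lightconv.py below):

```python
import torch
from fairseq.modules import DynamicConv  # LightweightConv takes the same arguments

# Illustrative sizes only; conv_dim must be divisible by num_heads.
conv_dim, num_heads, kernel_size = 64, 4, 3

# Same keyword arguments as the decoder layer in fairseq/models/lightconv.py.
conv = DynamicConv(conv_dim, kernel_size, padding_l=kernel_size - 1,
                   weight_softmax=True, num_heads=num_heads,
                   weight_dropout=0.1)

x = torch.randn(20, 2, conv_dim)  # assumed layout: time x batch x channels
# Move conv and x to .cuda() if the compiled kernels are installed;
# otherwise the pure-PyTorch fallback runs on CPU.
y = conv(x)
print(y.shape)  # same shape as the input
```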
Pull Request resolved: fairinternal/fairseq-py#547

Reviewed By: myleott, shubho

Differential Revision: D15703660

Pulled By: nng555

fbshipit-source-id: e9c913753be3a1cd571965f7200df6678b644520
nng555 authored and facebook-github-bot committed Aug 14, 2019
1 parent b870468 commit f840564
Showing 23 changed files with 1,958 additions and 27 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -111,6 +111,8 @@ ENV/

# Generated files
fairseq/temporal_convolution_tbc
fairseq/modules/*_layer/*_forward.cu
fairseq/modules/*_layer/*_backward.cu

# data
data-bin/
16 changes: 14 additions & 2 deletions examples/pay_less_attention_paper/README.md
@@ -1,5 +1,5 @@
# Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)
This page contains pointers to pre-trained models as well as instructions on how to train new models for [our paper](https://openreview.net/pdf?id=SkVhlh09tX)
This page contains pointers to pre-trained models as well as instructions on how to train new models for [our paper](https://arxiv.org/abs/1901.10430)

## Citation:
```bibtex
@@ -8,7 +8,7 @@ This page contains pointers to pre-trained models as well as instructions on how
author = {Felix Wu and Angela Fan and Alexei Baevski and Yann Dauphin and Michael Auli},
booktitle = {International Conference on Learning Representations},
year = {2019},
url = {https://openreview.net/forum?id=SkVhlh09tX},
url = {https://arxiv.org/abs/1901.10430},
}
```

@@ -39,6 +39,18 @@ To use the model without GLU, please set `--encoder-glu 0 --decoder-glu 0`.
For LightConv, please use `--encoder-conv-type lightweight --decoder-conv-type lightweight`, otherwise the default is DynamicConv.
For best BLEU results, lenpen may need to be manually tuned.

To use the CUDA kernels, first install the PyTorch modules using the commands below
```sh
# to install lightconv
python fairseq/modules/lightconv_layer/cuda_function_gen.py
python fairseq/modules/lightconv_layer/setup.py install

# to install dynamicconv
python fairseq/modules/dynamicconv_layer/cuda_function_gen.py
python fairseq/modules/dynamicconv_layer/setup.py install
```
Once the CUDA modules are installed, they will automatically be used instead of the PyTorch modules.
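To sanity-check the install, one option is to try importing the compiled extensions directly; a minimal check, assuming the extension names `lightconv_cuda` and `dynamicconv_cuda` used by the setup scripts:

```python
# Assumed extension names; adjust if the setup scripts register different ones.
try:
    import lightconv_cuda    # noqa: F401
    import dynamicconv_cuda  # noqa: F401
    print("CUDA convolution kernels are available")
except ImportError as err:
    print("CUDA kernels not found, falling back to the PyTorch modules:", err)
```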

### IWSLT14 De-En
Training and evaluating DynamicConv (without GLU) on a GPU:
```sh
38 changes: 19 additions & 19 deletions fairseq/models/lightconv.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import math
import sys

import torch
import torch.nn as nn
@@ -19,10 +20,10 @@
)
from fairseq.modules import (
AdaptiveSoftmax,
DynamicConv1dTBC,
DynamicConv,
LayerNorm,
PositionalEmbedding,
LightweightConv1dTBC,
LightweightConv,
MultiheadAttention,
)

@@ -173,7 +174,6 @@ def build_embedding(dictionary, embed_dim, path=None):
decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens)
return LightConvModel(encoder, decoder)


class LightConvEncoder(FairseqEncoder):
"""
LightConv encoder consisting of *args.encoder_layers* layers. Each layer
@@ -447,15 +447,15 @@ def __init__(self, args, kernel_size=0):
self.linear1 = Linear(self.embed_dim, self.conv_dim)
self.act = None
if args.encoder_conv_type == 'lightweight':
self.conv = LightweightConv1dTBC(self.conv_dim, kernel_size, padding_l=padding_l,
weight_softmax=args.weight_softmax,
num_heads=args.encoder_attention_heads,
weight_dropout=args.weight_dropout)
self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=padding_l,
weight_softmax=args.weight_softmax,
num_heads=args.encoder_attention_heads,
weight_dropout=args.weight_dropout)
elif args.encoder_conv_type == 'dynamic':
self.conv = DynamicConv1dTBC(self.conv_dim, kernel_size, padding_l=padding_l,
weight_softmax=args.weight_softmax,
num_heads=args.encoder_attention_heads,
weight_dropout=args.weight_dropout)
self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=padding_l,
weight_softmax=args.weight_softmax,
num_heads=args.encoder_attention_heads,
weight_dropout=args.weight_dropout)
else:
raise NotImplementedError
self.linear2 = Linear(self.conv_dim, self.embed_dim)
@@ -535,15 +535,15 @@ def __init__(self, args, no_encoder_attn=False, kernel_size=0):
self.linear1 = Linear(self.embed_dim, self.conv_dim)
self.act = None
if args.decoder_conv_type == 'lightweight':
self.conv = LightweightConv1dTBC(self.conv_dim, kernel_size, padding_l=kernel_size-1,
weight_softmax=args.weight_softmax,
num_heads=args.decoder_attention_heads,
weight_dropout=args.weight_dropout)
self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=kernel_size-1,
weight_softmax=args.weight_softmax,
num_heads=args.decoder_attention_heads,
weight_dropout=args.weight_dropout)
elif args.decoder_conv_type == 'dynamic':
self.conv = DynamicConv1dTBC(self.conv_dim, kernel_size, padding_l=kernel_size-1,
weight_softmax=args.weight_softmax,
num_heads=args.decoder_attention_heads,
weight_dropout=args.weight_dropout)
self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=kernel_size-1,
weight_softmax=args.weight_softmax,
num_heads=args.decoder_attention_heads,
weight_dropout=args.weight_dropout)
else:
raise NotImplementedError
self.linear2 = Linear(self.conv_dim, self.embed_dim)
10 changes: 8 additions & 2 deletions fairseq/modules/__init__.py
@@ -9,13 +9,15 @@
from .character_token_embedder import CharacterTokenEmbedder
from .conv_tbc import ConvTBC
from .downsampled_multihead_attention import DownsampledMultiHeadAttention
from .dynamic_convolution import DynamicConv1dTBC
from .dynamic_convolution import DynamicConv, DynamicConv1dTBC
#from .dynamicconv_layer import DynamicconvLayer
from .gelu import gelu, gelu_accurate
from .grad_multiply import GradMultiply
from .highway import Highway
from .layer_norm import LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv1dTBC
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
#from .lightconv_layer import LightconvLayer
from .linearized_convolution import LinearizedConvolution
from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork
@@ -36,14 +38,18 @@
'CharacterTokenEmbedder',
'ConvTBC',
'DownsampledMultiHeadAttention',
# 'DynamicconvLayer',
'DynamicConv1dTBC',
'DynamicConv',
'gelu',
'gelu_accurate',
'GradMultiply',
'Highway',
'LayerNorm',
'LearnedPositionalEmbedding',
# 'LightconvLayer',
'LightweightConv1dTBC',
'LightweightConv',
'LinearizedConvolution',
'LogSumExpMoE',
'MeanPoolGatingNetwork',
202 changes: 202 additions & 0 deletions fairseq/modules/cuda_utils.cu
@@ -0,0 +1,202 @@
/**
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
*
*/


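// Integer division rounded up, e.g. divUp(10, 4) == 3.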
template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
return (a + b - 1) / b;
}


template<int FS, int SB, int padding_l, typename scalar_t>
__inline__ __device__
void zeroSharedMem(scalar_t* data) {
/*
Given an array of length FS + SB, zero out the first padding_l and last
(FS - padding_l) values in the array
*/

int tid = threadIdx.x;

if (FS < SB) {

// zero all if we have enough threads in a block to do all of them
if (tid < padding_l || tid > SB - FS + padding_l - 1) {
data[tid] = scalar_t(0.0);
}
} else {

// otherwise zero out one block at a time
const int numIterations = divUp<int, int>(FS, SB);
for (int i = 0; i < numIterations; i++) {
int offset = i * SB;
if (tid + offset < padding_l) {
data[tid + offset] = scalar_t(0.0);
} else if (tid + offset < FS) {
data[SB + tid + offset] = scalar_t(0.0);
}
}
}
}

template<typename scalar_t>
__inline__ __device__
scalar_t warpReduce(scalar_t data) {
/*
Reduce a value across each warp. After processing, every lane in the warp
will contain the sum of all original values in that warp.
data - value to reduce
*/
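// Butterfly reduction via XOR shuffles: after log2(32) = 5 steps every
// lane in the warp holds the full warp sum.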
data += __shfl_xor_sync(SHFL_MASK, data, 16);
data += __shfl_xor_sync(SHFL_MASK, data, 8);
data += __shfl_xor_sync(SHFL_MASK, data, 4);
data += __shfl_xor_sync(SHFL_MASK, data, 2);
data += __shfl_xor_sync(SHFL_MASK, data, 1);
return data;
}

template<typename scalar_t>
__inline__ __device__
scalar_t blockReduce(scalar_t data) {
/*
Reduce a value across the whole block. After processing, threads in the
first warp will hold the block-wide sum.
data - value to reduce
*/

static __shared__ scalar_t warpSum[32];
const int tid = threadIdx.x;
int wid = tid / 32;
int lane = tid % 32;

__syncthreads();

// reduce each warp then write to shared memory
scalar_t sum = warpReduce(data);
if (lane == 0) {
warpSum[wid] = sum;
}

__syncthreads();

scalar_t v;
// perform final sum of partial warp sums
if (tid < blockDim.x / 32) {
v = warpSum[lane];
} else {
v = scalar_t(0.0);
}

if (wid == 0) {
v = warpReduce(v);
}
__syncthreads();

return v;
}

void checkCudaStatus(cudaError_t status, int lineNumber = -1) {

if (status != cudaSuccess) {
std::cout << cudaGetErrorString(status)
<< " at line " << lineNumber << std::endl;
std::cout << "Exiting" << std::endl;
exit(1);
}
}

template<int FS, int SB, int padding_l, typename scalar_t>
__device__
void load_input_to_shared(const scalar_t* input, // global memory
int inputOffset, int sequenceLength,
int iteration, int numIterations,
bool no_prev, scalar_t* output /* shared memory */) {
/*
Load one block of input into shared memory, with left and right
overhang totalling FS. If a previous block was already loaded, the
overlapping section is reused (shifted over) to reduce global
memory accesses.
input - pointer to start of channel sequence
inputOffset - how far in the sequence to start loading
sequenceLength - total length of sequence
iteration - which block of sequence we are loading
numIterations - total number of blocks to load
no_prev - whether to load the whole block if the previous block
wasn't loaded
output - shared memory to write input to
*/

const int tid = threadIdx.x;

// Load the left "overhang" of input
if (iteration > 0) {
if (padding_l < SB) {

// load all at once
if (tid < padding_l) {
output[tid] = (no_prev) ? input[inputOffset - padding_l + tid] : output[tid + SB];
}
} else {

// load in chunks of size SB
int numIterations = divUp<int, int>(padding_l, SB);
for (int i = 0; i < numIterations; i++) {
int offset = i * SB;
if ((tid + offset) < padding_l) {
output[tid + offset] = (no_prev) ? input[inputOffset - padding_l + tid + offset] : output[tid + offset + SB];
}
}
}
}

// Load the right "overhang" of input
if (iteration < (numIterations - 1)) {
const int elementsLeft = sequenceLength - (iteration+1) * SB;

if ((FS - padding_l) < SB) {

// load all at once
if (tid < (FS - padding_l)) {
output[padding_l + SB + tid] = (tid < elementsLeft) ? input[inputOffset + SB + tid] : scalar_t(0.0);
}
} else {

// load in chunks of size SB
int numIterations = divUp<int, int>(FS - padding_l, SB);
for (int i = 0; i < numIterations; i++) {
int offset = i * SB;
if ((tid + offset) < (FS - padding_l)) {
output[padding_l + SB + tid + offset] = ((tid + offset) < elementsLeft) ? input[inputOffset + SB + tid + offset] : scalar_t(0.0);
}
}
}
}

// We should also clear out the right "overhang"
if (iteration == (numIterations - 1)) {
if ((FS - padding_l) < SB) {

// clear out all at once
if (tid < (FS - padding_l)) {
output[padding_l + SB + tid] = scalar_t(0.0);
}
} else {

// clear in chunks of size SB
int numIterations = divUp<int, int>(FS - padding_l, SB);
for (int i = 0; i < numIterations; i++) {
int offset = i * SB;
if ((tid + offset) < (FS - padding_l)) {
output[padding_l + SB + tid + offset] = scalar_t(0.0);
}
}
}
}
output[tid + padding_l] = ((inputOffset + tid) < sequenceLength) ? input[inputOffset + tid] : scalar_t(0.0);
}