Commit

initial change
Hainan Xu committed Aug 20, 2024
1 parent e043ac1 commit 6f1cdc4
Showing 2 changed files with 141 additions and 54 deletions.
4 changes: 3 additions & 1 deletion nemo/collections/asr/modules/conformer_encoder.py
@@ -556,6 +556,8 @@ def forward_internal(

audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)

print("pos_emb", pos_emb.shape)

# Create the self-attention and padding masks
pad_mask, att_mask = self._create_masks(
att_context_size=cur_att_context_size,
@@ -679,7 +681,7 @@ def set_max_audio_length(self, max_audio_length):
"""
self.max_audio_length = max_audio_length
device = next(self.parameters()).device
self.pos_enc.extend_pe(max_audio_length, device)
self.pos_enc.extend_pe(24, device)

def _create_masks(self, att_context_size, padding_length, max_audio_length, offset, device):
if self.self_attention_model != "rel_pos_local_attn":
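With power-of-2 relative positions (see the multi_head_attention.py changes below), the argument to extend_pe behaves as an exponent bound rather than a frame count, which is presumably why a small constant is passed here. A quick illustrative check, not part of the commit:

max_offset = 2 ** 24   # largest power-of-2 offset covered by extend_pe(24, device)
print(max_offset)      # 16777216 frames, far beyond typical utterance lengths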
191 changes: 138 additions & 53 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -203,76 +203,140 @@ def rel_shift(self, x):
x = x[:, :, 1:].view(b, h, qlen, pos_len) # (b, h, t1, t2)
return x

def create_power2_mask(self, seq_length):
mask = torch.zeros(seq_length, seq_length, dtype=torch.bool, device=self.pos_bias_u.device)
def create_power2_indices(self, seq_length):
indices = []
valid_mask = []
legit_distances = [0]
dist = 1
while dist < seq_length:
legit_distances.append(dist)
dist = dist * 2

reversed_dists = legit_distances[::-1]
legit_distances = [-i for i in reversed_dists] + legit_distances[1:]

for i in range(seq_length):
for j in range(seq_length):
distance = abs(i - j)
if distance == 0 or (distance & (distance - 1) == 0): # Check if distance is a power of 2
mask[i, j] = True
return mask
for idx, j in enumerate(legit_distances):
valid = i + j >= 0 and i + j < seq_length
if valid:
valid_mask.append(True)
indices.append((i, i + j, idx))
else:
valid_mask.append(False)
indices.append((i, 0, 0)) # will be masked out anyway

ret = torch.tensor(indices, device=self.pos_bias_u.device)
mask = torch.tensor(valid_mask, device=self.pos_bias_u.device)
return ret, torch.reshape(mask, [seq_length, -1])
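For reference, a minimal standalone sketch of the same offset construction outside the NeMo class (the function name and the printed examples are illustrative, not part of the commit):

def power2_offsets(seq_length):
    # distances 0, 1, 2, 4, ... below seq_length, mirrored to the negative side
    dists = [0]
    d = 1
    while d < seq_length:
        dists.append(d)
        d *= 2
    return [-x for x in reversed(dists)] + dists[1:]

print(power2_offsets(8))          # [-4, -2, -1, 0, 1, 2, 4]
print(len(power2_offsets(1000)))  # 21 offsets for a 1000-frame input, i.e. O(log T) keys per query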

def forward(self, query, key, value, mask, pos_emb, cache=None):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value(torch.Tensor): (batch, time2, size)
mask (torch.Tensor): (batch, time1, time2)
pos_emb (torch.Tensor) : (batch, time1, size)
cache (torch.Tensor) : (batch, time_cache, size)
Returns:
output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
cache (torch.Tensor) : (batch, time_cache_next, size)
"""
key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)

B, T, _ = query.shape

if torch.is_autocast_enabled():
query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)

# temporary until we solve this more gracefully
with avoid_float16_autocast_context():
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
q = q.transpose(1, 2) # (batch, time2, head, d_k)
k = k.transpose(1, 2) # (batch, time2, head, d_k)

n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)

# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
# drops extra elements in the matrix_bd to match the matrix_ac's size
matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]

scores = (matrix_ac + matrix_bd) / self.s_d_k # (batch, head, time1, time2)

# Apply power-of-2 distance mask
power2_mask = self.create_power2_mask(scores.size(-1))
power2_mask = power2_mask.unsqueeze(0).unsqueeze(0) # Add batch and head dimensions
scores = scores.masked_fill(~power2_mask, float('-inf'))
heads = k.shape[2]

# pos_emb: (batch, E, d_model)
E = pos_emb.shape[1] # number of relative positions needed
p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k) # (n_batch_pos, E, head, d_k); the pe buffer typically has batch size 1

out = self.forward_attention(v, scores, mask)
q_with_bias_u = q + self.pos_bias_u # (batch, t, head, d_k)
q_with_bias_v = q + self.pos_bias_v # (batch, t, head, d_k)

# Get indices for power-of-2 distances:
# columns 0 and 1 of `indices` index the 't' axis (query position, gathered key position)
# column 2 indexes the pos_emb rows (0, 1, 2, 3, ...)
# att_mask is False where this_position + offset falls out of bounds
indices, att_mask = self.create_power2_indices(q.size(1))

q_indices, k_indices, log_indices = indices[:, 0], indices[:, 1], indices[:, 2]

# Compute attention scores only for power-of-2 distances
q_selected = q_with_bias_u[:, q_indices, :, :] # (batch, t * E, head, d_k)
d_k = q_selected.shape[-1]
q_selected = q_selected.transpose(1, 2)
q_selected = torch.reshape(q_selected, [B, heads, T, E, d_k])


k_selected = k[:, k_indices, :, :] # (batch, t * E, head, d_k)
k_selected = k_selected.transpose(1, 2)
k_selected = torch.reshape(k_selected, [B, heads, T, E, d_k])

matrix_ac = torch.sum(q_selected * k_selected, dim=-1) # (batch, head, T, E)
# print("HERE matrix_ac", matrix_ac.shape)

# Handle relative positional encoding
# p: batch, time, head, dimension
# print("HERE q_with_bias_v", q_with_bias_v.shape)
# p_selected = p[:, :, k_indices, :]
# p_selected = torch.reshape(p_selected, [B, E, T, E, -1])

q_v_selected = q_with_bias_v[:, q_indices, :, :] # (batch, t * E, head, d)
q_v_selected = q_v_selected.transpose(1, 2) # (batch, head, t*E, d)
q_v_selected = torch.reshape(q_v_selected, [B, heads, T, E, d_k])

# p shape: [n_batch_pos, E, head, d_k]
p = p.transpose(1, 2) # [n_batch_pos, head, E, d_k]
p = torch.reshape(p, [-1, heads, 1, E, d_k]) # broadcasts against q_v_selected below

matrix_bd = torch.sum(q_v_selected * p, dim=-1) # [B, head, T, E]

scores = (matrix_ac + matrix_bd) / self.s_d_k # (batch, head, t, E)

# print("k_indices is", torch.reshape(k_indices, [T, E]))
out = self.forward_hop_attention(v, scores, mask, att_mask, torch.reshape(k_indices, [T, E]))

if cache is None:
return out
else:
return out, cache
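The forward pass above replaces the dense (time1, time2) score matrix with a (time1, E) gather over power-of-2 offsets. A self-contained sketch of that indexing trick with random tensors and made-up shapes (none of these names come from the commit):

import torch

B, H, T, D = 2, 4, 100, 16                               # hypothetical batch, heads, frames, head dim
offsets = torch.tensor([-64, -32, -16, -8, -4, -2, -1, 0, 1, 2, 4, 8, 16, 32, 64])
E = offsets.numel()

q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)

pos = torch.arange(T).unsqueeze(1)                       # (T, 1)
k_idx = (pos + offsets).clamp(0, T - 1)                  # (T, E); out-of-range entries masked below
valid = ((pos + offsets) >= 0) & ((pos + offsets) < T)   # (T, E)

k_sel = k[:, :, k_idx, :]                                # (B, H, T, E, D) via advanced indexing
scores = (q.unsqueeze(3) * k_sel).sum(-1)                # (B, H, T, E) instead of (B, H, T, T)
scores = scores.masked_fill(~valid, float('-inf'))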

def forward_hop_attention(self, value, scores, mask, att_mask, k_indices):
"""Compute attention context vector.
Args:
value (torch.Tensor): (batch, time2, size)
scores(torch.Tensor): (batch, time1, time2)
mask(torch.Tensor): (batch, time1, time2)
returns:
value (torch.Tensor): transformed `value` (batch, time2, d_model) weighted by the attention scores
"""
B = value.size(0)
# print("HERE")
# print("value", value.shape) # b, head, t, d
# print("score", scores.shape) # b, head, t, E
# print("mask ", mask.shape)
# print("attmask", att_mask.shape) # t, E

selected_value = value[:,:,k_indices,:] # [batch, head, t, E, d]

if att_mask is not None:
att_mask = att_mask.unsqueeze(0).unsqueeze(0) # (1, 1, t, E); True where the offset stays in bounds
scores = scores.masked_fill(~att_mask, -10000.0)
attn = torch.softmax(scores, dim=-1).masked_fill(~att_mask, 0.0)
else:
attn = torch.softmax(scores, dim=-1)

# if mask is not None:
# mask = mask.unsqueeze(1) # (batch, 1, time1, time2)
# scores = scores.masked_fill(mask, -10000.0)
# attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
# else:
# attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)

p_attn = self.dropout(attn) # [b, head, t, E]
p_attn = p_attn.unsqueeze(-1) # [b, head, t, E, 1]
x = torch.sum(p_attn * selected_value, dim=-2) # [b, head, t, d]
x = x.transpose(1, 2).reshape(B, -1, self.h * self.d_k) # (batch, time1, d_model)

return self.linear_out(x) # (batch, time1, d_model)
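A matching sketch of the value-side aggregation that forward_hop_attention performs, again with hypothetical shapes and random inputs:

import torch

B, H, T, E, D = 2, 4, 100, 15, 16
v = torch.randn(B, H, T, D)                      # values, already split per head
scores = torch.randn(B, H, T, E)                 # one score per kept offset
k_idx = torch.randint(0, T, (T, E))              # gathered key index per (position, offset)
valid = torch.rand(T, E) > 0.1                   # True where position + offset is in range

v_sel = v[:, :, k_idx, :]                        # (B, H, T, E, D)
attn = torch.softmax(scores.masked_fill(~valid, -1e4), dim=-1)
attn = attn.masked_fill(~valid, 0.0)             # zero out the padded offsets after the softmax
out = (attn.unsqueeze(-1) * v_sel).sum(dim=-2)   # (B, H, T, D)
out = out.transpose(1, 2).reshape(B, T, H * D)   # back to (batch, time, d_model)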

class RelPositionMultiHeadAttentionLongformer(RelPositionMultiHeadAttention):
"""Multi-Head Attention layer of Transformer-XL with sliding window local+global attention from Longformer.
@@ -980,7 +1044,20 @@ def extend_pe(self, length, device):
return
# positions would be from negative numbers to positive
# positive positions would be used for left positions and negative for right positions
positions = torch.arange(length - 1, -length, -1, dtype=torch.float32, device=device).unsqueeze(1)
positions = []
for i in range(-length, length + 1):
if i < 0:
positions.append(2 ** (-i))
elif i == 0:
positions.append(0)
else:
positions.append(2 ** i)

positions = torch.tensor(positions, dtype=torch.float32, device=device)
positions = positions.unsqueeze(1)

print("HERE positions is", positions.shape)

self.create_pe(positions=positions)
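To make the new position table concrete, a tiny sketch of the list this extend_pe builds (the length value is chosen only for illustration):

length = 4
positions = [0 if i == 0 else 2 ** abs(i) for i in range(-length, length + 1)]
print(positions)   # [16, 8, 4, 2, 0, 2, 4, 8, 16] -> 2 * length + 1 rows, independent of the audio length

Note that, as written, the left and right offsets share the same magnitude, so create_pe assigns them the same sinusoidal row.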

def forward(self, x, cache_len=0):
@@ -999,10 +1076,18 @@ def forward(self, x, cache_len=0):
# center_pos would be the index of position 0
# negative positions would be used for right and positive for left tokens
# for input of length L, 2*L-1 positions are needed, positions from (L-1) to -(L-1)
input_len = x.size(1) + cache_len
input_len = int(math.log2(x.size(1))) + 2

print("x size", x.shape)
print("HERE input_len", input_len)


print("HERE self.pe", self.pe.shape)
center_pos = self.pe.size(1) // 2 + 1
print("HERE center_pos", center_pos)
start_pos = center_pos - input_len
end_pos = center_pos + input_len - 1
print("start_pos:end_pos", start_pos, end_pos)
pos_emb = self.pe[:, start_pos:end_pos]
if self.dropout_emb:
pos_emb = self.dropout_emb(pos_emb)
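The point of the log2-based slice above is that the number of relative positions consumed per utterance no longer scales with the input length. A small back-of-the-envelope comparison (numbers are illustrative):

import math

for T in [100, 1000, 10000]:
    dense = 2 * T - 1                    # rel. positions needed by the original scheme
    input_len = int(math.log2(T)) + 2    # as computed in the new forward above
    hop = 2 * input_len - 1              # length of the pe slice taken above
    print(T, dense, hop)                 # 100 -> 199 vs 15, 1000 -> 1999 vs 21, 10000 -> 19999 vs 29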
