# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn, Tensor
import math
import numpy as np
from .attention import MultiheadAttention
from .crossattention import MultiheadAttention as cateattention
class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
def inverse_sigmoid(x, eps=1e-3):
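    """Inverse of the sigmoid (logit), with clamping for numerical stability."""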
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1/x2)
def gen_sineembed_for_position(pos_tensor, d_model):
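    """Sine/cosine positional embedding for (center, span) reference points.

    pos_tensor: (#queries, batch_size, 2) with values in [0, 1].
    Returns (#queries, batch_size, d_model): the first d_model/2 channels encode
    the center, the remaining d_model/2 channels encode the span.
    """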
# n_query, bs, _ = pos_tensor.size()
# sineembed_tensor = torch.zeros(n_query, bs, 256)
scale = 2 * math.pi
dim_t = torch.arange(d_model//2, dtype=torch.float32, device=pos_tensor.device)
dim_t = 10000 ** (2 * (dim_t // 2) / (d_model//2))
center_embed = pos_tensor[:, :, 0] * scale
pos_x = center_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
span_embed = pos_tensor[:, :, 1] * scale
pos_w = span_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
pos = torch.cat((pos_x, pos_w), dim=2)
return pos
class Transformer(nn.Module):
def __init__(self, d_model=512, nhead=8, num_queries=2, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False,
return_intermediate_dec=False, query_dim=2,
keep_query_pos=False, query_scale_type='cond_elewise',
num_patterns=0,
modulate_t_attn=True,
bbox_embed_diff_each_layer=False, args=None
):
super().__init__()
self.args = args
mcls_encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
mcls_encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.mcls_encoder = TransformerEncoder(mcls_encoder_layer, args.moment_layers, mcls_encoder_norm)
t2v_encoder_layer = T2V_TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before, self.args.num_dummies)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.t2v_encoder = TransformerCATEEncoder(t2v_encoder_layer, args.t2v_layers, encoder_norm)
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before, keep_query_pos=keep_query_pos)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
return_intermediate=return_intermediate_dec,
d_model=d_model, query_dim=query_dim, keep_query_pos=keep_query_pos, query_scale_type=query_scale_type,
modulate_t_attn=modulate_t_attn,
bbox_embed_diff_each_layer=bbox_embed_diff_each_layer)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
self.dec_layers = num_decoder_layers
self.num_queries = num_queries
self.num_patterns = num_patterns
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, src, mask, query_embed, pos_embed, video_length=None, moment_idx=None, msrc=None, mpos=None, mmask=None,
nmsrc=None, nmpos=None, nmmask=None,
ctxtoken=None, gtoken=None, gpos=None, vlen=None):
"""
Args:
src: (batch_size, L, d)
mask: (batch_size, L)
query_embed: (#queries, d)
pos_embed: (batch_size, L, d) the same as src
video length: feature shape
vlen: actual video length
Returns:
"""
# moment token
device = ctxtoken.device
if msrc is not None:
msrc = msrc.permute(1, 0, 2) # (L, batch_size, d)
mpos = mpos.permute(1, 0, 2) # (L, batch_size, d)
mmemory = self.mcls_encoder(msrc, src_key_padding_mask=mmask, pos=mpos) # (L, batch_size, d)
mmemory_moment, mmemory_frames = mmemory[0], mmemory[1:]
else:
mmemory_moment = None
mmemory_frames = None
if nmsrc is not None:
nmsrc = nmsrc.permute(1, 0, 2) # (L, batch_size, d)
nmpos = nmpos.permute(1, 0, 2) # (L, batch_size, d)
nmmemory = self.mcls_encoder(nmsrc, src_key_padding_mask=nmmask, pos=nmpos) # (L, batch_size, d)
nmmemory_moment, nmmemory_frames = nmmemory[0], nmmemory[1:]
else:
nmmemory_moment = None
nmmemory_frames = None
        # (batch_size, L, d) -> (L, batch_size, d)
bs, l, d = src.shape
src = src.permute(1, 0, 2) # (L, batch_size, d)
pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d)
refpoint_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (#queries, batch_size, d)
t2v_src, attn_weights = self.t2v_encoder(src, src_key_padding_mask=mask, pos=pos_embed, video_length=video_length) # (L, batch_size, d)
# Saliency Token
## Context
ctx_src_ = ctxtoken.permute(1, 0, 2) # L b d
        ## Distribution token built from the learnable prompt tokens (gtoken)
        ### Video clip feature - context (avg) --> find the top-k (num_prompts) most similar prompt tokens --> weighted sum
fr_token_sim = torch.softmax(torch.matmul(F.normalize((src[:video_length] - ctx_src_).permute(1, 0, 2), dim=2), F.normalize(gtoken, dim=1).T), dim=-1)# src : b 75 d, token : 10 x d --> b 75 10
### Calculate clip importance
frame_importance = attn_weights[:, :, self.args.num_dummies:].sum(2).clone().detach() # b 75
### Masking empty clips
for i in range(len(frame_importance)):
frame_importance[i][vlen[i]:] *= 0.
### Normalize
frame_importance = (frame_importance / frame_importance.sum(1).unsqueeze(1)) * frame_importance.size(1) # b 75
### Scale the similarity with importance
fr_token_sim = fr_token_sim * frame_importance.unsqueeze(2).repeat(1, 1, fr_token_sim.size(2)) # b 75 10
fr_token_sim = fr_token_sim.mean(1) # b 10
topk_val, topkidx = torch.topk(fr_token_sim, k=self.args.num_prompts, dim=1)
src_ = torch.zeros((len(fr_token_sim), self.d_model), dtype=torch.bfloat16).to(device)
for i in range(len(fr_token_sim)):
src_[i] = (topk_val[i].unsqueeze(1) * gtoken[topkidx[i]]).sum(0)
src_ = src_.reshape(1, src.size(1), -1)
## Add context and distribution token
src_ = src_ + ctx_src_
pos_ = gpos.reshape([1, 1, self.d_model]).repeat(1, pos_embed.shape[1], 1)
mask_ = torch.tensor([[False]]).to(mask.device).repeat(mask.shape[0], 1)
src_, _ = self.t2v_encoder(src_, src_key_padding_mask=mask_, pos=pos_,
video_length=video_length, dummy=False) # (L, batch_size, d)
src = torch.cat([src_, t2v_src], dim=0)
mask = torch.cat([mask_, mask], dim=1)
pos_embed = torch.cat([pos_, pos_embed], dim=0)
src = src[:video_length + 1]
mask = mask[:, :video_length + 1]
pos_embed = pos_embed[:video_length + 1]
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) # (L, batch_size, d)
memory_global, memory_local = memory[0], memory[1:]
memory_local += memory_global.unsqueeze(0).repeat(memory_local.size(0), 1, 1)
mask_local = mask[:, 1:]
pos_embed_local = pos_embed[1:]
tgt = torch.zeros(refpoint_embed.shape[0], bs, d).to(device)
tgt = tgt.type(torch.bfloat16)
hs, references = self.decoder(tgt, memory_local, memory_key_padding_mask=mask_local, pos=pos_embed_local, refpoints_unsigmoid=refpoint_embed) # (#layers, #queries, batch_size, d)
memory_local = memory_local.transpose(0, 1) # (batch_size, L, d)
return hs, references, memory_local, memory_global, attn_weights, mmemory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames
class TransformerCATEEncoder(nn.Module):
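    """Stack of T2V cross-attention encoder layers.

    Returns the final output together with the attention weights averaged over layers.
    """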
def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(self, src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
dummy=True,
**kwargs):
output = src
intermediate = []
attn_weights = None
for i, layer in enumerate(self.layers):
output, attn_weight = layer(output, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, pos=pos, dummy=dummy, **kwargs)
if attn_weights is None:
attn_weights = attn_weight
else:
attn_weights = attn_weights + attn_weight
if self.return_intermediate:
intermediate.append(output)
attn_weights /= self.num_layers
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output, attn_weights
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(self, src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
**kwargs):
output = src
intermediate = []
for layer in self.layers:
output = layer(output, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, pos=pos, **kwargs)
if self.return_intermediate:
intermediate.append(output)
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output
class TransformerDecoder(nn.Module):
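    """DAB-DETR-style decoder operating on 1-D (center, span) anchors.

    Each layer predicts an offset with bbox_embed and refines the reference points
    in inverse-sigmoid space; intermediate outputs and reference points are returned.
    """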
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False,
d_model=256, query_dim=2, keep_query_pos=False, query_scale_type='cond_elewise',
modulate_t_attn=False,
bbox_embed_diff_each_layer=False,
):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
assert return_intermediate
self.query_dim = query_dim
assert query_scale_type in ['cond_elewise', 'cond_scalar', 'fix_elewise']
self.query_scale_type = query_scale_type
if query_scale_type == 'cond_elewise':
self.query_scale = MLP(d_model, d_model, d_model, 2)
elif query_scale_type == 'cond_scalar':
self.query_scale = MLP(d_model, d_model, 1, 2)
elif query_scale_type == 'fix_elewise':
self.query_scale = nn.Embedding(num_layers, d_model)
else:
raise NotImplementedError("Unknown query_scale_type: {}".format(query_scale_type))
self.ref_point_head = MLP(d_model, d_model, d_model, 2)
# self.bbox_embed = None
# for DAB-detr
if bbox_embed_diff_each_layer:
self.bbox_embed = nn.ModuleList([MLP(d_model, d_model, 2, 3) for i in range(num_layers)])
else:
self.bbox_embed = MLP(d_model, d_model, 2, 3)
# init bbox_embed
if bbox_embed_diff_each_layer:
for bbox_embed in self.bbox_embed:
nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
else:
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
self.d_model = d_model
self.modulate_t_attn = modulate_t_attn
self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer
if modulate_t_attn:
self.ref_anchor_head = MLP(d_model, d_model, 1, 2)
if not keep_query_pos:
for layer_id in range(num_layers - 1):
self.layers[layer_id + 1].ca_qpos_proj = None
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
):
output = tgt
intermediate = []
reference_points = refpoints_unsigmoid.sigmoid()
ref_points = [reference_points]
for layer_id, layer in enumerate(self.layers):
obj_center = reference_points[..., :self.query_dim]
# get sine embedding for the query vector
query_sine_embed = gen_sineembed_for_position(obj_center, self.d_model)
query_sine_embed = query_sine_embed.type(torch.bfloat16)
query_pos = self.ref_point_head(query_sine_embed)
# For the first decoder layer, we do not apply transformation over p_s
if self.query_scale_type != 'fix_elewise':
if layer_id == 0:
pos_transformation = 1
else:
pos_transformation = self.query_scale(output)
else:
pos_transformation = self.query_scale.weight[layer_id]
# apply transformation
query_sine_embed = query_sine_embed * pos_transformation
            # modulated temporal attention (1-D analogue of DAB-DETR's modulated HW attention)
if self.modulate_t_attn:
reft_cond = self.ref_anchor_head(output).sigmoid() # nq, bs, 1
query_sine_embed *= (reft_cond[..., 0] / obj_center[..., 1]).unsqueeze(-1)
output = layer(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
is_first=(layer_id == 0))
# iter update
if self.bbox_embed is not None:
if self.bbox_embed_diff_each_layer:
tmp = self.bbox_embed[layer_id](output)
else:
tmp = self.bbox_embed(output)
tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)
new_reference_points = tmp[..., :self.query_dim].sigmoid()
if layer_id != self.num_layers - 1:
ref_points.append(new_reference_points)
reference_points = new_reference_points.detach()
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
if self.bbox_embed is not None:
return [
torch.stack(intermediate).transpose(1, 2),
torch.stack(ref_points).transpose(1, 2),
]
else:
return [
torch.stack(intermediate).transpose(1, 2),
reference_points.unsqueeze(0).transpose(1, 2)
]
return output.unsqueeze(0)
class TransformerEncoderLayerThin(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
# self.linear1 = nn.Linear(d_model, dim_feedforward)
# self.dropout = nn.Dropout(dropout)
# self.linear2 = nn.Linear(dim_feedforward, d_model)
self.linear = nn.Linear(d_model, d_model)
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
# self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src2 = self.linear(src2)
src = src + self.dropout(src2)
src = self.norm(src)
# src = src + self.dropout1(src2)
# src = self.norm1(src)
# src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
# src = src + self.dropout2(src2)
# src = self.norm2(src)
return src
def forward_pre(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
"""not used"""
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class T2V_TransformerEncoderLayer(nn.Module):
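    """Encoder layer that cross-attends video tokens (queries) to text tokens (keys/values).

    The input sequence is assumed to be [video tokens; text tokens]: the first
    `video_length` positions form the queries, the remaining positions the keys/values.
    The custom attention module is built with `num_dummies`; Transformer.forward later
    drops the first `num_dummies` columns of the returned attention weights.
    """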
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False, num_dummies=3):
super().__init__()
self.self_attn = cateattention(d_model, nhead, dropout=dropout, num_dummies=num_dummies)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = DropPath(dropout)
self.dropout2 = DropPath(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self.nhead = nhead
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
video_length=None, dummy=True):
assert video_length is not None
pos_src = self.with_pos_embed(src, pos)
q, k, v = pos_src[:video_length], pos_src[video_length:], src[video_length:]
qmask, kmask = src_key_padding_mask[:, :video_length].unsqueeze(2), src_key_padding_mask[:, video_length:].unsqueeze(1)
attn_mask = torch.matmul(qmask.float(), kmask.float()).bool().repeat(self.nhead, 1, 1)
# - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
# If a FloatTensor is provided, it will be directly added to the value.
# If a BoolTensor is provided, the positions with the
# value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
# - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
# 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
# S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
# positions. If a BoolTensor is provided, positions with ``True``
# are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
# is provided, it will be added to the attention weight.
src2, attn_weights = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask[:, video_length:], dummy=dummy)
src2 = src[:video_length] + self.dropout1(src2)
src3 = self.norm1(src2)
src3 = self.linear2(self.dropout(self.activation(self.linear1(src3))))
src2 = src2 + self.dropout2(src3)
src2 = self.norm2(src2)
src = torch.cat([src2, src[video_length:]])
return src, attn_weights
def forward_pre(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None, dummy=True):
        raise NotImplementedError("normalize_before (pre-norm) is not implemented for T2V_TransformerEncoderLayer")
def forward(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None, dummy=True,
**kwargs):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos, dummy=dummy)
return self.forward_post(src, src_mask, src_key_padding_mask, pos, dummy=dummy, **kwargs)
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = DropPath(dropout)
self.dropout2 = DropPath(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
        raise NotImplementedError("normalize_before (pre-norm) is not implemented for TransformerEncoderLayer")
def forward(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
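    """Conditional-DETR-style decoder layer with separate content/positional projections.

    In cross-attention, the sine query embedding is concatenated head-wise with the
    content query (hence the `d_model * 2` cross-attention module).
    """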
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False, keep_query_pos=False,
rm_self_attn_decoder=False):
super().__init__()
# Decoder Self-Attention
if not rm_self_attn_decoder:
self.sa_qcontent_proj = nn.Linear(d_model, d_model)
self.sa_qpos_proj = nn.Linear(d_model, d_model)
self.sa_kcontent_proj = nn.Linear(d_model, d_model)
self.sa_kpos_proj = nn.Linear(d_model, d_model)
self.sa_v_proj = nn.Linear(d_model, d_model)
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, vdim=d_model)
self.norm1 = nn.LayerNorm(d_model)
self.dropout1 = DropPath(dropout)
# Decoder Cross-Attention
self.ca_qcontent_proj = nn.Linear(d_model, d_model)
self.ca_qpos_proj = nn.Linear(d_model, d_model)
self.ca_kcontent_proj = nn.Linear(d_model, d_model)
self.ca_kpos_proj = nn.Linear(d_model, d_model)
self.ca_v_proj = nn.Linear(d_model, d_model)
self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
self.cross_attn = MultiheadAttention(d_model * 2, nhead, dropout=dropout, vdim=d_model)
self.nhead = nhead
self.rm_self_attn_decoder = rm_self_attn_decoder
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout2 = DropPath(dropout)
self.dropout3 = DropPath(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self.keep_query_pos = keep_query_pos
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
query_sine_embed=None,
is_first=False):
# ========== Begin of Self-Attention =============
if not self.rm_self_attn_decoder:
# Apply projections here
# shape: num_queries x batch_size x 256
q_content = self.sa_qcontent_proj(tgt) # target is the input of the first decoder layer. zero by default.
q_pos = self.sa_qpos_proj(query_pos)
k_content = self.sa_kcontent_proj(tgt)
k_pos = self.sa_kpos_proj(query_pos)
v = self.sa_v_proj(tgt)
num_queries, bs, n_model = q_content.shape
hw, _, _ = k_content.shape
q = q_content + q_pos
k = k_content + k_pos
tgt2 = self.self_attn(q, k, value=v, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
# ========== End of Self-Attention =============
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ========== Begin of Cross-Attention =============
# Apply projections here
# shape: num_queries x batch_size x 256
q_content = self.ca_qcontent_proj(tgt)
k_content = self.ca_kcontent_proj(memory)
v = self.ca_v_proj(memory)
num_queries, bs, n_model = q_content.shape
hw, _, _ = k_content.shape
k_pos = self.ca_kpos_proj(pos)
        # For the first decoder layer (or when keep_query_pos), add the positional
        # embedding predicted from the object query (query_pos) to the content
        # query and key, as in DETR.
if is_first or self.keep_query_pos:
q_pos = self.ca_qpos_proj(query_pos)
q = q_content + q_pos
k = k_content + k_pos
else:
q = q_content
k = k_content
q = q.view(num_queries, bs, self.nhead, n_model // self.nhead)
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model // self.nhead)
q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
k = k.view(hw, bs, self.nhead, n_model // self.nhead)
k_pos = k_pos.view(hw, bs, self.nhead, n_model // self.nhead)
k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)
tgt2 = self.cross_attn(query=q,
key=k,
value=v, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
# ========== End of Cross-Attention =============
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
class TransformerDecoderLayerThin(nn.Module):
"""removed intermediate layer"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = DropPath(dropout)
self.dropout2 = DropPath(dropout)
# self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt2 = self.linear1(tgt2)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
return tgt
def forward_pre(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
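        # NOTE: this pre-norm path references self.norm3, self.linear2, self.dropout,
        # self.activation, and self.dropout3, which this thin variant does not define,
        # so calling it (normalize_before=True) would fail; it is effectively unused.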
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def build_transformer(args):
return Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
activation='prelu',
args=args
)
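# Note: besides the fields used above, `args` is also expected to provide
# `moment_layers`, `t2v_layers`, `num_dummies`, and `num_prompts`, which are
# consumed in Transformer.__init__ and Transformer.forward.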
def drop_path(x, drop_prob=0.0, training=False):
"""
Stochastic Depth per sample.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
mask.floor_()
x = x.div(keep_prob) * mask
return x
class DropPath(nn.Module):
"""
Drop paths per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
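        # Inputs are sequence-first (L, batch, d); permute so the batch dimension leads,
        # since drop_path makes one keep/drop decision per sample along dim 0.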
x = x.permute(1, 0, 2)
res = drop_path(x, self.drop_prob, self.training)
return res.permute(1, 0, 2)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
if activation == "prelu":
return nn.PReLU()
if activation == "selu":
return F.selu
    raise RuntimeError(f"activation should be relu/gelu/glu/prelu/selu, not {activation}.")