From 56107dd324bea4f4d3e56d4e95f771541b0b7bb1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 10:42:50 +0100 Subject: [PATCH 01/26] Adding implementation from pytorch video --- torchvision/models/video/mvit.py | 1204 ++++++++++++++++++++++++++++++ 1 file changed, 1204 insertions(+) create mode 100644 torchvision/models/video/mvit.py diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py new file mode 100644 index 00000000000..63e056afd23 --- /dev/null +++ b/torchvision/models/video/mvit.py @@ -0,0 +1,1204 @@ +import math +from functools import partial +from typing import Callable, List, Optional, Tuple + +import numpy +import torch +import torch.nn as nn +from torch.nn.common_types import _size_2_t, _size_3_t +from torchvision.ops import StochasticDepth + + +__all__ = ["create_mvit_b_16", "create_multiscale_vision_transformers"] + + +class Mlp(nn.Module): + """ + A MLP block that contains two linear layers with a normalization layer. The MLP + block is used in a transformer model after the attention block. + + :: + + Linear (in_features, hidden_features) + ↓ + Normalization (act_layer) + ↓ + Dropout (p=dropout_rate) + ↓ + Linear (hidden_features, out_features) + ↓ + Dropout (p=dropout_rate) + """ + + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable = nn.GELU, + dropout_rate: float = 0.0, + bias_on: bool = True, + ) -> None: + """ + Args: + in_features (int): Input feature dimension. + hidden_features (Optional[int]): Hidden feature dimension. By default, + hidden feature is set to input feature dimension. + out_features (Optional[int]): Output feature dimension. By default, output + features dimension is set to input feature dimension. + act_layer (Callable): Activation layer used after the first linear layer. + dropout_rate (float): Dropout rate after each linear layer. Dropout is not used + by default. + """ + super().__init__() + self.dropout_rate = dropout_rate + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias_on) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias_on) + if self.dropout_rate > 0.0: + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (tensor): Input tensor. + """ + x = self.fc1(x) + x = self.act(x) + if self.dropout_rate > 0.0: + x = self.dropout(x) + x = self.fc2(x) + if self.dropout_rate > 0.0: + x = self.dropout(x) + return x + + +def _attention_pool( + tensor: torch.Tensor, + pool: Optional[Callable], + thw_shape: List[int], + norm: Optional[Callable] = None, +) -> torch.Tensor: + """ + Apply pool to a flattened input (given pool operation and the unflattened shape). + + + Input + ↓ + Reshape + ↓ + Pool + ↓ + Reshape + ↓ + Norm + + + Args: + tensor (torch.Tensor): Input tensor. + pool (Optional[Callable]): Pool operation that is applied to the input tensor. + If pool is none, return the input tensor. + thw_shape (List): The shape of the input tensor (before flattening). + norm: (Optional[Callable]): Optional normalization operation applied to + tensor after pool. + + Returns: + tensor (torch.Tensor): Input tensor after pool. + thw_shape (List[int]): Output tensor shape (before flattening). 
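+
+    Example (an illustrative sketch): with a pooling conv of kernel (3, 3, 3),
+        stride (1, 2, 2) and padding (1, 1, 1), an input of shape
+        (B, num_heads, 1 + 8 * 56 * 56, C) with thw_shape=[8, 56, 56] is returned
+        with shape (B, num_heads, 1 + 8 * 28 * 28, C) and thw_shape=[8, 28, 28];
+        the leading class token is carried through without being pooled.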
+ """ + if pool is None: + return tensor, thw_shape + tensor_dim = tensor.ndim + if tensor_dim == 4: + pass + elif tensor_dim == 3: + tensor = tensor.unsqueeze(1) + else: + raise NotImplementedError(f"Unsupported input dimension {tensor.shape}") + + cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :] + + B, N, L, C = tensor.shape + T, H, W = thw_shape + tensor = tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + + if isinstance(norm, (nn.BatchNorm3d, nn.Identity)): + # If use BN, we apply norm before pooling instead of after pooling. + tensor = norm(tensor) + # We also empirically find that adding a GELU here is beneficial. + tensor = nn.functional.gelu(tensor) + + tensor = pool(tensor) + + thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]] + L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4] + tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3) + + tensor = torch.cat((cls_tok, tensor), dim=2) + if norm is not None and not isinstance(norm, nn.BatchNorm3d): + tensor = norm(tensor) + + if tensor_dim == 4: + pass + else: # For the case tensor_dim == 3. + tensor = tensor.squeeze(1) + return tensor, thw_shape + + +class MultiScaleAttention(nn.Module): + """ + Implementation of a multiscale attention block. Compare to a conventional attention + block, a multiscale attention block optionally supports pooling (either + before or after qkv projection). If pooling is not used, a multiscale attention + block is equivalent to a conventional attention block. + + :: + Input + | + |----------------|-----------------| + ↓ ↓ ↓ + Linear Linear Linear + & & & + Pool (Q) Pool (K) Pool (V) + → -------------- ← | + ↓ | + MatMul & Scale | + ↓ | + Softmax | + → ----------------------- ← + ↓ + MatMul & Scale + ↓ + DropOut + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + dropout_rate: float = 0.0, + kernel_q: _size_3_t = (1, 1, 1), + kernel_kv: _size_3_t = (1, 1, 1), + stride_q: _size_3_t = (1, 1, 1), + stride_kv: _size_3_t = (1, 1, 1), + norm_layer: Callable = nn.LayerNorm, + depthwise_conv: bool = True, + bias_on: bool = True, + ) -> None: + """ + Args: + dim (int): Input feature dimension. + num_heads (int): Number of heads in the attention layer. + qkv_bias (bool): If set to False, the qkv layer will not learn an additive + bias. Default: False. + dropout_rate (float): Dropout rate. + kernel_q (_size_3_t): Pooling kernel size for q. If both pooling kernel + size and pooling stride size are 1 for all the dimensions, pooling is + disabled. + kernel_kv (_size_3_t): Pooling kernel size for kv. If both pooling kernel + size and pooling stride size are 1 for all the dimensions, pooling is + disabled. + stride_q (_size_3_t): Pooling kernel stride for q. + stride_kv (_size_3_t): Pooling kernel stride for kv. + norm_layer (nn.Module): Normalization layer used after pooling. + depthwise_conv (bool): Wether use depthwise or full convolution for pooling. + bias_on (bool): Wether use biases for linear layers. + """ + + super().__init__() + + self.dropout_rate = dropout_rate + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + padding_q = [int(q // 2) for q in kernel_q] + padding_kv = [int(kv // 2) for kv in kernel_kv] + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=True if bias_on else False) + if dropout_rate > 0.0: + self.proj_drop = nn.Dropout(dropout_rate) + + # Skip pooling with kernel and stride size of (1, 1, 1). 
+ if kernel_q is not None and numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1: + kernel_q = None + if kernel_kv is not None and numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1: + kernel_kv = None + + self.pool_q = ( + nn.Conv3d( + head_dim, + head_dim, + kernel_q, + stride=stride_q, + padding=padding_q, + groups=head_dim if depthwise_conv else 1, + bias=False, + ) + if kernel_q is not None + else None + ) + self.norm_q = norm_layer(head_dim) if kernel_q is not None else None + self.pool_k = ( + nn.Conv3d( + head_dim, + head_dim, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=head_dim if depthwise_conv else 1, + bias=False, + ) + if kernel_kv is not None + else None + ) + self.norm_k = norm_layer(head_dim) if kernel_kv is not None else None + self.pool_v = ( + nn.Conv3d( + head_dim, + head_dim, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=head_dim if depthwise_conv else 1, + bias=False, + ) + if kernel_kv is not None + else None + ) + self.norm_v = norm_layer(head_dim) if kernel_kv is not None else None + + def _qkv_proj( + self, + q: torch.Tensor, + q_size: List[int], + k: torch.Tensor, + k_size: List[int], + v: torch.Tensor, + v_size: List[int], + batch_size: List[int], + chan_size: List[int], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q(q).reshape(batch_size, q_size, self.num_heads, chan_size // self.num_heads).permute(0, 2, 1, 3) + k = self.k(k).reshape(batch_size, k_size, self.num_heads, chan_size // self.num_heads).permute(0, 2, 1, 3) + v = self.v(v).reshape(batch_size, v_size, self.num_heads, chan_size // self.num_heads).permute(0, 2, 1, 3) + return q, k, v + + def _qkv_pool( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + thw_shape: Tuple[torch.Tensor, List[int]], + ) -> Tuple[torch.Tensor, List[int], torch.Tensor, List[int], torch.Tensor, List[int]]: + q, q_shape = _attention_pool( + q, + self.pool_q, + thw_shape, + norm=self.norm_q if hasattr(self, "norm_q") else None, + ) + k, k_shape = _attention_pool( + k, + self.pool_k, + thw_shape, + norm=self.norm_k if hasattr(self, "norm_k") else None, + ) + v, v_shape = _attention_pool( + v, + self.pool_v, + thw_shape, + norm=self.norm_v if hasattr(self, "norm_v") else None, + ) + return q, q_shape, k, k_shape, v, v_shape + + def _get_qkv_length( + self, + q_shape: List[int], + k_shape: List[int], + v_shape: List[int], + ) -> Tuple[int]: + q_N = numpy.prod(q_shape) + 1 + k_N = numpy.prod(k_shape) + 1 + v_N = numpy.prod(v_shape) + 1 + return q_N, k_N, v_N + + def _reshape_qkv_to_seq( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q_N: int, + v_N: int, + k_N: int, + B: int, + C: int, + ) -> Tuple[int]: + q = q.permute(0, 2, 1, 3).reshape(B, q_N, C) + v = v.permute(0, 2, 1, 3).reshape(B, v_N, C) + k = k.permute(0, 2, 1, 3).reshape(B, k_N, C) + return q, k, v + + def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, List[int]]: + """ + Args: + x (torch.Tensor): Input tensor. + thw_shape (List): The shape of the input tensor (before flattening). 
+ """ + + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q, q_shape, k, k_shape, v, v_shape = self._qkv_pool(q, k, v, thw_shape) + + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + + N = q.shape[2] + + x = (attn @ v + q).transpose(1, 2).reshape(B, N, C) + + x = self.proj(x) + if self.dropout_rate > 0.0: + x = self.proj_drop(x) + return x, q_shape + + +class MultiScaleBlock(nn.Module): + """ + Implementation of a multiscale vision transformer block. Each block contains a + multiscale attention layer and a Mlp layer. + + :: + + + Input + |-------------------+ + ↓ | + Norm | + ↓ | + MultiScaleAttention Pool + ↓ | + DropPath | + ↓ | + Summation ←-------------+ + | + |-------------------+ + ↓ | + Norm | + ↓ | + Mlp Proj + ↓ | + DropPath | + ↓ | + Summation ←------------+ + """ + + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + dropout_rate: float = 0.0, + droppath_rate: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + attn_norm_layer: nn.Module = nn.LayerNorm, + kernel_q: _size_3_t = (1, 1, 1), + kernel_kv: _size_3_t = (1, 1, 1), + stride_q: _size_3_t = (1, 1, 1), + stride_kv: _size_3_t = (1, 1, 1), + depthwise_conv: bool = True, + bias_on: bool = True, + ) -> None: + """ + Args: + dim (int): Input feature dimension. + dim_out (int): Output feature dimension. + num_heads (int): Number of heads in the attention layer. + mlp_ratio (float): Mlp ratio which controls the feature dimension in the + hidden layer of the Mlp block. + qkv_bias (bool): If set to False, the qkv layer will not learn an additive + bias. Default: False. + dropout_rate (float): DropOut rate. If set to 0, DropOut is disabled. + droppath_rate (float): DropPath rate. If set to 0, DropPath is disabled. + act_layer (nn.Module): Activation layer used in the Mlp layer. + norm_layer (nn.Module): Normalization layer. + attn_norm_layer (nn.Module): Normalization layer in the attention module. + kernel_q (_size_3_t): Pooling kernel size for q. If pooling kernel size is + 1 for all the dimensions, pooling is not used (by default). + kernel_kv (_size_3_t): Pooling kernel size for kv. If pooling kernel size + is 1 for all the dimensions, pooling is not used. By default, pooling + is disabled. + stride_q (_size_3_t): Pooling kernel stride for q. + stride_kv (_size_3_t): Pooling kernel stride for kv. + has_cls_embed (bool): If set to True, the first token of the input tensor + should be a cls token. Otherwise, the input tensor does not contain a + cls token. Pooling is not applied to the cls token. + depthwise_conv (bool): Wether use depthwise or full convolution for pooling. + bias_on (bool): Wether use biases for linear layers. 
+ """ + super().__init__() + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + kernel_skip = [s + 1 if s > 1 else s for s in stride_q] + stride_skip = stride_q + padding_skip = [int(skip // 2) for skip in kernel_skip] + self.attn = MultiScaleAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + dropout_rate=dropout_rate, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=attn_norm_layer, + bias_on=bias_on, + depthwise_conv=depthwise_conv, + ) + self.drop_path = StochasticDepth(droppath_rate, "row") if droppath_rate > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + out_features=dim_out, + act_layer=act_layer, + dropout_rate=dropout_rate, + bias_on=bias_on, + ) + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out, bias=bias_on) + + self.pool_skip = ( + nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False) + if len(stride_skip) > 0 and numpy.prod(stride_skip) > 1 + else None + ) + + def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, List[int]]: + """ + Args: + x (torch.Tensor): Input tensor. + thw_shape (List): The shape of the input tensor (before flattening). + """ + + x_block, thw_shape_new = self.attn( + ( + self.norm1(x.permute(0, 2, 1)).permute(0, 2, 1) + if isinstance(self.norm1, nn.BatchNorm1d) + else self.norm1(x) + ), + thw_shape, + ) + x_res, _ = _attention_pool(x, self.pool_skip, thw_shape) + x = x_res + self.drop_path(x_block) + x_norm = ( + self.norm2(x.permute(0, 2, 1)).permute(0, 2, 1) if isinstance(self.norm2, nn.BatchNorm1d) else self.norm2(x) + ) + x_mlp = self.mlp(x_norm) + if self.dim != self.dim_out: + x = self.proj(x_norm) + x = x + self.drop_path(x_mlp) + return x, thw_shape_new + + +class SpatioTemporalClsPositionalEncoding(nn.Module): + """ + Add a cls token and apply a spatiotemporal encoding to a tensor. + """ + + def __init__( + self, + embed_dim: int, + patch_embed_shape: Tuple[int, int, int], + ) -> None: + """ + Args: + embed_dim (int): Embedding dimension for input sequence. + patch_embed_shape (Tuple): The number of patches in each dimension + (T, H, W) after patch embedding. + """ + super().__init__() + assert len(patch_embed_shape) == 3, "Patch_embed_shape should be in the form of (T, H, W)." + self._patch_embed_shape = patch_embed_shape + self.num_spatial_patch = patch_embed_shape[1] * patch_embed_shape[2] + self.num_temporal_patch = patch_embed_shape[0] + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + self.pos_embed_spatial = nn.Parameter(torch.zeros(1, self.num_spatial_patch, embed_dim)) + self.pos_embed_temporal = nn.Parameter(torch.zeros(1, self.num_temporal_patch, embed_dim)) + self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + @property + def patch_embed_shape(self): + return self._patch_embed_shape + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor. 
+ """ + B, N, C = x.shape + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + pos_embed = self.pos_embed_spatial.repeat(1, self.num_temporal_patch, 1) + torch.repeat_interleave( + self.pos_embed_temporal, + self.num_spatial_patch, + dim=1, + ) + pos_embed = torch.cat([self.pos_embed_class, pos_embed], 1) + x = x + pos_embed + + return x + + +def c2_xavier_fill(module: nn.Module) -> None: + """ + Initialize `module.weight` using the "XavierFill" implemented in Caffe2. + Also initializes `module.bias` to 0. + + Args: + module (torch.nn.Module): module to initialize. + """ + # Caffe2 implementation of XavierFill in fact + # corresponds to kaiming_uniform_ in PyTorch + nn.init.kaiming_uniform_(module.weight, a=1) + if module.bias is not None: + # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, + # torch.Tensor]`. + nn.init.constant_(module.bias, 0) + + +def c2_msra_fill(module: nn.Module) -> None: + """ + Initialize `module.weight` using the "MSRAFill" implemented in Caffe2. + Also initializes `module.bias` to 0. + + Args: + module (torch.nn.Module): module to initialize. + """ + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, + # torch.Tensor]`. + nn.init.constant_(module.bias, 0) + + +def set_attributes(self, params: List[object] = None) -> None: + """ + An utility function used in classes to set attributes from the input list of parameters. + Args: + params (list): list of parameters. + """ + if params: + for k, v in params.items(): + if k != "self": + setattr(self, k, v) + + +def round_width(width, multiplier, min_width=8, divisor=8, ceil=False): + """ + Round width of filters based on width multiplier + Args: + width (int): the channel dimensions of the input. + multiplier (float): the multiplication factor. + min_width (int): the minimum width after multiplication. + divisor (int): the new width should be dividable by divisor. + ceil (bool): If True, use ceiling as the rounding method. + """ + if not multiplier: + return width + + width *= multiplier + min_width = min_width or divisor + if ceil: + width_out = max(min_width, int(math.ceil(width / divisor)) * divisor) + else: + width_out = max(min_width, int(width + divisor / 2) // divisor * divisor) + if width_out < 0.9 * width: + width_out += divisor + return int(width_out) + + +def round_repeats(repeats, multiplier): + """ + Round number of layers based on depth multiplier. + """ + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +class SequencePool(nn.Module): + """ + Sequence pool produces a single embedding from a sequence of embeddings. Currently + it supports "mean" and "cls". + + """ + + def __init__(self) -> None: + + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x[:, 0] + return x + + +class create_vit_basic_head(nn.Module): + def __init__( + self, + # Projection configs. + in_features: int, + out_features: int, + # Pooling configs. + seq_pool_type: str = "cls", + # Dropout configs. + dropout_rate: float = 0.5, + # Activation configs. + activation: Callable = None, + ) -> nn.Module: + """ + Creates vision transformer basic head. + + :: + + + Pooling + ↓ + Dropout + ↓ + Projection + ↓ + Activation + + + Activation examples include: ReLU, Softmax, Sigmoid, and None. + Pool type examples include: cls, mean and none. 
+ + Args: + + in_features: input channel size of the resnet head. + out_features: output channel size of the resnet head. + + pool_type (str): Pooling type. It supports "cls", "mean " and "none". If set to + "cls", it assumes the first element in the input is the cls token and + returns it. If set to "mean", it returns the mean of the entire sequence. + + activation (callable): a callable that constructs vision transformer head + activation layer, examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and + None (not applying activation). + + dropout_rate (float): dropout rate. + """ + super().__init__() + self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0.0 else None + self.proj = nn.Linear(in_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Pick cls embedding + x = x[:, 0] + # Performs dropout. + if self.dropout is not None: + x = self.dropout(x) + # Performs projection. + x = self.proj(x) + return x + + +class PatchEmbed(nn.Module): + """ + Transformer basic patch embedding module. Performs patchifying input, flatten and + and transpose. + + :: + + PatchModel + ↓ + flatten + ↓ + transpose + + The builder can be found in `create_patch_embed`. + + """ + + def __init__( + self, + patch_model: nn.Module = None, + ) -> None: + super().__init__() + set_attributes(self, locals()) + assert self.patch_model is not None + + def forward(self, x) -> torch.Tensor: + x = self.patch_model(x) + # B C (T) H W -> B (T)HW C + return x.flatten(2).transpose(1, 2) + + +def create_conv_patch_embed( + in_channels: int, + out_channels: int, + conv_kernel_size: Tuple[int] = (1, 16, 16), + conv_stride: Tuple[int] = (1, 4, 4), + conv_padding: Tuple[int] = (1, 7, 7), + conv_bias: bool = True, +) -> nn.Module: + """ + Creates the transformer basic patch embedding. It performs Convolution, flatten and + transpose. + + :: + + Conv3d + ↓ + flatten + ↓ + transpose + + Args: + in_channels (int): input channel size of the convolution. + out_channels (int): output channel size of the convolution. + conv_kernel_size (tuple): convolutional kernel size(s). + conv_stride (tuple): convolutional stride size(s). + conv_padding (tuple): convolutional padding size(s). + conv_bias (bool): convolutional bias. If true, adds a learnable bias to the + output. + conv (callable): Callable used to build the convolution layer. + + Returns: + (nn.Module): transformer patch embedding layer. + """ + conv_module = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=conv_kernel_size, + stride=conv_stride, + padding=conv_padding, + bias=conv_bias, + ) + return PatchEmbed(patch_model=conv_module) + + +def _init_resnet_weights(model: nn.Module, fc_init_std: float = 0.01) -> None: + """ + Performs ResNet style weight initialization. That is, recursively initialize the + given model in the following way for each type: + Conv - Follow the initialization of kaiming_normal: + https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_ + BatchNorm - Set weight and bias of last BatchNorm at every residual bottleneck + to 0. + Linear - Set weight to 0 mean Gaussian with std deviation fc_init_std and bias + to 0. + Args: + model (nn.Module): Model to be initialized. + fc_init_std (float): the expected standard deviation for fully-connected layer. + """ + for m in model.modules(): + if isinstance(m, (nn.Conv2d, nn.Conv3d)): + """ + Follow the initialization method proposed in: + {He, Kaiming, et al. 
+ "Delving deep into rectifiers: Surpassing human-level + performance on imagenet classification." + arXiv preprint arXiv:1502.01852 (2015)} + """ + c2_msra_fill(m) + elif isinstance(m, nn.modules.batchnorm._NormBase): + if m.weight is not None: + if hasattr(m, "block_final_bn") and m.block_final_bn: + m.weight.data.fill_(0.0) + else: + m.weight.data.fill_(1.0) + if m.bias is not None: + m.bias.data.zero_() + if isinstance(m, nn.Linear): + if hasattr(m, "xavier_init") and m.xavier_init: + c2_xavier_fill(m) + else: + m.weight.data.normal_(mean=0.0, std=fc_init_std) + if m.bias is not None: + m.bias.data.zero_() + return model + + +def _init_vit_weights(model: nn.Module, trunc_normal_std: float = 0.02) -> None: + """ + Weight initialization for vision transformers. + + Args: + model (nn.Module): Model to be initialized. + trunc_normal_std (float): the expected standard deviation for fully-connected + layer and ClsPositionalEncoding. + """ + for m in model.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=trunc_normal_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, SpatioTemporalClsPositionalEncoding): + for weights in m.parameters(): + nn.init.trunc_normal_(weights, std=trunc_normal_std) + + +def init_net_weights( + model: nn.Module, + init_std: float = 0.01, + style: str = "resnet", +) -> None: + """ + Performs weight initialization. Options include ResNet style weight initialization + and transformer style weight initialization. + + Args: + model (nn.Module): Model to be initialized. + init_std (float): The expected standard deviation for initialization. + style (str): Options include "resnet" and "vit". + """ + assert style in ["resnet", "vit"] + if style == "resnet": + return _init_resnet_weights(model, init_std) + elif style == "vit": + return _init_vit_weights(model, init_std) + else: + raise NotImplementedError + + +class MultiscaleVisionTransformers(nn.Module): + """ + Multiscale Vision Transformers + Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan, Jitendra Malik, + Christoph Feichtenhofer + https://arxiv.org/abs/2104.11227 + + :: + + PatchEmbed + ↓ + PositionalEncoding + ↓ + Dropout + ↓ + Normalization + ↓ + Block 1 + ↓ + . + . + . + ↓ + Block N + ↓ + Normalization + ↓ + Head + + + The builder can be found in `create_mvit`. + """ + + def __init__( + self, + patch_embed: Optional[nn.Module], + cls_positional_encoding: nn.Module, + pos_drop: Optional[nn.Module], + blocks: nn.ModuleList, + norm_embed: Optional[nn.Module], + head: Optional[nn.Module], + ) -> None: + """ + Args: + patch_embed (nn.Module): Patch embed module. + cls_positional_encoding (nn.Module): Positional encoding module. + pos_drop (Optional[nn.Module]): Dropout module after patch embed. + blocks (nn.ModuleList): Stack of multi-scale transformer blocks. + norm_layer (nn.Module): Normalization layer before head. + head (Optional[nn.Module]): Head module. + """ + super().__init__() + set_attributes(self, locals()) + assert hasattr( + cls_positional_encoding, "patch_embed_shape" + ), "cls_positional_encoding should have attribute patch_embed_shape." 
+ init_net_weights(self, init_std=0.02, style="vit") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.patch_embed is not None: + x = self.patch_embed(x) + x = self.cls_positional_encoding(x) + + if self.pos_drop is not None: + x = self.pos_drop(x) + + thw = self.cls_positional_encoding.patch_embed_shape + for blk in self.blocks: + x, thw = blk(x, thw) + if self.norm_embed is not None: + x = self.norm_embed(x) + if self.head is not None: + x = self.head(x) + return x + + +def create_multiscale_vision_transformers( + spatial_size: _size_2_t, + temporal_size: int, + depth: int = 16, + norm: str = "layernorm", + # Patch embed config. + input_channels: int = 3, + patch_embed_dim: int = 96, + conv_patch_embed_kernel: Tuple[int] = (3, 7, 7), + conv_patch_embed_stride: Tuple[int] = (2, 4, 4), + conv_patch_embed_padding: Tuple[int] = (1, 3, 3), + enable_patch_embed_norm: bool = False, + # Attention block config. + num_heads: int = 1, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + dropout_rate_block: float = 0.0, + droppath_rate_block: float = 0.0, + depthwise_conv: bool = True, + bias_on: bool = True, + embed_dim_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), + atten_head_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), + pool_q_stride_size: Optional[List[List[int]]] = ([1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]), + pool_kv_stride_size: Optional[List[List[int]]] = None, + pool_kv_stride_adaptive: Optional[_size_3_t] = (1, 8, 8), + pool_kvq_kernel: Optional[_size_3_t] = (3, 3, 3), + # Head config. + head: Optional[Callable] = create_vit_basic_head, + head_dropout_rate: float = 0.5, + head_activation: Callable = None, + num_classes: int = 400, + **kwargs, +) -> nn.Module: + """ + Build Multiscale Vision Transformers (MViT) for recognition. A Vision Transformer + (ViT) is a specific case of MViT that only uses a single scale attention block. + + Args: + spatial_size (_size_2_t): Input video spatial resolution (H, W). If a single + int is given, it assumes the width and the height are the same. + temporal_size (int): Number of frames in the input video. + depth (int): The depth of the model. + norm (str): Normalization layer. It currently supports "layernorm". + + input_channels (int): Channel dimension of the input video. + patch_embed_dim (int): Embedding dimension after patchifing the video input. + conv_patch_embed_kernel (Tuple[int]): Kernel size of the convolution for + patchifing the video input. + conv_patch_embed_stride (Tuple[int]): Stride size of the convolution for + patchifing the video input. + conv_patch_embed_padding (Tuple[int]): Padding size of the convolution for + patchifing the video input. + enable_patch_embed_norm (bool): If True, apply normalization after patchifing + the video input. + + num_heads (int): Number of heads in the first transformer block. + mlp_ratio (float): Mlp ratio which controls the feature dimension in the + hidden layer of the Mlp block. + qkv_bias (bool): If set to False, the qkv layer will not learn an additive + bias. Default: True. + dropout_rate_block (float): Dropout rate for the attention block. + droppath_rate_block (float): Droppath rate for the attention block. + depthwise_conv (bool): Wether use depthwise or full convolution for pooling. + bias_on (bool): Wether use biases for linear layers. + embed_dim_mul (Optional[List[List[int]]]): Dimension multiplication at layer i. + If X is used, then the next block will increase the embed dimension by X + times. Format: [depth_i, mul_dim_ratio]. 
+        atten_head_mul (Optional[List[List[int]]]): Head dimension multiplication at
+            layer i. If X is used, then the next block will increase the head by
+            X times. Format: [depth_i, mul_dim_ratio].
+        pool_q_stride_size (Optional[List[List[int]]]): List of stride sizes for the
+            pool q at each layer. Format:
+            [[i, stride_t_i, stride_h_i, stride_w_i], ...,].
+        pool_kv_stride_size (Optional[List[List[int]]]): List of stride sizes for the
+            pool kv at each layer. Format:
+            [[i, stride_t_i, stride_h_i, stride_w_i], ...,].
+        pool_kv_stride_adaptive (Optional[_size_3_t]): Initial kv stride size for the
+            first block. The stride size will be further reduced at the layer where q
+            is pooled with the ratio of the stride of q pooling. If
+            pool_kv_stride_adaptive is set, then pool_kv_stride_size should be none.
+        pool_kvq_kernel (Optional[_size_3_t]): Pooling kernel size for q and kv. If None,
+            the kernel_size is [s + 1 if s > 1 else s for s in stride_size].
+
+        head (Callable): Head model.
+        head_dropout_rate (float): Dropout rate in the head.
+        head_activation (Callable): Activation in the head.
+        num_classes (int): Number of classes in the final classification head.
+
+        Example usage (building a MViT_B model for Kinetics400):
+
+            spatial_size = 224
+            temporal_size = 16
+            embed_dim_mul = [[1, 2.0], [3, 2.0], [14, 2.0]]
+            atten_head_mul = [[1, 2.0], [3, 2.0], [14, 2.0]]
+            pool_q_stride_size = [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]]
+            pool_kv_stride_adaptive = [1, 8, 8]
+            pool_kvq_kernel = [3, 3, 3]
+            num_classes = 400
+            MViT_B = create_multiscale_vision_transformers(
+                spatial_size=spatial_size,
+                temporal_size=temporal_size,
+                embed_dim_mul=embed_dim_mul,
+                atten_head_mul=atten_head_mul,
+                pool_q_stride_size=pool_q_stride_size,
+                pool_kv_stride_adaptive=pool_kv_stride_adaptive,
+                pool_kvq_kernel=pool_kvq_kernel,
+                num_classes=num_classes,
+            )
+    """
+
+    if pool_kv_stride_adaptive is not None:
+        assert pool_kv_stride_size is None, "pool_kv_stride_size should be none if pool_kv_stride_adaptive is set."
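+    # For reference, with the default (MViT-B) arguments above, the block loop below
+    # roughly yields: block 0 with dim 96 and 1 head, blocks 1-2 with dim 192 and
+    # 2 heads, blocks 3-13 with dim 384 and 4 heads, and blocks 14-15 with dim 768
+    # and 8 heads; q pooling at blocks 1, 3 and 14 shrinks an 8 x 56 x 56 token grid
+    # to 8 x 28 x 28, 8 x 14 x 14 and finally 8 x 7 x 7.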
+ norm_layer = partial(nn.LayerNorm, eps=1e-6) + block_norm_layer = partial(nn.LayerNorm, eps=1e-6) + attn_norm_layer = partial(nn.LayerNorm, eps=1e-6) + + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + patch_embed = create_conv_patch_embed( + in_channels=input_channels, + out_channels=patch_embed_dim, + conv_kernel_size=conv_patch_embed_kernel, + conv_stride=conv_patch_embed_stride, + conv_padding=conv_patch_embed_padding, + ) + + input_dims = [temporal_size, spatial_size[0], spatial_size[1]] + input_stirde = conv_patch_embed_stride + + patch_embed_shape = [input_dims[i] // input_stirde[i] for i in range(len(input_dims))] + + cls_positional_encoding = SpatioTemporalClsPositionalEncoding( + embed_dim=patch_embed_dim, + patch_embed_shape=patch_embed_shape, + ) + + dpr = [x.item() for x in torch.linspace(0, droppath_rate_block, depth)] # stochastic depth decay rule + + if dropout_rate_block > 0.0: + pos_drop = nn.Dropout(p=dropout_rate_block) + + dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1) + if embed_dim_mul is not None: + for i in range(len(embed_dim_mul)): + dim_mul[embed_dim_mul[i][0]] = embed_dim_mul[i][1] + if atten_head_mul is not None: + for i in range(len(atten_head_mul)): + head_mul[atten_head_mul[i][0]] = atten_head_mul[i][1] + + mvit_blocks = nn.ModuleList() + + pool_q = [[] for i in range(depth)] + pool_kv = [[] for i in range(depth)] + stride_q = [[] for i in range(depth)] + stride_kv = [[] for i in range(depth)] + + if pool_q_stride_size is not None: + for i in range(len(pool_q_stride_size)): + stride_q[pool_q_stride_size[i][0]] = pool_q_stride_size[i][1:] + if pool_kvq_kernel is not None: + pool_q[pool_q_stride_size[i][0]] = pool_kvq_kernel + else: + pool_q[pool_q_stride_size[i][0]] = [s + 1 if s > 1 else s for s in pool_q_stride_size[i][1:]] + + # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE. 
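+    # For example, with pool_kv_stride_adaptive=(1, 8, 8) and the default q strides,
+    # the kv stride works out to (1, 8, 8) for block 0, (1, 4, 4) for blocks 1-2,
+    # (1, 2, 2) for blocks 3-13 and (1, 1, 1) for blocks 14-15.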
+ if pool_kv_stride_adaptive is not None: + _stride_kv = pool_kv_stride_adaptive + pool_kv_stride_size = [] + for i in range(depth): + if len(stride_q[i]) > 0: + _stride_kv = [max(_stride_kv[d] // stride_q[i][d], 1) for d in range(len(_stride_kv))] + pool_kv_stride_size.append([i] + _stride_kv) + + if pool_kv_stride_size is not None: + for i in range(len(pool_kv_stride_size)): + stride_kv[pool_kv_stride_size[i][0]] = pool_kv_stride_size[i][1:] + if pool_kvq_kernel is not None: + pool_kv[pool_kv_stride_size[i][0]] = pool_kvq_kernel + else: + pool_kv[pool_kv_stride_size[i][0]] = [s + 1 if s > 1 else s for s in pool_kv_stride_size[i][1:]] + + for i in range(depth): + num_heads = round_width(num_heads, head_mul[i], min_width=1, divisor=1) + patch_embed_dim = round_width(patch_embed_dim, dim_mul[i], divisor=num_heads) + dim_out = round_width( + patch_embed_dim, + dim_mul[i + 1], + divisor=round_width(num_heads, head_mul[i + 1]), + ) + + block_func = MultiScaleBlock + + mvit_blocks.append( + block_func( + dim=patch_embed_dim, + dim_out=dim_out, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + dropout_rate=dropout_rate_block, + droppath_rate=dpr[i], + norm_layer=block_norm_layer, + attn_norm_layer=attn_norm_layer, + kernel_q=pool_q[i], + kernel_kv=pool_kv[i], + stride_q=stride_q[i], + stride_kv=stride_kv[i], + bias_on=bias_on, + depthwise_conv=depthwise_conv, + ) + ) + + embed_dim = dim_out + norm_embed = None if norm_layer is None else norm_layer(embed_dim) + if head is not None: + head_model = head( + in_features=embed_dim, + out_features=num_classes, + dropout_rate=head_dropout_rate, + activation=head_activation, + ) + else: + head_model = None + + return MultiscaleVisionTransformers( + patch_embed=patch_embed, + cls_positional_encoding=cls_positional_encoding, + pos_drop=pos_drop if dropout_rate_block > 0.0 else None, + blocks=mvit_blocks, + norm_embed=norm_embed, + head=head_model, + ) + + +def create_mvit_b_16( + spatial_size=224, + temporal_size=16, + num_classes=400, + **kwargs, +): + return create_multiscale_vision_transformers( + spatial_size=spatial_size, + temporal_size=temporal_size, + num_classes=num_classes, + **kwargs, + ) From 951c91cda146a7292f892c116b4041ed9fbb091c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 11:18:37 +0100 Subject: [PATCH 02/26] Validate against expected files on videos --- test/expect/ModelTester.test_mc3_18_expect.pkl | Bin 0 -> 939 bytes .../ModelTester.test_r2plus1d_18_expect.pkl | Bin 0 -> 939 bytes test/expect/ModelTester.test_r3d_18_expect.pkl | Bin 0 -> 939 bytes test/test_models.py | 17 +++++++++++++---- 4 files changed, 13 insertions(+), 4 deletions(-) create mode 100644 test/expect/ModelTester.test_mc3_18_expect.pkl create mode 100644 test/expect/ModelTester.test_r2plus1d_18_expect.pkl create mode 100644 test/expect/ModelTester.test_r3d_18_expect.pkl diff --git a/test/expect/ModelTester.test_mc3_18_expect.pkl b/test/expect/ModelTester.test_mc3_18_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..938c52160508ea183ef471aadb741222607fe326 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5@*AUe%pihdH3pRp4q~p>Y#U#3+RFXz+3OI(w(n9XqumOHC3}BJ zNZ20IIJjG5`GvifizIFNH=MBjbn3yLo~;}MsiUoJ!*1IR^{qA!+Vgf-mh#$8S|Mc@#yDjU z--3mEzr2#&0}8FgIpzWGd!5SV~WvNBQz*ul|GAA;)kU|c^H0A=?d~sfS zC=<|D5DxHW1X1ubi5!OlAPE$JooXBrWYCp 
z0p4tEI#5M&%(`&ppu`LUFnT+L%P zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK62~j(w(X3A7xpel3fUvlRl8?P&+0vg7G&*_ebHmf71C#WUF(_6tXqrsEbE)M z_tyX8d-tC=vYDd2d+#caCAKb)jIF)OWo+NMZXt1a15Sy(*^;Mf(5hdG_(9hN? zvn2MOJu!9fYTwOUU-iDP^_SWf`?e)kxVqIpgYO8tft4&5Lw{7?H zLfa)VtM>%Avh8KKz1v#k&<)#b{>6JAJQTD2Aj7q{;)I^9z?lf!PZmkGCIyf7{zyM& zvnf(ykKv07TTp1Jiutvj0EQF@;|?u;25Wd|m8BLH17pF>$(+dGLJBzu)0hin^Tm1T zp-ezqK{&vh5k$e$Byt=IfFw`=dJ09?jqE2r6rHbtJY?Pa2IzW`UB!>0R|4olm|kcY z1bDNt=|C09G3&yWgAy|c!07D|F2f|SCqdq21LX|{PpAS=CJ69mWdn&Z10hH~L@fYc Ci}#2C literal 0 HcmV?d00001 diff --git a/test/expect/ModelTester.test_r3d_18_expect.pkl b/test/expect/ModelTester.test_r3d_18_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..11cf06bb1fcd3ead072daad144c5bdc8bd2c37ba GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK66eJ-?!Am^JbNt9AJ}sv@yagVpq+cJ?J%@v{j6%c;vuKawVnTMT$E+4xlK*? zq=OKJeyg4uWa^N)a=o<(%SRw?;#t9#kY65 zEHv4(V27mD7oL4K532X>@o2cQr?FbfR>D@$_6uvJja*yc-ll6_dn5Mm+@<&AhE4a$ zS9_e6>e%?NoWFbDyB{`RH|N-Ry=}F*bmx(c!>lbkk9fVaX<%rvnox6ikAarw9#5v# zd+byv?fuWvWCIGV{0m;wP5?s+gmH%!KZ7+qw8~P8ih;4<=44J}a3O^pglWtLviah? z^iU?CtsorW%?P64X%aaO1waxg06m4G>qhnyABxUbKpwJgeFJp8$gbi?(JKLTAxtkc z3 Date: Tue, 24 May 2022 11:33:10 +0100 Subject: [PATCH 03/26] Plus tests for autocast --- test/test_models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_models.py b/test/test_models.py index acdf232ec97..0acef4dcef6 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -848,6 +848,9 @@ def test_video_model(model_fn, dev): if dev == "cuda": with torch.cuda.amp.autocast(): out = model(x) + # See autocast_flaky_numerics comment at top of file. 
+ if model_name not in autocast_flaky_numerics: + _assert_expected(out.cpu(), model_name, prec=0.1) assert out.shape[-1] == num_classes _check_input_backprop(model, x) From 693666a7cc42b07c7132d4e60e6f1e132879ff53 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 11:56:31 +0100 Subject: [PATCH 04/26] Fix broken code and fx-traceability --- test/test_models.py | 3 +++ torchvision/models/video/__init__.py | 1 + torchvision/models/video/mvit.py | 21 ++++++++++++++------- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 0acef4dcef6..f81491171d0 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -309,6 +309,9 @@ def _check_input_backprop(model, inputs): "image_size": 56, "input_shape": (1, 3, 56, 56), }, + "mvit_b_16": { + "input_shape": (1, 3, 16, 224, 224), + }, } # speeding up slow models: slow_models = [ diff --git a/torchvision/models/video/__init__.py b/torchvision/models/video/__init__.py index b792ca6ecf7..8990f64a1dc 100644 --- a/torchvision/models/video/__init__.py +++ b/torchvision/models/video/__init__.py @@ -1 +1,2 @@ +from .mvit import * from .resnet import * diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 63e056afd23..574e5e95e9a 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -4,12 +4,13 @@ import numpy import torch +import torch.fx import torch.nn as nn from torch.nn.common_types import _size_2_t, _size_3_t from torchvision.ops import StochasticDepth -__all__ = ["create_mvit_b_16", "create_multiscale_vision_transformers"] +__all__ = ["mvit_b_16"] class Mlp(nn.Module): @@ -80,7 +81,7 @@ def _attention_pool( pool: Optional[Callable], thw_shape: List[int], norm: Optional[Callable] = None, -) -> torch.Tensor: +) -> Tuple[torch.Tensor, List[int]]: """ Apply pool to a flattened input (given pool operation and the unflattened shape). @@ -147,6 +148,9 @@ def _attention_pool( return tensor, thw_shape +torch.fx.wrap("_attention_pool") + + class MultiScaleAttention(nn.Module): """ Implementation of a multiscale attention block. Compare to a conventional attention @@ -292,7 +296,7 @@ def _qkv_pool( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - thw_shape: Tuple[torch.Tensor, List[int]], + thw_shape: List[int], ) -> Tuple[torch.Tensor, List[int], torch.Tensor, List[int], torch.Tensor, List[int]]: q, q_shape = _attention_pool( q, @@ -367,6 +371,9 @@ def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, return x, q_shape + +torch.fx.wrap("_attention_pool") + class MultiScaleBlock(nn.Module): """ Implementation of a multiscale vision transformer block. Each block contains a @@ -667,7 +674,7 @@ def __init__( # Dropout configs. dropout_rate: float = 0.5, # Activation configs. - activation: Callable = None, + activation: Optional[Callable] = None, ) -> nn.Module: """ Creates vision transformer basic head. @@ -984,7 +991,7 @@ def create_multiscale_vision_transformers( # Head config. 
head: Optional[Callable] = create_vit_basic_head, head_dropout_rate: float = 0.5, - head_activation: Callable = None, + head_activation: Optional[Callable] = None, num_classes: int = 400, **kwargs, ) -> nn.Module: @@ -1127,7 +1134,7 @@ def create_multiscale_vision_transformers( for i in range(depth): if len(stride_q[i]) > 0: _stride_kv = [max(_stride_kv[d] // stride_q[i][d], 1) for d in range(len(_stride_kv))] - pool_kv_stride_size.append([i] + _stride_kv) + pool_kv_stride_size.append([i] + list(_stride_kv)) if pool_kv_stride_size is not None: for i in range(len(pool_kv_stride_size)): @@ -1190,7 +1197,7 @@ def create_multiscale_vision_transformers( ) -def create_mvit_b_16( +def mvit_b_16( spatial_size=224, temporal_size=16, num_classes=400, From 4941ac96de7fc7856b8991b4700399cb4c7bc05a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 12:02:56 +0100 Subject: [PATCH 05/26] Use TorchVision's MLP block. --- torchvision/models/video/mvit.py | 74 +------------------------------- 1 file changed, 2 insertions(+), 72 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 574e5e95e9a..787b05a78a4 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -7,75 +7,12 @@ import torch.fx import torch.nn as nn from torch.nn.common_types import _size_2_t, _size_3_t -from torchvision.ops import StochasticDepth +from torchvision.ops import StochasticDepth, MLP __all__ = ["mvit_b_16"] -class Mlp(nn.Module): - """ - A MLP block that contains two linear layers with a normalization layer. The MLP - block is used in a transformer model after the attention block. - - :: - - Linear (in_features, hidden_features) - ↓ - Normalization (act_layer) - ↓ - Dropout (p=dropout_rate) - ↓ - Linear (hidden_features, out_features) - ↓ - Dropout (p=dropout_rate) - """ - - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable = nn.GELU, - dropout_rate: float = 0.0, - bias_on: bool = True, - ) -> None: - """ - Args: - in_features (int): Input feature dimension. - hidden_features (Optional[int]): Hidden feature dimension. By default, - hidden feature is set to input feature dimension. - out_features (Optional[int]): Output feature dimension. By default, output - features dimension is set to input feature dimension. - act_layer (Callable): Activation layer used after the first linear layer. - dropout_rate (float): Dropout rate after each linear layer. Dropout is not used - by default. - """ - super().__init__() - self.dropout_rate = dropout_rate - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias_on) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias_on) - if self.dropout_rate > 0.0: - self.dropout = nn.Dropout(dropout_rate) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x (tensor): Input tensor. 
- """ - x = self.fc1(x) - x = self.act(x) - if self.dropout_rate > 0.0: - x = self.dropout(x) - x = self.fc2(x) - if self.dropout_rate > 0.0: - x = self.dropout(x) - return x - - def _attention_pool( tensor: torch.Tensor, pool: Optional[Callable], @@ -473,14 +410,7 @@ def __init__( self.drop_path = StochasticDepth(droppath_rate, "row") if droppath_rate > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - out_features=dim_out, - act_layer=act_layer, - dropout_rate=dropout_rate, - bias_on=bias_on, - ) + self.mlp = MLP(dim, [mlp_hidden_dim, dim_out], activation_layer=act_layer, dropout=dropout_rate, bias=bias_on, inplace=None) if dim != dim_out: self.proj = nn.Linear(dim, dim_out, bias=bias_on) From 0426e2ba610c6cded8c1f545e452da7cb06a35b9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 12:07:32 +0100 Subject: [PATCH 06/26] Replace @ with matmul --- torchvision/models/video/mvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 787b05a78a4..934cb6e82fe 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -295,12 +295,12 @@ def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, q, k, v = qkv[0], qkv[1], qkv[2] q, q_shape, k, k_shape, v, v_shape = self._qkv_pool(q, k, v, thw_shape) - attn = (q * self.scale) @ k.transpose(-2, -1) + attn = torch.matmul(q * self.scale, k.transpose(-2, -1)) attn = attn.softmax(dim=-1) N = q.shape[2] - x = (attn @ v + q).transpose(1, 2).reshape(B, N, C) + x = (torch.matmul(attn, v) + q).transpose(1, 2).reshape(B, N, C) x = self.proj(x) if self.dropout_rate > 0.0: From 749c3d8c1f756884d233f1e1ae500a579d7c4e2c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 12:31:44 +0100 Subject: [PATCH 07/26] Clean up unused variables and methods, fixing typing annotations and general clean up. --- torchvision/models/video/mvit.py | 83 +++++++------------------------- 1 file changed, 18 insertions(+), 65 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 934cb6e82fe..8c1f7598d63 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -49,11 +49,9 @@ def _attention_pool( if pool is None: return tensor, thw_shape tensor_dim = tensor.ndim - if tensor_dim == 4: - pass - elif tensor_dim == 3: + if tensor_dim == 3: tensor = tensor.unsqueeze(1) - else: + elif tensor_dim != 4: raise NotImplementedError(f"Unsupported input dimension {tensor.shape}") cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :] @@ -78,9 +76,7 @@ def _attention_pool( if norm is not None and not isinstance(norm, nn.BatchNorm3d): tensor = norm(tensor) - if tensor_dim == 4: - pass - else: # For the case tensor_dim == 3. 
+ if tensor_dim == 3: tensor = tensor.squeeze(1) return tensor, thw_shape @@ -239,19 +235,19 @@ def _qkv_pool( q, self.pool_q, thw_shape, - norm=self.norm_q if hasattr(self, "norm_q") else None, + norm=self.norm_q, ) k, k_shape = _attention_pool( k, self.pool_k, thw_shape, - norm=self.norm_k if hasattr(self, "norm_k") else None, + norm=self.norm_k, ) v, v_shape = _attention_pool( v, self.pool_v, thw_shape, - norm=self.norm_v if hasattr(self, "norm_v") else None, + norm=self.norm_v, ) return q, q_shape, k, k_shape, v, v_shape @@ -260,7 +256,7 @@ def _get_qkv_length( q_shape: List[int], k_shape: List[int], v_shape: List[int], - ) -> Tuple[int]: + ) -> Tuple[int, int, int]: q_N = numpy.prod(q_shape) + 1 k_N = numpy.prod(k_shape) + 1 v_N = numpy.prod(v_shape) + 1 @@ -276,7 +272,7 @@ def _reshape_qkv_to_seq( k_N: int, B: int, C: int, - ) -> Tuple[int]: + ) -> Tuple[int, int, int]: q = q.permute(0, 2, 1, 3).reshape(B, q_N, C) v = v.permute(0, 2, 1, 3).reshape(B, v_N, C) k = k.permute(0, 2, 1, 3).reshape(B, k_N, C) @@ -455,7 +451,7 @@ class SpatioTemporalClsPositionalEncoding(nn.Module): def __init__( self, embed_dim: int, - patch_embed_shape: Tuple[int, int, int], + patch_embed_shape: List[int], ) -> None: """ Args: @@ -532,18 +528,6 @@ def c2_msra_fill(module: nn.Module) -> None: nn.init.constant_(module.bias, 0) -def set_attributes(self, params: List[object] = None) -> None: - """ - An utility function used in classes to set attributes from the input list of parameters. - Args: - params (list): list of parameters. - """ - if params: - for k, v in params.items(): - if k != "self": - setattr(self, k, v) - - def round_width(width, multiplier, min_width=8, divisor=8, ceil=False): """ Round width of filters based on width multiplier @@ -568,31 +552,6 @@ def round_width(width, multiplier, min_width=8, divisor=8, ceil=False): return int(width_out) -def round_repeats(repeats, multiplier): - """ - Round number of layers based on depth multiplier. - """ - if not multiplier: - return repeats - return int(math.ceil(multiplier * repeats)) - - -class SequencePool(nn.Module): - """ - Sequence pool produces a single embedding from a sequence of embeddings. Currently - it supports "mean" and "cls". - - """ - - def __init__(self) -> None: - - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x[:, 0] - return x - - class create_vit_basic_head(nn.Module): def __init__( self, @@ -673,11 +632,10 @@ class PatchEmbed(nn.Module): def __init__( self, - patch_model: nn.Module = None, + patch_model: nn.Module, ) -> None: super().__init__() - set_attributes(self, locals()) - assert self.patch_model is not None + self.patch_model = patch_model def forward(self, x) -> torch.Tensor: x = self.patch_model(x) @@ -868,10 +826,12 @@ def __init__( head (Optional[nn.Module]): Head module. """ super().__init__() - set_attributes(self, locals()) - assert hasattr( - cls_positional_encoding, "patch_embed_shape" - ), "cls_positional_encoding should have attribute patch_embed_shape." + self.patch_embed = patch_embed + self.cls_positional_encoding = cls_positional_encoding + self.pos_drop = pos_drop + self.blocks = blocks + self.norm_embed = norm_embed + self.head = head init_net_weights(self, init_std=0.02, style="vit") def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -896,14 +856,12 @@ def create_multiscale_vision_transformers( spatial_size: _size_2_t, temporal_size: int, depth: int = 16, - norm: str = "layernorm", # Patch embed config. 
input_channels: int = 3, patch_embed_dim: int = 96, conv_patch_embed_kernel: Tuple[int] = (3, 7, 7), conv_patch_embed_stride: Tuple[int] = (2, 4, 4), conv_patch_embed_padding: Tuple[int] = (1, 3, 3), - enable_patch_embed_norm: bool = False, # Attention block config. num_heads: int = 1, mlp_ratio: float = 4.0, @@ -934,7 +892,6 @@ def create_multiscale_vision_transformers( int is given, it assumes the width and the height are the same. temporal_size (int): Number of frames in the input video. depth (int): The depth of the model. - norm (str): Normalization layer. It currently supports "layernorm". input_channels (int): Channel dimension of the input video. patch_embed_dim (int): Embedding dimension after patchifing the video input. @@ -944,8 +901,6 @@ def create_multiscale_vision_transformers( patchifing the video input. conv_patch_embed_padding (Tuple[int]): Padding size of the convolution for patchifing the video input. - enable_patch_embed_norm (bool): If True, apply normalization after patchifing - the video input. num_heads (int): Number of heads in the first transformer block. mlp_ratio (float): Mlp ratio which controls the feature dimension in the @@ -1083,10 +1038,8 @@ def create_multiscale_vision_transformers( divisor=round_width(num_heads, head_mul[i + 1]), ) - block_func = MultiScaleBlock - mvit_blocks.append( - block_func( + MultiScaleBlock( dim=patch_embed_dim, dim_out=dim_out, num_heads=num_heads, From d91511a318e3190391dc52b43f204aa776e7d2ee Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 13:03:28 +0100 Subject: [PATCH 08/26] Remove used init methods and replace others with the ones from TorchVision. --- torchvision/models/video/mvit.py | 170 ++++--------------------------- 1 file changed, 18 insertions(+), 152 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 8c1f7598d63..e2b3e677887 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -1,4 +1,3 @@ -import math from functools import partial from typing import Callable, List, Optional, Tuple @@ -7,7 +6,9 @@ import torch.fx import torch.nn as nn from torch.nn.common_types import _size_2_t, _size_3_t -from torchvision.ops import StochasticDepth, MLP + +from ...ops import StochasticDepth, MLP +from .._utils import _make_divisible __all__ = ["mvit_b_16"] @@ -496,62 +497,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def c2_xavier_fill(module: nn.Module) -> None: - """ - Initialize `module.weight` using the "XavierFill" implemented in Caffe2. - Also initializes `module.bias` to 0. - - Args: - module (torch.nn.Module): module to initialize. - """ - # Caffe2 implementation of XavierFill in fact - # corresponds to kaiming_uniform_ in PyTorch - nn.init.kaiming_uniform_(module.weight, a=1) - if module.bias is not None: - # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, - # torch.Tensor]`. - nn.init.constant_(module.bias, 0) - - -def c2_msra_fill(module: nn.Module) -> None: - """ - Initialize `module.weight` using the "MSRAFill" implemented in Caffe2. - Also initializes `module.bias` to 0. - - Args: - module (torch.nn.Module): module to initialize. - """ - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") - if module.bias is not None: - # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, - # torch.Tensor]`. 
- nn.init.constant_(module.bias, 0) - - -def round_width(width, multiplier, min_width=8, divisor=8, ceil=False): - """ - Round width of filters based on width multiplier - Args: - width (int): the channel dimensions of the input. - multiplier (float): the multiplication factor. - min_width (int): the minimum width after multiplication. - divisor (int): the new width should be dividable by divisor. - ceil (bool): If True, use ceiling as the rounding method. - """ - if not multiplier: - return width - - width *= multiplier - min_width = min_width or divisor - if ceil: - width_out = max(min_width, int(math.ceil(width / divisor)) * divisor) - else: - width_out = max(min_width, int(width + divisor / 2) // divisor * divisor) - if width_out < 0.9 * width: - width_out += divisor - return int(width_out) - - class create_vit_basic_head(nn.Module): def __init__( self, @@ -687,93 +632,6 @@ def create_conv_patch_embed( return PatchEmbed(patch_model=conv_module) -def _init_resnet_weights(model: nn.Module, fc_init_std: float = 0.01) -> None: - """ - Performs ResNet style weight initialization. That is, recursively initialize the - given model in the following way for each type: - Conv - Follow the initialization of kaiming_normal: - https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_ - BatchNorm - Set weight and bias of last BatchNorm at every residual bottleneck - to 0. - Linear - Set weight to 0 mean Gaussian with std deviation fc_init_std and bias - to 0. - Args: - model (nn.Module): Model to be initialized. - fc_init_std (float): the expected standard deviation for fully-connected layer. - """ - for m in model.modules(): - if isinstance(m, (nn.Conv2d, nn.Conv3d)): - """ - Follow the initialization method proposed in: - {He, Kaiming, et al. - "Delving deep into rectifiers: Surpassing human-level - performance on imagenet classification." - arXiv preprint arXiv:1502.01852 (2015)} - """ - c2_msra_fill(m) - elif isinstance(m, nn.modules.batchnorm._NormBase): - if m.weight is not None: - if hasattr(m, "block_final_bn") and m.block_final_bn: - m.weight.data.fill_(0.0) - else: - m.weight.data.fill_(1.0) - if m.bias is not None: - m.bias.data.zero_() - if isinstance(m, nn.Linear): - if hasattr(m, "xavier_init") and m.xavier_init: - c2_xavier_fill(m) - else: - m.weight.data.normal_(mean=0.0, std=fc_init_std) - if m.bias is not None: - m.bias.data.zero_() - return model - - -def _init_vit_weights(model: nn.Module, trunc_normal_std: float = 0.02) -> None: - """ - Weight initialization for vision transformers. - - Args: - model (nn.Module): Model to be initialized. - trunc_normal_std (float): the expected standard deviation for fully-connected - layer and ClsPositionalEncoding. - """ - for m in model.modules(): - if isinstance(m, nn.Linear): - nn.init.trunc_normal_(m.weight, std=trunc_normal_std) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, SpatioTemporalClsPositionalEncoding): - for weights in m.parameters(): - nn.init.trunc_normal_(weights, std=trunc_normal_std) - - -def init_net_weights( - model: nn.Module, - init_std: float = 0.01, - style: str = "resnet", -) -> None: - """ - Performs weight initialization. Options include ResNet style weight initialization - and transformer style weight initialization. - - Args: - model (nn.Module): Model to be initialized. 
- init_std (float): The expected standard deviation for initialization. - style (str): Options include "resnet" and "vit". - """ - assert style in ["resnet", "vit"] - if style == "resnet": - return _init_resnet_weights(model, init_std) - elif style == "vit": - return _init_vit_weights(model, init_std) - else: - raise NotImplementedError - - class MultiscaleVisionTransformers(nn.Module): """ Multiscale Vision Transformers @@ -832,7 +690,17 @@ def __init__( self.blocks = blocks self.norm_embed = norm_embed self.head = head - init_net_weights(self, init_std=0.02, style="vit") + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, SpatioTemporalClsPositionalEncoding): + for weights in m.parameters(): + nn.init.trunc_normal_(weights, std=0.02) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.patch_embed is not None: @@ -1030,12 +898,10 @@ def create_multiscale_vision_transformers( pool_kv[pool_kv_stride_size[i][0]] = [s + 1 if s > 1 else s for s in pool_kv_stride_size[i][1:]] for i in range(depth): - num_heads = round_width(num_heads, head_mul[i], min_width=1, divisor=1) - patch_embed_dim = round_width(patch_embed_dim, dim_mul[i], divisor=num_heads) - dim_out = round_width( - patch_embed_dim, - dim_mul[i + 1], - divisor=round_width(num_heads, head_mul[i + 1]), + num_heads = _make_divisible(num_heads * head_mul[i], 1) + patch_embed_dim = _make_divisible(patch_embed_dim * dim_mul[i], num_heads, min_value=8) + dim_out = _make_divisible( + patch_embed_dim * dim_mul[i + 1], divisor=_make_divisible(num_heads * head_mul[i + 1], 8), min_value=8, ) mvit_blocks.append( From ee996083997c888f85f2b5dfb2d5aa9a8a92c4bf Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 13:32:51 +0100 Subject: [PATCH 09/26] Drop unused private methods and further cleanups. --- torchvision/models/video/mvit.py | 135 +++++++------------------------ 1 file changed, 27 insertions(+), 108 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index e2b3e677887..b35ec5f4f13 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -1,7 +1,6 @@ from functools import partial -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Sequence, Tuple -import numpy import torch import torch.fx import torch.nn as nn @@ -14,6 +13,13 @@ __all__ = ["mvit_b_16"] +def _prod(s: Sequence[int]) -> int: + product = 1 + for v in s: + product *= v + return product + + def _attention_pool( tensor: torch.Tensor, pool: Optional[Callable], @@ -161,9 +167,9 @@ def __init__( self.proj_drop = nn.Dropout(dropout_rate) # Skip pooling with kernel and stride size of (1, 1, 1). 
- if kernel_q is not None and numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1: + if kernel_q is not None and _prod(kernel_q) == 1 and _prod(stride_q) == 1: kernel_q = None - if kernel_kv is not None and numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1: + if kernel_kv is not None and _prod(kernel_kv) == 1 and _prod(stride_kv) == 1: kernel_kv = None self.pool_q = ( @@ -209,22 +215,6 @@ def __init__( ) self.norm_v = norm_layer(head_dim) if kernel_kv is not None else None - def _qkv_proj( - self, - q: torch.Tensor, - q_size: List[int], - k: torch.Tensor, - k_size: List[int], - v: torch.Tensor, - v_size: List[int], - batch_size: List[int], - chan_size: List[int], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q = self.q(q).reshape(batch_size, q_size, self.num_heads, chan_size // self.num_heads).permute(0, 2, 1, 3) - k = self.k(k).reshape(batch_size, k_size, self.num_heads, chan_size // self.num_heads).permute(0, 2, 1, 3) - v = self.v(v).reshape(batch_size, v_size, self.num_heads, chan_size // self.num_heads).permute(0, 2, 1, 3) - return q, k, v - def _qkv_pool( self, q: torch.Tensor, @@ -252,33 +242,6 @@ def _qkv_pool( ) return q, q_shape, k, k_shape, v, v_shape - def _get_qkv_length( - self, - q_shape: List[int], - k_shape: List[int], - v_shape: List[int], - ) -> Tuple[int, int, int]: - q_N = numpy.prod(q_shape) + 1 - k_N = numpy.prod(k_shape) + 1 - v_N = numpy.prod(v_shape) + 1 - return q_N, k_N, v_N - - def _reshape_qkv_to_seq( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - q_N: int, - v_N: int, - k_N: int, - B: int, - C: int, - ) -> Tuple[int, int, int]: - q = q.permute(0, 2, 1, 3).reshape(B, q_N, C) - v = v.permute(0, 2, 1, 3).reshape(B, v_N, C) - k = k.permute(0, 2, 1, 3).reshape(B, k_N, C) - return q, k, v - def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, List[int]]: """ Args: @@ -413,7 +376,7 @@ def __init__( self.pool_skip = ( nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False) - if len(stride_skip) > 0 and numpy.prod(stride_skip) > 1 + if len(stride_skip) > 0 and _prod(stride_skip) > 1 else None ) @@ -452,7 +415,7 @@ class SpatioTemporalClsPositionalEncoding(nn.Module): def __init__( self, embed_dim: int, - patch_embed_shape: List[int], + patch_embed_shape: Tuple[int, int, int], ) -> None: """ Args: @@ -503,8 +466,6 @@ def __init__( # Projection configs. in_features: int, out_features: int, - # Pooling configs. - seq_pool_type: str = "cls", # Dropout configs. dropout_rate: float = 0.5, # Activation configs. @@ -588,50 +549,6 @@ def forward(self, x) -> torch.Tensor: return x.flatten(2).transpose(1, 2) -def create_conv_patch_embed( - in_channels: int, - out_channels: int, - conv_kernel_size: Tuple[int] = (1, 16, 16), - conv_stride: Tuple[int] = (1, 4, 4), - conv_padding: Tuple[int] = (1, 7, 7), - conv_bias: bool = True, -) -> nn.Module: - """ - Creates the transformer basic patch embedding. It performs Convolution, flatten and - transpose. - - :: - - Conv3d - ↓ - flatten - ↓ - transpose - - Args: - in_channels (int): input channel size of the convolution. - out_channels (int): output channel size of the convolution. - conv_kernel_size (tuple): convolutional kernel size(s). - conv_stride (tuple): convolutional stride size(s). - conv_padding (tuple): convolutional padding size(s). - conv_bias (bool): convolutional bias. If true, adds a learnable bias to the - output. - conv (callable): Callable used to build the convolution layer. - - Returns: - (nn.Module): transformer patch embedding layer. 
- """ - conv_module = nn.Conv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=conv_kernel_size, - stride=conv_stride, - padding=conv_padding, - bias=conv_bias, - ) - return PatchEmbed(patch_model=conv_module) - - class MultiscaleVisionTransformers(nn.Module): """ Multiscale Vision Transformers @@ -834,18 +751,20 @@ def create_multiscale_vision_transformers( if isinstance(spatial_size, int): spatial_size = (spatial_size, spatial_size) - patch_embed = create_conv_patch_embed( - in_channels=input_channels, - out_channels=patch_embed_dim, - conv_kernel_size=conv_patch_embed_kernel, - conv_stride=conv_patch_embed_stride, - conv_padding=conv_patch_embed_padding, + patch_embed = PatchEmbed( + patch_model=nn.Conv3d( + in_channels=input_channels, + out_channels=patch_embed_dim, + kernel_size=conv_patch_embed_kernel, + stride=conv_patch_embed_stride, + padding=conv_patch_embed_padding, + bias=True, + ) ) - input_dims = [temporal_size, spatial_size[0], spatial_size[1]] - input_stirde = conv_patch_embed_stride + input_dims = (temporal_size, spatial_size[0], spatial_size[1]) - patch_embed_shape = [input_dims[i] // input_stirde[i] for i in range(len(input_dims))] + patch_embed_shape = tuple(v // conv_patch_embed_stride[i] for i, v in enumerate(input_dims)) cls_positional_encoding = SpatioTemporalClsPositionalEncoding( embed_dim=patch_embed_dim, @@ -867,10 +786,10 @@ def create_multiscale_vision_transformers( mvit_blocks = nn.ModuleList() - pool_q = [[] for i in range(depth)] - pool_kv = [[] for i in range(depth)] - stride_q = [[] for i in range(depth)] - stride_kv = [[] for i in range(depth)] + pool_q = [[] for _ in range(depth)] + pool_kv = [[] for _ in range(depth)] + stride_q = [[] for _ in range(depth)] + stride_kv = [[] for _ in range(depth)] if pool_q_stride_size is not None: for i in range(len(pool_q_stride_size)): From 03625fd5437b0d96c299b5538c999238f5438d86 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 14:00:08 +0100 Subject: [PATCH 10/26] Fixing classifier head. --- torchvision/models/video/mvit.py | 77 +++++++------------------------- 1 file changed, 15 insertions(+), 62 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index b35ec5f4f13..1faf33347fb 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -310,9 +310,9 @@ def __init__( qkv_bias: bool = False, dropout_rate: float = 0.0, droppath_rate: float = 0.0, - act_layer: nn.Module = nn.GELU, - norm_layer: nn.Module = nn.LayerNorm, - attn_norm_layer: nn.Module = nn.LayerNorm, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_norm_layer: Callable[..., nn.Module] = nn.LayerNorm, kernel_q: _size_3_t = (1, 1, 1), kernel_kv: _size_3_t = (1, 1, 1), stride_q: _size_3_t = (1, 1, 1), @@ -460,7 +460,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class create_vit_basic_head(nn.Module): +class ClassificationHead(nn.Module): def __init__( self, # Projection configs. @@ -468,52 +468,16 @@ def __init__( out_features: int, # Dropout configs. dropout_rate: float = 0.5, - # Activation configs. - activation: Optional[Callable] = None, - ) -> nn.Module: - """ - Creates vision transformer basic head. - - :: - - - Pooling - ↓ - Dropout - ↓ - Projection - ↓ - Activation - - - Activation examples include: ReLU, Softmax, Sigmoid, and None. - Pool type examples include: cls, mean and none. - - Args: - - in_features: input channel size of the resnet head. 
- out_features: output channel size of the resnet head. - - pool_type (str): Pooling type. It supports "cls", "mean " and "none". If set to - "cls", it assumes the first element in the input is the cls token and - returns it. If set to "mean", it returns the mean of the entire sequence. - - activation (callable): a callable that constructs vision transformer head - activation layer, examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and - None (not applying activation). - - dropout_rate (float): dropout rate. - """ + ) -> None: super().__init__() - self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0.0 else None + self.dropout = nn.Dropout(dropout_rate) self.proj = nn.Linear(in_features, out_features) def forward(self, x: torch.Tensor) -> torch.Tensor: # Pick cls embedding x = x[:, 0] # Performs dropout. - if self.dropout is not None: - x = self.dropout(x) + x = self.dropout(x) # Performs projection. x = self.proj(x) return x @@ -662,9 +626,7 @@ def create_multiscale_vision_transformers( pool_kv_stride_adaptive: Optional[_size_3_t] = (1, 8, 8), pool_kvq_kernel: Optional[_size_3_t] = (3, 3, 3), # Head config. - head: Optional[Callable] = create_vit_basic_head, head_dropout_rate: float = 0.5, - head_activation: Optional[Callable] = None, num_classes: int = 400, **kwargs, ) -> nn.Module: @@ -715,9 +677,7 @@ def create_multiscale_vision_transformers( pool_kvq_kernel (Optional[_size_3_t]): Pooling kernel size for q and kv. It None, the kernel_size is [s + 1 if s > 1 else s for s in stride_size]. - head (Callable): Head model. head_dropout_rate (float): Dropout rate in the head. - head_activation (Callable): Activation in the head. num_classes (int): Number of classes in the final classification head. Example usage (building a MViT_B model for Kinetics400): @@ -773,9 +733,6 @@ def create_multiscale_vision_transformers( dpr = [x.item() for x in torch.linspace(0, droppath_rate_block, depth)] # stochastic depth decay rule - if dropout_rate_block > 0.0: - pos_drop = nn.Dropout(p=dropout_rate_block) - dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1) if embed_dim_mul is not None: for i in range(len(embed_dim_mul)): @@ -816,6 +773,7 @@ def create_multiscale_vision_transformers( else: pool_kv[pool_kv_stride_size[i][0]] = [s + 1 if s > 1 else s for s in pool_kv_stride_size[i][1:]] + dim_out = 0 for i in range(depth): num_heads = _make_divisible(num_heads * head_mul[i], 1) patch_embed_dim = _make_divisible(patch_embed_dim * dim_mul[i], num_heads, min_value=8) @@ -844,23 +802,18 @@ def create_multiscale_vision_transformers( ) embed_dim = dim_out - norm_embed = None if norm_layer is None else norm_layer(embed_dim) - if head is not None: - head_model = head( - in_features=embed_dim, - out_features=num_classes, - dropout_rate=head_dropout_rate, - activation=head_activation, - ) - else: - head_model = None + head_model = ClassificationHead( + in_features=embed_dim, + out_features=num_classes, + dropout_rate=head_dropout_rate, + ) return MultiscaleVisionTransformers( patch_embed=patch_embed, cls_positional_encoding=cls_positional_encoding, - pos_drop=pos_drop if dropout_rate_block > 0.0 else None, + pos_drop=nn.Dropout(p=dropout_rate_block) if dropout_rate_block > 0.0 else None, blocks=mvit_blocks, - norm_embed=norm_embed, + norm_embed=norm_layer(embed_dim), head=head_model, ) From 2c8e2393ff626b5309d303d867e6ddf0bd3d4745 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 14:19:06 +0100 Subject: [PATCH 11/26] Fixing typing info. 
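A short, hedged illustration of the intent (not part of the diff; the helper
below is hypothetical): annotating layer factories as Callable[..., nn.Module]
instead of a bare Callable lets type checkers confirm that whatever is passed
in (a class such as nn.LayerNorm, or a functools.partial wrapping one) really
constructs an nn.Module when called.

    # Hypothetical sketch, not taken from this patch:
    from functools import partial
    from typing import Callable

    import torch.nn as nn

    def build_norm(norm_layer: Callable[..., nn.Module], dim: int) -> nn.Module:
        # norm_layer is a factory, e.g. nn.LayerNorm or partial(nn.LayerNorm, eps=1e-6)
        return norm_layer(dim)

    norm = build_norm(partial(nn.LayerNorm, eps=1e-6), 96)  # -> nn.LayerNorm(96, eps=1e-6)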
--- torchvision/models/video/mvit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 1faf33347fb..a5083f984bc 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -22,9 +22,9 @@ def _prod(s: Sequence[int]) -> int: def _attention_pool( tensor: torch.Tensor, - pool: Optional[Callable], + pool: Optional[nn.Module], thw_shape: List[int], - norm: Optional[Callable] = None, + norm: Optional[nn.Module] = None, ) -> Tuple[torch.Tensor, List[int]]: """ Apply pool to a flattened input (given pool operation and the unflattened shape). @@ -128,7 +128,7 @@ def __init__( kernel_kv: _size_3_t = (1, 1, 1), stride_q: _size_3_t = (1, 1, 1), stride_kv: _size_3_t = (1, 1, 1), - norm_layer: Callable = nn.LayerNorm, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, depthwise_conv: bool = True, bias_on: bool = True, ) -> None: From 2f57775d19ab7c135257d3fe73469bd8c943ac55 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 14:36:13 +0100 Subject: [PATCH 12/26] Remove identity option from attention pool. --- torchvision/models/video/mvit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index a5083f984bc..be7a02a117e 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -67,7 +67,7 @@ def _attention_pool( T, H, W = thw_shape tensor = tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - if isinstance(norm, (nn.BatchNorm3d, nn.Identity)): + if isinstance(norm, nn.BatchNorm3d): # If use BN, we apply norm before pooling instead of after pooling. tensor = norm(tensor) # We also empirically find that adding a GELU here is beneficial. 
@@ -367,7 +367,7 @@ def __init__( bias_on=bias_on, depthwise_conv=depthwise_conv, ) - self.drop_path = StochasticDepth(droppath_rate, "row") if droppath_rate > 0.0 else nn.Identity() + self.stochastic_depth = StochasticDepth(droppath_rate, "row") self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = MLP(dim, [mlp_hidden_dim, dim_out], activation_layer=act_layer, dropout=dropout_rate, bias=bias_on, inplace=None) @@ -396,14 +396,14 @@ def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, thw_shape, ) x_res, _ = _attention_pool(x, self.pool_skip, thw_shape) - x = x_res + self.drop_path(x_block) + x = x_res + self.stochastic_depth(x_block) x_norm = ( self.norm2(x.permute(0, 2, 1)).permute(0, 2, 1) if isinstance(self.norm2, nn.BatchNorm1d) else self.norm2(x) ) x_mlp = self.mlp(x_norm) if self.dim != self.dim_out: x = self.proj(x_norm) - x = x + self.drop_path(x_mlp) + x = x + self.stochastic_depth(x_mlp) return x, thw_shape_new From 589a6430f8b223cf8105ae1f052d9e22cdbe2c2e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 May 2022 18:05:44 +0100 Subject: [PATCH 13/26] Fixing JIT-scriptability --- torchvision/models/video/mvit.py | 190 +++++++++++++------------------ 1 file changed, 78 insertions(+), 112 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index be7a02a117e..b1333e31ebe 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -20,75 +20,64 @@ def _prod(s: Sequence[int]) -> int: return product -def _attention_pool( - tensor: torch.Tensor, - pool: Optional[nn.Module], - thw_shape: List[int], - norm: Optional[nn.Module] = None, -) -> Tuple[torch.Tensor, List[int]]: - """ - Apply pool to a flattened input (given pool operation and the unflattened shape). - - - Input - ↓ - Reshape - ↓ - Pool - ↓ - Reshape - ↓ - Norm - - - Args: - tensor (torch.Tensor): Input tensor. - pool (Optional[Callable]): Pool operation that is applied to the input tensor. - If pool is none, return the input tensor. - thw_shape (List): The shape of the input tensor (before flattening). - norm: (Optional[Callable]): Optional normalization operation applied to - tensor after pool. - - Returns: - tensor (torch.Tensor): Input tensor after pool. - thw_shape (List[int]): Output tensor shape (before flattening). - """ - if pool is None: - return tensor, thw_shape - tensor_dim = tensor.ndim +def _unsqueeze(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: + tensor_dim = tensor.dim() if tensor_dim == 3: tensor = tensor.unsqueeze(1) elif tensor_dim != 4: raise NotImplementedError(f"Unsupported input dimension {tensor.shape}") + return tensor, tensor_dim - cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :] +def _squeeze(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor: + if tensor_dim == 3: + tensor = tensor.squeeze(1) + return tensor - B, N, L, C = tensor.shape - T, H, W = thw_shape - tensor = tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - if isinstance(norm, nn.BatchNorm3d): - # If use BN, we apply norm before pooling instead of after pooling. - tensor = norm(tensor) - # We also empirically find that adding a GELU here is beneficial. 
- tensor = nn.functional.gelu(tensor) +torch.fx.wrap("_unsqueeze") +torch.fx.wrap("_squeeze") - tensor = pool(tensor) - thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]] - L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4] - tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3) +class AttentionPool(nn.Module): + def __init__(self, pool: Optional[nn.Module], norm: Optional[nn.Module]): + super().__init__() + self.pool = pool + self.norm = norm + self._norm_before_pool = isinstance(norm, nn.BatchNorm3d) - tensor = torch.cat((cls_tok, tensor), dim=2) - if norm is not None and not isinstance(norm, nn.BatchNorm3d): - tensor = norm(tensor) + def forward( + self, + tensor: torch.Tensor, + thw_shape: List[int], + ) -> Tuple[torch.Tensor, List[int]]: + if self.pool is None: + return tensor, thw_shape + tensor, tensor_dim = _unsqueeze(tensor) - if tensor_dim == 3: - tensor = tensor.squeeze(1) - return tensor, thw_shape + cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :] + B, N, L, C = tensor.shape + T, H, W = thw_shape + tensor = tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous() -torch.fx.wrap("_attention_pool") + if self.norm is not None and self._norm_before_pool: + # If use BN, we apply norm before pooling instead of after pooling. + tensor = self.norm(tensor) + # We also empirically find that adding a GELU here is beneficial. + tensor = nn.functional.gelu(tensor) + + tensor = self.pool(tensor) + + thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]] + L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4] + tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3) + + tensor = torch.cat((cls_tok, tensor), dim=2) + if self.norm is not None and not self._norm_before_pool: + tensor = self.norm(tensor) + + tensor = _squeeze(tensor, tensor_dim) + return tensor, thw_shape class MultiScaleAttention(nn.Module): @@ -163,8 +152,7 @@ def __init__( self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim, bias=True if bias_on else False) - if dropout_rate > 0.0: - self.proj_drop = nn.Dropout(dropout_rate) + self.proj_drop = nn.Dropout(dropout_rate) # Skip pooling with kernel and stride size of (1, 1, 1). 
if kernel_q is not None and _prod(kernel_q) == 1 and _prod(stride_q) == 1: @@ -172,7 +160,7 @@ def __init__( if kernel_kv is not None and _prod(kernel_kv) == 1 and _prod(stride_kv) == 1: kernel_kv = None - self.pool_q = ( + self.att_pool_q = AttentionPool( nn.Conv3d( head_dim, head_dim, @@ -181,12 +169,10 @@ def __init__( padding=padding_q, groups=head_dim if depthwise_conv else 1, bias=False, - ) - if kernel_q is not None - else None + ) if kernel_q is not None else None, + norm_layer(head_dim) if kernel_q is not None else None ) - self.norm_q = norm_layer(head_dim) if kernel_q is not None else None - self.pool_k = ( + self.att_pool_k = AttentionPool( nn.Conv3d( head_dim, head_dim, @@ -195,12 +181,10 @@ def __init__( padding=padding_kv, groups=head_dim if depthwise_conv else 1, bias=False, - ) - if kernel_kv is not None - else None + ) if kernel_kv is not None else None, + norm_layer(head_dim) if kernel_kv is not None else None ) - self.norm_k = norm_layer(head_dim) if kernel_kv is not None else None - self.pool_v = ( + self.att_pool_v = AttentionPool( nn.Conv3d( head_dim, head_dim, @@ -209,38 +193,9 @@ def __init__( padding=padding_kv, groups=head_dim if depthwise_conv else 1, bias=False, - ) - if kernel_kv is not None - else None + ) if kernel_kv is not None else None, + norm_layer(head_dim) if kernel_kv is not None else None ) - self.norm_v = norm_layer(head_dim) if kernel_kv is not None else None - - def _qkv_pool( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - thw_shape: List[int], - ) -> Tuple[torch.Tensor, List[int], torch.Tensor, List[int], torch.Tensor, List[int]]: - q, q_shape = _attention_pool( - q, - self.pool_q, - thw_shape, - norm=self.norm_q, - ) - k, k_shape = _attention_pool( - k, - self.pool_k, - thw_shape, - norm=self.norm_k, - ) - v, v_shape = _attention_pool( - v, - self.pool_v, - thw_shape, - norm=self.norm_v, - ) - return q, q_shape, k, k_shape, v, v_shape def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, List[int]]: """ @@ -253,7 +208,18 @@ def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] - q, q_shape, k, k_shape, v, v_shape = self._qkv_pool(q, k, v, thw_shape) + q, q_shape = self.att_pool_q( + q, + thw_shape, + ) + k, k_shape = self.att_pool_k( + k, + thw_shape, + ) + v, v_shape = self.att_pool_v( + v, + thw_shape, + ) attn = torch.matmul(q * self.scale, k.transpose(-2, -1)) attn = attn.softmax(dim=-1) @@ -263,14 +229,10 @@ def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, x = (torch.matmul(attn, v) + q).transpose(1, 2).reshape(B, N, C) x = self.proj(x) - if self.dropout_rate > 0.0: - x = self.proj_drop(x) + x = self.proj_drop(x) return x, q_shape - -torch.fx.wrap("_attention_pool") - class MultiScaleBlock(nn.Module): """ Implementation of a multiscale vision transformer block. 
Each block contains a @@ -373,12 +335,16 @@ def __init__( self.mlp = MLP(dim, [mlp_hidden_dim, dim_out], activation_layer=act_layer, dropout=dropout_rate, bias=bias_on, inplace=None) if dim != dim_out: self.proj = nn.Linear(dim, dim_out, bias=bias_on) + else: + self.proj = None - self.pool_skip = ( + self.att_pool_skip = AttentionPool( nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False) if len(stride_skip) > 0 and _prod(stride_skip) > 1 - else None + else None, + None ) + self.need_permutation = [isinstance(self.norm1, nn.BatchNorm1d), isinstance(self.norm2, nn.BatchNorm1d)] def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, List[int]]: """ @@ -390,18 +356,18 @@ def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, x_block, thw_shape_new = self.attn( ( self.norm1(x.permute(0, 2, 1)).permute(0, 2, 1) - if isinstance(self.norm1, nn.BatchNorm1d) + if self.need_permutation[0] else self.norm1(x) ), thw_shape, ) - x_res, _ = _attention_pool(x, self.pool_skip, thw_shape) + x_res, _ = self.att_pool_skip(x, thw_shape) x = x_res + self.stochastic_depth(x_block) x_norm = ( - self.norm2(x.permute(0, 2, 1)).permute(0, 2, 1) if isinstance(self.norm2, nn.BatchNorm1d) else self.norm2(x) + self.norm2(x.permute(0, 2, 1)).permute(0, 2, 1) if self.need_permutation[1] else self.norm2(x) ) x_mlp = self.mlp(x_norm) - if self.dim != self.dim_out: + if self.proj is not None: x = self.proj(x_norm) x = x + self.stochastic_depth(x_mlp) return x, thw_shape_new From 6a6e0b649c1fe65ce9f66aa5c076fad675dfc526 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 10:26:10 +0100 Subject: [PATCH 14/26] Apply recommendations from code-review. --- torchvision/models/video/mvit.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index b1333e31ebe..0bb64edd14c 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -43,7 +43,9 @@ def __init__(self, pool: Optional[nn.Module], norm: Optional[nn.Module]): super().__init__() self.pool = pool self.norm = norm - self._norm_before_pool = isinstance(norm, nn.BatchNorm3d) + # The standard mvit uses layer norm and normalizes after pooling. Nevertheless in some production use-cases, it + # might be prefered to "absorb" the norm in order to make the inference faster. + self.norm_before_pool = isinstance(norm, (nn.BatchNorm3d, nn.Identity)) def forward( self, @@ -60,7 +62,7 @@ def forward( T, H, W = thw_shape tensor = tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - if self.norm is not None and self._norm_before_pool: + if self.norm is not None and self.norm_before_pool: # If use BN, we apply norm before pooling instead of after pooling. tensor = self.norm(tensor) # We also empirically find that adding a GELU here is beneficial. 
@@ -73,7 +75,7 @@ def forward( tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3) tensor = torch.cat((cls_tok, tensor), dim=2) - if self.norm is not None and not self._norm_before_pool: + if self.norm is not None and not self.norm_before_pool: tensor = self.norm(tensor) tensor = _squeeze(tensor, tensor_dim) From 9a72fd6be035df10e1e6107520d18cdceaa44548 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 11:11:51 +0100 Subject: [PATCH 15/26] Adding expected file for `mvit_b_16` --- test/expect/ModelTester.test_mvit_b_16_expect.pkl | Bin 0 -> 939 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/expect/ModelTester.test_mvit_b_16_expect.pkl diff --git a/test/expect/ModelTester.test_mvit_b_16_expect.pkl b/test/expect/ModelTester.test_mvit_b_16_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..0ba169ef1d86077665fe79f0cd1934911cb152a6 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=ZLp=6xssY_oqGHQCPd_|JV|``_=oHb-&)?)6@F%gR{yKaS_z-&NXSdq=e0 zeiak@{yW!;_xW8E-T(T7zuko1D|;I9ysh85d)Ub~d)W&l@817d`@LQM%y9b(=LxpQ z1&r-{UT53!T~D%W&o|gNMM1%Cv8%zO4dW&dQy_sGfa$>u5c z3MzUwtP}e7fkG>6lSbDGU`T;5?$F|Au!e_LS!z)+Fc#dL%!v#xq>zI!jk!QJU!0d7 z$^^6(gaf=8K@>bqBFCWsNCE|*r%-g=$bRBO(fJC4k Date: Wed, 25 May 2022 15:11:26 +0100 Subject: [PATCH 16/26] Fixing linter and some typing issues. --- torchvision/models/video/mvit.py | 99 ++++++++++++++++---------------- 1 file changed, 50 insertions(+), 49 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 0bb64edd14c..f9a9b7841af 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -4,7 +4,6 @@ import torch import torch.fx import torch.nn as nn -from torch.nn.common_types import _size_2_t, _size_3_t from ...ops import StochasticDepth, MLP from .._utils import _make_divisible @@ -28,6 +27,7 @@ def _unsqueeze(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: raise NotImplementedError(f"Unsupported input dimension {tensor.shape}") return tensor, tensor_dim + def _squeeze(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor: if tensor_dim == 3: tensor = tensor.squeeze(1) @@ -115,10 +115,10 @@ def __init__( num_heads: int = 8, qkv_bias: bool = False, dropout_rate: float = 0.0, - kernel_q: _size_3_t = (1, 1, 1), - kernel_kv: _size_3_t = (1, 1, 1), - stride_q: _size_3_t = (1, 1, 1), - stride_kv: _size_3_t = (1, 1, 1), + kernel_q: Tuple[int, int, int] = (1, 1, 1), + kernel_kv: Tuple[int, int, int] = (1, 1, 1), + stride_q: Tuple[int, int, int] = (1, 1, 1), + stride_kv: Tuple[int, int, int] = (1, 1, 1), norm_layer: Callable[..., nn.Module] = nn.LayerNorm, depthwise_conv: bool = True, bias_on: bool = True, @@ -130,14 +130,14 @@ def __init__( qkv_bias (bool): If set to False, the qkv layer will not learn an additive bias. Default: False. dropout_rate (float): Dropout rate. - kernel_q (_size_3_t): Pooling kernel size for q. If both pooling kernel + kernel_q (Tuple[int, int, int]): Pooling kernel size for q. If both pooling kernel size and pooling stride size are 1 for all the dimensions, pooling is disabled. - kernel_kv (_size_3_t): Pooling kernel size for kv. If both pooling kernel + kernel_kv (Tuple[int, int, int]): Pooling kernel size for kv. If both pooling kernel size and pooling stride size are 1 for all the dimensions, pooling is disabled. 
- stride_q (_size_3_t): Pooling kernel stride for q. - stride_kv (_size_3_t): Pooling kernel stride for kv. + stride_q (Tuple[int, int, int]): Pooling kernel stride for q. + stride_kv (Tuple[int, int, int]): Pooling kernel stride for kv. norm_layer (nn.Module): Normalization layer used after pooling. depthwise_conv (bool): Wether use depthwise or full convolution for pooling. bias_on (bool): Wether use biases for linear layers. @@ -157,10 +157,8 @@ def __init__( self.proj_drop = nn.Dropout(dropout_rate) # Skip pooling with kernel and stride size of (1, 1, 1). - if kernel_q is not None and _prod(kernel_q) == 1 and _prod(stride_q) == 1: - kernel_q = None - if kernel_kv is not None and _prod(kernel_kv) == 1 and _prod(stride_kv) == 1: - kernel_kv = None + skip_pool_q = _prod(kernel_q) == 1 and _prod(stride_q) == 1 + skip_pool_kv = _prod(kernel_kv) == 1 and _prod(stride_kv) == 1 self.att_pool_q = AttentionPool( nn.Conv3d( @@ -171,8 +169,10 @@ def __init__( padding=padding_q, groups=head_dim if depthwise_conv else 1, bias=False, - ) if kernel_q is not None else None, - norm_layer(head_dim) if kernel_q is not None else None + ) + if not skip_pool_q + else None, + norm_layer(head_dim) if not skip_pool_q else None, ) self.att_pool_k = AttentionPool( nn.Conv3d( @@ -183,8 +183,10 @@ def __init__( padding=padding_kv, groups=head_dim if depthwise_conv else 1, bias=False, - ) if kernel_kv is not None else None, - norm_layer(head_dim) if kernel_kv is not None else None + ) + if not skip_pool_kv + else None, + norm_layer(head_dim) if not skip_pool_kv else None, ) self.att_pool_v = AttentionPool( nn.Conv3d( @@ -195,8 +197,10 @@ def __init__( padding=padding_kv, groups=head_dim if depthwise_conv else 1, bias=False, - ) if kernel_kv is not None else None, - norm_layer(head_dim) if kernel_kv is not None else None + ) + if not skip_pool_kv + else None, + norm_layer(head_dim) if not skip_pool_kv else None, ) def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, List[int]]: @@ -277,10 +281,10 @@ def __init__( act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_norm_layer: Callable[..., nn.Module] = nn.LayerNorm, - kernel_q: _size_3_t = (1, 1, 1), - kernel_kv: _size_3_t = (1, 1, 1), - stride_q: _size_3_t = (1, 1, 1), - stride_kv: _size_3_t = (1, 1, 1), + kernel_q: Tuple[int, int, int] = (1, 1, 1), + kernel_kv: Tuple[int, int, int] = (1, 1, 1), + stride_q: Tuple[int, int, int] = (1, 1, 1), + stride_kv: Tuple[int, int, int] = (1, 1, 1), depthwise_conv: bool = True, bias_on: bool = True, ) -> None: @@ -298,13 +302,13 @@ def __init__( act_layer (nn.Module): Activation layer used in the Mlp layer. norm_layer (nn.Module): Normalization layer. attn_norm_layer (nn.Module): Normalization layer in the attention module. - kernel_q (_size_3_t): Pooling kernel size for q. If pooling kernel size is + kernel_q (Tuple[int, int, int]): Pooling kernel size for q. If pooling kernel size is 1 for all the dimensions, pooling is not used (by default). - kernel_kv (_size_3_t): Pooling kernel size for kv. If pooling kernel size + kernel_kv (Tuple[int, int, int]): Pooling kernel size for kv. If pooling kernel size is 1 for all the dimensions, pooling is not used. By default, pooling is disabled. - stride_q (_size_3_t): Pooling kernel stride for q. - stride_kv (_size_3_t): Pooling kernel stride for kv. + stride_q (Tuple[int, int, int]): Pooling kernel stride for q. + stride_kv (Tuple[int, int, int]): Pooling kernel stride for kv. 
has_cls_embed (bool): If set to True, the first token of the input tensor should be a cls token. Otherwise, the input tensor does not contain a cls token. Pooling is not applied to the cls token. @@ -334,17 +338,18 @@ def __init__( self.stochastic_depth = StochasticDepth(droppath_rate, "row") self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MLP(dim, [mlp_hidden_dim, dim_out], activation_layer=act_layer, dropout=dropout_rate, bias=bias_on, inplace=None) + self.mlp = MLP( + dim, [mlp_hidden_dim, dim_out], activation_layer=act_layer, dropout=dropout_rate, bias=bias_on, inplace=None + ) + self.proj: Optional[nn.Module] = None if dim != dim_out: self.proj = nn.Linear(dim, dim_out, bias=bias_on) - else: - self.proj = None self.att_pool_skip = AttentionPool( nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False) if len(stride_skip) > 0 and _prod(stride_skip) > 1 else None, - None + None, ) self.need_permutation = [isinstance(self.norm1, nn.BatchNorm1d), isinstance(self.norm2, nn.BatchNorm1d)] @@ -356,18 +361,12 @@ def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, """ x_block, thw_shape_new = self.attn( - ( - self.norm1(x.permute(0, 2, 1)).permute(0, 2, 1) - if self.need_permutation[0] - else self.norm1(x) - ), + (self.norm1(x.permute(0, 2, 1)).permute(0, 2, 1) if self.need_permutation[0] else self.norm1(x)), thw_shape, ) x_res, _ = self.att_pool_skip(x, thw_shape) x = x_res + self.stochastic_depth(x_block) - x_norm = ( - self.norm2(x.permute(0, 2, 1)).permute(0, 2, 1) if self.need_permutation[1] else self.norm2(x) - ) + x_norm = self.norm2(x.permute(0, 2, 1)).permute(0, 2, 1) if self.need_permutation[1] else self.norm2(x) x_mlp = self.mlp(x_norm) if self.proj is not None: x = self.proj(x_norm) @@ -570,15 +569,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def create_multiscale_vision_transformers( - spatial_size: _size_2_t, + spatial_size: Tuple[int, int], temporal_size: int, depth: int = 16, # Patch embed config. input_channels: int = 3, patch_embed_dim: int = 96, - conv_patch_embed_kernel: Tuple[int] = (3, 7, 7), - conv_patch_embed_stride: Tuple[int] = (2, 4, 4), - conv_patch_embed_padding: Tuple[int] = (1, 3, 3), + conv_patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7), + conv_patch_embed_stride: Tuple[int, int, int] = (2, 4, 4), + conv_patch_embed_padding: Tuple[int, int, int] = (1, 3, 3), # Attention block config. num_heads: int = 1, mlp_ratio: float = 4.0, @@ -591,8 +590,8 @@ def create_multiscale_vision_transformers( atten_head_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), pool_q_stride_size: Optional[List[List[int]]] = ([1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]), pool_kv_stride_size: Optional[List[List[int]]] = None, - pool_kv_stride_adaptive: Optional[_size_3_t] = (1, 8, 8), - pool_kvq_kernel: Optional[_size_3_t] = (3, 3, 3), + pool_kv_stride_adaptive: Optional[Tuple[int, int, int]] = (1, 8, 8), + pool_kvq_kernel: Optional[Tuple[int, int, int]] = (3, 3, 3), # Head config. head_dropout_rate: float = 0.5, num_classes: int = 400, @@ -603,7 +602,7 @@ def create_multiscale_vision_transformers( (ViT) is a specific case of MViT that only uses a single scale attention block. Args: - spatial_size (_size_2_t): Input video spatial resolution (H, W). If a single + spatial_size (Tuple[int, int]): Input video spatial resolution (H, W). If a single int is given, it assumes the width and the height are the same. temporal_size (int): Number of frames in the input video. 
depth (int): The depth of the model. @@ -638,11 +637,11 @@ def create_multiscale_vision_transformers( pool_kv_stride_size (Optional[List[List[int]]]): List of stride sizes for the pool kv at each layer. Format: [[i, stride_t_i, stride_h_i, stride_w_i], ...,]. - pool_kv_stride_adaptive (Optional[_size_3_t]): Initial kv stride size for the + pool_kv_stride_adaptive (Optional[Tuple[int, int, int]]): Initial kv stride size for the first block. The stride size will be further reduced at the layer where q is pooled with the ratio of the stride of q pooling. If pool_kv_stride_adaptive is set, then pool_kv_stride_size should be none. - pool_kvq_kernel (Optional[_size_3_t]): Pooling kernel size for q and kv. It None, + pool_kvq_kernel (Optional[Tuple[int, int, int]]): Pooling kernel size for q and kv. It None, the kernel_size is [s + 1 if s > 1 else s for s in stride_size]. head_dropout_rate (float): Dropout rate in the head. @@ -746,7 +745,9 @@ def create_multiscale_vision_transformers( num_heads = _make_divisible(num_heads * head_mul[i], 1) patch_embed_dim = _make_divisible(patch_embed_dim * dim_mul[i], num_heads, min_value=8) dim_out = _make_divisible( - patch_embed_dim * dim_mul[i + 1], divisor=_make_divisible(num_heads * head_mul[i + 1], 8), min_value=8, + patch_embed_dim * dim_mul[i + 1], + divisor=_make_divisible(num_heads * head_mul[i + 1], 8), + min_value=8, ) mvit_blocks.append( From c4b08bf29c038ff26222acd17a4a1bba0c84855e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 15:13:36 +0100 Subject: [PATCH 17/26] Removing input_channels. --- torchvision/models/video/mvit.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index f9a9b7841af..61b7f52e2d4 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -573,7 +573,6 @@ def create_multiscale_vision_transformers( temporal_size: int, depth: int = 16, # Patch embed config. - input_channels: int = 3, patch_embed_dim: int = 96, conv_patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7), conv_patch_embed_stride: Tuple[int, int, int] = (2, 4, 4), @@ -607,7 +606,6 @@ def create_multiscale_vision_transformers( temporal_size (int): Number of frames in the input video. depth (int): The depth of the model. - input_channels (int): Channel dimension of the input video. patch_embed_dim (int): Embedding dimension after patchifing the video input. conv_patch_embed_kernel (Tuple[int]): Kernel size of the convolution for patchifing the video input. @@ -680,7 +678,7 @@ def create_multiscale_vision_transformers( patch_embed = PatchEmbed( patch_model=nn.Conv3d( - in_channels=input_channels, + in_channels=3, out_channels=patch_embed_dim, kernel_size=conv_patch_embed_kernel, stride=conv_patch_embed_stride, From 8957db3d20bd237828374b661f96c946247f619f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 15:16:40 +0100 Subject: [PATCH 18/26] Removing mlp_ratio. --- torchvision/models/video/mvit.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 61b7f52e2d4..e42d5b17451 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -274,7 +274,6 @@ def __init__( dim: int, dim_out: int, num_heads: int, - mlp_ratio: float = 4.0, qkv_bias: bool = False, dropout_rate: float = 0.0, droppath_rate: float = 0.0, @@ -293,8 +292,6 @@ def __init__( dim (int): Input feature dimension. 
dim_out (int): Output feature dimension. num_heads (int): Number of heads in the attention layer. - mlp_ratio (float): Mlp ratio which controls the feature dimension in the - hidden layer of the Mlp block. qkv_bias (bool): If set to False, the qkv layer will not learn an additive bias. Default: False. dropout_rate (float): DropOut rate. If set to 0, DropOut is disabled. @@ -337,7 +334,7 @@ def __init__( ) self.stochastic_depth = StochasticDepth(droppath_rate, "row") self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) + mlp_hidden_dim = 4 * dim # 4x mlp ratio self.mlp = MLP( dim, [mlp_hidden_dim, dim_out], activation_layer=act_layer, dropout=dropout_rate, bias=bias_on, inplace=None ) @@ -579,7 +576,6 @@ def create_multiscale_vision_transformers( conv_patch_embed_padding: Tuple[int, int, int] = (1, 3, 3), # Attention block config. num_heads: int = 1, - mlp_ratio: float = 4.0, qkv_bias: bool = True, dropout_rate_block: float = 0.0, droppath_rate_block: float = 0.0, @@ -615,8 +611,6 @@ def create_multiscale_vision_transformers( patchifing the video input. num_heads (int): Number of heads in the first transformer block. - mlp_ratio (float): Mlp ratio which controls the feature dimension in the - hidden layer of the Mlp block. qkv_bias (bool): If set to False, the qkv layer will not learn an additive bias. Default: True. dropout_rate_block (float): Dropout rate for the attention block. @@ -753,7 +747,6 @@ def create_multiscale_vision_transformers( dim=patch_embed_dim, dim_out=dim_out, num_heads=num_heads, - mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, dropout_rate=dropout_rate_block, droppath_rate=dpr[i], From 0d4d5da9d849dc335da42d251b7c27aa55cffc26 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 15:25:30 +0100 Subject: [PATCH 19/26] Removing qkv_bias. --- torchvision/models/video/mvit.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index e42d5b17451..fe47159c624 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -113,7 +113,6 @@ def __init__( self, dim: int, num_heads: int = 8, - qkv_bias: bool = False, dropout_rate: float = 0.0, kernel_q: Tuple[int, int, int] = (1, 1, 1), kernel_kv: Tuple[int, int, int] = (1, 1, 1), @@ -127,8 +126,6 @@ def __init__( Args: dim (int): Input feature dimension. num_heads (int): Number of heads in the attention layer. - qkv_bias (bool): If set to False, the qkv layer will not learn an additive - bias. Default: False. dropout_rate (float): Dropout rate. kernel_q (Tuple[int, int, int]): Pooling kernel size for q. If both pooling kernel size and pooling stride size are 1 for all the dimensions, pooling is @@ -152,7 +149,7 @@ def __init__( padding_q = [int(q // 2) for q in kernel_q] padding_kv = [int(kv // 2) for kv in kernel_kv] - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, dim * 3) self.proj = nn.Linear(dim, dim, bias=True if bias_on else False) self.proj_drop = nn.Dropout(dropout_rate) @@ -274,7 +271,6 @@ def __init__( dim: int, dim_out: int, num_heads: int, - qkv_bias: bool = False, dropout_rate: float = 0.0, droppath_rate: float = 0.0, act_layer: Callable[..., nn.Module] = nn.GELU, @@ -292,8 +288,6 @@ def __init__( dim (int): Input feature dimension. dim_out (int): Output feature dimension. num_heads (int): Number of heads in the attention layer. - qkv_bias (bool): If set to False, the qkv layer will not learn an additive - bias. Default: False. 
dropout_rate (float): DropOut rate. If set to 0, DropOut is disabled. droppath_rate (float): DropPath rate. If set to 0, DropPath is disabled. act_layer (nn.Module): Activation layer used in the Mlp layer. @@ -322,7 +316,6 @@ def __init__( self.attn = MultiScaleAttention( dim, num_heads=num_heads, - qkv_bias=qkv_bias, dropout_rate=dropout_rate, kernel_q=kernel_q, kernel_kv=kernel_kv, @@ -576,7 +569,6 @@ def create_multiscale_vision_transformers( conv_patch_embed_padding: Tuple[int, int, int] = (1, 3, 3), # Attention block config. num_heads: int = 1, - qkv_bias: bool = True, dropout_rate_block: float = 0.0, droppath_rate_block: float = 0.0, depthwise_conv: bool = True, @@ -611,8 +603,6 @@ def create_multiscale_vision_transformers( patchifing the video input. num_heads (int): Number of heads in the first transformer block. - qkv_bias (bool): If set to False, the qkv layer will not learn an additive - bias. Default: True. dropout_rate_block (float): Dropout rate for the attention block. droppath_rate_block (float): Droppath rate for the attention block. depthwise_conv (bool): Wether use depthwise or full convolution for pooling. @@ -747,7 +737,6 @@ def create_multiscale_vision_transformers( dim=patch_embed_dim, dim_out=dim_out, num_heads=num_heads, - qkv_bias=qkv_bias, dropout_rate=dropout_rate_block, droppath_rate=dpr[i], norm_layer=block_norm_layer, From 11c855462bb6e249f0f293893e2564414f1fde06 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 15:30:27 +0100 Subject: [PATCH 20/26] Removing dropout_rate_block. --- torchvision/models/video/mvit.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index fe47159c624..8fefdd8ffc3 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -507,7 +507,6 @@ def __init__( self, patch_embed: Optional[nn.Module], cls_positional_encoding: nn.Module, - pos_drop: Optional[nn.Module], blocks: nn.ModuleList, norm_embed: Optional[nn.Module], head: Optional[nn.Module], @@ -516,7 +515,6 @@ def __init__( Args: patch_embed (nn.Module): Patch embed module. cls_positional_encoding (nn.Module): Positional encoding module. - pos_drop (Optional[nn.Module]): Dropout module after patch embed. blocks (nn.ModuleList): Stack of multi-scale transformer blocks. norm_layer (nn.Module): Normalization layer before head. head (Optional[nn.Module]): Head module. @@ -524,7 +522,6 @@ def __init__( super().__init__() self.patch_embed = patch_embed self.cls_positional_encoding = cls_positional_encoding - self.pos_drop = pos_drop self.blocks = blocks self.norm_embed = norm_embed self.head = head @@ -545,9 +542,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) x = self.cls_positional_encoding(x) - if self.pos_drop is not None: - x = self.pos_drop(x) - thw = self.cls_positional_encoding.patch_embed_shape for blk in self.blocks: x, thw = blk(x, thw) @@ -569,7 +563,6 @@ def create_multiscale_vision_transformers( conv_patch_embed_padding: Tuple[int, int, int] = (1, 3, 3), # Attention block config. num_heads: int = 1, - dropout_rate_block: float = 0.0, droppath_rate_block: float = 0.0, depthwise_conv: bool = True, bias_on: bool = True, @@ -603,7 +596,6 @@ def create_multiscale_vision_transformers( patchifing the video input. num_heads (int): Number of heads in the first transformer block. - dropout_rate_block (float): Dropout rate for the attention block. droppath_rate_block (float): Droppath rate for the attention block. 
depthwise_conv (bool): Wether use depthwise or full convolution for pooling. bias_on (bool): Wether use biases for linear layers. @@ -737,7 +729,6 @@ def create_multiscale_vision_transformers( dim=patch_embed_dim, dim_out=dim_out, num_heads=num_heads, - dropout_rate=dropout_rate_block, droppath_rate=dpr[i], norm_layer=block_norm_layer, attn_norm_layer=attn_norm_layer, @@ -760,7 +751,6 @@ def create_multiscale_vision_transformers( return MultiscaleVisionTransformers( patch_embed=patch_embed, cls_positional_encoding=cls_positional_encoding, - pos_drop=nn.Dropout(p=dropout_rate_block) if dropout_rate_block > 0.0 else None, blocks=mvit_blocks, norm_embed=norm_layer(embed_dim), head=head_model, From ca9506a289437b1c87424ef4003b5132950d22a8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 15:36:19 +0100 Subject: [PATCH 21/26] Rename var and clean up docs --- torchvision/models/video/mvit.py | 135 ++----------------------------- 1 file changed, 7 insertions(+), 128 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 8fefdd8ffc3..602604e4f11 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -83,32 +83,6 @@ def forward( class MultiScaleAttention(nn.Module): - """ - Implementation of a multiscale attention block. Compare to a conventional attention - block, a multiscale attention block optionally supports pooling (either - before or after qkv projection). If pooling is not used, a multiscale attention - block is equivalent to a conventional attention block. - - :: - Input - | - |----------------|-----------------| - ↓ ↓ ↓ - Linear Linear Linear - & & & - Pool (Q) Pool (K) Pool (V) - → -------------- ← | - ↓ | - MatMul & Scale | - ↓ | - Softmax | - → ----------------------- ← - ↓ - MatMul & Scale - ↓ - DropOut - """ - def __init__( self, dim: int, @@ -201,12 +175,6 @@ def __init__( ) def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, List[int]]: - """ - Args: - x (torch.Tensor): Input tensor. - thw_shape (List): The shape of the input tensor (before flattening). - """ - B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) @@ -237,42 +205,13 @@ def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, class MultiScaleBlock(nn.Module): - """ - Implementation of a multiscale vision transformer block. Each block contains a - multiscale attention layer and a Mlp layer. - - :: - - - Input - |-------------------+ - ↓ | - Norm | - ↓ | - MultiScaleAttention Pool - ↓ | - DropPath | - ↓ | - Summation ←-------------+ - | - |-------------------+ - ↓ | - Norm | - ↓ | - Mlp Proj - ↓ | - DropPath | - ↓ | - Summation ←------------+ - """ - def __init__( self, dim: int, dim_out: int, num_heads: int, dropout_rate: float = 0.0, - droppath_rate: float = 0.0, + stochastic_depth_prob: float = 0.0, act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_norm_layer: Callable[..., nn.Module] = nn.LayerNorm, @@ -289,7 +228,7 @@ def __init__( dim_out (int): Output feature dimension. num_heads (int): Number of heads in the attention layer. dropout_rate (float): DropOut rate. If set to 0, DropOut is disabled. - droppath_rate (float): DropPath rate. If set to 0, DropPath is disabled. + stochastic_depth_prob (float): Stochastic Depth probability. If set to 0, it's disabled. act_layer (nn.Module): Activation layer used in the Mlp layer. norm_layer (nn.Module): Normalization layer. 
attn_norm_layer (nn.Module): Normalization layer in the attention module. @@ -325,7 +264,7 @@ def __init__( bias_on=bias_on, depthwise_conv=depthwise_conv, ) - self.stochastic_depth = StochasticDepth(droppath_rate, "row") + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) mlp_hidden_dim = 4 * dim # 4x mlp ratio self.mlp = MLP( @@ -344,12 +283,6 @@ def __init__( self.need_permutation = [isinstance(self.norm1, nn.BatchNorm1d), isinstance(self.norm2, nn.BatchNorm1d)] def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, List[int]]: - """ - Args: - x (torch.Tensor): Input tensor. - thw_shape (List): The shape of the input tensor (before flattening). - """ - x_block, thw_shape_new = self.attn( (self.norm1(x.permute(0, 2, 1)).permute(0, 2, 1) if self.need_permutation[0] else self.norm1(x)), thw_shape, @@ -365,10 +298,6 @@ def forward(self, x: torch.Tensor, thw_shape: List[int]) -> Tuple[torch.Tensor, class SpatioTemporalClsPositionalEncoding(nn.Module): - """ - Add a cls token and apply a spatiotemporal encoding to a tensor. - """ - def __init__( self, embed_dim: int, @@ -397,10 +326,6 @@ def patch_embed_shape(self): return self._patch_embed_shape def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x (torch.Tensor): Input tensor. - """ B, N, C = x.shape cls_tokens = self.cls_token.expand(B, -1, -1) @@ -420,10 +345,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ClassificationHead(nn.Module): def __init__( self, - # Projection configs. in_features: int, out_features: int, - # Dropout configs. dropout_rate: float = 0.5, ) -> None: super().__init__() @@ -431,32 +354,13 @@ def __init__( self.proj = nn.Linear(in_features, out_features) def forward(self, x: torch.Tensor) -> torch.Tensor: - # Pick cls embedding x = x[:, 0] - # Performs dropout. x = self.dropout(x) - # Performs projection. x = self.proj(x) return x class PatchEmbed(nn.Module): - """ - Transformer basic patch embedding module. Performs patchifying input, flatten and - and transpose. - - :: - - PatchModel - ↓ - flatten - ↓ - transpose - - The builder can be found in `create_patch_embed`. - - """ - def __init__( self, patch_model: nn.Module, @@ -476,31 +380,6 @@ class MultiscaleVisionTransformers(nn.Module): Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan, Jitendra Malik, Christoph Feichtenhofer https://arxiv.org/abs/2104.11227 - - :: - - PatchEmbed - ↓ - PositionalEncoding - ↓ - Dropout - ↓ - Normalization - ↓ - Block 1 - ↓ - . - . - . - ↓ - Block N - ↓ - Normalization - ↓ - Head - - - The builder can be found in `create_mvit`. """ def __init__( @@ -563,7 +442,7 @@ def create_multiscale_vision_transformers( conv_patch_embed_padding: Tuple[int, int, int] = (1, 3, 3), # Attention block config. num_heads: int = 1, - droppath_rate_block: float = 0.0, + stochastic_depth_prob_block: float = 0.0, depthwise_conv: bool = True, bias_on: bool = True, embed_dim_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), @@ -596,7 +475,7 @@ def create_multiscale_vision_transformers( patchifing the video input. num_heads (int): Number of heads in the first transformer block. - droppath_rate_block (float): Droppath rate for the attention block. + stochastic_depth_prob_block (float): Stochastic Depth probability for the attention block. depthwise_conv (bool): Wether use depthwise or full convolution for pooling. bias_on (bool): Wether use biases for linear layers. 
embed_dim_mul (Optional[List[List[int]]]): Dimension multiplication at layer i. @@ -672,7 +551,7 @@ def create_multiscale_vision_transformers( patch_embed_shape=patch_embed_shape, ) - dpr = [x.item() for x in torch.linspace(0, droppath_rate_block, depth)] # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, stochastic_depth_prob_block, depth)] # stochastic depth decay rule dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1) if embed_dim_mul is not None: @@ -729,7 +608,7 @@ def create_multiscale_vision_transformers( dim=patch_embed_dim, dim_out=dim_out, num_heads=num_heads, - droppath_rate=dpr[i], + stochastic_depth_prob=dpr[i], norm_layer=block_norm_layer, attn_norm_layer=attn_norm_layer, kernel_q=pool_q[i], From f386f866407247ce11b4b5a0aae0f9949e09d924 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 15:39:21 +0100 Subject: [PATCH 22/26] Remove bias_on. --- torchvision/models/video/mvit.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 602604e4f11..15fb654f164 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -94,7 +94,6 @@ def __init__( stride_kv: Tuple[int, int, int] = (1, 1, 1), norm_layer: Callable[..., nn.Module] = nn.LayerNorm, depthwise_conv: bool = True, - bias_on: bool = True, ) -> None: """ Args: @@ -111,7 +110,6 @@ def __init__( stride_kv (Tuple[int, int, int]): Pooling kernel stride for kv. norm_layer (nn.Module): Normalization layer used after pooling. depthwise_conv (bool): Wether use depthwise or full convolution for pooling. - bias_on (bool): Wether use biases for linear layers. """ super().__init__() @@ -124,7 +122,7 @@ def __init__( padding_kv = [int(kv // 2) for kv in kernel_kv] self.qkv = nn.Linear(dim, dim * 3) - self.proj = nn.Linear(dim, dim, bias=True if bias_on else False) + self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(dropout_rate) # Skip pooling with kernel and stride size of (1, 1, 1). @@ -220,7 +218,6 @@ def __init__( stride_q: Tuple[int, int, int] = (1, 1, 1), stride_kv: Tuple[int, int, int] = (1, 1, 1), depthwise_conv: bool = True, - bias_on: bool = True, ) -> None: """ Args: @@ -243,7 +240,6 @@ def __init__( should be a cls token. Otherwise, the input tensor does not contain a cls token. Pooling is not applied to the cls token. depthwise_conv (bool): Wether use depthwise or full convolution for pooling. - bias_on (bool): Wether use biases for linear layers. 
""" super().__init__() self.dim = dim @@ -261,18 +257,17 @@ def __init__( stride_q=stride_q, stride_kv=stride_kv, norm_layer=attn_norm_layer, - bias_on=bias_on, depthwise_conv=depthwise_conv, ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) mlp_hidden_dim = 4 * dim # 4x mlp ratio self.mlp = MLP( - dim, [mlp_hidden_dim, dim_out], activation_layer=act_layer, dropout=dropout_rate, bias=bias_on, inplace=None + dim, [mlp_hidden_dim, dim_out], activation_layer=act_layer, dropout=dropout_rate, inplace=None ) self.proj: Optional[nn.Module] = None if dim != dim_out: - self.proj = nn.Linear(dim, dim_out, bias=bias_on) + self.proj = nn.Linear(dim, dim_out) self.att_pool_skip = AttentionPool( nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False) @@ -444,7 +439,6 @@ def create_multiscale_vision_transformers( num_heads: int = 1, stochastic_depth_prob_block: float = 0.0, depthwise_conv: bool = True, - bias_on: bool = True, embed_dim_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), atten_head_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), pool_q_stride_size: Optional[List[List[int]]] = ([1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]), @@ -477,7 +471,6 @@ def create_multiscale_vision_transformers( num_heads (int): Number of heads in the first transformer block. stochastic_depth_prob_block (float): Stochastic Depth probability for the attention block. depthwise_conv (bool): Wether use depthwise or full convolution for pooling. - bias_on (bool): Wether use biases for linear layers. embed_dim_mul (Optional[List[List[int]]]): Dimension multiplication at layer i. If X is used, then the next block will increase the embed dimension by X times. Format: [depth_i, mul_dim_ratio]. @@ -615,7 +608,6 @@ def create_multiscale_vision_transformers( kernel_kv=pool_kv[i], stride_q=stride_q[i], stride_kv=stride_kv[i], - bias_on=bias_on, depthwise_conv=depthwise_conv, ) ) From d5a3e324d40eab44308c83622a1101920eb2043f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 15:41:06 +0100 Subject: [PATCH 23/26] Remove depthwise_conv. --- torchvision/models/video/mvit.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 15fb654f164..5451cfa29dd 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -93,7 +93,6 @@ def __init__( stride_q: Tuple[int, int, int] = (1, 1, 1), stride_kv: Tuple[int, int, int] = (1, 1, 1), norm_layer: Callable[..., nn.Module] = nn.LayerNorm, - depthwise_conv: bool = True, ) -> None: """ Args: @@ -109,7 +108,6 @@ def __init__( stride_q (Tuple[int, int, int]): Pooling kernel stride for q. stride_kv (Tuple[int, int, int]): Pooling kernel stride for kv. norm_layer (nn.Module): Normalization layer used after pooling. - depthwise_conv (bool): Wether use depthwise or full convolution for pooling. 
""" super().__init__() @@ -136,7 +134,7 @@ def __init__( kernel_q, stride=stride_q, padding=padding_q, - groups=head_dim if depthwise_conv else 1, + groups=head_dim, bias=False, ) if not skip_pool_q @@ -150,7 +148,7 @@ def __init__( kernel_kv, stride=stride_kv, padding=padding_kv, - groups=head_dim if depthwise_conv else 1, + groups=head_dim, bias=False, ) if not skip_pool_kv @@ -164,7 +162,7 @@ def __init__( kernel_kv, stride=stride_kv, padding=padding_kv, - groups=head_dim if depthwise_conv else 1, + groups=head_dim, bias=False, ) if not skip_pool_kv @@ -217,7 +215,6 @@ def __init__( kernel_kv: Tuple[int, int, int] = (1, 1, 1), stride_q: Tuple[int, int, int] = (1, 1, 1), stride_kv: Tuple[int, int, int] = (1, 1, 1), - depthwise_conv: bool = True, ) -> None: """ Args: @@ -239,7 +236,6 @@ def __init__( has_cls_embed (bool): If set to True, the first token of the input tensor should be a cls token. Otherwise, the input tensor does not contain a cls token. Pooling is not applied to the cls token. - depthwise_conv (bool): Wether use depthwise or full convolution for pooling. """ super().__init__() self.dim = dim @@ -257,7 +253,6 @@ def __init__( stride_q=stride_q, stride_kv=stride_kv, norm_layer=attn_norm_layer, - depthwise_conv=depthwise_conv, ) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.norm2 = norm_layer(dim) @@ -438,7 +433,6 @@ def create_multiscale_vision_transformers( # Attention block config. num_heads: int = 1, stochastic_depth_prob_block: float = 0.0, - depthwise_conv: bool = True, embed_dim_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), atten_head_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), pool_q_stride_size: Optional[List[List[int]]] = ([1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]), @@ -470,7 +464,6 @@ def create_multiscale_vision_transformers( num_heads (int): Number of heads in the first transformer block. stochastic_depth_prob_block (float): Stochastic Depth probability for the attention block. - depthwise_conv (bool): Wether use depthwise or full convolution for pooling. embed_dim_mul (Optional[List[List[int]]]): Dimension multiplication at layer i. If X is used, then the next block will increase the embed dimension by X times. Format: [depth_i, mul_dim_ratio]. @@ -608,7 +601,6 @@ def create_multiscale_vision_transformers( kernel_kv=pool_kv[i], stride_q=stride_q[i], stride_kv=stride_kv[i], - depthwise_conv=depthwise_conv, ) ) From 7d569d59f1a373b4968c8df541ecd19ad4e4d9a5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 15:55:48 +0100 Subject: [PATCH 24/26] Remove conv_patch_embed_kernel|stride|padding --- torchvision/models/video/mvit.py | 33 ++++++++------------------------ 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 5451cfa29dd..80ed046290a 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -300,7 +300,6 @@ def __init__( (T, H, W) after patch embedding. """ super().__init__() - assert len(patch_embed_shape) == 3, "Patch_embed_shape should be in the form of (T, H, W)." self._patch_embed_shape = patch_embed_shape self.num_spatial_patch = patch_embed_shape[1] * patch_embed_shape[2] self.num_temporal_patch = patch_embed_shape[0] @@ -427,9 +426,6 @@ def create_multiscale_vision_transformers( depth: int = 16, # Patch embed config. 
patch_embed_dim: int = 96, - conv_patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7), - conv_patch_embed_stride: Tuple[int, int, int] = (2, 4, 4), - conv_patch_embed_padding: Tuple[int, int, int] = (1, 3, 3), # Attention block config. num_heads: int = 1, stochastic_depth_prob_block: float = 0.0, @@ -449,18 +445,11 @@ def create_multiscale_vision_transformers( (ViT) is a specific case of MViT that only uses a single scale attention block. Args: - spatial_size (Tuple[int, int]): Input video spatial resolution (H, W). If a single - int is given, it assumes the width and the height are the same. + spatial_size (Tuple[int, int]): Input video spatial resolution (H, W). temporal_size (int): Number of frames in the input video. depth (int): The depth of the model. patch_embed_dim (int): Embedding dimension after patchifing the video input. - conv_patch_embed_kernel (Tuple[int]): Kernel size of the convolution for - patchifing the video input. - conv_patch_embed_stride (Tuple[int]): Stride size of the convolution for - patchifing the video input. - conv_patch_embed_padding (Tuple[int]): Padding size of the convolution for - patchifing the video input. num_heads (int): Number of heads in the first transformer block. stochastic_depth_prob_block (float): Stochastic Depth probability for the attention block. @@ -488,7 +477,7 @@ def create_multiscale_vision_transformers( Example usage (building a MViT_B model for Kinetics400): - spatial_size = 224 + spatial_size = (224, 224) temporal_size = 16 embed_dim_mul = [[1, 2.0], [3, 2.0], [14, 2.0]] atten_head_mul = [[1, 2.0], [3, 2.0], [14, 2.0]] @@ -514,24 +503,18 @@ def create_multiscale_vision_transformers( block_norm_layer = partial(nn.LayerNorm, eps=1e-6) attn_norm_layer = partial(nn.LayerNorm, eps=1e-6) - if isinstance(spatial_size, int): - spatial_size = (spatial_size, spatial_size) - + s = (2, 4, 4) patch_embed = PatchEmbed( patch_model=nn.Conv3d( in_channels=3, out_channels=patch_embed_dim, - kernel_size=conv_patch_embed_kernel, - stride=conv_patch_embed_stride, - padding=conv_patch_embed_padding, + kernel_size=(3, 7, 7), + stride=s, + padding=(1, 3, 3), bias=True, ) ) - - input_dims = (temporal_size, spatial_size[0], spatial_size[1]) - - patch_embed_shape = tuple(v // conv_patch_embed_stride[i] for i, v in enumerate(input_dims)) - + patch_embed_shape = (temporal_size // s[0], spatial_size[0] // s[1], spatial_size[1] // s[2]) cls_positional_encoding = SpatioTemporalClsPositionalEncoding( embed_dim=patch_embed_dim, patch_embed_shape=patch_embed_shape, @@ -621,7 +604,7 @@ def create_multiscale_vision_transformers( def mvit_b_16( - spatial_size=224, + spatial_size=(224, 224), temporal_size=16, num_classes=400, **kwargs, From 2386e07e34e4dbc51214fb5b451dac7d1774ea16 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 19:04:51 +0100 Subject: [PATCH 25/26] Remove pool_kv_stride_size, pool_kv_stride_adaptive and pool_kvq_kernel. 
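
With these knobs removed, the builder always uses the MViT-B pooling schedule: a fixed (3, 3, 3) pooling kernel for q and kv, an initial kv stride of (1, 8, 8), and a kv stride that is divided by the q stride at every layer where q is pooled. The snippet below is only an illustrative sketch of that schedule (it is not part of the patch; `schedule` and the small loop are demonstration-only names, the real logic lives inside create_multiscale_vision_transformers):

    # Default 16-block layout: q is pooled with stride (1, 2, 2) at layers 1, 3 and 14.
    depth = 16
    pool_q_stride_size = [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]]

    stride_q = [[] for _ in range(depth)]
    for idx, *stride in pool_q_stride_size:
        stride_q[idx] = stride

    stride_kv = [1, 8, 8]  # the value previously exposed as pool_kv_stride_adaptive
    schedule = []
    for i in range(depth):
        if stride_q[i]:
            # Reduce the kv stride by the q stride wherever q is pooled, never below 1.
            stride_kv = [max(kv // q, 1) for kv, q in zip(stride_kv, stride_q[i])]
        schedule.append((i, tuple(stride_kv)))

    print(schedule)  # block 0: (1, 8, 8), blocks 1-2: (1, 4, 4), blocks 3-13: (1, 2, 2), blocks 14-15: (1, 1, 1)
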
--- torchvision/models/video/mvit.py | 56 +++++++++----------------------- 1 file changed, 15 insertions(+), 41 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 80ed046290a..58ba8931689 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -432,9 +432,6 @@ def create_multiscale_vision_transformers( embed_dim_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), atten_head_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), pool_q_stride_size: Optional[List[List[int]]] = ([1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]), - pool_kv_stride_size: Optional[List[List[int]]] = None, - pool_kv_stride_adaptive: Optional[Tuple[int, int, int]] = (1, 8, 8), - pool_kvq_kernel: Optional[Tuple[int, int, int]] = (3, 3, 3), # Head config. head_dropout_rate: float = 0.5, num_classes: int = 400, @@ -462,15 +459,6 @@ def create_multiscale_vision_transformers( pool_q_stride_size (Optional[List[List[int]]]): List of stride sizes for the pool q at each layer. Format: [[i, stride_t_i, stride_h_i, stride_w_i], ...,]. - pool_kv_stride_size (Optional[List[List[int]]]): List of stride sizes for the - pool kv at each layer. Format: - [[i, stride_t_i, stride_h_i, stride_w_i], ...,]. - pool_kv_stride_adaptive (Optional[Tuple[int, int, int]]): Initial kv stride size for the - first block. The stride size will be further reduced at the layer where q - is pooled with the ratio of the stride of q pooling. If - pool_kv_stride_adaptive is set, then pool_kv_stride_size should be none. - pool_kvq_kernel (Optional[Tuple[int, int, int]]): Pooling kernel size for q and kv. It None, - the kernel_size is [s + 1 if s > 1 else s for s in stride_size]. head_dropout_rate (float): Dropout rate in the head. num_classes (int): Number of classes in the final classification head. @@ -482,23 +470,16 @@ def create_multiscale_vision_transformers( embed_dim_mul = [[1, 2.0], [3, 2.0], [14, 2.0]] atten_head_mul = [[1, 2.0], [3, 2.0], [14, 2.0]] pool_q_stride_size = [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]] - pool_kv_stride_adaptive = [1, 8, 8] - pool_kvq_kernel = [3, 3, 3] num_classes = 400 - MViT_B = create_multiscale_vision_transformers( + MViTv2_S = create_multiscale_vision_transformers( spatial_size=spatial_size, temporal_size=temporal_size, embed_dim_mul=embed_dim_mul, atten_head_mul=atten_head_mul, pool_q_stride_size=pool_q_stride_size, - pool_kv_stride_adaptive=pool_kv_stride_adaptive, - pool_kvq_kernel=pool_kvq_kernel, num_classes=num_classes, ) """ - - if pool_kv_stride_adaptive is not None: - assert pool_kv_stride_size is None, "pool_kv_stride_size should be none if pool_kv_stride_adaptive is set." norm_layer = partial(nn.LayerNorm, eps=1e-6) block_norm_layer = partial(nn.LayerNorm, eps=1e-6) attn_norm_layer = partial(nn.LayerNorm, eps=1e-6) @@ -537,30 +518,22 @@ def create_multiscale_vision_transformers( stride_q = [[] for _ in range(depth)] stride_kv = [[] for _ in range(depth)] + pool_kvq_kernel = (3, 3, 3) if pool_q_stride_size is not None: for i in range(len(pool_q_stride_size)): stride_q[pool_q_stride_size[i][0]] = pool_q_stride_size[i][1:] - if pool_kvq_kernel is not None: - pool_q[pool_q_stride_size[i][0]] = pool_kvq_kernel - else: - pool_q[pool_q_stride_size[i][0]] = [s + 1 if s > 1 else s for s in pool_q_stride_size[i][1:]] - - # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE. 
- if pool_kv_stride_adaptive is not None: - _stride_kv = pool_kv_stride_adaptive - pool_kv_stride_size = [] - for i in range(depth): - if len(stride_q[i]) > 0: - _stride_kv = [max(_stride_kv[d] // stride_q[i][d], 1) for d in range(len(_stride_kv))] - pool_kv_stride_size.append([i] + list(_stride_kv)) - - if pool_kv_stride_size is not None: - for i in range(len(pool_kv_stride_size)): - stride_kv[pool_kv_stride_size[i][0]] = pool_kv_stride_size[i][1:] - if pool_kvq_kernel is not None: - pool_kv[pool_kv_stride_size[i][0]] = pool_kvq_kernel - else: - pool_kv[pool_kv_stride_size[i][0]] = [s + 1 if s > 1 else s for s in pool_kv_stride_size[i][1:]] + pool_q[pool_q_stride_size[i][0]] = pool_kvq_kernel + + _stride_kv = (1, 8, 8) + pool_kv_stride_size = [] + for i in range(depth): + if len(stride_q[i]) > 0: + _stride_kv = [max(_stride_kv[d] // stride_q[i][d], 1) for d in range(len(_stride_kv))] + pool_kv_stride_size.append([i] + list(_stride_kv)) + + for i in range(len(pool_kv_stride_size)): + stride_kv[pool_kv_stride_size[i][0]] = pool_kv_stride_size[i][1:] + pool_kv[pool_kv_stride_size[i][0]] = pool_kvq_kernel dim_out = 0 for i in range(depth): @@ -603,6 +576,7 @@ def create_multiscale_vision_transformers( ) +# TODO: rename this. This is actually the small version for v2 def mvit_b_16( spatial_size=(224, 224), temporal_size=16, From 48ff9e1fcf9dc49b15fc3ab75051410b2f62f16f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 May 2022 21:22:22 +0100 Subject: [PATCH 26/26] Adding real variants with validation files produced by the original implementation. --- ...l => ModelTester.test_mvitv2_b_expect.pkl} | Bin 939 -> 939 bytes .../ModelTester.test_mvitv2_s_expect.pkl | Bin 0 -> 939 bytes .../ModelTester.test_mvitv2_t_expect.pkl | Bin 0 -> 939 bytes test/test_models.py | 15 +++- torchvision/models/video/mvit.py | 83 +++++++++++++++--- 5 files changed, 81 insertions(+), 17 deletions(-) rename test/expect/{ModelTester.test_mvit_b_16_expect.pkl => ModelTester.test_mvitv2_b_expect.pkl} (65%) create mode 100644 test/expect/ModelTester.test_mvitv2_s_expect.pkl create mode 100644 test/expect/ModelTester.test_mvitv2_t_expect.pkl diff --git a/test/expect/ModelTester.test_mvit_b_16_expect.pkl b/test/expect/ModelTester.test_mvitv2_b_expect.pkl similarity index 65% rename from test/expect/ModelTester.test_mvit_b_16_expect.pkl rename to test/expect/ModelTester.test_mvitv2_b_expect.pkl index 0ba169ef1d86077665fe79f0cd1934911cb152a6..a39dd2e5754797feaba9c8e10f103900e77a6c98 100644 GIT binary patch delta 230 zcmVUD)K(Ia$S*boSaz;E*pdh`{Y?Z#TbJf09vDLq;S)spf`cJ>4G@?HdFde=V zNix6k|71SzdJaF9^=iJD#uh&R=_Nm;+K#?c7{b0+eR$?2z^i>SL z3;~?JyxpU|>OGl1f4MNf(&2?Z_!TWb%;RXke&Q>?Xt<0&rY|NxNQwqN6X}M&N+XWG gD2PzMV-gL&P)i30vKe4hlMn*X1hN@mRFmWa$0WpSF#rGn delta 230 zcmVuWv^*J3`0Z!o@;ARs=Y9X38d|5ZJ>x9&ca<$+N@{ALrrv^2Us?%@N! 
z%DfFfaPmJtJa`+wo>7gzE;j)`c}qM$0BpuOb;zH-ddl^_rZ zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=U>o(!Q?udi&G=JhWAO>uf(|x4C_RL5lqwr#M^7!iD>e*~i%}|Lr5t{wZNh`=4ETzVD5R;(ocL$#&@lHxrBfrgKjF7X$>`FDPo*)9`=wZV%Q4doC~@ z-FqZ*mE9(7r~S2w#`dv4U)$9rKi*sL_03)z7r%XC{}=CV^XJ#89K2T`Q(|Et_1TdsP7;2j35f0CXwS%03?9|&{HV7Ze&04q3C=C?(c~y%Ind!t_GJ zAi$fAO$Vw-j#(G39F&+r07h?za2Y0nJqhwI8z^ructRC`GC_bhD;r3R83;k@A!-5R CO8RjC literal 0 HcmV?d00001 diff --git a/test/expect/ModelTester.test_mvitv2_t_expect.pkl b/test/expect/ModelTester.test_mvitv2_t_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..384fe05b50c9fc4fcc92c39f62d6dc299282ea9f GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK630<(z3r5iKXw5nDSNkg`|tOO6S5UQG;`mh6C3Qgwkp`asd~Gw?Gu~5$V%@0 z)1L3zJ0pDezJEt<+x@X$yN84Iq8&qpo^{*vpL=aOH`%R;+_rC~v4q`}y$9@UPA;%B zIF)JNakO_oLzaP++ea(=yx6<8M<=xIXIXh@FRRlcJ9br}{jyvRb`BqV_d0m%*oty8 z?AP3?X|FYdbAQF&{reA{{bTz)nA5IrVUm6E62|>)wUzc2_jm5=Ygu5|5M;YYBYK~G z(fSK^e{}cl1%=j*hC4bZfFT9KxI>Gd!5SV~WvNBQz*ul|GAA;)kU|c^H0A=?d~sfS zC=<|D5DxHW1X1ubi5!OlAPE$JooXBrWYCp z0p4tEI#5M&%(`&ppu`LUFnT+L%P int: @@ -428,7 +428,7 @@ def create_multiscale_vision_transformers( patch_embed_dim: int = 96, # Attention block config. num_heads: int = 1, - stochastic_depth_prob_block: float = 0.0, + stochastic_depth_prob: float = 0.0, embed_dim_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), atten_head_mul: Optional[List[List[int]]] = ([1, 2.0], [3, 2.0], [14, 2.0]), pool_q_stride_size: Optional[List[List[int]]] = ([1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]), @@ -449,7 +449,7 @@ def create_multiscale_vision_transformers( patch_embed_dim (int): Embedding dimension after patchifing the video input. num_heads (int): Number of heads in the first transformer block. - stochastic_depth_prob_block (float): Stochastic Depth probability for the attention block. + stochastic_depth_prob (float): Stochastic Depth probability for the attention block. embed_dim_mul (Optional[List[List[int]]]): Dimension multiplication at layer i. If X is used, then the next block will increase the embed dimension by X times. Format: [depth_i, mul_dim_ratio]. @@ -501,7 +501,7 @@ def create_multiscale_vision_transformers( patch_embed_shape=patch_embed_shape, ) - dpr = [x.item() for x in torch.linspace(0, stochastic_depth_prob_block, depth)] # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, stochastic_depth_prob, depth)] # stochastic depth decay rule dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1) if embed_dim_mul is not None: @@ -576,16 +576,71 @@ def create_multiscale_vision_transformers( ) -# TODO: rename this. 
This is actually the small version for v2
-def mvit_b_16(
-    spatial_size=(224, 224),
-    temporal_size=16,
-    num_classes=400,
-    **kwargs,
-):
+def mvitv2_t(num_classes=400, **kwargs):
+    return create_multiscale_vision_transformers(
+        spatial_size=(224, 224),
+        temporal_size=16,
+        depth=10,
+        embed_dim_mul=[[1, 2.0], [3, 2.0], [8, 2.0]],
+        atten_head_mul=[[1, 2.0], [3, 2.0], [8, 2.0]],
+        pool_q_stride_size=[[0, 1, 1, 1], [1, 1, 2, 2], [2, 1, 1, 1], [3, 1, 2, 2], [4, 1, 1, 1], [5, 1, 1, 1],
+                            [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 2, 2], [9, 1, 1, 1]],
+        stochastic_depth_prob=0.1,
+        num_classes=num_classes,
+    )
+
+
+def mvitv2_s(num_classes=400, **kwargs):
+    return create_multiscale_vision_transformers(
+        spatial_size=(224, 224),
+        temporal_size=16,
+        depth=16,
+        embed_dim_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
+        atten_head_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
+        pool_q_stride_size=[[0, 1, 1, 1], [1, 1, 2, 2], [2, 1, 1, 1], [3, 1, 2, 2], [4, 1, 1, 1], [5, 1, 1, 1],
+                            [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 1, 1], [9, 1, 1, 1], [10, 1, 1, 1], [11, 1, 1, 1],
+                            [12, 1, 1, 1], [13, 1, 1, 1], [14, 1, 2, 2], [15, 1, 1, 1]],
+        stochastic_depth_prob=0.1,
+        num_classes=num_classes,
+    )
+
+
+def mvitv2_b(num_classes=400, **kwargs):
+    return create_multiscale_vision_transformers(
+        spatial_size=(224, 224),
+        temporal_size=32,
+        depth=24,
+        embed_dim_mul=[[2, 2.0], [5, 2.0], [21, 2.0]],
+        atten_head_mul=[[2, 2.0], [5, 2.0], [21, 2.0]],
+        pool_q_stride_size=[[0, 1, 1, 1], [1, 1, 1, 1], [2, 1, 2, 2], [3, 1, 1, 1], [4, 1, 1, 1], [5, 1, 2, 2],
+                            [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 1, 1], [9, 1, 1, 1], [10, 1, 1, 1], [11, 1, 1, 1],
+                            [12, 1, 1, 1], [13, 1, 1, 1], [14, 1, 1, 1], [15, 1, 1, 1], [16, 1, 1, 1],
+                            [17, 1, 1, 1], [18, 1, 1, 1], [19, 1, 1, 1], [20, 1, 1, 1], [21, 1, 2, 2],
+                            [22, 1, 1, 1], [23, 1, 1, 1]],
+        stochastic_depth_prob=0.3,
+        num_classes=num_classes,
+    )
+
+"""
+def mvitv2_l(num_classes=400, **kwargs):
     return create_multiscale_vision_transformers(
-        spatial_size=spatial_size,
-        temporal_size=temporal_size,
+        spatial_size=(312, 312),
+        temporal_size=40,
+        depth=48,
+        num_heads=2,
+        embed_dim_mul=[[2, 2.0], [8, 2.0], [44, 2.0]],
+        atten_head_mul=[[2, 2.0], [8, 2.0], [44, 2.0]],
+        pool_q_stride_size=[[0, 1, 1, 1], [1, 1, 1, 1], [2, 1, 2, 2], [3, 1, 1, 1], [4, 1, 1, 1], [5, 1, 1, 1],
+                            [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 2, 2], [9, 1, 1, 1], [10, 1, 1, 1], [11, 1, 1, 1],
+                            [12, 1, 1, 1], [13, 1, 1, 1], [14, 1, 1, 1], [15, 1, 1, 1], [16, 1, 1, 1],
+                            [17, 1, 1, 1], [18, 1, 1, 1], [19, 1, 1, 1], [20, 1, 1, 1], [21, 1, 1, 1],
+                            [22, 1, 1, 1], [23, 1, 1, 1], [24, 1, 1, 1], [25, 1, 1, 1], [26, 1, 1, 1],
+                            [27, 1, 1, 1], [28, 1, 1, 1], [29, 1, 1, 1], [30, 1, 1, 1], [31, 1, 1, 1],
+                            [32, 1, 1, 1], [33, 1, 1, 1], [34, 1, 1, 1], [35, 1, 1, 1], [36, 1, 1, 1],
+                            [37, 1, 1, 1], [38, 1, 1, 1], [39, 1, 1, 1], [40, 1, 1, 1], [41, 1, 1, 1],
+                            [42, 1, 1, 1], [43, 1, 1, 1], [44, 1, 2, 2], [45, 1, 1, 1], [46, 1, 1, 1],
+                            [47, 1, 1, 1]],
+        stochastic_depth_prob=0.5,
         num_classes=num_classes,
-        **kwargs,
     )
+"""
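
For a quick end-to-end check of the new variant builders, here is a minimal usage sketch. It is not part of the patch; it assumes the builders are importable from torchvision.models.video.mvit (the module created by this series), and the input layout follows the Conv3d patch stem, i.e. clips are (batch, channels, frames, height, width):

    import torch

    from torchvision.models.video.mvit import mvitv2_s

    model = mvitv2_s(num_classes=400).eval()

    # mvitv2_s is configured for 3-channel, 16-frame clips at 224x224.
    clip = torch.rand(1, 3, 16, 224, 224)

    with torch.no_grad():
        logits = model(clip)

    print(logits.shape)  # expected: torch.Size([1, 400])
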