From 78851e63d76365add67cba747a48bc2de5f9f166 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 5 Aug 2022 19:21:15 +0100 Subject: [PATCH 01/17] Extending to support MViTv2 --- .../ModelTester.test_mvit_v2_s_expect.pkl | Bin 0 -> 939 bytes test/test_models.py | 3 + torchvision/models/video/mvit.py | 347 ++++++++++++++++-- 3 files changed, 323 insertions(+), 27 deletions(-) create mode 100644 test/expect/ModelTester.test_mvit_v2_s_expect.pkl diff --git a/test/expect/ModelTester.test_mvit_v2_s_expect.pkl b/test/expect/ModelTester.test_mvit_v2_s_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..836084297d7a1186194c0c53bd54da08ab2529f4 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK66b;3E8FmU>HB|NU9#VA(@gs+cb$DxqMGdv?U--tpn7N@Ltw;SyIq_2wM6aQ zC$+kB-`~=vy$2`C*)!dXvgfNUvTKa&wL77qzW20LzP&ZKzMaZ`vwccR3-(?5b!z`P z*Hv~5UmEv*kk_{rT64%wW#JcFo;%F@3?hv7F6z9vFKFF_ed`{T*mD{0+&AT)!2X7e z?|a1bFYIf|xMH_KG~ZU-ZJXWHl??V<4~Xr5C^>JR==$mQHWPX5`TYC#&oDl-_lvgl z{@Ig!>=QoZ+k--DrU&z(6TpxHVcem`&tMG?t+LdjVqh$|Ihhj~Tu31YVH$IRY`!=z zJ(LM(D+mX8GlD31nnaF60gwa=Ku@9Qx{>|FhobWpkcX^W-vC`Nva9$}^hy9-2-6D< zg8*+fHXW!UIc8nBa!_Ih0T{g torch.Tensor: + if rel_pos.shape[0] == d: + return rel_pos + + return ( + nn.functional.interpolate( + rel_pos.permute(1, 0).unsqueeze(0), + size=d, + mode="linear", + ) + .squeeze(0) + .permute(1, 0) + ) + + +def cal_rel_pos_spatial( + attn: torch.Tensor, + q: torch.Tensor, + q_shape: Tuple[int, int, int], + k_shape: Tuple[int, int, int], + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, +) -> torch.Tensor: + q_t, q_h, q_w = q_shape + k_t, k_h, k_w = k_shape + dh = int(2 * max(q_h, k_h) - 1) + dw = int(2 * max(q_w, k_w) - 1) + + # Scale up rel pos if shapes for q and k are different. + q_h_ratio = max(k_h / q_h, 1.0) + k_h_ratio = max(q_h / k_h, 1.0) + dist_h = torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio + dist_h += (k_h - 1) * k_h_ratio + q_w_ratio = max(k_w / q_w, 1.0) + k_w_ratio = max(q_w / k_w, 1.0) + dist_w = torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio + dist_w += (k_w - 1) * k_w_ratio + + # Intepolate rel pos if needed. + rel_pos_h = get_rel_pos(rel_pos_h, dh) + rel_pos_w = get_rel_pos(rel_pos_w, dw) + Rh = rel_pos_h[dist_h.long()] + Rw = rel_pos_w[dist_w.long()] + + B, n_head, q_N, dim = q.shape + + r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim) + rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh) # [B, H, q_t, qh, qw, k_h] + rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw) # [B, H, q_t, qh, qw, k_w] + + attn[:, :, 1:, 1:] = ( + attn[:, :, 1:, 1:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w) + + rel_h_q[:, :, :, :, :, None, :, None] + + rel_w_q[:, :, :, :, :, None, None, :] + ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w) + + return attn + + +def cal_rel_pos_temporal( + attn: torch.Tensor, + q: torch.Tensor, + q_shape: Tuple[int, int, int], + k_shape: Tuple[int, int, int], + rel_pos_t: torch.Tensor, +) -> torch.Tensor: + """ + Temporal Relative Positional Embeddings. + """ + q_t, q_h, q_w = q_shape + k_t, k_h, k_w = k_shape + dt = int(2 * max(q_t, k_t) - 1) + # Intepolate rel pos if needed. + rel_pos_t = get_rel_pos(rel_pos_t, dt) + + # Scale up rel pos if shapes for q and k are different. 
+ q_t_ratio = max(k_t / q_t, 1.0) + k_t_ratio = max(q_t / k_t, 1.0) + dist_t = torch.arange(q_t)[:, None] * q_t_ratio - torch.arange(k_t)[None, :] * k_t_ratio + dist_t += (k_t - 1) * k_t_ratio + Rt = rel_pos_t[dist_t.long()] + + B, n_head, q_N, dim = q.shape + + r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim) + # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim] + r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim) + + # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t] + rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1) + # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t] + rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5) + + attn[:, :, 1:, 1:] = ( + attn[:, :, 1:, 1:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w) + rel[:, :, :, :, :, :, None, None] + ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w) + + return attn + + # TODO: Consider handle 2d input if Temporal is 1 @@ -106,28 +210,37 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten return x, (T, H, W) +torch.fx.wrap("get_rel_pos") +torch.fx.wrap("cal_rel_pos_spatial") +torch.fx.wrap("cal_rel_pos_temporal") + + class MultiscaleAttention(nn.Module): def __init__( self, embed_dim: int, + dim_out: int, + input_size: Tuple[int, int, int], # TODO: switch to List num_heads: int, kernel_q: List[int], kernel_kv: List[int], stride_q: List[int], stride_kv: List[int], residual_pool: bool, + rel_pos: bool, dropout: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ) -> None: super().__init__() self.embed_dim = embed_dim + self.dim_out = dim_out self.num_heads = num_heads - self.head_dim = embed_dim // num_heads + self.head_dim = dim_out // num_heads self.scaler = 1.0 / math.sqrt(self.head_dim) self.residual_pool = residual_pool - self.qkv = nn.Linear(embed_dim, 3 * embed_dim) - layers: List[nn.Module] = [nn.Linear(embed_dim, embed_dim)] + self.qkv = nn.Linear(embed_dim, 3 * dim_out) + layers: List[nn.Module] = [nn.Linear(dim_out, dim_out)] if dropout > 0.0: layers.append(nn.Dropout(dropout, inplace=True)) self.project = nn.Sequential(*layers) @@ -177,24 +290,59 @@ def __init__( norm_layer(self.head_dim), ) + self.rel_pos_h: Optional[nn.Module] = None + self.rel_pos_w: Optional[nn.Module] = None + self.rel_pos_t: Optional[nn.Module] = None + if rel_pos: + assert input_size[1] == input_size[2] # TODO: remove this limitation + size = input_size[1] + q_size = size // stride_q[1] if len(stride_q) > 0 else size + kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size + rel_sp_dim = 2 * max(q_size, kv_size) - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + self.rel_pos_t = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim)) + nn.init.trunc_normal_(self.rel_pos_h, std=0.02) + nn.init.trunc_normal_(self.rel_pos_w, std=0.02) + nn.init.trunc_normal_(self.rel_pos_t, std=0.02) + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: B, N, C = x.shape q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2) if self.pool_k is not None: - k = self.pool_k(k, thw)[0] + k, k_shape = self.pool_k(k, thw) + else: + k_shape = thw if self.pool_v is not None: v = self.pool_v(v, thw)[0] if self.pool_q is not None: q, thw = self.pool_q(q, thw) attn = 
torch.matmul(self.scaler * q, k.transpose(2, 3)) + if self.rel_pos_h is not None and self.rel_pos_w is not None: + attn = cal_rel_pos_spatial( + attn, + q, + thw, + k_shape, + self.rel_pos_h, + self.rel_pos_w, + ) + if self.rel_pos_t is not None: + attn = cal_rel_pos_temporal( + attn, + q, + thw, + k_shape, + self.rel_pos_t, + ) attn = attn.softmax(dim=-1) x = torch.matmul(attn, v) if self.residual_pool: - x.add_(q) - x = x.transpose(1, 2).reshape(B, -1, C) + x.add_(q) # TODO: check x[:, :, 1:, :] += q[:, :, 1:, :] + x = x.transpose(1, 2).reshape(B, -1, self.dim_out) x = self.project(x) return x, thw @@ -203,13 +351,17 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten class MultiscaleBlock(nn.Module): def __init__( self, + input_size: Tuple[int, int, int], # TODO: switch to List cnf: MSBlockConfig, residual_pool: bool, + rel_pos: bool, + dim_mul_in_att: bool, dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ) -> None: super().__init__() + self.dim_mul_in_att = dim_mul_in_att self.pool_skip: Optional[nn.Module] = None if _prod(cnf.stride_q) > 1: @@ -219,24 +371,29 @@ def __init__( nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None # type: ignore[arg-type] ) + att_dim = cnf.output_channels if dim_mul_in_att else cnf.input_channels + self.norm1 = norm_layer(cnf.input_channels) - self.norm2 = norm_layer(cnf.input_channels) + self.norm2 = norm_layer(att_dim) self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d) self.attn = MultiscaleAttention( cnf.input_channels, + att_dim, + input_size, cnf.num_heads, kernel_q=cnf.kernel_q, kernel_kv=cnf.kernel_kv, stride_q=cnf.stride_q, stride_kv=cnf.stride_kv, + rel_pos=rel_pos, residual_pool=residual_pool, dropout=dropout, norm_layer=norm_layer, ) self.mlp = MLP( - cnf.input_channels, - [4 * cnf.input_channels, cnf.output_channels], + att_dim, + [4 * att_dim, cnf.output_channels], activation_layer=nn.GELU, dropout=dropout, inplace=None, @@ -249,36 +406,45 @@ def __init__( self.project = nn.Linear(cnf.input_channels, cnf.output_channels) def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x) + x_att, thw_new = self.attn(x_norm1, thw) + x = x if self.project is None or not self.dim_mul_in_att else self.project(x_norm1) x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0] + x = x_skip + self.stochastic_depth(x_att) - x = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x) - x, thw = self.attn(x, thw) - x = x_skip + self.stochastic_depth(x) - - x_norm = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x) - x_proj = x if self.project is None else self.project(x_norm) + x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x) + x_proj = x if self.project is None or self.dim_mul_in_att else self.project(x_norm2) - return x_proj + self.stochastic_depth(self.mlp(x_norm)), thw + return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new class PositionalEncoding(nn.Module): - def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int) -> None: + def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos: bool) -> None: super().__init__() self.spatial_size = spatial_size self.temporal_size 
= temporal_size self.class_token = nn.Parameter(torch.zeros(embed_size)) - self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size)) - self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size)) - self.class_pos = nn.Parameter(torch.zeros(embed_size)) + self.spatial_pos: Optional[nn.Parameter] = None + self.temporal_pos: Optional[nn.Parameter] = None + self.class_pos: Optional[nn.Parameter] = None + if not rel_pos: + self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size)) + self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size)) + self.class_pos = nn.Parameter(torch.zeros(embed_size)) def forward(self, x: torch.Tensor) -> torch.Tensor: - hw_size, embed_size = self.spatial_pos.shape - pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0) - pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size)) - pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0) class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1) - return torch.cat((class_token, x), dim=1).add_(pos_embedding) + x = torch.cat((class_token, x), dim=1) + + if self.spatial_pos is not None and self.temporal_pos is not None and self.class_pos is not None: + hw_size, embed_size = self.spatial_pos.shape + pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0) + pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size)) + pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0) + x.add_(pos_embedding) + + return x class MViT(nn.Module): @@ -288,6 +454,8 @@ def __init__( temporal_size: int, block_setting: Sequence[MSBlockConfig], residual_pool: bool, + rel_pos: bool, + dim_mul_in_att: bool, dropout: float = 0.5, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, @@ -335,11 +503,14 @@ def __init__( padding=(1, 3, 3), ) + input_size = [size // stride for size, stride in zip((temporal_size,) + spatial_size, self.conv_proj.stride)] + # Spatio-Temporal Class Positional Encoding self.pos_encoding = PositionalEncoding( embed_size=block_setting[0].input_channels, - spatial_size=(spatial_size[0] // self.conv_proj.stride[1], spatial_size[1] // self.conv_proj.stride[2]), - temporal_size=temporal_size // self.conv_proj.stride[0], + spatial_size=tuple(input_size[1:]), + temporal_size=input_size[0], + rel_pos=rel_pos, ) # Encoder module @@ -350,13 +521,19 @@ def __init__( self.blocks.append( block( + input_size=input_size, cnf=cnf, residual_pool=residual_pool, + rel_pos=rel_pos, + dim_mul_in_att=dim_mul_in_att, dropout=attention_dropout, stochastic_depth_prob=sd_prob, norm_layer=norm_layer, ) ) + + if len(cnf.stride_q) > 0: + input_size = [size // stride for size, stride in zip(input_size, cnf.stride_q)] self.norm = norm_layer(block_setting[-1].output_channels) # Classifier module @@ -420,6 +597,8 @@ def _mvit( temporal_size=temporal_size, block_setting=block_setting, residual_pool=kwargs.pop("residual_pool", False), + rel_pos=kwargs.pop("rel_pos", False), + dim_mul_in_att=kwargs.pop("dim_mul_in_att", False), stochastic_depth_prob=stochastic_depth_prob, **kwargs, ) @@ -461,6 +640,10 @@ class MViT_V1_B_Weights(WeightsEnum): DEFAULT = KINETICS400_V1 +class MViT_V2_S_Weights(WeightsEnum): + pass + + @register_model() def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool 
= True, **kwargs: Any) -> MViT: """ @@ -553,3 +736,113 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T progress=progress, **kwargs, ) + + +@register_model() +def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT: + weights = MViT_V1_B_Weights.verify(weights) + + config: Dict[str, List] = { + "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8], + "input_channels": [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768], + "output_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768], + "kernel_q": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + ], + "kernel_kv": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + ], + "stride_q": [ + [1, 1, 1], + [1, 2, 2], + [1, 1, 1], + [1, 2, 2], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 2, 2], + [1, 1, 1], + ], + "stride_kv": [ + [1, 8, 8], + [1, 4, 4], + [1, 4, 4], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 1, 1], + [1, 1, 1], + ], + } + + block_setting = [] + for i in range(len(config["num_heads"])): + block_setting.append( + MSBlockConfig( + num_heads=config["num_heads"][i], + input_channels=config["input_channels"][i], + output_channels=config["output_channels"][i], + kernel_q=config["kernel_q"][i], + kernel_kv=config["kernel_kv"][i], + stride_q=config["stride_q"][i], + stride_kv=config["stride_kv"][i], + ) + ) + + return _mvit( + spatial_size=(224, 224), + temporal_size=16, + block_setting=block_setting, + residual_pool=True, + rel_pos=True, + dim_mul_in_att=True, + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), + weights=weights, + progress=progress, + **kwargs, + ) From 22c9850908f1da65a4ea235be45d0791c80b8bb5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 8 Aug 2022 09:54:22 +0100 Subject: [PATCH 02/17] Fix docs, mypy and linter --- torchvision/models/video/mvit.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 48430611da2..499215109a5 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -26,6 +26,7 @@ # Reference: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932# + def get_rel_pos(rel_pos: torch.Tensor, d: int) -> torch.Tensor: if rel_pos.shape[0] == d: return rel_pos @@ -290,9 +291,9 @@ def __init__( norm_layer(self.head_dim), ) - self.rel_pos_h: Optional[nn.Module] = None - self.rel_pos_w: Optional[nn.Module] = None - self.rel_pos_t: Optional[nn.Module] = None + self.rel_pos_h: Optional[nn.Parameter] = None + self.rel_pos_w: Optional[nn.Parameter] = None + self.rel_pos_t: Optional[nn.Parameter] = None if rel_pos: assert input_size[1] == input_size[2] # TODO: remove this limitation size = input_size[1] @@ -471,6 +472,8 @@ def __init__( temporal_size (int): The temporal size ``T`` of the input. block_setting (sequence of MSBlockConfig): The Network structure. 
residual_pool (bool): If True, use MViTv2 pooling residual connection. + rel_pos (bool): TODO + dim_mul_in_att (bool): TODO dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. @@ -508,7 +511,7 @@ def __init__( # Spatio-Temporal Class Positional Encoding self.pos_encoding = PositionalEncoding( embed_size=block_setting[0].input_channels, - spatial_size=tuple(input_size[1:]), + spatial_size=(input_size[1], input_size[2]), temporal_size=input_size[0], rel_pos=rel_pos, ) From cb62dce8fd9be7fa6adbcae5c2cb0ac66b57f8ab Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 8 Aug 2022 11:00:45 +0100 Subject: [PATCH 03/17] Refactor the relative positional code. --- torchvision/models/video/mvit.py | 101 ++++++++++--------------------- 1 file changed, 32 insertions(+), 69 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 499215109a5..5ff5cba9e94 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -24,10 +24,7 @@ ] -# Reference: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932# - - -def get_rel_pos(rel_pos: torch.Tensor, d: int) -> torch.Tensor: +def _interpolate(rel_pos: torch.Tensor, d: int) -> torch.Tensor: if rel_pos.shape[0] == d: return rel_pos @@ -42,87 +39,62 @@ def get_rel_pos(rel_pos: torch.Tensor, d: int) -> torch.Tensor: ) -def cal_rel_pos_spatial( +def add_rel_pos( attn: torch.Tensor, q: torch.Tensor, q_shape: Tuple[int, int, int], k_shape: Tuple[int, int, int], rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, + rel_pos_t: torch.Tensor, ) -> torch.Tensor: + # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932 q_t, q_h, q_w = q_shape k_t, k_h, k_w = k_shape dh = int(2 * max(q_h, k_h) - 1) dw = int(2 * max(q_w, k_w) - 1) + dt = int(2 * max(q_t, k_t) - 1) # Scale up rel pos if shapes for q and k are different. q_h_ratio = max(k_h / q_h, 1.0) k_h_ratio = max(q_h / k_h, 1.0) - dist_h = torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio - dist_h += (k_h - 1) * k_h_ratio + dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio q_w_ratio = max(k_w / q_w, 1.0) k_w_ratio = max(q_w / k_w, 1.0) - dist_w = torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio - dist_w += (k_w - 1) * k_w_ratio + dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio + q_t_ratio = max(k_t / q_t, 1.0) + k_t_ratio = max(q_t / k_t, 1.0) + dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio # Intepolate rel pos if needed. 
- rel_pos_h = get_rel_pos(rel_pos_h, dh) - rel_pos_w = get_rel_pos(rel_pos_w, dw) + rel_pos_h = _interpolate(rel_pos_h, dh) + rel_pos_w = _interpolate(rel_pos_w, dw) + rel_pos_t = _interpolate(rel_pos_t, dt) Rh = rel_pos_h[dist_h.long()] Rw = rel_pos_w[dist_w.long()] + Rt = rel_pos_t[dist_t.long()] - B, n_head, q_N, dim = q.shape + B, n_head, _, dim = q.shape r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim) rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh) # [B, H, q_t, qh, qw, k_h] rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw) # [B, H, q_t, qh, qw, k_w] - - attn[:, :, 1:, 1:] = ( - attn[:, :, 1:, 1:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w) - + rel_h_q[:, :, :, :, :, None, :, None] - + rel_w_q[:, :, :, :, :, None, None, :] - ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w) - - return attn - - -def cal_rel_pos_temporal( - attn: torch.Tensor, - q: torch.Tensor, - q_shape: Tuple[int, int, int], - k_shape: Tuple[int, int, int], - rel_pos_t: torch.Tensor, -) -> torch.Tensor: - """ - Temporal Relative Positional Embeddings. - """ - q_t, q_h, q_w = q_shape - k_t, k_h, k_w = k_shape - dt = int(2 * max(q_t, k_t) - 1) - # Intepolate rel pos if needed. - rel_pos_t = get_rel_pos(rel_pos_t, dt) - - # Scale up rel pos if shapes for q and k are different. - q_t_ratio = max(k_t / q_t, 1.0) - k_t_ratio = max(q_t / k_t, 1.0) - dist_t = torch.arange(q_t)[:, None] * q_t_ratio - torch.arange(k_t)[None, :] * k_t_ratio - dist_t += (k_t - 1) * k_t_ratio - Rt = rel_pos_t[dist_t.long()] - - B, n_head, q_N, dim = q.shape - - r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim) # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim] r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim) - # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t] - rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1) + rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1) # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t] - rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5) + rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5) + + # Combine rel pos. 
+ rel_pos = ( + rel_h_q[:, :, :, :, :, None, :, None] + + rel_w_q[:, :, :, :, :, None, None, :] + + rel_q_t[:, :, :, :, :, :, None, None] + ).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w) - attn[:, :, 1:, 1:] = ( - attn[:, :, 1:, 1:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w) + rel[:, :, :, :, :, :, None, None] - ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w) + # Add it to attention + attn[:, :, 1:, 1:] += rel_pos return attn @@ -211,9 +183,7 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten return x, (T, H, W) -torch.fx.wrap("get_rel_pos") -torch.fx.wrap("cal_rel_pos_spatial") -torch.fx.wrap("cal_rel_pos_temporal") +torch.fx.wrap("add_rel_pos") class MultiscaleAttention(nn.Module): @@ -312,30 +282,23 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2) if self.pool_k is not None: - k, k_shape = self.pool_k(k, thw) + k, k_thw = self.pool_k(k, thw) else: - k_shape = thw + k_thw = thw if self.pool_v is not None: v = self.pool_v(v, thw)[0] if self.pool_q is not None: q, thw = self.pool_q(q, thw) attn = torch.matmul(self.scaler * q, k.transpose(2, 3)) - if self.rel_pos_h is not None and self.rel_pos_w is not None: - attn = cal_rel_pos_spatial( + if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None: + attn = add_rel_pos( attn, q, thw, - k_shape, + k_thw, self.rel_pos_h, self.rel_pos_w, - ) - if self.rel_pos_t is not None: - attn = cal_rel_pos_temporal( - attn, - q, - thw, - k_shape, self.rel_pos_t, ) attn = attn.softmax(dim=-1) From 9f1dcaa4882b97945ea2271cce8b257ae7946fd6 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 8 Aug 2022 12:20:43 +0100 Subject: [PATCH 04/17] Code refactoring. 
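
Follow-up to the relative-position refactor in the previous commit: rename `rel_pos` to `embedding` in `_interpolate`, `q_shape`/`k_shape` to `q_thw`/`k_thw`, `dim_out` to `output_dim` and `dim_mul_in_att` to `proj_after_att`, move `input_size` to the front of the `MultiscaleAttention`/`MultiscaleBlock` signatures, and document the new `rel_pos` and `proj_after_att` arguments.

For reference, a standalone sketch (not part of the diff; the sizes below are made up for illustration) of the distance indexing that `add_rel_pos` performs when the query and key grids differ:

    import torch

    q_h, k_h = 2, 4                                    # query pooled more aggressively than key
    q_h_ratio = max(k_h / q_h, 1.0)                    # 2.0
    k_h_ratio = max(q_h / k_h, 1.0)                    # 1.0
    dist_h = (
        torch.arange(q_h)[:, None] * q_h_ratio
        - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio
    )
    # tensor([[3., 2., 1., 0.],
    #         [5., 4., 3., 2.]])  -> non-negative indices into a table of 2 * max(q_h, k_h) - 1 = 7 rows
    rel_pos_h = torch.zeros(2 * max(q_h, k_h) - 1, 8)  # head_dim = 8, illustration only
    Rh = rel_pos_h[dist_h.long()]                      # (q_h, k_h, head_dim) lookup fed to the einsum
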
--- torchvision/models/video/mvit.py | 64 ++++++++++++++++---------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 5ff5cba9e94..281db631916 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -24,13 +24,13 @@ ] -def _interpolate(rel_pos: torch.Tensor, d: int) -> torch.Tensor: - if rel_pos.shape[0] == d: - return rel_pos +def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor: + if embedding.shape[0] == d: + return embedding return ( nn.functional.interpolate( - rel_pos.permute(1, 0).unsqueeze(0), + embedding.permute(1, 0).unsqueeze(0), size=d, mode="linear", ) @@ -42,15 +42,15 @@ def _interpolate(rel_pos: torch.Tensor, d: int) -> torch.Tensor: def add_rel_pos( attn: torch.Tensor, q: torch.Tensor, - q_shape: Tuple[int, int, int], - k_shape: Tuple[int, int, int], + q_thw: Tuple[int, int, int], + k_thw: Tuple[int, int, int], rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, rel_pos_t: torch.Tensor, ) -> torch.Tensor: # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932 - q_t, q_h, q_w = q_shape - k_t, k_h, k_w = k_shape + q_t, q_h, q_w = q_thw + k_t, k_h, k_w = k_thw dh = int(2 * max(q_h, k_h) - 1) dw = int(2 * max(q_w, k_w) - 1) dt = int(2 * max(q_t, k_t) - 1) @@ -189,9 +189,9 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten class MultiscaleAttention(nn.Module): def __init__( self, + input_size: List[int], embed_dim: int, - dim_out: int, - input_size: Tuple[int, int, int], # TODO: switch to List + output_dim: int, num_heads: int, kernel_q: List[int], kernel_kv: List[int], @@ -204,14 +204,14 @@ def __init__( ) -> None: super().__init__() self.embed_dim = embed_dim - self.dim_out = dim_out + self.output_dim = output_dim self.num_heads = num_heads - self.head_dim = dim_out // num_heads + self.head_dim = output_dim // num_heads self.scaler = 1.0 / math.sqrt(self.head_dim) self.residual_pool = residual_pool - self.qkv = nn.Linear(embed_dim, 3 * dim_out) - layers: List[nn.Module] = [nn.Linear(dim_out, dim_out)] + self.qkv = nn.Linear(embed_dim, 3 * output_dim) + layers: List[nn.Module] = [nn.Linear(output_dim, output_dim)] if dropout > 0.0: layers.append(nn.Dropout(dropout, inplace=True)) self.project = nn.Sequential(*layers) @@ -306,7 +306,7 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten x = torch.matmul(attn, v) if self.residual_pool: x.add_(q) # TODO: check x[:, :, 1:, :] += q[:, :, 1:, :] - x = x.transpose(1, 2).reshape(B, -1, self.dim_out) + x = x.transpose(1, 2).reshape(B, -1, self.output_dim) x = self.project(x) return x, thw @@ -315,17 +315,17 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten class MultiscaleBlock(nn.Module): def __init__( self, - input_size: Tuple[int, int, int], # TODO: switch to List + input_size: List[int], cnf: MSBlockConfig, residual_pool: bool, rel_pos: bool, - dim_mul_in_att: bool, + proj_after_att: bool, dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ) -> None: super().__init__() - self.dim_mul_in_att = dim_mul_in_att + self.proj_after_att = proj_after_att self.pool_skip: Optional[nn.Module] = None if _prod(cnf.stride_q) > 1: @@ -335,16 +335,16 @@ def __init__( nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None # type: ignore[arg-type] ) - att_dim = cnf.output_channels if 
dim_mul_in_att else cnf.input_channels + attn_dim = cnf.output_channels if proj_after_att else cnf.input_channels self.norm1 = norm_layer(cnf.input_channels) - self.norm2 = norm_layer(att_dim) + self.norm2 = norm_layer(attn_dim) self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d) self.attn = MultiscaleAttention( - cnf.input_channels, - att_dim, input_size, + cnf.input_channels, + attn_dim, cnf.num_heads, kernel_q=cnf.kernel_q, kernel_kv=cnf.kernel_kv, @@ -356,8 +356,8 @@ def __init__( norm_layer=norm_layer, ) self.mlp = MLP( - att_dim, - [4 * att_dim, cnf.output_channels], + attn_dim, + [4 * attn_dim, cnf.output_channels], activation_layer=nn.GELU, dropout=dropout, inplace=None, @@ -372,12 +372,12 @@ def __init__( def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x) x_att, thw_new = self.attn(x_norm1, thw) - x = x if self.project is None or not self.dim_mul_in_att else self.project(x_norm1) + x = x if self.project is None or not self.proj_after_att else self.project(x_norm1) x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0] x = x_skip + self.stochastic_depth(x_att) x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x) - x_proj = x if self.project is None or self.dim_mul_in_att else self.project(x_norm2) + x_proj = x if self.project is None or self.proj_after_att else self.project(x_norm2) return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new @@ -419,7 +419,7 @@ def __init__( block_setting: Sequence[MSBlockConfig], residual_pool: bool, rel_pos: bool, - dim_mul_in_att: bool, + proj_after_att: bool, dropout: float = 0.5, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, @@ -435,8 +435,8 @@ def __init__( temporal_size (int): The temporal size ``T`` of the input. block_setting (sequence of MSBlockConfig): The Network structure. residual_pool (bool): If True, use MViTv2 pooling residual connection. - rel_pos (bool): TODO - dim_mul_in_att (bool): TODO + rel_pos (bool): If True, use MViTv2's relative positional embeddings. + proj_after_att (bool): If True, do the projection step on the attention output. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. @@ -491,7 +491,7 @@ def __init__( cnf=cnf, residual_pool=residual_pool, rel_pos=rel_pos, - dim_mul_in_att=dim_mul_in_att, + proj_after_att=proj_after_att, dropout=attention_dropout, stochastic_depth_prob=sd_prob, norm_layer=norm_layer, @@ -564,7 +564,7 @@ def _mvit( block_setting=block_setting, residual_pool=kwargs.pop("residual_pool", False), rel_pos=kwargs.pop("rel_pos", False), - dim_mul_in_att=kwargs.pop("dim_mul_in_att", False), + proj_after_att=kwargs.pop("proj_after_att", False), stochastic_depth_prob=stochastic_depth_prob, **kwargs, ) @@ -806,7 +806,7 @@ def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = T block_setting=block_setting, residual_pool=True, rel_pos=True, - dim_mul_in_att=True, + proj_after_att=True, stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), weights=weights, progress=progress, From 9378abdefd26b6de851450e0885dd034ecd45474 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 8 Aug 2022 13:52:03 +0100 Subject: [PATCH 05/17] Rename vars. 
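
Spell the `attn` abbreviation consistently: `proj_after_att` becomes `proj_after_attn` and `x_att` becomes `x_attn`; no functional change. The flag controls where the channel projection sits in `MultiscaleBlock.forward`: when it is True (the MViTv2 setting) the skip connection is projected from `norm1(x)` before the attention residual and the attention itself already runs at `output_channels`, whereas when it is False (the MViTv1 setting) the attention runs at `input_channels` and the projection to `output_channels` is applied to `norm2(x)` just before the MLP residual.
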
--- torchvision/models/video/mvit.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 281db631916..6a67d9a6343 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -319,13 +319,13 @@ def __init__( cnf: MSBlockConfig, residual_pool: bool, rel_pos: bool, - proj_after_att: bool, + proj_after_attn: bool, dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ) -> None: super().__init__() - self.proj_after_att = proj_after_att + self.proj_after_attn = proj_after_attn self.pool_skip: Optional[nn.Module] = None if _prod(cnf.stride_q) > 1: @@ -335,7 +335,7 @@ def __init__( nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None # type: ignore[arg-type] ) - attn_dim = cnf.output_channels if proj_after_att else cnf.input_channels + attn_dim = cnf.output_channels if proj_after_attn else cnf.input_channels self.norm1 = norm_layer(cnf.input_channels) self.norm2 = norm_layer(attn_dim) @@ -371,13 +371,13 @@ def __init__( def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x) - x_att, thw_new = self.attn(x_norm1, thw) - x = x if self.project is None or not self.proj_after_att else self.project(x_norm1) + x_attn, thw_new = self.attn(x_norm1, thw) + x = x if self.project is None or not self.proj_after_attn else self.project(x_norm1) x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0] - x = x_skip + self.stochastic_depth(x_att) + x = x_skip + self.stochastic_depth(x_attn) x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x) - x_proj = x if self.project is None or self.proj_after_att else self.project(x_norm2) + x_proj = x if self.project is None or self.proj_after_attn else self.project(x_norm2) return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new @@ -419,7 +419,7 @@ def __init__( block_setting: Sequence[MSBlockConfig], residual_pool: bool, rel_pos: bool, - proj_after_att: bool, + proj_after_attn: bool, dropout: float = 0.5, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, @@ -436,7 +436,7 @@ def __init__( block_setting (sequence of MSBlockConfig): The Network structure. residual_pool (bool): If True, use MViTv2 pooling residual connection. rel_pos (bool): If True, use MViTv2's relative positional embeddings. - proj_after_att (bool): If True, do the projection step on the attention output. + proj_after_attn (bool): If True, do the projection step on the attention output. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. 
@@ -491,7 +491,7 @@ def __init__( cnf=cnf, residual_pool=residual_pool, rel_pos=rel_pos, - proj_after_att=proj_after_att, + proj_after_attn=proj_after_attn, dropout=attention_dropout, stochastic_depth_prob=sd_prob, norm_layer=norm_layer, @@ -564,7 +564,7 @@ def _mvit( block_setting=block_setting, residual_pool=kwargs.pop("residual_pool", False), rel_pos=kwargs.pop("rel_pos", False), - proj_after_att=kwargs.pop("proj_after_att", False), + proj_after_attn=kwargs.pop("proj_after_attn", False), stochastic_depth_prob=stochastic_depth_prob, **kwargs, ) @@ -806,7 +806,7 @@ def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = T block_setting=block_setting, residual_pool=True, rel_pos=True, - proj_after_att=True, + proj_after_attn=True, stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), weights=weights, progress=progress, From a7d917efa9a9e7e38ba434141a790e9dc8d884a4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 8 Aug 2022 16:15:45 +0100 Subject: [PATCH 06/17] Update docs. --- docs/source/models/video_mvit.rst | 3 ++- torchvision/models/video/mvit.py | 22 +++++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/docs/source/models/video_mvit.rst b/docs/source/models/video_mvit.rst index d5be1245ac9..cd23754b7bb 100644 --- a/docs/source/models/video_mvit.rst +++ b/docs/source/models/video_mvit.rst @@ -12,7 +12,7 @@ The MViT model is based on the Model builders -------------- -The following model builders can be used to instantiate a MViT model, with or +The following model builders can be used to instantiate a MViT v1 or v2 model, with or without pre-trained weights. All the model builders internally rely on the ``torchvision.models.video.MViT`` base class. Please refer to the `source code @@ -24,3 +24,4 @@ more details about this class. :template: function.rst mvit_v1_b + mvit_v2_s diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 6a67d9a6343..b65547bbc64 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -706,7 +706,27 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T @register_model() def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT: - weights = MViT_V1_B_Weights.verify(weights) + """ + Constructs a small MViTV2 architecture from + `Multiscale Vision Transformers `__. + + Args: + weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MViT_V2_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.MViT`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.MViT_V2_S_Weights + :members: + """ + weights = MViT_V2_S_Weights.verify(weights) config: Dict[str, List] = { "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8], From b39ac57c69bd638f9dfb9f7fea0b4bfa9de3c889 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 8 Aug 2022 16:30:14 +0100 Subject: [PATCH 07/17] Replace assert with exception. 
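
Constructing `MultiscaleAttention` with `rel_pos=True` on a non-square spatial grid now raises `ValueError("Relative Positional Embeddings require square input shape.")` instead of tripping a bare `assert`; the restriction itself is lifted later in the series (patch 10).
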
--- torchvision/models/video/mvit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index b65547bbc64..2318149363b 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -265,7 +265,8 @@ def __init__( self.rel_pos_w: Optional[nn.Parameter] = None self.rel_pos_t: Optional[nn.Parameter] = None if rel_pos: - assert input_size[1] == input_size[2] # TODO: remove this limitation + if input_size[1] != input_size[2]: + raise ValueError("Relative Positional Embeddings require square input shape.") size = input_size[1] q_size = size // stride_q[1] if len(stride_q) > 0 else size kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size From 1bad54f36c0a6c00e5bf668065273b73867baa56 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 8 Aug 2022 16:35:06 +0100 Subject: [PATCH 08/17] Updat docs. --- torchvision/models/video/mvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 2318149363b..d3cac60eaf0 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -437,7 +437,7 @@ def __init__( block_setting (sequence of MSBlockConfig): The Network structure. residual_pool (bool): If True, use MViTv2 pooling residual connection. rel_pos (bool): If True, use MViTv2's relative positional embeddings. - proj_after_attn (bool): If True, do the projection step on the attention output. + proj_after_attn (bool): If True, apply the projection after the attention. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. From 2ce2c65f4604cb0e0cd053f26c1c12899274d9b7 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 8 Aug 2022 16:46:04 +0100 Subject: [PATCH 09/17] Minor refactoring. --- torchvision/models/video/mvit.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index d3cac60eaf0..1dfb6943268 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -270,10 +270,11 @@ def __init__( size = input_size[1] q_size = size // stride_q[1] if len(stride_q) > 0 else size kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size - rel_sp_dim = 2 * max(q_size, kv_size) - 1 - self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) - self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) - self.rel_pos_t = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim)) + spatial_dim = 2 * max(q_size, kv_size) - 1 + temporal_dim = 2 * input_size[0] - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(spatial_dim, self.head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(spatial_dim, self.head_dim)) + self.rel_pos_t = nn.Parameter(torch.zeros(temporal_dim, self.head_dim)) nn.init.trunc_normal_(self.rel_pos_h, std=0.02) nn.init.trunc_normal_(self.rel_pos_w, std=0.02) nn.init.trunc_normal_(self.rel_pos_t, std=0.02) From 03d365cf8ee717905dbe799446f17e9d42976f22 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 9 Aug 2022 08:40:04 +0100 Subject: [PATCH 10/17] Remove the square input limitation. 
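
Size the relative-position tables from the larger spatial dimension (`size = max(input_size[1:])`) instead of requiring a square grid; `_interpolate` already resizes the tables at forward time whenever the realized distance range differs from the table length.

A rough sketch (not part of the diff; the grid and strides below are made up) of how the table sizes come out for a non-square input:

    input_size = (8, 14, 10)                      # (T', H', W') after the patchification stem
    stride_q, stride_kv = [1, 2, 2], [1, 8, 8]
    size = max(input_size[1:])                    # 14
    q_size = size // stride_q[1]                  # 7
    kv_size = size // stride_kv[1]                # 1
    spatial_dim = 2 * max(q_size, kv_size) - 1    # 13 rows, shared by rel_pos_h and rel_pos_w
    temporal_dim = 2 * input_size[0] - 1          # 15 rows for rel_pos_t
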
--- torchvision/models/video/mvit.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 1dfb6943268..cbce67a66b3 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -265,9 +265,7 @@ def __init__( self.rel_pos_w: Optional[nn.Parameter] = None self.rel_pos_t: Optional[nn.Parameter] = None if rel_pos: - if input_size[1] != input_size[2]: - raise ValueError("Relative Positional Embeddings require square input shape.") - size = input_size[1] + size = max(input_size[1:]) q_size = size // stride_q[1] if len(stride_q) > 0 else size kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size spatial_dim = 2 * max(q_size, kv_size) - 1 From f260ecfd68b161cf28ab186361f8299881f58ded Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 9 Aug 2022 09:32:32 +0100 Subject: [PATCH 11/17] Moving methods around. --- torchvision/models/video/mvit.py | 157 ++++++++++++++++--------------- 1 file changed, 79 insertions(+), 78 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index cbce67a66b3..968a0a442a8 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -24,81 +24,6 @@ ] -def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor: - if embedding.shape[0] == d: - return embedding - - return ( - nn.functional.interpolate( - embedding.permute(1, 0).unsqueeze(0), - size=d, - mode="linear", - ) - .squeeze(0) - .permute(1, 0) - ) - - -def add_rel_pos( - attn: torch.Tensor, - q: torch.Tensor, - q_thw: Tuple[int, int, int], - k_thw: Tuple[int, int, int], - rel_pos_h: torch.Tensor, - rel_pos_w: torch.Tensor, - rel_pos_t: torch.Tensor, -) -> torch.Tensor: - # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932 - q_t, q_h, q_w = q_thw - k_t, k_h, k_w = k_thw - dh = int(2 * max(q_h, k_h) - 1) - dw = int(2 * max(q_w, k_w) - 1) - dt = int(2 * max(q_t, k_t) - 1) - - # Scale up rel pos if shapes for q and k are different. - q_h_ratio = max(k_h / q_h, 1.0) - k_h_ratio = max(q_h / k_h, 1.0) - dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio - q_w_ratio = max(k_w / q_w, 1.0) - k_w_ratio = max(q_w / k_w, 1.0) - dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio - q_t_ratio = max(k_t / q_t, 1.0) - k_t_ratio = max(q_t / k_t, 1.0) - dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio - - # Intepolate rel pos if needed. 
- rel_pos_h = _interpolate(rel_pos_h, dh) - rel_pos_w = _interpolate(rel_pos_w, dw) - rel_pos_t = _interpolate(rel_pos_t, dt) - Rh = rel_pos_h[dist_h.long()] - Rw = rel_pos_w[dist_w.long()] - Rt = rel_pos_t[dist_t.long()] - - B, n_head, _, dim = q.shape - - r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim) - rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh) # [B, H, q_t, qh, qw, k_h] - rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw) # [B, H, q_t, qh, qw, k_w] - # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim] - r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim) - # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t] - rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1) - # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t] - rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5) - - # Combine rel pos. - rel_pos = ( - rel_h_q[:, :, :, :, :, None, :, None] - + rel_w_q[:, :, :, :, :, None, None, :] - + rel_q_t[:, :, :, :, :, :, None, None] - ).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w) - - # Add it to attention - attn[:, :, 1:, 1:] += rel_pos - - return attn - - # TODO: Consider handle 2d input if Temporal is 1 @@ -183,7 +108,82 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten return x, (T, H, W) -torch.fx.wrap("add_rel_pos") +def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor: + if embedding.shape[0] == d: + return embedding + + return ( + nn.functional.interpolate( + embedding.permute(1, 0).unsqueeze(0), + size=d, + mode="linear", + ) + .squeeze(0) + .permute(1, 0) + ) + + +def _add_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + q_thw: Tuple[int, int, int], + k_thw: Tuple[int, int, int], + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + rel_pos_t: torch.Tensor, +) -> torch.Tensor: + # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932 + q_t, q_h, q_w = q_thw + k_t, k_h, k_w = k_thw + dh = int(2 * max(q_h, k_h) - 1) + dw = int(2 * max(q_w, k_w) - 1) + dt = int(2 * max(q_t, k_t) - 1) + + # Scale up rel pos if shapes for q and k are different. + q_h_ratio = max(k_h / q_h, 1.0) + k_h_ratio = max(q_h / k_h, 1.0) + dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio + q_w_ratio = max(k_w / q_w, 1.0) + k_w_ratio = max(q_w / k_w, 1.0) + dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio + q_t_ratio = max(k_t / q_t, 1.0) + k_t_ratio = max(q_t / k_t, 1.0) + dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio + + # Intepolate rel pos if needed. 
+ rel_pos_h = _interpolate(rel_pos_h, dh) + rel_pos_w = _interpolate(rel_pos_w, dw) + rel_pos_t = _interpolate(rel_pos_t, dt) + Rh = rel_pos_h[dist_h.long()] + Rw = rel_pos_w[dist_w.long()] + Rt = rel_pos_t[dist_t.long()] + + B, n_head, _, dim = q.shape + + r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim) + rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh) # [B, H, q_t, qh, qw, k_h] + rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw) # [B, H, q_t, qh, qw, k_w] + # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim] + r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim) + # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t] + rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1) + # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t] + rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5) + + # Combine rel pos. + rel_pos = ( + rel_h_q[:, :, :, :, :, None, :, None] + + rel_w_q[:, :, :, :, :, None, None, :] + + rel_q_t[:, :, :, :, :, :, None, None] + ).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w) + + # Add it to attention + attn[:, :, 1:, 1:] += rel_pos + + return attn + + +torch.fx.wrap("_add_rel_pos") class MultiscaleAttention(nn.Module): @@ -209,6 +209,7 @@ def __init__( self.head_dim = output_dim // num_heads self.scaler = 1.0 / math.sqrt(self.head_dim) self.residual_pool = residual_pool + self.rel_pos = rel_pos self.qkv = nn.Linear(embed_dim, 3 * output_dim) layers: List[nn.Module] = [nn.Linear(output_dim, output_dim)] @@ -264,7 +265,7 @@ def __init__( self.rel_pos_h: Optional[nn.Parameter] = None self.rel_pos_w: Optional[nn.Parameter] = None self.rel_pos_t: Optional[nn.Parameter] = None - if rel_pos: + if self.rel_pos: size = max(input_size[1:]) q_size = size // stride_q[1] if len(stride_q) > 0 else size kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size @@ -292,7 +293,7 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten attn = torch.matmul(self.scaler * q, k.transpose(2, 3)) if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None: - attn = add_rel_pos( + attn = _add_rel_pos( attn, q, thw, From 076c353e4cffdfe34e7769adf525af4dbf622c73 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 9 Aug 2022 09:53:19 +0100 Subject: [PATCH 12/17] Modify the shortcut in the attention layer. --- .../ModelTester.test_mvit_v2_s_expect.pkl | Bin 939 -> 939 bytes torchvision/models/video/mvit.py | 11 ++++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/test/expect/ModelTester.test_mvit_v2_s_expect.pkl b/test/expect/ModelTester.test_mvit_v2_s_expect.pkl index 836084297d7a1186194c0c53bd54da08ab2529f4..5ae3e4a0d768f0f04a78cc8fc0cd705cd85e172f 100644 GIT binary patch delta 230 zcmV+;P-GV<%rZ7JYQzSoe$45W!0YX0tl{~&SgvY)Leu_Rc zsXo7r6}3Js+Bm<&_Z>do-t|7Ob+$e<-^IRyixlMn*X1j;8$xs&7q$6T#!U;qFB delta 230 zcmVi_uHtk*3NyLBl>QLE zfN1x;7BA4gf@st}uoZ7T7e=-|m8k$fwZIm?;TfL36|a{+Jdq4P4o{B1m@~}1^eY>` go03dFVDN80P)i30nMeb|lMn*X1er(!!jt3z$Dzq Tuple[torch.Ten x = torch.matmul(attn, v) if self.residual_pool: - x.add_(q) # TODO: check x[:, :, 1:, :] += q[:, :, 1:, :] + _add_shortcut(x, q, self.rel_pos) x = x.transpose(1, 2).reshape(B, -1, self.output_dim) x = self.project(x) From be6d7e080a9b22de59089efd8268f99c19dfa5c9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 9 Aug 2022 14:02:09 +0100 Subject: [PATCH 13/17] Add ported weights. 
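
Add Kinetics-400 weights for `mvit_v2_s`, ported from the paper authors' checkpoints (SlowFast model zoo): 34,537,744 parameters, 80.757% acc@1 / 94.665% acc@5 measured video-level with `frame_rate=7.5`, `clips_per_video=5` and `clip_len=16`. The preset transforms resize to 256, center-crop to 224x224 and normalize with mean 0.45 / std 0.225 per channel.

A minimal loading sketch, assuming the new builder and weights enum are exported next to their v1 counterparts:

    from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights

    weights = MViT_V2_S_Weights.KINETICS400_V1
    model = mvit_v2_s(weights=weights).eval()
    preprocess = weights.transforms()             # resize 256 -> center-crop 224x224 -> normalize
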
--- torchvision/models/video/mvit.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 5b4f803f858..80b3ee635ae 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -617,7 +617,34 @@ class MViT_V1_B_Weights(WeightsEnum): class MViT_V2_S_Weights(WeightsEnum): - pass + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.45, 0.45, 0.45), + std=(0.225, 0.225, 0.225), + ), + meta={ + "min_size": (224, 224), + "min_temporal_size": 16, + "categories": _KINETICS400_CATEGORIES, + "recipe": "https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md", + "_docs": ( + "The weights were ported from the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`" + ), + "num_params": 34537744, + "_metrics": { + "Kinetics-400": { + "acc@1": 80.757, + "acc@5": 94.665, + } + }, + }, + ) + DEFAULT = KINETICS400_V1 @register_model() From a4173acb191da84a3c54b7bfbe0cfce03ef97329 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 9 Aug 2022 20:51:41 +0100 Subject: [PATCH 14/17] Introduce a `residual_cls` config on the attention layer. --- torchvision/models/video/mvit.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 80b3ee635ae..bfe65e46226 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -183,11 +183,11 @@ def _add_rel_pos( return attn -def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, skip_cls: bool): - if skip_cls: - x[:, :, 1:, :] += shortcut[:, :, 1:, :] - else: +def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_cls: bool): + if residual_cls: x.add_(shortcut) + else: + x[:, :, 1:, :] += shortcut[:, :, 1:, :] return x @@ -207,6 +207,7 @@ def __init__( stride_q: List[int], stride_kv: List[int], residual_pool: bool, + residual_cls: bool, rel_pos: bool, dropout: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, @@ -219,6 +220,7 @@ def __init__( self.scaler = 1.0 / math.sqrt(self.head_dim) self.residual_pool = residual_pool self.rel_pos = rel_pos + self.residual_cls = residual_cls self.qkv = nn.Linear(embed_dim, 3 * output_dim) layers: List[nn.Module] = [nn.Linear(output_dim, output_dim)] @@ -315,7 +317,7 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten x = torch.matmul(attn, v) if self.residual_pool: - _add_shortcut(x, q, self.rel_pos) + _add_shortcut(x, q, self.residual_cls) x = x.transpose(1, 2).reshape(B, -1, self.output_dim) x = self.project(x) @@ -328,6 +330,7 @@ def __init__( input_size: List[int], cnf: MSBlockConfig, residual_pool: bool, + residual_cls: bool, rel_pos: bool, proj_after_attn: bool, dropout: float = 0.0, @@ -362,6 +365,7 @@ def __init__( stride_kv=cnf.stride_kv, rel_pos=rel_pos, residual_pool=residual_pool, + residual_cls=residual_cls, dropout=dropout, norm_layer=norm_layer, ) @@ -428,6 +432,7 @@ def __init__( temporal_size: int, block_setting: Sequence[MSBlockConfig], residual_pool: bool, + residual_cls: bool, rel_pos: bool, proj_after_attn: bool, dropout: float = 0.5, @@ -445,6 +450,7 @@ def __init__( temporal_size (int): The temporal size ``T`` of the input. 
block_setting (sequence of MSBlockConfig): The Network structure. residual_pool (bool): If True, use MViTv2 pooling residual connection. + residual_cls (bool): If True, the addition on the residual connection will include the class embedding. rel_pos (bool): If True, use MViTv2's relative positional embeddings. proj_after_attn (bool): If True, apply the projection after the attention. dropout (float): Dropout rate. Default: 0.0. @@ -500,6 +506,7 @@ def __init__( input_size=input_size, cnf=cnf, residual_pool=residual_pool, + residual_cls=residual_cls, rel_pos=rel_pos, proj_after_attn=proj_after_attn, dropout=attention_dropout, @@ -573,6 +580,7 @@ def _mvit( temporal_size=temporal_size, block_setting=block_setting, residual_pool=kwargs.pop("residual_pool", False), + residual_cls=kwargs.pop("residual_cls", True), rel_pos=kwargs.pop("rel_pos", False), proj_after_attn=kwargs.pop("proj_after_attn", False), stochastic_depth_prob=stochastic_depth_prob, @@ -734,6 +742,7 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T temporal_size=16, block_setting=block_setting, residual_pool=False, + residual_cls=False, stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), weights=weights, progress=progress, @@ -862,6 +871,7 @@ def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = T temporal_size=16, block_setting=block_setting, residual_pool=True, + residual_cls=False, rel_pos=True, proj_after_attn=True, stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), From 94e510bb5f7e2dd0f741dd90021fb119424fe880 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 10 Aug 2022 10:47:58 +0100 Subject: [PATCH 15/17] Make the patch_embed kernel/padding/stride configurable. --- torchvision/models/video/mvit.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index bfe65e46226..0710123fafb 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -441,6 +441,9 @@ def __init__( num_classes: int = 400, block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, + patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7), + patch_embed_stride: Tuple[int, int, int] = (2, 4, 4), + patch_embed_padding: Tuple[int, int, int] = (1, 3, 3), ) -> None: """ MViT main class. @@ -459,6 +462,9 @@ def __init__( num_classes (int): The number of classes. block (callable, optional): Module specifying the layer which consists of the attention and mlp. norm_layer (callable, optional): Module specifying the normalization layer to use. + patch_embed_kernel (tuple of ints): The kernel of the convolution that patchifies the input. + patch_embed_stride (tuple of ints): The stride of the convolution that patchifies the input. + patch_embed_padding (tuple of ints): The padding of the convolution that patchifies the input. 
""" super().__init__() # This implementation employs a different parameterization scheme than the one used at PyTorch Video: @@ -480,9 +486,9 @@ def __init__( self.conv_proj = nn.Conv3d( in_channels=3, out_channels=block_setting[0].input_channels, - kernel_size=(3, 7, 7), - stride=(2, 4, 4), - padding=(1, 3, 3), + kernel_size=patch_embed_kernel, + stride=patch_embed_stride, + padding=patch_embed_padding, ) input_size = [size // stride for size, stride in zip((temporal_size,) + spatial_size, self.conv_proj.stride)] @@ -540,6 +546,8 @@ def __init__( nn.init.trunc_normal_(weights, std=0.02) def forward(self, x: torch.Tensor) -> torch.Tensor: + # Convert if necessary (B, C, H, W) -> (B, C, 1, H, W) + x = _unsqueeze(x, 5, 2)[0] # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0]) x = self.conv_proj(x) x = x.flatten(2).transpose(1, 2) From 538ffb519b80c8181d5a0870c4fa1367f95582c8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 10 Aug 2022 11:44:41 +0100 Subject: [PATCH 16/17] Apply changes from code-review. --- torchvision/models/video/mvit.py | 52 ++++++++++++++++---------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 0710123fafb..723e82c3bdf 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -183,8 +183,8 @@ def _add_rel_pos( return attn -def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_cls: bool): - if residual_cls: +def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool): + if residual_with_cls_embed: x.add_(shortcut) else: x[:, :, 1:, :] += shortcut[:, :, 1:, :] @@ -207,8 +207,8 @@ def __init__( stride_q: List[int], stride_kv: List[int], residual_pool: bool, - residual_cls: bool, - rel_pos: bool, + residual_with_cls_embed: bool, + rel_pos_embed: bool, dropout: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ) -> None: @@ -219,8 +219,7 @@ def __init__( self.head_dim = output_dim // num_heads self.scaler = 1.0 / math.sqrt(self.head_dim) self.residual_pool = residual_pool - self.rel_pos = rel_pos - self.residual_cls = residual_cls + self.residual_with_cls_embed = residual_with_cls_embed self.qkv = nn.Linear(embed_dim, 3 * output_dim) layers: List[nn.Module] = [nn.Linear(output_dim, output_dim)] @@ -276,7 +275,7 @@ def __init__( self.rel_pos_h: Optional[nn.Parameter] = None self.rel_pos_w: Optional[nn.Parameter] = None self.rel_pos_t: Optional[nn.Parameter] = None - if self.rel_pos: + if rel_pos_embed: size = max(input_size[1:]) q_size = size // stride_q[1] if len(stride_q) > 0 else size kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size @@ -317,7 +316,7 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten x = torch.matmul(attn, v) if self.residual_pool: - _add_shortcut(x, q, self.residual_cls) + _add_shortcut(x, q, self.residual_with_cls_embed) x = x.transpose(1, 2).reshape(B, -1, self.output_dim) x = self.project(x) @@ -330,8 +329,8 @@ def __init__( input_size: List[int], cnf: MSBlockConfig, residual_pool: bool, - residual_cls: bool, - rel_pos: bool, + residual_with_cls_embed: bool, + rel_pos_embed: bool, proj_after_attn: bool, dropout: float = 0.0, stochastic_depth_prob: float = 0.0, @@ -363,9 +362,9 @@ def __init__( kernel_kv=cnf.kernel_kv, stride_q=cnf.stride_q, stride_kv=cnf.stride_kv, - rel_pos=rel_pos, + rel_pos_embed=rel_pos_embed, residual_pool=residual_pool, - 
residual_cls=residual_cls, + residual_with_cls_embed=residual_with_cls_embed, dropout=dropout, norm_layer=norm_layer, ) @@ -397,7 +396,7 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten class PositionalEncoding(nn.Module): - def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos: bool) -> None: + def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None: super().__init__() self.spatial_size = spatial_size self.temporal_size = temporal_size @@ -406,7 +405,7 @@ def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size self.spatial_pos: Optional[nn.Parameter] = None self.temporal_pos: Optional[nn.Parameter] = None self.class_pos: Optional[nn.Parameter] = None - if not rel_pos: + if not rel_pos_embed: self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size)) self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size)) self.class_pos = nn.Parameter(torch.zeros(embed_size)) @@ -432,8 +431,8 @@ def __init__( temporal_size: int, block_setting: Sequence[MSBlockConfig], residual_pool: bool, - residual_cls: bool, - rel_pos: bool, + residual_with_cls_embed: bool, + rel_pos_embed: bool, proj_after_attn: bool, dropout: float = 0.5, attention_dropout: float = 0.0, @@ -453,8 +452,9 @@ def __init__( temporal_size (int): The temporal size ``T`` of the input. block_setting (sequence of MSBlockConfig): The Network structure. residual_pool (bool): If True, use MViTv2 pooling residual connection. - residual_cls (bool): If True, the addition on the residual connection will include the class embedding. - rel_pos (bool): If True, use MViTv2's relative positional embeddings. + residual_with_cls_embed (bool): If True, the addition on the residual connection will include + the class embedding. + rel_pos_embed (bool): If True, use MViTv2's relative positional embeddings. proj_after_attn (bool): If True, apply the projection after the attention. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. 
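The renames in the hunks above only touch naming; the behaviour of the flag is unchanged. As a quick illustration of what `residual_with_cls_embed` controls, here is a minimal, self-contained sketch that mirrors the `_add_shortcut` helper from this patch (the tensor shapes are illustrative and not taken from the model):

import torch

def add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool) -> None:
    # In-place residual add on the pooled attention output; optionally skip the
    # class token, which sits at index 0 of the token dimension.
    if residual_with_cls_embed:
        x.add_(shortcut)
    else:
        x[:, :, 1:, :] += shortcut[:, :, 1:, :]

x = torch.zeros(2, 2, 1 + 8, 4)  # (B, num_heads, 1 class token + 8 patch tokens, head_dim)
q = torch.ones_like(x)
add_shortcut(x, q, residual_with_cls_embed=False)
assert x[:, :, 0, :].eq(0).all() and x[:, :, 1:, :].eq(1).all()

With `residual_with_cls_embed=False`, the class embedding is excluded from the pooled shortcut, which is the setting these patches wire up for `mvit_v2_s` in the builder hunks below.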
@@ -498,7 +498,7 @@ def __init__( embed_size=block_setting[0].input_channels, spatial_size=(input_size[1], input_size[2]), temporal_size=input_size[0], - rel_pos=rel_pos, + rel_pos_embed=rel_pos_embed, ) # Encoder module @@ -512,8 +512,8 @@ def __init__( input_size=input_size, cnf=cnf, residual_pool=residual_pool, - residual_cls=residual_cls, - rel_pos=rel_pos, + residual_with_cls_embed=residual_with_cls_embed, + rel_pos_embed=rel_pos_embed, proj_after_attn=proj_after_attn, dropout=attention_dropout, stochastic_depth_prob=sd_prob, @@ -588,8 +588,8 @@ def _mvit( temporal_size=temporal_size, block_setting=block_setting, residual_pool=kwargs.pop("residual_pool", False), - residual_cls=kwargs.pop("residual_cls", True), - rel_pos=kwargs.pop("rel_pos", False), + residual_with_cls_embed=kwargs.pop("residual_with_cls_embed", True), + rel_pos_embed=kwargs.pop("rel_pos_embed", False), proj_after_attn=kwargs.pop("proj_after_attn", False), stochastic_depth_prob=stochastic_depth_prob, **kwargs, @@ -750,7 +750,7 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T temporal_size=16, block_setting=block_setting, residual_pool=False, - residual_cls=False, + residual_with_cls_embed=False, stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), weights=weights, progress=progress, @@ -879,8 +879,8 @@ def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = T temporal_size=16, block_setting=block_setting, residual_pool=True, - residual_cls=False, - rel_pos=True, + residual_with_cls_embed=False, + rel_pos_embed=True, proj_after_attn=True, stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), weights=weights, From bae10696d94f0506be9f8a15430ee4949293fc3a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 10 Aug 2022 11:48:04 +0100 Subject: [PATCH 17/17] Remove stale todo. --- torchvision/models/video/mvit.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 723e82c3bdf..7283a21bb0d 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -24,9 +24,6 @@ ] -# TODO: Consider handle 2d input if Temporal is 1 - - @dataclass class MSBlockConfig: num_heads: int
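With the full series applied, the new builder can be smoke-tested end to end. The snippet below is a usage sketch, not part of the patches: it assumes the series is applied on top of torchvision's main branch, and the 224x224 clip size is an assumption for illustration (only `temporal_size=16` and the 400-class default appear in the diffs above).

import torch
from torchvision.models.video import mvit_v2_s

# Randomly initialised MViTv2-S, as wired up by the builder above:
# residual_pool=True, residual_with_cls_embed=False, rel_pos_embed=True, proj_after_attn=True.
model = mvit_v2_s().eval()

clip = torch.rand(1, 3, 16, 224, 224)  # (B, C, T, H, W)
with torch.no_grad():
    logits = model(clip)
print(logits.shape)  # expected: torch.Size([1, 400])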