diff --git a/docs/source/models/swin_transformer.rst b/docs/source/models/swin_transformer.rst
index 3eb74069176..35b52987954 100644
--- a/docs/source/models/swin_transformer.rst
+++ b/docs/source/models/swin_transformer.rst
@@ -3,16 +3,18 @@ SwinTransformer
 
 .. currentmodule:: torchvision.models
 
-The SwinTransformer model is based on the `Swin Transformer: Hierarchical Vision
+The SwinTransformer models are based on the `Swin Transformer: Hierarchical Vision
 Transformer using Shifted Windows `__
 paper.
+SwinTransformer V2 models are based on the `Swin Transformer V2: Scaling Up Capacity
+and Resolution `__
+paper.
 
 Model builders
 --------------
 
-The following model builders can be used to instantiate an SwinTransformer model.
-`swin_t` can be instantiated with pre-trained weights and all others without.
+The following model builders can be used to instantiate a SwinTransformer model (original and V2) with and without pre-trained weights.
 All the model builders internally rely on the
 ``torchvision.models.swin_transformer.SwinTransformer`` base class. Please
 refer to the `source code `_ for
@@ -25,3 +27,6 @@ more details about this class.
     swin_t
     swin_s
     swin_b
+    swin_v2_t
+    swin_v2_s
+    swin_v2_b
diff --git a/references/classification/README.md b/references/classification/README.md
index da30159542b..e8d62134ca2 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -236,6 +236,17 @@ Note that `--val-resize-size` was optimized in a post-training step, see their `
 
+### SwinTransformer V2
+```
+torchrun --nproc_per_node=8 train.py\
+--model $MODEL --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 256 --val-crop-size 256 --train-crop-size 256
+```
+Here `$MODEL` is one of `swin_v2_t`, `swin_v2_s` or `swin_v2_b`.
+Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
+ + + + ### ShuffleNet V2 ``` torchrun --nproc_per_node=8 train.py \ diff --git a/test/expect/ModelTester.test_swin_v2_b_expect.pkl b/test/expect/ModelTester.test_swin_v2_b_expect.pkl new file mode 100644 index 00000000000..5b2be51a2a9 Binary files /dev/null and b/test/expect/ModelTester.test_swin_v2_b_expect.pkl differ diff --git a/test/expect/ModelTester.test_swin_v2_s_expect.pkl b/test/expect/ModelTester.test_swin_v2_s_expect.pkl new file mode 100644 index 00000000000..fe313b0c284 Binary files /dev/null and b/test/expect/ModelTester.test_swin_v2_s_expect.pkl differ diff --git a/test/expect/ModelTester.test_swin_v2_t_expect.pkl b/test/expect/ModelTester.test_swin_v2_t_expect.pkl new file mode 100644 index 00000000000..f3752af6265 Binary files /dev/null and b/test/expect/ModelTester.test_swin_v2_t_expect.pkl differ diff --git a/test/test_models.py b/test/test_models.py index 5061888d71d..9e6251924c5 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -332,6 +332,9 @@ def _check_input_backprop(model, inputs): "swin_t", "swin_s", "swin_b", + "swin_v2_t", + "swin_v2_s", + "swin_v2_b", ] for m in slow_models: _model_params[m] = {"input_shape": (1, 3, 64, 64)} diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index c5bc43a14fd..9f43b546d59 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,3 +1,4 @@ +import math from functools import partial from typing import Any, Callable, List, Optional @@ -19,21 +20,45 @@ "Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights", + "Swin_V2_T_Weights", + "Swin_V2_S_Weights", + "Swin_V2_B_Weights", "swin_t", "swin_s", "swin_b", + "swin_v2_t", + "swin_v2_s", + "swin_v2_b", ] -def _patch_merging_pad(x): +def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor: H, W, _ = x.shape[-3:] x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C + x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C + x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C + x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C return x torch.fx.wrap("_patch_merging_pad") +def _get_relative_position_bias( + relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int] +) -> torch.Tensor: + N = window_size[0] * window_size[1] + relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index] + relative_position_bias = relative_position_bias.view(N, N, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + return relative_position_bias + + +torch.fx.wrap("_get_relative_position_bias") + + class PatchMerging(nn.Module): """Patch Merging Layer. Args: @@ -56,15 +81,35 @@ def forward(self, x: Tensor): Tensor with layout of [..., H/2, W/2, 2*C] """ x = _patch_merging_pad(x) + x = self.norm(x) + x = self.reduction(x) # ... H/2 W/2 2*C + return x - x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C - x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C - x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C - x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C - x = self.norm(x) +class PatchMergingV2(nn.Module): + """Patch Merging Layer for Swin Transformer V2. + Args: + dim (int): Number of input channels. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 
+ """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + _log_api_usage_once(self) + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) # difference + + def forward(self, x: Tensor): + """ + Args: + x (Tensor): input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] + """ + x = _patch_merging_pad(x) x = self.reduction(x) # ... H/2 W/2 2*C + x = self.norm(x) return x @@ -80,6 +125,7 @@ def shifted_window_attention( dropout: float = 0.0, qkv_bias: Optional[Tensor] = None, proj_bias: Optional[Tensor] = None, + logit_scale: Optional[torch.Tensor] = None, ): """ Window based multi-head self attention (W-MSA) module with relative position bias. @@ -96,6 +142,7 @@ def shifted_window_attention( dropout (float): Dropout ratio of output. Default: 0.0. qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. + logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None. Returns: Tensor[N, H, W, C]: The output tensor after shifted window attention. """ @@ -123,11 +170,21 @@ def shifted_window_attention( x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C # multi-head attention + if logit_scale is not None and qkv_bias is not None: + qkv_bias = qkv_bias.clone() + length = qkv_bias.numel() // 3 + qkv_bias[length : 2 * length].zero_() qkv = F.linear(x, qkv_weight, qkv_bias) qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] - q = q * (C // num_heads) ** -0.5 - attn = q.matmul(k.transpose(-2, -1)) + if logit_scale is not None: + # cosine attention + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp() + attn = attn * logit_scale + else: + q = q * (C // num_heads) ** -0.5 + attn = q.matmul(k.transpose(-2, -1)) # add relative position bias attn = attn + relative_position_bias @@ -200,11 +257,17 @@ def __init__( self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.define_relative_position_bias_table() + self.define_relative_position_index() + + def define_relative_position_bias_table(self): # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads) ) # 2*Wh-1 * 2*Ww-1, nH + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + def define_relative_position_index(self): # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) @@ -215,10 +278,13 @@ def __init__( relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww + relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) - 
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + def get_relative_position_bias(self) -> torch.Tensor: + return _get_relative_position_bias( + self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type] + ) def forward(self, x: Tensor): """ @@ -227,12 +293,91 @@ def forward(self, x: Tensor): Returns: Tensor with same layout as input, i.e. [B, H, W, C] """ + relative_position_bias = self.get_relative_position_bias() + return shifted_window_attention( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + self.window_size, + self.num_heads, + shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + ) + + +class ShiftedWindowAttentionV2(ShiftedWindowAttention): + """ + See :func:`shifted_window_attention_v2`. + """ + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ): + super().__init__( + dim, + window_size, + shift_size, + num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attention_dropout=attention_dropout, + dropout=dropout, + ) + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) + ) + if qkv_bias: + length = self.qkv.bias.numel() // 3 + self.qkv.bias[length : 2 * length].data.zero_() + + def define_relative_position_bias_table(self): + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij")) + relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + + relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 + + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = ( + torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0 + ) + self.register_buffer("relative_coords_table", relative_coords_table) - N = self.window_size[0] * self.window_size[1] - relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] - relative_position_bias = relative_position_bias.view(N, N, -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + def get_relative_position_bias(self) -> torch.Tensor: + relative_position_bias = _get_relative_position_bias( + self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads), + self.relative_position_index, # type: ignore[arg-type] + self.window_size, + ) + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + return relative_position_bias + def forward(self, x: Tensor): + """ + Args: + x (Tensor): Tensor with layout of [B, H, W, C] + Returns: + Tensor with same layout as input, i.e. 
[B, H, W, C] + """ + relative_position_bias = self.get_relative_position_bias() return shifted_window_attention( x, self.qkv.weight, @@ -245,6 +390,7 @@ def forward(self, x: Tensor): dropout=self.dropout, qkv_bias=self.qkv.bias, proj_bias=self.proj.bias, + logit_scale=self.logit_scale, ) @@ -305,6 +451,54 @@ def forward(self, x: Tensor): return x +class SwinTransformerBlockV2(SwinTransformerBlock): + """ + Swin Transformer V2 Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (List[int]): Window size. + shift_size (List[int]): Shift size for shifted window attention. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2. + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: List[int], + shift_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2, + ): + super().__init__( + dim, + num_heads, + window_size, + shift_size, + mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=stochastic_depth_prob, + norm_layer=norm_layer, + attn_layer=attn_layer, + ) + + def forward(self, x: Tensor): + x = x + self.stochastic_depth(self.norm1(self.attn(x))) + x = x + self.stochastic_depth(self.norm2(self.mlp(x))) + return x + + class SwinTransformer(nn.Module): """ Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using @@ -318,10 +512,11 @@ class SwinTransformer(nn.Module): mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. - stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0. + stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. + downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging. 
""" def __init__( @@ -334,10 +529,11 @@ def __init__( mlp_ratio: float = 4.0, dropout: float = 0.0, attention_dropout: float = 0.0, - stochastic_depth_prob: float = 0.0, + stochastic_depth_prob: float = 0.1, num_classes: int = 1000, norm_layer: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None, + downsample_layer: Callable[..., nn.Module] = PatchMerging, ): super().__init__() _log_api_usage_once(self) @@ -345,7 +541,6 @@ def __init__( if block is None: block = SwinTransformerBlock - if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-5) @@ -387,12 +582,14 @@ def __init__( layers.append(nn.Sequential(*stage)) # add patch merging layer if i_stage < (len(depths) - 1): - layers.append(PatchMerging(dim, norm_layer)) + layers.append(downsample_layer(dim, norm_layer)) self.features = nn.Sequential(*layers) num_features = embed_dim * 2 ** (len(depths) - 1) self.norm = norm_layer(num_features) + self.permute = Permute([0, 3, 1, 2]) self.avgpool = nn.AdaptiveAvgPool2d(1) + self.flatten = nn.Flatten(1) self.head = nn.Linear(num_features, num_classes) for m in self.modules(): @@ -404,9 +601,9 @@ def __init__( def forward(self, x): x = self.features(x) x = self.norm(x) - x = x.permute(0, 3, 1, 2) + x = self.permute(x) x = self.avgpool(x) - x = torch.flatten(x, 1) + x = self.flatten(x) x = self.head(x) return x @@ -515,6 +712,75 @@ class Swin_B_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +class Swin_V2_T_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth", + transforms=partial( + ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 28351570, + "min_size": (256, 256), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.072, + "acc@5": 96.132, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_V2_S_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth", + transforms=partial( + ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 49737442, + "min_size": (256, 256), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.712, + "acc@5": 96.816, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_V2_B_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_b-781e5279.pth", + transforms=partial( + ImageClassification, crop_size=256, resize_size=272, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 87930848, + "min_size": (256, 256), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2", + "_metrics": { + "ImageNet-1K": { + "acc@1": 84.112, + "acc@5": 96.864, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + @register_model() def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, 
**kwargs: Any) -> SwinTransformer: """ @@ -624,3 +890,120 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * progress=progress, **kwargs, ) + + +@register_model() +def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_v2_tiny architecture from + `Swin Transformer V2: Scaling Up Capacity and Resolution `_. + + Args: + weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_V2_T_Weights + :members: + """ + weights = Swin_V2_T_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=[8, 8], + stochastic_depth_prob=0.2, + weights=weights, + progress=progress, + block=SwinTransformerBlockV2, + downsample_layer=PatchMergingV2, + **kwargs, + ) + + +@register_model() +def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_v2_small architecture from + `Swin Transformer V2: Scaling Up Capacity and Resolution `_. + + Args: + weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_V2_S_Weights + :members: + """ + weights = Swin_V2_S_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=[8, 8], + stochastic_depth_prob=0.3, + weights=weights, + progress=progress, + block=SwinTransformerBlockV2, + downsample_layer=PatchMergingV2, + **kwargs, + ) + + +@register_model() +def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_v2_base architecture from + `Swin Transformer V2: Scaling Up Capacity and Resolution `_. + + Args: + weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.Swin_V2_B_Weights
+        :members:
+    """
+    weights = Swin_V2_B_Weights.verify(weights)
+
+    return _swin_transformer(
+        patch_size=[4, 4],
+        embed_dim=128,
+        depths=[2, 2, 18, 2],
+        num_heads=[4, 8, 16, 32],
+        window_size=[8, 8],
+        stochastic_depth_prob=0.5,
+        weights=weights,
+        progress=progress,
+        block=SwinTransformerBlockV2,
+        downsample_layer=PatchMergingV2,
+        **kwargs,
+    )
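
For quick reference, here is a minimal inference sketch using the new V2 builders and weights introduced above. It assumes this diff is merged and relies on torchvision's standard weights API (`weights.transforms()`, `weights.meta["categories"]`); the random input tensor is only a stand-in for a real image.

```python
import torch
from torchvision.models import swin_v2_t, Swin_V2_T_Weights

# Instantiate swin_v2_t with the ImageNet-1K weights added in this PR.
weights = Swin_V2_T_Weights.IMAGENET1K_V1
model = swin_v2_t(weights=weights).eval()

# The weights bundle their own preprocessing (bicubic resize to 260, 256x256 center crop).
preprocess = weights.transforms()

img = torch.rand(3, 300, 300)  # stand-in for a real (C, H, W) image in [0, 1]
batch = preprocess(img).unsqueeze(0)

with torch.inference_mode():
    logits = model(batch)
probs = logits.squeeze(0).softmax(dim=0)
class_id = int(probs.argmax())
print(weights.meta["categories"][class_id], float(probs[class_id]))
```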
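The main change inside `shifted_window_attention` for V2 is cosine attention: queries and keys are L2-normalized and their similarity is scaled by a learned, clamped per-head temperature (`logit_scale`) instead of `1/sqrt(head_dim)`. The following standalone sketch mirrors just that branch of the diff; the function name and tensor shapes are illustrative, not part of the torchvision API.

```python
import math
import torch
import torch.nn.functional as F

def cosine_attention_scores(q: torch.Tensor, k: torch.Tensor, logit_scale: torch.Tensor) -> torch.Tensor:
    # Cosine similarity between normalized queries and keys ...
    attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
    # ... scaled by a learned per-head temperature, clamped at log(100) before exp,
    # as done when logit_scale is not None in the diff above.
    scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
    return attn * scale

# Shapes mirror the windowed attention: (num_windows * B, num_heads, tokens, head_dim).
q = torch.randn(4, 3, 64, 32)
k = torch.randn(4, 3, 64, 32)
logit_scale = torch.log(10 * torch.ones(3, 1, 1))  # initialization used by ShiftedWindowAttentionV2
scores = cosine_attention_scores(q, k, logit_scale)
print(scores.shape)  # torch.Size([4, 3, 64, 64])
```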
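V2 also replaces the learned relative position bias table with a continuous bias: a small MLP maps log-spaced relative coordinates to one bias value per head, and the gathered bias is later passed through `16 * sigmoid` before being added to the attention scores. Below is a sketch of how that table is produced, following `define_relative_position_bias_table` in `ShiftedWindowAttentionV2`; the variable names are local to the example.

```python
import torch
import torch.nn as nn

window_size, num_heads = [8, 8], 3

# Relative offsets in each axis, rescaled to [-8, 8] and log-spaced, as in the diff.
rel_h = torch.arange(-(window_size[0] - 1), window_size[0], dtype=torch.float32)
rel_w = torch.arange(-(window_size[1] - 1), window_size[1], dtype=torch.float32)
table = torch.stack(torch.meshgrid([rel_h, rel_w], indexing="ij"))
table = table.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
table[:, :, :, 0] /= window_size[0] - 1
table[:, :, :, 1] /= window_size[1] - 1
table *= 8
table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / 3.0

# MLP that generates a continuous bias per head for every relative offset.
cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False))
bias_table = cpb_mlp(table).view(-1, num_heads)  # (2*Wh-1)*(2*Ww-1), nH
# At attention time this table is indexed with relative_position_index and the gathered
# bias goes through 16 * torch.sigmoid(...) before being added to the attention scores.
print(bias_table.shape)  # torch.Size([225, 3])
```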