""" Vision Transformer (ViT) in PyTorch

A PyTorch implementation of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929

The official jax code is released and available at https://github.com/google-research/vision_transformer

Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.

Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2020 Ross Wightman
"""

import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_


def _cfg(url="", **kwargs):

    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 224, 224),
        "pool_size": None,
        "crop_pct": 0.9,
        "interpolation": "bicubic",
        "mean": (0.5, 0.5, 0.5),
        "std": (0.5, 0.5, 0.5),
        **kwargs,
    }


def torch_memory(device, tag=""):
    """Print current and peak CUDA allocator statistics for ``device`` in MB.

    Prints four lines, each prefixed with ``tag``: allocated, reserved,
    max allocated, and max reserved memory. A blank line follows them.
    """
    queries = (
        (torch.cuda.memory_allocated, "USED"),
        (torch.cuda.memory_reserved, "RESERVED"),
        (torch.cuda.max_memory_allocated, "USED MAX"),
        (torch.cuda.max_memory_reserved, "RESERVED MAX"),
    )
    for query, label in queries:
        print(tag, f"{query(device)/1024/1024:.2f} MB {label}")
    print("")


class DropPath(nn.Module):
    """Stochastic depth: drop whole residual paths per sample.

    A thin ``nn.Module`` wrapper around ``drop_path``. The module's
    ``training`` flag is forwarded on every call, presumably so that the
    drop is a no-op in eval mode.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # Drop probability handed to drop_path; stored as given (may be None).
        self.drop_prob = drop_prob

    def forward(self, x):
        is_training = self.training
        return drop_path(x, self.drop_prob, is_training)

    def extra_repr(self) -> str:
        # Shown inside the module repr, e.g. "DropPath(p=0.1)".
        return f"p={self.drop_prob}"


class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear -> Dropout.

    Args:
        in_features: Input feature size.
        hidden_features: Hidden size; falls back to ``in_features`` when falsy.
        out_features: Output size; falls back to ``in_features`` when falsy.
        act_layer: Activation class, instantiated with no arguments.
        drop: Dropout probability applied to the output of ``fc2``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        width = hidden_features or in_features
        fc2_out = out_features or in_features
        self.fc1 = nn.Linear(in_features, width)
        self.act = act_layer()
        self.fc2 = nn.Linear(width, fc2_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Dropout is applied only once, after fc2. Skipping it after the
        # activation is deliberate: the original author's comment ties this
        # to the original BERT implementation.
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention with an optional BEiT-style relative position bias.

    When ``window_size`` is given, a learnable table holds one bias per head for
    every in-window (dy, dx) offset, plus 3 extra entries for the leading
    special token: cls->token, token->cls and cls->cls. The input sequence is
    therefore expected to be ``1 + Wh*Ww`` tokens long with the special token
    first. At forward time the bias table can be bicubically resized to a
    different patch grid (``training_window_size``).

    Args:
        dim: Input/output channel size.
        num_heads: Number of attention heads.
        qkv_bias: If True, learn separate q and v biases (k gets a zero bias).
        qk_scale: Override for the query scale; falls back to
            ``head_dim ** -0.5`` when falsy.
        attn_drop: Dropout on the attention weights.
        proj_drop: Dropout after the output projection.
        window_size: (Wh, Ww) patch grid for the relative position bias, or a
            falsy value to disable it.
        attn_head_dim: Optional per-head dim overriding ``dim // num_heads``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        window_size=None,
        attn_head_dim=None,
    ):
        super().__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        # The qkv projection has no built-in bias; q/v biases are concatenated
        # with a zero k bias in forward().
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = self._num_relative_distance(window_size)
            # (2*Wh-1)*(2*Ww-1) + 3 rows, one column per head; zero-initialised.
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads)
            )
            self.register_buffer(
                "relative_position_index", self._relative_position_index(window_size)
            )
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    @staticmethod
    def _num_relative_distance(window_size):
        """Number of bias-table rows: in-window offsets plus the 3 special-token entries."""
        return (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3

    @staticmethod
    def _relative_position_index(window_size, device=None):
        """Build the (1+Wh*Ww, 1+Wh*Ww) index into the bias table for ``window_size``.

        Row/column 0 belong to the special token; the last three table rows
        are used for cls->token (row 0), token->cls (column 0) and cls->cls
        ((0, 0)).
        """
        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        coords_h = torch.arange(window_size[0], device=device)
        coords_w = torch.arange(window_size[1], device=device)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift offsets to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1  # row-major flatten of (dy, dx)
        num_tokens = window_size[0] * window_size[1] + 1
        index = torch.zeros(
            size=(num_tokens, num_tokens), dtype=relative_coords.dtype, device=device
        )
        index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        index[0, 0:] = num_relative_distance - 3  # cls -> token
        index[0:, 0] = num_relative_distance - 2  # token -> cls
        index[0, 0] = num_relative_distance - 1  # cls -> cls
        return index

    def _window_tuple(self, training_window_size):
        """Normalise a window given as None, a tensor or a sequence to an int tuple.

        None means "use the configured window". This also makes the equality
        test against ``self.window_size`` meaningful when callers pass a tensor.
        """
        if training_window_size is None:
            return tuple(self.window_size)
        if torch.is_tensor(training_window_size):
            training_window_size = training_window_size.tolist()
        return tuple(int(s) for s in training_window_size)

    def _resized_bias_table(self, window_size):
        """Bicubically resize the in-window part of the bias table to ``window_size``.

        The 3 special-token rows are appended unchanged.
        """
        old_h = 2 * self.window_size[0] - 1
        old_w = 2 * self.window_size[1] - 1
        new_h = 2 * window_size[0] - 1
        new_w = 2 * window_size[1] - 1
        table = (
            self.relative_position_bias_table[:-3, :]
            .permute(1, 0)
            .view(1, self.num_heads, old_h, old_w)
        )
        table = F.interpolate(
            table, size=(new_h, new_w), mode="bicubic", align_corners=False
        )
        table = table.view(self.num_heads, new_h * new_w).permute(1, 0)
        return torch.cat([table, self.relative_position_bias_table[-3:]], dim=0)

    def _relative_position_bias(self, window_size):
        """Return the (num_heads, 1+Wh*Ww, 1+Wh*Ww) bias for the ``window_size`` grid."""
        if window_size == tuple(self.window_size):
            # Native grid: use the stored table and precomputed index directly.
            table = self.relative_position_bias_table
            index = self.relative_position_index
        else:
            table = self._resized_bias_table(window_size)
            index = self._relative_position_index(window_size, device=table.device)
        num_tokens = window_size[0] * window_size[1] + 1
        bias = table[index.view(-1)].view(num_tokens, num_tokens, -1)  # N, N, nH
        return bias.permute(2, 0, 1).contiguous()  # nH, N, N

    def forward(self, x, rel_pos_bias=None, training_window_size=None):
        """Apply attention to ``x`` of shape (B, N, C).

        Args:
            x: Token sequence; with relative bias enabled, N must equal
                ``1 + Wh*Ww`` for the effective window.
            rel_pos_bias: Optional extra bias added to the attention logits
                (must broadcast to (B, num_heads, N, N)).
            training_window_size: Patch grid (Wh, Ww) of this input, as a
                tensor or sequence; None uses the configured ``window_size``.
                A different grid makes the bias table be resized bicubically.

        Returns:
            Tensor of shape (B, N, dim).
        """
        B, N, _ = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # k has no learnable bias: splice a zero vector between q and v.
            qkv_bias = torch.cat(
                (
                    self.q_bias,
                    torch.zeros_like(self.v_bias, requires_grad=False),
                    self.v_bias,
                )
            )
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            window = self._window_tuple(training_window_size)
            attn = attn + self._relative_position_bias(window).unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Pre-norm Transformer encoder block with attention and MLP residual branches.

    Each branch normalises its input, computes its output, optionally scales
    it by a learnable per-channel vector (``gamma_1`` / ``gamma_2``, created
    only when ``init_values`` is given), applies stochastic depth and adds the
    result back to the residual stream.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        window_size=None,
        attn_head_dim=None,
    ):
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
        )
        # Stochastic depth only when a positive drop rate is requested.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        if init_values is None:
            self.gamma_1 = self.gamma_2 = None
        else:
            # Per-channel branch scales, both initialised to init_values.
            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x, rel_pos_bias=None, training_window_size=None):
        # gamma_1 and gamma_2 are created together, so one flag governs both branches.
        scaled = self.gamma_1 is not None

        attn_out = self.attn(
            self.norm1(x),
            rel_pos_bias=rel_pos_bias,
            training_window_size=training_window_size,
        )
        if scaled:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path(attn_out)

        mlp_out = self.mlp(self.norm2(x))
        if scaled:
            mlp_out = self.gamma_2 * mlp_out
        x = x + self.drop_path(mlp_out)
        return x


class PatchEmbed(nn.Module):
    """Embed non-overlapping image patches with a strided convolution.

    The Conv2d kernel and stride both equal ``patch_size``. ``patch_shape``
    and ``num_patches`` describe the patch grid for ``img_size``, the
    pre-training size that a position embedding is laid out for. Other input
    sizes are accepted; a supplied position embedding is bicubically resized
    to the actual grid.
    """

    def __init__(
        self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768, bias=True
    ):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid_rows = img_size[0] // patch_size[0]
        grid_cols = img_size[1] // patch_size[1]
        self.patch_shape = (grid_rows, grid_cols)
        # NOTE(review): despite the names, num_patches_w holds the row count
        # (patch_shape[0]) and num_patches_h the column count; kept as-is for
        # compatibility with existing users.
        self.num_patches_w = grid_rows
        self.num_patches_h = grid_cols
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_rows * grid_cols

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )

    def forward(self, x, position_embedding=None, **kwargs):
        """Embed ``x`` and optionally add a resized position embedding.

        Args:
            x: Image batch fed to the patch convolution.
            position_embedding: Optional embedding with ``patch_shape[0] *
                patch_shape[1]`` positions; it is reshaped onto the
                pre-training grid and bicubically resized to this input's grid.
            **kwargs: Accepted and ignored.

        Returns:
            ``(tokens, (Hp, Wp))``: the patch grid flattened to a token sequence
            (``flatten(2).transpose(1, 2)``) and the grid size of this input.
        """
        feats = self.proj(x)
        Hp = feats.shape[2]
        Wp = feats.shape[3]

        if position_embedding is not None:
            grid_pos = position_embedding.view(
                1, self.patch_shape[0], self.patch_shape[1], -1
            )
            grid_pos = grid_pos.permute(0, 3, 1, 2)
            grid_pos = F.interpolate(grid_pos, size=(Hp, Wp), mode="bicubic")
            feats = feats + grid_pos

        tokens = feats.flatten(2).transpose(1, 2)
        return tokens, (Hp, Wp)


class BEiT(nn.Module):
    """Dual-stream BEiT-style Vision Transformer backbone.

    Two inputs are embedded independently: the image ``x`` (``in_chans``
    channels) and a second ``grid`` input (``grid_chans`` channels; its
    meaning is defined by the caller). Each stream runs through its own stack
    of ``self_depth`` self-attention ``Block``s. The two streams are then
    processed jointly by ``cross_depth`` ``CrossBlock``s. Selected
    intermediate token sequences are reshaped into 2-D maps and resampled by
    four FPN-style ops per stream.

    Args:
        img_size: Size used to lay out the patch grid and the absolute
            position embeddings.
        patch_size: Patch size. FPN ops exist only for 16 and 8; with any other
            value ``forward`` fails because ``fpn*`` are not defined.
        in_chans: Channels of the image input.
        grid_chans: Channels of the grid input.
        num_classes: Stored on the module; not otherwise used here.
        embed_dim: Token embedding size.
        self_depth: Number of self-attention blocks per stream.
        cross_depth: Number of cross blocks.
        num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate,
        init_values: Forwarded to every block.
        drop_path_rate: Maximum stochastic-depth rate. Rates rise linearly over
            the self_depth + cross_depth levels.
        hybrid_backbone: If given, a ``HybridEmbed`` replaces the image patch
            embedding (see the NOTE in ``__init__``).
        norm_layer: Norm layer factory; defaults to LayerNorm(eps=1e-6).
        use_abs_pos_emb: Learn absolute position embeddings for both streams.
        use_rel_pos_bias: Give each block a per-layer relative position bias.
        use_shared_rel_pos_bias: Use one ``RelativePositionBias`` shared by
            all self-attention blocks.
        use_checkpoint: Run blocks through ``torch.utils.checkpoint``.
        pretrained: Accepted for API compatibility; unused here.
        out_features: Names of exported levels. Each name's integer suffix
            after the first five characters (presumably "layerN") is the index
            of the self-attention block whose output is exported. ``None`` means
            no names.
    """

    def __init__(
        self,
        img_size=[224, 224],
        patch_size=16,
        in_chans=3,
        grid_chans=64,
        num_classes=80,
        embed_dim=768,
        self_depth=7,
        cross_depth=5,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        hybrid_backbone=None,
        norm_layer=None,
        init_values=None,
        use_abs_pos_emb=False,
        use_rel_pos_bias=False,
        use_shared_rel_pos_bias=False,
        use_checkpoint=True,
        pretrained=None,
        out_features=None,
    ):
        super().__init__()

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        # num_features mirrors embed_dim for consistency with other models.
        self.num_features = self.embed_dim = embed_dim
        self.use_checkpoint = use_checkpoint

        if hybrid_backbone is not None:
            # NOTE(review): this branch creates no grid_patch_embed, so
            # forward_features() cannot run with a hybrid backbone.
            self.patch_embed = HybridEmbed(
                hybrid_backbone,
                img_size=img_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.grid_patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=grid_chans,
                embed_dim=embed_dim,
                bias=True,
            )
        num_patches = self.patch_embed.num_patches

        # Previously None crashed here; treat it as "no exported names".
        self.out_features = out_features if out_features is not None else []
        self.out_indices = [int(name[5:]) for name in self.out_features]

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.grid_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            # Slot 0 belongs to the prepended cls/grid token.
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
            self.grid_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches + 1, embed_dim)
            )
        else:
            self.pos_embed = None
            self.grid_pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(
                window_size=self.patch_embed.patch_shape, num_heads=num_heads
            )
        else:
            self.rel_pos_bias = None

        # Stochastic depth decay rule, shared by both self-attention stacks;
        # the cross blocks take the remaining (largest) rates.
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, self_depth + cross_depth)
        ]
        self.use_rel_pos_bias = use_rel_pos_bias
        window_size = self.patch_embed.patch_shape if use_rel_pos_bias else None

        def make_block(drop_path_prob):
            return Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_prob,
                norm_layer=norm_layer,
                init_values=init_values,
                window_size=window_size,
            )

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(self_depth)])
        self.grid_blocks = nn.ModuleList(
            [make_block(dpr[i]) for i in range(self_depth)]
        )
        self.cross_blocks = nn.ModuleList(
            [
                CrossBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i + self_depth],
                    norm_layer=norm_layer,
                    init_values=init_values,
                    window_size=window_size,
                )
                for i in range(cross_depth)
            ]
        )

        # Image-stream ops are built before grid-stream ops, as before, so
        # parameter registration order and default-init RNG use are unchanged.
        fpn_ops = self._build_fpn_ops(embed_dim, patch_size)
        if fpn_ops is not None:
            self.fpn1, self.fpn2, self.fpn3, self.fpn4 = fpn_ops
            (
                self.grid_fpn1,
                self.grid_fpn2,
                self.grid_fpn3,
                self.grid_fpn4,
            ) = self._build_fpn_ops(embed_dim, patch_size)

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)
            trunc_normal_(self.grid_pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        trunc_normal_(self.grid_token, std=0.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

    @staticmethod
    def _build_fpn_ops(embed_dim, patch_size):
        """Return the four per-level resampling ops for one stream.

        Relative to the token map, the levels are: patch 16 -> upsample x4,
        x2, identity, downsample x2; patch 8 -> upsample x2, identity,
        downsample x2, x4. Returns None for other patch sizes. The Sequential
        wrappers are kept so state_dict keys stay e.g. ``fpn2.0.weight``.
        """
        if patch_size == 16:
            return (
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                    nn.BatchNorm2d(embed_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Identity(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
        if patch_size == 8:
            return (
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Identity(),
                nn.Sequential(
                    nn.MaxPool2d(kernel_size=2, stride=2),
                ),
                nn.Sequential(
                    nn.MaxPool2d(kernel_size=4, stride=4),
                ),
            )
        return None

    def fix_init_weight(self):
        """Scale down each image-stream block's output projections by sqrt(2 * depth).

        NOTE(review): grid_blocks are not rescaled; confirm this asymmetry is intended.
        """

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm reset to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Number of image-stream self-attention blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay.

        Both streams' position embeddings and leading tokens are listed.
        """
        return {"pos_embed", "cls_token", "grid_pos_embed", "grid_token"}

    @staticmethod
    def _tokens_to_map(tokens, batch_size, height, width):
        """Drop the leading special token and reshape the rest into a (B, D, H, W) map."""
        return (
            tokens[:, 1:, :]
            .permute(0, 2, 1)
            .reshape(batch_size, -1, height, width)
            .contiguous()
        )

    def forward_features(self, x, grid):
        """Run both streams and return per-level feature maps.

        Args:
            x: Image batch; must be 4-D (unpacked as B, C, H, W).
            grid: Grid-input batch fed to ``grid_patch_embed``.

        Returns:
            ``(feat_out, grid_feat_out)``: dicts mapping names in
            ``out_features`` to the resampled image / grid maps. Levels are the
            self-attention outputs at ``out_indices`` followed by every cross
            block's output, paired with names by position. ``zip`` drops levels
            beyond ``len(out_features)``; more than four levels raises IndexError.
        """
        B, _, _, _ = x.shape

        vis_x, (Hp, Wp) = self.patch_embed(
            x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None
        )
        grid_x, (grid_Hp, grid_Wp) = self.grid_patch_embed(
            grid,
            self.grid_pos_embed[:, 1:, :] if self.grid_pos_embed is not None else None,
        )

        # Prepend the learnable cls/grid tokens (batch size taken from the grid stream).
        batch_size = grid_x.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        grid_tokens = self.grid_token.expand(batch_size, -1, -1)
        if self.pos_embed is not None:
            cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
            grid_tokens = grid_tokens + self.grid_pos_embed[:, :1, :]
        vis_x = self.pos_drop(torch.cat((cls_tokens, vis_x), dim=1))
        grid_x = self.pos_drop(torch.cat((grid_tokens, grid_x), dim=1))

        features = []
        grid_features = []
        # Each stream's actual patch grid, passed to the blocks so relative
        # position biases can be resized when it differs from pre-training.
        training_window_size = torch.tensor([Hp, Wp])
        grid_training_window_size = torch.tensor([grid_Hp, grid_Wp])

        # NOTE(review): the shared bias is computed for the image grid but is
        # also fed to the grid stream; confirm the two grids always match.
        rel_pos_bias = (
            self.rel_pos_bias(training_window_size)
            if self.rel_pos_bias is not None
            else None
        )

        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                vis_x = checkpoint.checkpoint(
                    blk, vis_x, rel_pos_bias, training_window_size
                )
            else:
                vis_x = blk(
                    vis_x,
                    rel_pos_bias=rel_pos_bias,
                    training_window_size=training_window_size,
                )
            if i in self.out_indices:
                features.append(self._tokens_to_map(vis_x, B, Hp, Wp))

        for i, grid_blk in enumerate(self.grid_blocks):
            if self.use_checkpoint:
                grid_x = checkpoint.checkpoint(
                    grid_blk, grid_x, rel_pos_bias, grid_training_window_size
                )
            else:
                grid_x = grid_blk(
                    grid_x,
                    rel_pos_bias=rel_pos_bias,
                    training_window_size=grid_training_window_size,
                )
            if i in self.out_indices:
                grid_features.append(self._tokens_to_map(grid_x, B, grid_Hp, grid_Wp))

        # Every cross block's output is exported (no out_indices filtering).
        for cross_blk in self.cross_blocks:
            if self.use_checkpoint:
                vis_x, grid_x = checkpoint.checkpoint(cross_blk, vis_x, grid_x)
            else:
                vis_x, grid_x = cross_blk(vis_input=vis_x, grid_input=grid_x)
            features.append(self._tokens_to_map(vis_x, B, Hp, Wp))
            grid_features.append(self._tokens_to_map(grid_x, B, grid_Hp, grid_Wp))

        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        grid_ops = [self.grid_fpn1, self.grid_fpn2, self.grid_fpn3, self.grid_fpn4]
        features = [ops[i](feat) for i, feat in enumerate(features)]
        grid_features = [grid_ops[i](feat) for i, feat in enumerate(grid_features)]

        feat_out = {}
        grid_feat_out = {}
        for name, vis_value, grid_value in zip(
            self.out_features, features, grid_features
        ):
            feat_out[name] = vis_value
            grid_feat_out[name] = grid_value

        return feat_out, grid_feat_out

    def forward(self, x, grid):
        """Return ``(feat_out, grid_feat_out)`` as produced by ``forward_features``."""
        feat_out, grid_feat_out = self.forward_features(x, grid)
        return feat_out, grid_feat_out


def VGT_dit_base_patch16(pretrained=False, **kwargs):
    """Build the base-size, patch-16 dual-stream backbone.

    The fixed configuration is:
    - 768-dim embeddings, 12 heads, MLP ratio 4, q/v biases;
    - 12 self-attention blocks per stream and no cross blocks;
    - ``init_values=0.1``;
    - 3 image channels and 64 grid channels.

    Extra ``kwargs`` are forwarded to ``BEiT``. Repeating one of the fixed
    keys raises ``TypeError``. ``pretrained`` is accepted but unused here.
    """
    base_config = dict(
        patch_size=16,
        embed_dim=768,
        self_depth=12,
        cross_depth=0,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=0.1,
        in_chans=3,
        grid_chans=64,
    )
    model = BEiT(**base_config, **kwargs)
    model.default_cfg = _cfg()
    return model
