import torch
import torch.nn.functional as F

def cus_sample(feat, **kwargs):
    assert len(kwargs.keys()) == 1 and list(kwargs.keys())[0] in ["size", "scale_factor"]
    return F.interpolate(feat, mode="bilinear", align_corners=True,**kwargs)


def upsample_add(*xs):
    y = xs[-1]
    for x in xs[:-1]:
        y = y + F.interpolate(x, size=y.size()[2:], mode="bilinear", align_corners=False)
    return y


def upsample_cat(*xs):
    y = xs[-1]
    out = []
    for x in xs[:-1]:
        out.append(F.interpolate(x, size=y.size()[2:], mode="bilinear", align_corners=False))
    return torch.cat([out, y], dim=1)


def upsample_reduce(b, a):
    _, C, _, _ = b.size()
    N, _, H, W = a.size()

    b = F.interpolate(b, size=(H, W), mode="bilinear", align_corners=False)
    a = a.reshape(N, -1, C, H, W).mean(1)

    return b + a


def shuffle_channels(x, groups):
    N, C, H, W = x.size()
    x = x.reshape(N, groups, C // groups, H, W).permute(0, 2, 1, 3, 4)
    return x.reshape(N, C, H, W)