from monai.inferers import SlidingWindowInferer
from monai.transforms import (
    Activationsd,
    AddChanneld,
    AsDiscreted,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    Spacingd,
    SqueezeDimd,
    ToNumpyd,
    ToTensord,
)

from monailabel.interfaces.tasks import InferTask, InferType
from monailabel.utils.others.post import BoundingBoxd, Restored


class MyInfer(InferTask):
    """
    3D segmentation inference task.

    The network weights/module are supplied by the caller (``path`` /
    ``network``); this class only defines the pipeline around the model:

    * pre-processing of the ``"image"`` entry: load, add a channel axis,
      reorient to RAS, resample to a spacing of 1.0 on every axis, normalize
      intensities and convert to a tensor;
    * inference: sliding-window inference with a 128x128x128 window;
    * post-processing of the ``"pred"`` entry: softmax, argmax, drop the
      temporary leading axis, convert to numpy, restore against the
      ``"image"`` reference and compute a bounding box.

    NOTE(review): an earlier docstring described the model as a pre-trained
    HighResNet; nothing in this class enforces a particular architecture.
    """

    # Default label names forwarded to InferTask. Presumably position i names
    # output class i of the network -- confirm against the checkpoint's
    # training configuration.
    _DEFAULT_LABELS = (
        "Non-ventricular",
        "3rd-Ventricle",
        "4th-Ventricle",
        "5th-Ventricle",
        "Right-Accumbens-Area",
        "Left-Accumbens-Area",
        "Right-Amygdala",
        "Left-Amygdala",
        "Pons",
        "Brain-Stem",
        "Right-Caudate",
        "Left-Caudate",
        "Right-Cerebellum-Exterior",
        "Left-Cerebellum-Exterior",
        "Right-Cerebellum",
        "Left-Cerebellum",
        "3rd-Ventricle-(Posterior-part)",
        "Right-Hippocampus",
        "Left-Hippocampus",
        "Right-Inf-Lat-Vent",
        "Left-Inf-Lat-Vent",
        "Right-Lateral-Ventricle",
        "Left-Lateral-Ventricle",
        "Right-Pallidum",
        "Left-Pallidum",
        "Right-Putamen",
        "Left-Putamen",
        "Right-Thalamus-Proper",
        "Left-Thalamus-Proper",
        "Right-Ventral-DC",
        "Left-Ventral-DC",
        "Right-vessel",
        "Left-vessel",
        "Right-periventricular-white-matter",
        "Left-periventricular-white-matter",
        "Optic-Chiasm",
        "Cerebellar-Vermal-Lobules-I-V",
        "Cerebellar-Vermal-Lobules-VI-VII",
        "Cerebellar-Vermal-Lobules-VIII-X",
        "Left-Basal-Forebrain",
        "Right-Basal-Forebrain",
        "Right-Temporal-White-Matter",
        "Right-Insula-White-Matter",
        "Right-Cingulate-White-Matter",
        "Right-Frontal-White-Matter",
        "Right-Occipital-White-Matter",
        "Right-Parietal-White-Matter",
        "Corpus-Callosum",
        "Left-Temporal-White-Matter",
        "Left-Insula-White-Matter",
        "Left-Cingulate-White-Matter",
        "Left-Frontal-White-Matter",
        "Left-Occipital-White-Matter",
        "Left-Parietal-White-Matter",
        "Right-Claustrum",
        "Left-Claustrum",
        "Right-ACgG-anterior-cingulate-gyrus",
        "Left-ACgG-anterior-cingulate-gyrus",
        "Right-AIns-anterior-insula",
        "Left-AIns-anterior-insula",
        "Right-AOrG-anterior-orbital-gyrus",
        "Left-AOrG-anterior-orbital-gyrus",
        "Right-AnG-angular-gyrus",
        "Left-AnG-angular-gyrus",
        "Right-Calc-calcarine-cortex",
        "Left-Calc-calcarine-cortex",
        "Right-CO-central-operculum",
        "Left-CO-central-operculum",
        "Right-Cun-cuneus",
        "Left-Cun-cuneus",
        "Right-Ent-entorhinal-area",
        "Left-Ent-entorhinal-area",
        "Right-FO-frontal-operculum",
        "Left-FO-frontal-operculum",
        "Right-FRP-frontal-pole",
        "Left-FRP-frontal-pole",
        "Right-FuG-fusiform-gyrus",
        "Left-FuG-fusiform-gyrus",
        "Right-GRe-gyrus-rectus",
        "Left-GRe-gyrus-rectus",
        "Right-IOG-inferior-occipital-gyrus",
        "Left-IOG-inferior-occipital-gyrus",
        "Right-ITG-inferior-temporal-gyrus",
        "Left-ITG-inferior-temporal-gyrus",
        "Right-LiG-lingual-gyrus",
        "Left-LiG-lingual-gyrus",
        "Right-LOrG-lateral-orbital-gyrus",
        "Left-LOrG-lateral-orbital-gyrus",
        "Right-MCgG-middle-cingulate-gyrus",
        "Left-MCgG-middle-cingulate-gyrus",
        "Right-MFC-medial-frontal-cortex",
        "Left-MFC-medial-frontal-cortex",
        "Right-MFG-middle-frontal-gyrus",
        "Left-MFG-middle-frontal-gyrus",
        "Right-MOG-middle-occipital-gyrus",
        "Left-MOG-middle-occipital-gyrus",
        "Right-MOrG-medial-orbital-gyrus",
        "Left-MOrG-medial-orbital-gyrus",
        "Right-MPoG-postcentral-gyrus-medial-segment",
        "Left-MPoG-postcentral-gyrus-medial-segment",
        "Right-MPrG-precentral-gyrus-medial-segment",
        "Left-MPrG-precentral-gyrus-medial-segment",
        "Right-MSFG-superior-frontal-gyrus-medial-segment",
        "Left-MSFG-superior-frontal-gyrus-medial-segment",
        "Right-MTG-middle-temporal-gyrus",
        "Left-MTG-middle-temporal-gyrus",
        "Right-OCP-occipital-pole",
        "Left-OCP-occipital-pole",
        "Right-OFuG-occipital-fusiform-gyrus",
        "Left-OFuG-occipital-fusiform-gyrus",
        "Right-OpIFG-opercular-part-of-the-inferior-frontal-gyrus",
        "Left-OpIFG-opercular-part-of-the-inferior-frontal-gyrus",
        "Right-OrIFG-orbital-part-of-the-inferior-frontal-gyrus",
        "Left-OrIFG-orbital-part-of-the-inferior-frontal-gyrus",
        "Right-PCgG-posterior-cingulate-gyrus",
        "Left-PCgG-posterior-cingulate-gyrus",
        "Right-PCu-precuneus",
        "Left-PCu-precuneus",
        "Right-PHG-parahippocampal-gyrus",
        "Left-PHG-parahippocampal-gyrus",
        "Right-PIns-posterior-insula",
        "Left-PIns-posterior-insula",
        "Right-PO-parietal-operculum",
        "Left-PO-parietal-operculum",
        "Right-PoG-postcentral-gyrus",
        "Left-PoG-postcentral-gyrus",
        "Right-POrG-posterior-orbital-gyrus",
        "Left-POrG-posterior-orbital-gyrus",
        "Right-PP-planum-polare",
        "Left-PP-planum-polare",
        "Right-PrG-precentral-gyrus",
        "Left-PrG-precentral-gyrus",
        "Right-PT-planum-temporale",
        "Left-PT-planum-temporale",
        "Right-SCA-subcallosal-area",
        "Left-SCA-subcallosal-area",
        "Right-SFG-superior-frontal-gyrus",
        "Left-SFG-superior-frontal-gyrus",
        "Right-SMC-supplementary-motor-cortex",
        "Left-SMC-supplementary-motor-cortex",
        "Right-SMG-supramarginal-gyrus",
        "Left-SMG-supramarginal-gyrus",
        "Right-SOG-superior-occipital-gyrus",
        "Left-SOG-superior-occipital-gyrus",
        "Right-SPL-superior-parietal-lobule",
        "Left-SPL-superior-parietal-lobule",
        "Right-STG-superior-temporal-gyrus",
        "Left-STG-superior-temporal-gyrus",
        "Right-TMP-temporal-pole",
        "Left-TMP-temporal-pole",
        "Right-TrIFG-triangular-part-of-the-inferior-frontal-gyrus",
        "Left-TrIFG-triangular-part-of-the-inferior-frontal-gyrus",
        "Right-TTG-transverse-temporal-gyrus",
        "Left-TTG-transverse-temporal-gyrus",
    )

    def __init__(
        self,
        path,
        network=None,
        type=InferType.SEGMENTATION,
        labels=_DEFAULT_LABELS,
        dimension=3,
        description="A pre-trained model for volumetric (3D) segmentation",
    ):
        """
        Configure the task; every argument is forwarded unchanged to InferTask.

        :param path: passed through to InferTask as ``path``.
        :param network: optional network passed through to InferTask.
        :param type: inference type, segmentation by default.
        :param labels: label names, ``_DEFAULT_LABELS`` by default.
        :param dimension: spatial dimensionality, 3 by default.
        :param description: human-readable task description.
        """
        super().__init__(
            path=path, network=network, type=type, labels=labels, dimension=dimension, description=description
        )

    def pre_transforms(self):
        """Build the dictionary transforms applied to the ``"image"`` entry before inference."""
        key = "image"
        spacing = [1.0] * 3  # same spacing on all three axes
        pipeline = [
            LoadImaged(keys=key),
            AddChanneld(keys=key),
            Orientationd(keys=[key], axcodes="RAS"),
            Spacingd(keys=key, pixdim=spacing),
            NormalizeIntensityd(keys=key),
            ToTensord(keys=[key]),
        ]
        return pipeline

    def inferer(self):
        """Create a sliding-window inferer using a 128x128x128 region of interest."""
        window = [128] * 3
        return SlidingWindowInferer(roi_size=window)

    def post_transforms(self):
        """Build the dictionary transforms that turn the raw ``"pred"`` output into a label map."""
        key = "pred"
        # A leading axis is added here and removed again by SqueezeDimd(dim=0)
        # below. NOTE(review): presumably this lets softmax/argmax act on the
        # class axis under the MONAI version this was written for -- confirm
        # against the installed MONAI's Activations/AsDiscrete axis convention.
        head = [
            AddChanneld(keys=key),
            Activationsd(keys=key, softmax=True),
            AsDiscreted(keys=key, argmax=True),
            SqueezeDimd(keys=key, dim=0),
        ]
        tail = [
            ToNumpyd(keys=key),
            Restored(keys=key, ref_image="image"),
            BoundingBoxd(keys=key, result="result", bbox="bbox"),
        ]
        return head + tail
