import time
import yaml
from pathlib import Path
import json
import SimpleITK as sitk
from wsireg.reg_image import RegImage, TransformRegImage
from wsireg.utils.reg_utils import (
    register_2d_images_itkelx,
    sitk_pmap_to_dict,
    pmap_dict_to_json,
    json_to_pmap_dict,
)
from wsireg.reg_shapes import RegShapes
from wsireg.utils.config_utils import parse_check_reg_config


class WsiReg2D(object):
    """
    Class to define a 2D registration graph and execute the registrations and transformations of the graph

    Parameters
    ----------
    project_name: str
        Project name will prefix all output files and directories. If None, "RegProj" is used.
    output_dir: str
        Directory where registration data will be stored. If None, the current directory ("./") is used.
    cache_images: bool
        whether to store images as they are preprocessed for registration (if you need to repeat or modify settings
        this will avoid image io and preprocessing)
    reload_data: bool
        currently unused; kept for backward compatibility
    config: str
        optional path to a registration configuration file, loaded with ``add_data_from_config``

    Attributes
    ----------
    modalities: dict
        dictionary of modality information (file path, spatial res., preprocessing), defines a graph node

    modality_names: list
        list of all modality names

    n_modalities: int
        number of modalities (nodes) in the graphs

    reg_paths: dict
        dictionary of a modalities path from node to node

    reg_graph_edges: list
        generated list of necessary registrations (edges) to move modalities to their target modality

    n_registrations: int
        number of explicit registrations (edges) in the graphs

    transform_paths: dict
        generated dictionary of necessary source - target transformations to transform modalities to their target modality

    transformations: dict
        per modality dictionary containing transformation parameters for each registration

    attachment_images: dict
        images to be transformed along the path of the defined graph, associated to a given modality (masks, other registered images)

    shape_sets: dict
        shape data attached to a modality to be transformed along the graph

    Notes
    -----
    The setters of ``modalities``, ``modality_names``, ``shape_sets``, ``shape_set_names``, ``reg_paths`` and
    ``reg_graph_edges`` ADD to the underlying collection rather than replacing it.
    """

    # Sentinel marking "argument omitted": lets mutable defaults ({} / []) be
    # created fresh on every call, while an explicit ``None`` is still stored
    # as-is (callers such as add_data_from_config pass None through).
    _UNSET = object()

    def __init__(
        self,
        project_name: str,
        output_dir: str,
        cache_images=True,
        reload_data=False,
        config=None,
    ):
        self.project_name = (
            'RegProj' if project_name is None else project_name
        )

        if output_dir is None:
            output_dir = "./"
        self.output_dir = Path(output_dir)
        # use the resolved project name so that a None project_name does not
        # produce a ".imcache_None" directory
        self.image_cache = self.output_dir / ".imcache_{}".format(
            self.project_name
        )
        self.cache_images = cache_images

        self._modalities = {}
        self._modality_names = []
        self._reg_paths = {}
        self._reg_graph_edges = []
        self._transform_paths = {}

        self._transformations = None

        self.n_modalities = None
        self.n_registrations = None

        self.attachment_images = {}

        self._shape_sets = {}
        self._shape_set_names = []

        if config is not None:
            self.add_data_from_config(config)

    @property
    def modalities(self):
        """dict of modality name -> modality information."""
        return self._modalities

    @modalities.setter
    def modalities(self, modality):
        # merges the given {name: info} mapping into the existing modalities
        self._modalities.update(modality)
        self.n_modalities = len(self._modalities)

    @property
    def shape_sets(self):
        """dict of shape set name -> shape set information."""
        return self._shape_sets

    @shape_sets.setter
    def shape_sets(self, shape_set):
        # merges the given {name: info} mapping into the existing shape sets
        self._shape_sets.update(shape_set)

    @property
    def shape_set_names(self):
        """list of shape set names in insertion order."""
        return self._shape_set_names

    @shape_set_names.setter
    def shape_set_names(self, shape_set_name):
        # appends a single name
        self._shape_set_names.append(shape_set_name)

    @property
    def modality_names(self):
        """list of modality names in insertion order."""
        return self._modality_names

    @modality_names.setter
    def modality_names(self, modality_name):
        # appends a single name
        self._modality_names.append(modality_name)

    def add_modality(
        self,
        modality_name,
        image_fp,
        image_res=1,
        channel_names=None,
        channel_colors=None,
        prepro_dict=_UNSET,
        mask=None,
    ):
        """
        Add an image modality (node) to the registration graph

        Parameters
        ----------
        modality_name : str
            Unique name identifier for the modality
        image_fp : str
            file path to the image to be read
        image_res : float
            spatial resolution of image in units per px (i.e. 0.9 um / px)
        channel_names : list or None
            optional channel names, stored with the modality
        channel_colors : list or None
            optional channel colors, stored with the modality
        prepro_dict :
            preprocessing parameters for the modality for registration. Registration images should be a xy single plane
            so many modalities (multi-channel, RGB) must "create" a single channel.
            Defaults: multi-channel images -> max intensity project image
            RGB -> greyscale then intensity inversion (black background, white foreground)
            When omitted, a new empty dict is used; an explicit None is stored unchanged.
        mask :
            optional mask stored with the modality and passed to registration

        Raises
        ------
        ValueError
            if a modality with the same name was already added
        """
        if modality_name in self._modality_names:
            raise ValueError(
                'modality named \"{}\" is already in modality_names'.format(
                    modality_name
                )
            )

        if prepro_dict is self._UNSET:
            # fresh dict per modality: a shared {} default would alias the
            # preprocessing settings of every modality added without one
            prepro_dict = {}

        self.modalities = {
            modality_name: {
                "image_filepath": image_fp,
                "image_res": image_res,
                "channel_names": channel_names,
                "channel_colors": channel_colors,
                "preprocessing": prepro_dict,
                "mask": mask,
            }
        }

        self.modality_names = modality_name

    def add_shape_set(
        self, attachment_modality, shape_set_name, shape_files, image_res
    ):
        """
        Add a shape set to the graph

        Parameters
        ----------
        shape_set_name : str
            Unique name identifier for the shape set
        shape_files : list
            list of shape data in geoJSON format or list of dicts containing following keys:
            "array" = np.ndarray of xy coordinates, "shape_type" = geojson shape type (Polygon, MultiPolygon, etc.),
            "shape_name" = class of the shape("tumor","normal",etc.)
        image_res : float
            spatial resolution of shape data's associated image in units per px (i.e. 0.9 um / px)
        attachment_modality :
            image modality to which the shapes are attached

        Raises
        ------
        ValueError
            if a shape set with the same name was already added
        """
        if shape_set_name in self._shape_set_names:
            raise ValueError(
                'shape named \"{}\" is already in shape_set_names'.format(
                    shape_set_name
                )
            )

        self.shape_sets = {
            shape_set_name: {
                "shape_files": shape_files,
                "image_res": image_res,
                "attachment_modality": attachment_modality,
            }
        }

        self.shape_set_names = shape_set_name

    def add_attachment_images(
        self,
        attachment_modality,
        modality_name,
        image_fp,
        image_res=1,
        channel_names=None,
        channel_colors=None,
    ):
        """
        Images which are unregistered between modalities, but are transformed following the path of one of the graph's
        modalities.

        Parameters
        ----------
        attachment_modality : str
            image modality to which the new image are attached
        modality_name : str
            name of the added attachment image
        image_fp : str
            path to the attachment modality, it will be imported and transformed without preprocessing
        image_res : float
            spatial resolution of attachment image data's in units per px (i.e. 0.9 um / px)
        channel_names : list or None
            optional channel names
        channel_colors : list or None
            optional channel colors

        Raises
        ------
        ValueError
            if attachment_modality has not been added, or modality_name already exists
        """
        if attachment_modality not in self.modality_names:
            raise ValueError(
                'attachment modality named \"{}\" not found in modality_names'.format(
                    attachment_modality
                )
            )
        self.add_modality(
            modality_name,
            image_fp,
            image_res,
            channel_names=channel_names,
            channel_colors=channel_colors,
        )
        self.attachment_images[modality_name] = attachment_modality

    def add_attachment_shapes(
        self, attachment_modality, shape_set_name, shape_files
    ):
        """
        Attach a shape set to an existing modality, using that modality's spatial resolution.

        Raises
        ------
        ValueError
            if attachment_modality has not been added, or the shape set name already exists
        """
        if attachment_modality not in self.modality_names:
            raise ValueError(
                'attachment modality \"{}\" for shapes \"{}\" not found in modality_names {}'.format(
                    attachment_modality, shape_set_name, self.modality_names
                )
            )

        image_res = self.modalities[attachment_modality]["image_res"]
        self.add_shape_set(
            attachment_modality, shape_set_name, shape_files, image_res
        )

    @property
    def reg_paths(self):
        """dict of source modality -> [thru modality, target modality] or [target modality]."""
        return self._reg_paths

    @reg_paths.setter
    def reg_paths(self, path_values):
        # path_values: (src, tgt, thru, reg_params, override_prepro). Records
        # the path, appends the src -> thru registration edge and recomputes
        # all transform paths.
        (
            src_modality,
            tgt_modality,
            thru_modality,
            reg_params,
            override_prepro,
        ) = path_values

        if thru_modality != tgt_modality:
            self._reg_paths.update(
                {src_modality: [thru_modality, tgt_modality]}
            )
        else:
            self._reg_paths.update({src_modality: [tgt_modality]})

        # the explicit registration is always source -> thru modality
        self.reg_graph_edges = {
            'modalities': {'source': src_modality, 'target': thru_modality},
            'params': reg_params,
            "override_prepro": override_prepro,
        }
        self.transform_paths = self._reg_paths

    def add_reg_path(
        self,
        src_modality_name,
        tgt_modality_name,
        thru_modality=None,
        reg_params=_UNSET,
        override_prepro=_UNSET,
    ):
        """
        Add registration path between modalities as well as a thru modality that describes where to attach edges.

        Parameters
        ----------
        src_modality_name : str
            modality that has been added to graph to be transformed to tgt_modality
        tgt_modality_name : str
            modality that has been added to graph that is being aligned to
        thru_modality: str
            modality that has been added to graph by which another should be run through
            (defaults to tgt_modality_name)
        reg_params:
            SimpleElastix registration parameters (a new empty list when omitted)
        override_prepro:
            set specific preprocessing for a given registration edge for the source or target image that will override
            the set modality preprocessing FOR THIS REGISTRATION ONLY.
            (a new {"source": None, "target": None} dict when omitted)

        Raises
        ------
        ValueError
            if any referenced modality has not been added, or src_modality_name already has a registration path
        """
        if src_modality_name not in self.modality_names:
            raise ValueError("source modality not found!")
        if tgt_modality_name not in self.modality_names:
            raise ValueError("target modality not found!")

        if thru_modality is None:
            thru_modality = tgt_modality_name
        elif thru_modality not in self.modality_names:
            raise ValueError("thru modality not found!")

        # each source owns exactly one path; a second call would overwrite the
        # path but append a second edge, corrupting the transform paths
        if src_modality_name in self.reg_paths:
            raise ValueError(
                'registration path for source modality \"{}\" already exists'.format(
                    src_modality_name
                )
            )

        # fresh containers per edge instead of shared mutable defaults
        if reg_params is self._UNSET:
            reg_params = []
        if override_prepro is self._UNSET:
            override_prepro = {"source": None, "target": None}

        self.reg_paths = (
            src_modality_name,
            tgt_modality_name,
            thru_modality,
            reg_params,
            override_prepro,
        )

    @property
    def reg_graph_edges(self):
        """list of registration edge dicts ("modalities", "params", "override_prepro", ...)."""
        return self._reg_graph_edges

    @reg_graph_edges.setter
    def reg_graph_edges(self, edge):
        # appends a single edge
        self._reg_graph_edges.append(edge)
        self.n_registrations = len(self._reg_graph_edges)

    @property
    def transform_paths(self):
        """dict of source modality -> ordered list of edge "modalities" dicts to apply."""
        return self._transform_paths

    @transform_paths.setter
    def transform_paths(self, reg_paths):
        # For every source, walk the graph to its final target and collect,
        # in path order, every edge whose source lies on that path.
        transform_path_dict = {}

        for src_modality, path_modalities in reg_paths.items():
            tform_path = (
                self.find_path(src_modality, path_modalities[-1]) or []
            )
            transform_path_dict[src_modality] = [
                edge["modalities"]
                for modality in tform_path
                for edge in self.reg_graph_edges
                if edge["modalities"]['source'] == modality
            ]

        self._transform_paths = transform_path_dict

    def find_path(self, start_modality, end_modality, path=None):
        """
        Find a path from start_modality to end_modality in the graph

        Returns
        -------
        list or None
            modality names from start to end (inclusive), or None if no path exists
        """
        if path is None:
            path = []
        path = path + [start_modality]
        if start_modality == end_modality:
            return path
        if start_modality not in self.reg_paths:
            return None
        for modality in self.reg_paths[start_modality]:
            if modality not in path:
                extended_path = self.find_path(modality, end_modality, path)
                if extended_path:
                    return extended_path
        return None

    def _check_cache_modality(self, modality_name):
        """
        Look up cached preprocessing output for a modality.

        Returns
        -------
        tuple
            (image file path, [initial transforms] or None, whether the image came from the cache)
        """
        cache_im_fp = self.image_cache / "{}_prepro.tiff".format(modality_name)
        cache_transform_fp = cache_im_fp.parent / "{}_init_tforms.json".format(
            cache_im_fp.stem
        )

        if cache_im_fp.exists():
            im_fp = str(cache_im_fp)
            im_from_cache = True
        else:
            im_fp = self.modalities[modality_name]["image_filepath"]
            im_from_cache = False

        if cache_transform_fp.exists():
            im_initial_transforms = [json_to_pmap_dict(cache_transform_fp)]
        else:
            im_initial_transforms = None

        return im_fp, im_initial_transforms, im_from_cache

    def _prepare_modality(self, modality_name, reg_edge, src_or_tgt):
        """
        Resolve the inputs used to build a RegImage for one side of an edge.

        Parameters
        ----------
        modality_name : str
            modality to prepare
        reg_edge : dict
            registration edge, optionally holding "override_prepro"
        src_or_tgt : str
            "source" or "target": which side of the override to use

        Returns
        -------
        tuple
            (image file path, image resolution, preprocessing, initial transforms, mask)
        """
        mod_data = self.modalities[modality_name]

        override_prepro = reg_edge.get("override_prepro")
        # .get() so an override that only specifies one side does not raise
        override_preprocessing = (
            override_prepro.get(src_or_tgt)
            if override_prepro is not None
            else None
        )

        if override_preprocessing is not None:
            # edge-specific preprocessing always starts from the raw image
            return (
                mod_data["image_filepath"],
                mod_data["image_res"],
                override_preprocessing,
                None,
                mod_data["mask"],
            )

        image_fp, transforms, im_from_cache = self._check_cache_modality(
            modality_name
        )
        # a cached image has already been preprocessed
        preprocessing = None if im_from_cache else mod_data["preprocessing"]

        return (
            image_fp,
            mod_data["image_res"],
            preprocessing,
            transforms,
            mod_data["mask"],
        )

    def _cache_images(self, modality_name, reg_image):
        """Write a preprocessed image, its mask and initial transforms to the cache, skipping existing files."""
        cache_im_fp = self.image_cache / "{}_prepro.tiff".format(modality_name)

        cache_transform_fp = self.image_cache / "{}_init_tforms.json".format(
            cache_im_fp.stem
        )

        if not cache_im_fp.is_file():
            sitk.WriteImage(reg_image.image, str(cache_im_fp), True)

        if reg_image.mask is not None:
            cache_mask_im_fp = self.image_cache / "{}_prepro_mask.tiff".format(
                modality_name
            )
            if not cache_mask_im_fp.is_file():
                sitk.WriteImage(reg_image.mask, str(cache_mask_im_fp), True)

        if not cache_transform_fp.is_file():
            pmap_dict_to_json(reg_image.transforms, str(cache_transform_fp))

    def _find_nonreg_modalities(self):
        """
        List modalities that are neither the source of a registration edge nor attachment images.

        Attachment images are excluded because they are transformed along their attachment modality's path.
        The result follows modality_names order.
        """
        registered_modalities = {
            edge.get("modalities").get("source")
            for edge in self.reg_graph_edges
        }
        return [
            modality
            for modality in self.modality_names
            if modality not in registered_modalities
            and modality not in self.attachment_images
        ]

    def save_config(self, registered=False):
        """
        Write the current graph setup to a timestamped YAML file in output_dir.

        Parameters
        ----------
        registered : bool
            whether registration has run; if so, the registration edges (with transforms) are included
        """
        ts = time.strftime('%Y%m%d-%H%M%S')
        status = "registered" if registered is True else "setup"

        reg_paths = {}
        for idx, edge in enumerate(self.reg_graph_edges):
            src_modality = edge.get("modalities").get("source")
            src_path = self.reg_paths[src_modality]
            thru_modality = src_path[0] if len(src_path) > 1 else None
            tgt_modality = src_path[-1]
            reg_paths[f"reg_path_{idx}"] = {
                "src_modality_name": src_modality,
                "tgt_modality_name": tgt_modality,
                "thru_modality": thru_modality,
                "reg_params": edge.get("params"),
                # read back by add_data_from_config; previously lost on reload
                "override_prepro": edge.get("override_prepro"),
            }

        # NOTE(review): attachment images are saved only as plain modalities
        # and attachment shapes are not saved, so reloading this file does not
        # restore the attachments.
        config = {
            "project_name": self.project_name,
            "output_dir": str(self.output_dir),
            "cache_images": self.cache_images,
            "modalities": self.modalities,
            "reg_paths": reg_paths,
            "reg_graph_edges": self.reg_graph_edges
            if status == "registered"
            else None,
        }

        output_path = (
            self.output_dir
            / f"{ts}-{self.project_name}-configuration-{status}.yaml"
        )

        with open(str(output_path), "w") as f:
            yaml.dump(config, f, sort_keys=False)

    def register_images(self, parallel=False):
        """
        Start image registration process for all modalities

        Edges already marked "registered" are skipped.

        Parameters
        ----------
        parallel : bool
            whether to run each edge in parallel (not implemented yet)
        """
        # save_config and the per-edge output directories write into
        # output_dir, so it must exist
        self.output_dir.mkdir(parents=True, exist_ok=True)
        if self.cache_images is True:
            self.image_cache.mkdir(parents=False, exist_ok=True)

        self.save_config(registered=False)

        for reg_edge in self.reg_graph_edges:
            if not reg_edge.get("registered"):
                self._register_edge(reg_edge)

        self.transformations = self.reg_graph_edges
        self.save_config(registered=True)

    def _register_edge(self, reg_edge):
        """Register one edge, storing its transforms on reg_edge and writing them to JSON."""
        src_name = reg_edge["modalities"]["source"]
        tgt_name = reg_edge["modalities"]["target"]

        (
            src_reg_image_fp,
            src_res,
            src_prepro,
            src_transforms,
            src_mask,
        ) = self._prepare_modality(src_name, reg_edge, "source")

        (
            tgt_reg_image_fp,
            tgt_res,
            tgt_prepro,
            tgt_transforms,
            tgt_mask,
        ) = self._prepare_modality(tgt_name, reg_edge, "target")

        src_reg_image = RegImage(
            src_reg_image_fp,
            src_res,
            src_prepro,
            src_transforms,
            mask=src_mask,
        )

        tgt_reg_image = RegImage(
            tgt_reg_image_fp,
            tgt_res,
            tgt_prepro,
            tgt_transforms,
            mask=tgt_mask,
        )

        if self.cache_images is True:
            # only cache images built with the modality's own preprocessing,
            # never edge-specific overrides
            override_prepro = reg_edge.get("override_prepro") or {}
            if override_prepro.get("source") is None:
                self._cache_images(src_name, src_reg_image)
            if override_prepro.get("target") is None:
                self._cache_images(tgt_name, tgt_reg_image)

        edge_prefix = "{}-{}_to_{}".format(
            self.project_name, src_name, tgt_name
        )
        output_path = self.output_dir / "{}_reg_output".format(edge_prefix)
        output_path.mkdir(parents=False, exist_ok=True)
        output_path_tform = self.output_dir / "{}_transformations.json".format(
            edge_prefix
        )

        reg_tforms = register_2d_images_itkelx(
            src_reg_image,
            tgt_reg_image,
            reg_edge["params"],
            output_path,
        )
        reg_tforms = [sitk_pmap_to_dict(tf) for tf in reg_tforms]

        # transforms loaded from the cache arrive wrapped in a single-item list
        if src_transforms is not None:
            initial_transforms = src_reg_image.transforms[0]
        else:
            initial_transforms = src_reg_image.transforms

        reg_edge["transforms"] = {
            'initial': initial_transforms,
            'registration': reg_tforms,
        }

        reg_edge["registered"] = True
        pmap_dict_to_json(reg_edge["transforms"], str(output_path_tform))

    @property
    def transformations(self):
        """dict of modality -> {"initial": ..., 0: ..., 1: ...} collated along each transform path."""
        return self._transformations

    @transformations.setter
    def transformations(self, reg_graph_edges):
        self._transformations = self._collate_transformations(reg_graph_edges)

    def _collate_transformations(self, reg_graph_edges):
        """
        Collate per-edge transforms into per-modality transform sequences.

        Parameters
        ----------
        reg_graph_edges : list
            registered edges, each with a "transforms" entry

        Returns
        -------
        dict
            modality -> {"initial": first edge's initial transforms, idx: registration transforms of the idx-th edge}
        """
        transforms = {}
        edge_modality_pairs = [edge['modalities'] for edge in reg_graph_edges]
        for modality, tform_edges in self.transform_paths.items():
            for idx, tform_edge in enumerate(tform_edges):
                reg_edge_tforms = reg_graph_edges[
                    edge_modality_pairs.index(tform_edge)
                ]["transforms"]
                if idx == 0:
                    transforms[modality] = {
                        'initial': reg_edge_tforms['initial'],
                        idx: reg_edge_tforms['registration'],
                    }
                else:
                    transforms[modality][idx] = reg_edge_tforms['registration']

        return transforms

    def _transform_nonreg_image(self, modality_key, file_writer="ome.tiff"):
        """
        Write a modality that is not the source of any registration.

        Only the cached initial (rotation/flip) transforms are applied.

        Returns
        -------
        the file path returned by TransformRegImage.transform_image

        Raises
        ------
        ValueError
            if rotation/flip preprocessing is set but no cached initial transforms exist
        """
        print(
            "transforming non-registered modality : {} ".format(modality_key)
        )
        output_path = self.output_dir / "{}-{}_registered".format(
            self.project_name, modality_key
        )
        im_data = self.modalities[modality_key]

        # preprocessing may be None (e.g. a config without a preprocessing entry)
        preprocessing = im_data.get("preprocessing") or {}

        if (
            preprocessing.get("rot_cc") is not None
            or preprocessing.get("flip") is not None
        ):
            cached_transforms = self._check_cache_modality(modality_key)[1]
            if cached_transforms is None:
                raise ValueError(
                    'no cached initial transforms for modality \"{}\" found in {}; '
                    'rotation/flip of non-registered modalities requires images '
                    'cached by register_images (cache_images=True)'.format(
                        modality_key, self.image_cache
                    )
                )
            transformations = {
                "initial": cached_transforms[0],
                "registered": None,
            }
        else:
            transformations = None

        tfregimage = TransformRegImage(
            output_path.stem,
            im_data["image_filepath"],
            im_data["image_res"],
            transform_data=transformations,
            channel_names=im_data.get("channel_names"),
            channel_colors=im_data.get("channel_colors"),
        )
        im_fp = tfregimage.transform_image(
            str(output_path.parent), output_type=file_writer, tile_size=512
        )
        return im_fp

    def _transform_image(
        self,
        edge_key,
        file_writer="ome.tiff",
        attachment=False,
        attachment_modality=None,
    ):
        """
        Transform a modality (or attachment image) to its final target and write it to output_dir.

        Returns
        -------
        the file path returned by TransformRegImage.transform_image
        """
        im_data = self.modalities[edge_key]

        # attachment images follow the path of the modality they are attached to
        path_key = attachment_modality if attachment is True else edge_key
        final_modality = self.reg_paths[path_key][-1]
        transformations = self.transformations[path_key]

        print("transforming {} to {}".format(edge_key, final_modality))

        output_path = self.output_dir / "{}-{}_to_{}_registered".format(
            self.project_name,
            edge_key,
            final_modality,
        )
        tfregimage = TransformRegImage(
            output_path.stem,
            im_data["image_filepath"],
            im_data["image_res"],
            transform_data=transformations,
            channel_names=im_data.get("channel_names"),
            channel_colors=im_data.get("channel_colors"),
        )
        im_fp = tfregimage.transform_image(
            str(self.output_dir), output_type=file_writer, tile_size=512
        )

        return im_fp

    def transform_images(self, file_writer="ome.tiff", transform_non_reg=True):
        """
        Transform and write images to disk after registration. Also transforms all attachment images

        Parameters
        ----------
        file_writer : str
            output type to use, sitk writes a single resolution tiff, "zarr" writes an ome-zarr multiscale
            zarr store
        transform_non_reg : bool
            also write modalities that are not the source of any registration

        Returns
        -------
        list
            file paths of every image written (registered, attachment and non-registered)
        """
        image_fps = []

        if all(reg_edge.get("registered") for reg_edge in self.reg_graph_edges):
            for key in self.reg_paths:
                image_fps.append(
                    self._transform_image(key, file_writer=file_writer)
                )

            for modality, attachment_modality in self.attachment_images.items():
                image_fps.append(
                    self._transform_image(
                        modality,
                        file_writer=file_writer,
                        attachment=True,
                        attachment_modality=attachment_modality,
                    )
                )
        else:
            print(
                "warning: not all registrations have been run, registered and attachment images were not transformed"
            )

        if transform_non_reg is True:
            # preprocess and save unregistered nodes
            for key in self._find_nonreg_modalities():
                image_fps.append(
                    self._transform_nonreg_image(key, file_writer=file_writer)
                )

        return image_fps

    def transform_shapes(self):
        """
        Transform all attached shapes and write out shape data to geoJSON.
        """
        for set_name, set_data in self.shape_sets.items():
            attachment_modality = set_data["attachment_modality"]

            final_modality = self.reg_paths[attachment_modality][-1]

            print(
                "transforming shape set {} associated with {} to {}".format(
                    set_name, attachment_modality, final_modality
                )
            )

            rs = RegShapes(
                set_data["shape_files"], source_res=set_data["image_res"]
            )
            rs.transform_shapes(
                self.transformations[attachment_modality],
            )

            output_path = (
                self.output_dir
                / "{}-{}-{}_to_{}-transformed_shapes.json".format(
                    self.project_name,
                    set_name,
                    attachment_modality,
                    final_modality,
                )
            )

            rs.save_shape_data(output_path, transformed=True)

    def _write_transformations(self, name, path_modality):
        """Dump the collated transformations of path_modality to "<project>-<name>_to_<final>_transformations.json"."""
        final_modality = self.reg_paths[path_modality][-1]

        output_path = (
            self.output_dir
            / "{}-{}_to_{}_transformations.json".format(
                self.project_name,
                name,
                final_modality,
            )
        )

        with open(output_path, 'w') as fp:
            json.dump(self.transformations[path_modality], fp, indent=4)

    def save_transformations(self):
        """
        Save all transformations for a given modality as JSON

        Attachment images get the transformations of the modality they are attached to.

        Raises
        ------
        ValueError
            if registration has not been run yet
        """
        if self.transformations is None:
            raise ValueError(
                "no transformations available; run register_images first"
            )

        for key in self.reg_paths:
            self._write_transformations(key, key)

        for modality, attachment_modality in self.attachment_images.items():
            # use the attachment modality's transforms (previously a stale
            # loop variable from the loop above was used here)
            self._write_transformations(modality, attachment_modality)

    def add_data_from_config(self, config_filepath):
        """
        Populate the graph (modalities, registration paths, attachments, registered edges) from a config file.

        Parameters
        ----------
        config_filepath : str
            path to the configuration file, parsed by parse_check_reg_config
        """
        reg_config = parse_check_reg_config(config_filepath)

        if reg_config.get("modalities") is not None:
            for key, val in reg_config["modalities"].items():
                self.add_modality(
                    key,
                    val.get("image_filepath"),
                    image_res=val.get("image_res"),
                    channel_names=val.get("channel_names"),
                    channel_colors=val.get("channel_colors"),
                    prepro_dict=val.get("preprocessing"),
                    mask=val.get("mask"),
                )
        else:
            print("warning: config file did not contain any image modalities")

        if reg_config.get("reg_paths") is not None:

            for key, val in reg_config["reg_paths"].items():
                self.add_reg_path(
                    val.get("src_modality_name"),
                    val.get("tgt_modality_name"),
                    val.get("thru_modality"),
                    reg_params=val.get("reg_params"),
                    override_prepro=val.get("override_prepro"),
                )
        else:
            print(
                "warning: config file did not contain any registration paths"
            )
        if reg_config.get("attachment_images") is not None:

            for key, val in reg_config["attachment_images"].items():
                self.add_attachment_images(
                    val.get("attachment_modality"),
                    key,
                    val.get("image_filepath"),
                    val.get("image_res"),
                    channel_names=val.get("channel_names"),
                    channel_colors=val.get("channel_colors"),
                )

        if reg_config.get("attachment_shapes") is not None:

            for key, val in reg_config["attachment_shapes"].items():
                self.add_attachment_shapes(
                    val.get("attachment_modality"), key, val.get("shape_files")
                )

        if reg_config.get("reg_graph_edges") is not None:
            # replace the freshly built edges with the saved (possibly
            # registered) ones and keep the edge count in sync
            self._reg_graph_edges = reg_config["reg_graph_edges"]
            self.n_registrations = len(self._reg_graph_edges)
            if all(edge.get("registered") for edge in self.reg_graph_edges):
                self.transformations = self.reg_graph_edges

    def reset_registered_modality(self, modalities):
        """
        Mark the registration edge(s) whose source is in modalities as not registered, so they are re-run.

        Parameters
        ----------
        modalities : str or list of str
            source modality name(s)

        Raises
        ------
        ValueError
            if a modality is not the source of any registration edge
        """
        edge_keys = [
            r.get("modalities").get("source") for r in self.reg_graph_edges
        ]
        if isinstance(modalities, str):
            modalities = [modalities]

        for modality in modalities:
            if modality not in edge_keys:
                raise ValueError(
                    'modality \"{}\" is not the source of any registration edge'.format(
                        modality
                    )
                )
            modality_idx = edge_keys.index(modality)
            self.reg_graph_edges[modality_idx]["registered"] = False


if __name__ == "__main__":
    import argparse

    def config_to_WsiReg2D(config_filepath):
        """Build an empty WsiReg2D from the project-level settings (name, output dir, caching) of a config file."""
        reg_config = parse_check_reg_config(config_filepath)

        # A config without "cache_images" keeps the class default (True).
        # Passing None through would silently disable caching, because
        # caching is only enabled when cache_images is True.
        cache_images = reg_config.get("cache_images")
        if cache_images is None:
            cache_images = True

        reg_graph = WsiReg2D(
            reg_config.get("project_name"),
            reg_config.get("output_dir"),
            cache_images,
        )
        return reg_graph

    parser = argparse.ArgumentParser(
        description='Load Whole Slide Image 2D Registration Graph from configuration file'
    )

    parser.add_argument(
        "config_filepath",
        metavar="C",
        type=str,
        nargs=1,
        help="full filepath for .yaml configuration file",
    )
    parser.add_argument(
        "--fw",
        type=str,
        nargs=1,
        help="how to write output registered images: ome.tiff, ome.zarr (default: ome.tiff)",
    )

    args = parser.parse_args()
    config_filepath = args.config_filepath[0]
    # nargs=1 yields a one-element list when the option is given
    file_writer = "ome.tiff" if args.fw is None else args.fw[0]

    reg_graph = config_to_WsiReg2D(config_filepath)
    reg_graph.add_data_from_config(config_filepath)

    reg_graph.register_images()
    reg_graph.save_transformations()
    reg_graph.transform_images(file_writer=file_writer)

    if reg_graph.shape_sets:
        reg_graph.transform_shapes()
