Registration

register(msims, transform_key=None, reg_channel_index=None, reg_channel=None, new_transform_key=None, registration_binning=None, reg_res_level=None, overlap_tolerance=0.0, pairwise_reg_func=phase_correlation_registration, pairwise_reg_func_kwargs=None, groupwise_resolution_method='global_optimization', groupwise_resolution_kwargs=None, pre_registration_pruning_method='alternating_pattern', pre_reg_pruning_method_kwargs=None, post_registration_do_quality_filter=False, post_registration_quality_threshold=0.2, plot_summary=False, pairs=None, scheduler=None, n_parallel_pairwise_regs=None, return_dict=False)

Register a list of views to a common extrinsic coordinate system.

High-level flow:
1) Build the overlap graph.
2) Run pairwise registrations for selected edges.
3) Resolve global transforms from the pairwise results.
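The three steps above can be illustrated with a deliberately simplified, translation-only sketch in one dimension. This is not the multiview_stitcher implementation: the function names are hypothetical, the "pairwise registration" step is simulated from known offsets instead of being estimated from image data, and resolution propagates shifts outward from a fixed reference view along a breadth-first tree, loosely analogous to the 'shortest_paths' method.

```python
from collections import deque


def overlap_graph(boxes):
    """Step 1: connect every pair of views whose 1D intervals [start, stop) overlap."""
    edges = []
    for i in range(len(boxes)):
        for j in range(i + 1, len(boxes)):
            (a0, a1), (b0, b1) = boxes[i], boxes[j]
            if min(a1, b1) > max(a0, b0):
                edges.append((i, j))
    return edges


def pairwise_shifts(edges, true_offsets):
    """Step 2 stand-in: 'measure' the shift of view j relative to view i.
    A real pairwise_reg_func would estimate this from the overlapping image data."""
    return {(i, j): true_offsets[j] - true_offsets[i] for i, j in edges}


def resolve(n_views, shifts, reference_view=0):
    """Step 3: propagate shifts from a fixed reference view (breadth-first).
    Assumes the overlap graph is connected."""
    adjacency = {v: [] for v in range(n_views)}
    for (i, j), s in shifts.items():
        adjacency[i].append((j, s))
        adjacency[j].append((i, -s))
    params = {reference_view: 0.0}
    queue = deque([reference_view])
    while queue:
        v = queue.popleft()
        for w, s in adjacency[v]:
            if w not in params:
                params[w] = params[v] + s
                queue.append(w)
    return [params[v] for v in range(n_views)]


edges = overlap_graph([(0, 10), (8, 18), (16, 26)])
shifts = pairwise_shifts(edges, [0.0, 0.5, 1.25])
print(resolve(3, shifts))
```

With three tiles in a row, only neighboring tiles overlap, so two pairwise registrations suffice and the resolved offsets recover the simulated ones. The real resolvers also handle loops in the graph, where pairwise results can disagree.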

Parameters

msims : list of MultiscaleSpatialImage
    Input views.
reg_channel_index : int, optional
    Index of the channel to be used for registration, by default None.
reg_channel : str, optional
    Name of the channel to be used for registration, by default None. Overrides reg_channel_index.
transform_key : str, optional
    Extrinsic coordinate system to use as a starting point for the registration, by default None.
new_transform_key : str, optional
    If set, the registration result will be registered as a new extrinsic coordinate system in the input views (with the given name), by default None.
registration_binning : dict, optional
    Binning applied to each dimension during registration, by default None. If reg_res_level is also provided, the binning factors must be compatible with the resolution level.
reg_res_level : int, optional
    Resolution level to use for registration (e.g., 0 for scale0, 1 for scale1). If None and registration_binning is provided, the optimal resolution level is automatically determined. By default None.
overlap_tolerance : float or dict, optional
    Extend overlap regions considered for pairwise registration.
    - if 0, the overlap region is the intersection of the bounding boxes.
    - if > 0, the overlap region is the intersection of the bounding boxes extended by this value in all spatial dimensions.
    - if None, the full images are used for registration.
pairwise_reg_func : Callable, optional
    Function used for registration. See the docs for the function API. By default, phase_correlation_registration is used. Another useful built-in registration function is pairwise_reg_func=registration.registration_ANTsPy for translation, rigid, similarity or affine registration using ANTsPy.
pairwise_reg_func_kwargs : dict, optional
    Additional keyword arguments passed to the registration function.
    In the case of pairwise_reg_func=registration_ANTsPy, this can include, e.g.:
    - 'transform_type': ['Translation', 'Rigid', 'Affine'] or ['Similarity']
    For further parameters, see the docstring of the registration function.
groupwise_resolution_method : str, optional
    Method used to resolve global transforms from pairwise registrations:
    - 'global_optimization' (transform: translation|rigid|similarity|affine)
    - 'shortest_paths' (uses the transform type defined by the pairwise registrations)
    - 'linear_two_pass' (transform: translation|rigid)
    Custom component-level methods can be registered via param_resolution.register_groupwise_resolution_method(...) and referenced by name.
groupwise_resolution_kwargs : dict, optional
    Additional keyword arguments passed to the groupwise resolver. Parameters are method-specific. Common options include:
    - 'transform': final transform type (see method notes above)
    - 'reference_view': node index to keep fixed
    See the resolver docstrings for full details.
pre_registration_pruning_method : str, optional
    Method used to eliminate registration edges (e.g. diagonals) from the view adjacency graph before registration, by default 'alternating_pattern'. Available methods:
    - None: No pruning, useful when no regular arrangement is present.
    - 'alternating_pattern': Prune to edges between squares of differing colors in a checkerboard pattern. Useful for regular 2D tile arrangements (of both 2D and 3D data).
    - 'shortest_paths_overlap_weighted': Prune to shortest paths in the overlap graph (weighted by overlap). Useful to minimize the number of pairwise registrations.
    - 'otsu_threshold_on_overlap': Prune to edges with overlap above the Otsu threshold. This is useful for regular 2D or 3D grid arrangements, as diagonal edges will be pruned.
    - 'keep_axis_aligned': Keep only edges that align with tile axes. This is useful for regular grid arrangements and to explicitly prune diagonals, e.g. when other methods fail.
pre_reg_pruning_method_kwargs : dict, optional
    Additional keyword arguments passed to the pre-registration pruning method, e.g.
    - 'keep_axis_aligned': 'max_angle' (larger angles between stack axis and pair edge are discarded, default 0.2)
post_registration_do_quality_filter : bool, optional
    If True, filter pairwise registration edges by quality (see post_registration_quality_threshold) before resolving global parameters, by default False.
post_registration_quality_threshold : float, optional
    Threshold applied to the quality of each pairwise registration edge when post_registration_do_quality_filter is True, by default 0.2.
plot_summary : bool, optional
    If True (and new_transform_key is set), plot graphs summarising the registration process and results:
    1) Cross correlation values of pairwise registrations (stack boundaries shown as before registration).
    2) Residual distances between registration edges after global parameter resolution. Grey edges have been removed during global parameter resolution (stack boundaries shown as after registration). Solid edges were used by the resolver, dotted edges were unused.
    Stack boundary positions reflect the registration result. By default False.
pairs : list of tuples, optional
    If set, initialises the view adjacency graph using the indicated pairs of view/tile indices, by default None.
scheduler : str, optional
    (Deprecated since >0.1.28) Dask scheduler to use for parallel computation, by default None. This parameter is deprecated and no longer used. Use a context manager instead to set the dask scheduler used within register(), e.g. with dask.config.set(scheduler='threads'): register(...)
n_parallel_pairwise_regs : int, optional
    Number of parallel pairwise registrations to run. Setting this is specifically useful for limiting memory usage. By default None (all pairwise registrations are run in parallel).
return_dict : bool, optional
    If True, return a dict containing params, registration metrics and more, by default False.
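To illustrate the idea behind 'alternating_pattern', here is a minimal sketch that assumes tiles sit on a regular 2D grid with known (row, col) indices; the function names and the grid-index input are hypothetical, and how the library assigns colors to views is not shown here. Coloring each tile by (row + col) % 2 gives edge-sharing neighbors different colors while diagonal neighbors share a color, so keeping only edges between differing colors removes the diagonals.

```python
from itertools import combinations


def grid_neighbors(grid_positions):
    """All pairs of tiles whose grid indices differ by at most 1 along each axis
    (the 8-connected neighborhood, which includes diagonals)."""
    edges = []
    for i, j in combinations(range(len(grid_positions)), 2):
        (ri, ci), (rj, cj) = grid_positions[i], grid_positions[j]
        if max(abs(ri - rj), abs(ci - cj)) == 1:
            edges.append((i, j))
    return edges


def prune_alternating_pattern(edges, grid_positions):
    """Keep only edges between tiles of different checkerboard color."""
    color = [(r + c) % 2 for r, c in grid_positions]
    return [(i, j) for i, j in edges if color[i] != color[j]]


positions = [(0, 0), (0, 1), (1, 0), (1, 1)]  # a 2x2 tile grid
edges = grid_neighbors(positions)
print(prune_alternating_pattern(edges, positions))
```

For the 2x2 grid, the six candidate edges (four axis-aligned, two diagonal) are reduced to the four axis-aligned ones, which roughly halves the number of pairwise registrations in larger grids.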

Returns

list of xr.DataArray
    Parameters mapping each view into a new extrinsic coordinate system
or
dict
    Dictionary containing the following keys:
    - 'params': Parameters mapping each view into a new extrinsic coordinate system
    - 'pairwise_registration': Dictionary containing the following:
        - 'summary_plot': Tuple containing the figure and axis of the summary plot
        - 'graph': networkx graph of pairwise registrations
        - 'metrics': Dictionary containing the following metrics:
            - 'qualities': Edge registration qualities
    - 'groupwise_resolution': Dictionary containing the following:
        - 'summary_plot': Tuple containing the figure and axis of the summary plot
        - 'metrics': Dictionary containing the following metrics:
            - 'edge_residuals': Dict[int, dict[tuple, float]] mapping timepoint index to edge residuals
            - 'used_edges': Dict[int, list[tuple]] mapping timepoint index to edges used by the resolution method
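For intuition about the 'edge_residuals' metric, the following translation-only sketch assumes that a residual is the distance between the shift measured by a pairwise registration and the shift implied by the resolved global parameters. The exact definition used by each resolver may differ, and the function name is hypothetical.

```python
import math


def edge_residuals(global_translations, measured_shifts):
    """Residual per edge (i, j): Euclidean distance between the measured shift of
    view j relative to view i and the shift implied by the global translations."""
    residuals = {}
    for (i, j), measured in measured_shifts.items():
        implied = [tj - ti for ti, tj in zip(global_translations[i], global_translations[j])]
        residuals[(i, j)] = math.dist(implied, measured)
    return residuals


global_translations = {0: (0.0, 0.0), 1: (10.0, 0.0), 2: (10.0, 10.0)}
measured = {(0, 1): (10.0, 0.0), (1, 2): (0.0, 10.0), (0, 2): (10.0, 13.0)}
print(edge_residuals(global_translations, measured))
```

Here the direct measurement between views 0 and 2 disagrees with the path through view 1 by 3 units, so that edge stands out with a large residual while the other two are consistent.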

Source code in src/multiview_stitcher/registration.py
def register(
    msims: list[MultiscaleSpatialImage],
    transform_key: str = None,
    reg_channel_index: int = None,
    reg_channel: str = None,
    new_transform_key: str = None,
    registration_binning: dict[str, int] = None,
    reg_res_level: int = None,
    overlap_tolerance: Union[float, dict[str, float]] = 0.0,
    pairwise_reg_func=phase_correlation_registration,
    pairwise_reg_func_kwargs: dict = None,
    groupwise_resolution_method="global_optimization",
    groupwise_resolution_kwargs: dict = None,
    pre_registration_pruning_method="alternating_pattern",
    pre_reg_pruning_method_kwargs: dict = None,
    post_registration_do_quality_filter: bool = False,
    post_registration_quality_threshold: float = 0.2,
    plot_summary: bool = False,
    pairs: list[tuple[int, int]] = None,
    scheduler=None,  # deprecated, see docstring
    n_parallel_pairwise_regs: int = None,
    return_dict: bool = False,
):
    """
    Register a list of views to a common extrinsic coordinate system.

    High-level flow:
    1) Build the overlap graph.
    2) Run pairwise registrations for selected edges.
    3) Resolve global transforms from the pairwise results.

    Parameters
    ----------
    msims : list of MultiscaleSpatialImage
        Input views
    reg_channel_index : int, optional
        Index of channel to be used for registration, by default None
    reg_channel : str, optional
        Name of channel to be used for registration, by default None
        Overrides reg_channel_index
    transform_key : str, optional
        Extrinsic coordinate system to use as a starting point
        for the registration, by default None
    new_transform_key : str, optional
        If set, the registration result will be registered as a new extrinsic
        coordinate system in the input views (with the given name), by default None
    registration_binning : dict, optional
        Binning applied to each dimension during registration, by default None.
        If reg_res_level is also provided, the binning factors must be compatible 
        with the resolution level.
    reg_res_level : int, optional
        Resolution level to use for registration (e.g., 0 for scale0, 1 for scale1).
        If None and registration_binning is provided, the optimal resolution level 
        is automatically determined. By default None.
    overlap_tolerance : float or dict, optional
        Extend overlap regions considered for pairwise registration.
        - if 0, the overlap region is the intersection of the bounding boxes.
        - if > 0, the overlap region is the intersection of the bounding boxes
            extended by this value in all spatial dimensions.
        - if None, the full images are used for registration
    pairwise_reg_func : Callable, optional
        Function used for registration. See the docs for the function API.
        By default, phase_correlation_registration is used. Another useful built-in
        registration function is `pairwise_reg_func=registration.registration_ANTsPy`
        for translation, rigid, similarity or affine registration using ANTsPy.
    pairwise_reg_func_kwargs : dict, optional
        Additional keyword arguments passed to the registration function.
        In the case of `pairwise_reg_func=registration_ANTsPy`, this can include e.g:
        - 'transform_type': ['Translation', 'Rigid', 'Affine'] or ['Similarity']
        For further parameters, see the docstring of the registration function.
    groupwise_resolution_method : str, optional
        Method used to resolve global transforms from pairwise registrations:
        - 'global_optimization' (transform: translation|rigid|similarity|affine)
        - 'shortest_paths' (uses the transform type defined by the pairwise registrations)
        - 'linear_two_pass' (transform: translation|rigid)
        Custom component-level methods can be registered via
        `param_resolution.register_groupwise_resolution_method(...)` and
        referenced by name.
    groupwise_resolution_kwargs : dict, optional
        Additional keyword arguments passed to the groupwise resolver.
        Parameters are method-specific. Common options include:
        - 'transform': final transform type (see method notes above)
        - 'reference_view': node index to keep fixed
        See the resolver docstrings for full details.
    pre_registration_pruning_method : str, optional
        Method used to eliminate registration edges (e.g. diagonals) from the view adjacency
        graph before registration. Available methods:
        - None: No pruning, useful when no regular arrangement is present.
        - 'alternating_pattern': Prune to edges between squares of differing
            colors in a checkerboard pattern. Useful for regular 2D tile arrangements (of both 2D and 3D data).
        - 'shortest_paths_overlap_weighted': Prune to shortest paths in overlap graph
            (weighted by overlap). Useful to minimize the number of pairwise registrations.
        - 'otsu_threshold_on_overlap': Prune to edges with overlap above Otsu threshold.
            This is useful for regular 2D or 3D grid arrangements, as diagonal edges will be pruned.
        - 'keep_axis_aligned': Keep only edges that align with tile axes. This is useful for regular grid
            arrangements and to explicitly prune diagonals, e.g. when other methods fail.
    pre_reg_pruning_method_kwargs : dict, optional
        Additional keyword arguments passed to the pre-registration pruning method, e.g.
        - 'keep_axis_aligned': 'max_angle' (larger angles between stack axis and pair edge are discarded, default 0.2)
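    post_registration_do_quality_filter : bool, optional
        If True, filter pairwise registration edges by quality (see
        post_registration_quality_threshold) before resolving global
        parameters, by default False
    post_registration_quality_threshold : float, optional
        Threshold applied to the quality of each pairwise registration edge
        when post_registration_do_quality_filter is True, by default 0.2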
    plot_summary : bool, optional
        If True (and `new_transform_key` is set), plot graphs summarising the registration process and results:
        1) Cross correlation values of pairwise registrations
           (stack boundaries shown as before registration)
        2) Residual distances between registration edges after global parameter resolution.
           Grey edges have been removed during global parameter resolution (stack boundaries shown as after registration).
           Solid edges were used by the resolver, dotted edges were unused.
        Stack boundary positions reflect the registration result.
        By default False
    pairs : list of tuples, optional
        If set, initialises the view adjacency graph using the indicated
        pairs of view/tile indices, by default None
    scheduler : str, optional
        (Deprecated since >0.1.28) Dask scheduler to use for parallel computation, by default None
        This parameter is deprecated and no longer used.
        Use a context manager instead to set the dask scheduler used within register(), e.g.
        `with dask.config.set(scheduler='threads'): register(...)`
    n_parallel_pairwise_regs : int, optional
        Number of parallel pairwise registrations to run. Setting this is specifically
        useful for limiting memory usage.
        By default None (all pairwise registrations are run in parallel)
    return_dict : bool, optional
        If True, return a dict containing params, registration metrics and more, by default False

    Returns
    -------
    list of xr.DataArray
        Parameters mapping each view into a new extrinsic coordinate system
    or
    dict
        Dictionary containing the following keys:
        - 'params': Parameters mapping each view into a new extrinsic coordinate system
        - 'pairwise_registration': Dictionary containing the following
            - 'summary_plot': Tuple containing the figure and axis of the summary plot
            - 'graph': networkx graph of pairwise registrations
            - 'metrics': Dictionary containing the following metrics:
                - 'qualities': Edge registration qualities
        - 'groupwise_resolution': Dictionary containing the following
            - 'summary_plot': Tuple containing the figure and axis of the summary plot
            - 'metrics': Dictionary containing the following metrics:
                - 'edge_residuals': Dict[int, dict[tuple, float]] mapping timepoint index
                  to edge residuals
                - 'used_edges': Dict[int, list[tuple]] mapping timepoint index to edges
                  used by the resolution method
    """

    # warn about deprecated parameter
    if scheduler is not None:
        warnings.warn(
            "The register(..., scheduler=) parameter is deprecated, no longer used "
            "and will be removed in a future version. "
            "Use a context manager to set the dask scheduler used within register(), e.g. "
            "`with dask.config.set(scheduler='threads'): register(...)`",
            DeprecationWarning,
            stacklevel=2,
        )

    if pairwise_reg_func_kwargs is None:
        pairwise_reg_func_kwargs = {}

    if groupwise_resolution_kwargs is None:
        groupwise_resolution_kwargs = {}

    if pre_reg_pruning_method_kwargs is None:
        pre_reg_pruning_method_kwargs = {}

    sims = [msi_utils.get_sim_from_msim(msim) for msim in msims]

    if "c" in msi_utils.get_dims(msims[0]):
        if reg_channel is None:
            if reg_channel_index is None:
                for msim in msims:
                    if "c" in msi_utils.get_dims(msim):
                        raise Exception(
                            "Please choose a registration channel "
                            "(reg_channel or reg_channel_index)."
                        )
            else:
                reg_channel = sims[0].coords["c"][reg_channel_index]

        msims_reg = [
            msi_utils.multiscale_sel_coords(msim, {"c": reg_channel})
            if "c" in msi_utils.get_dims(msim)
            else msim
            for msim in msims
        ]
    else:
        msims_reg = msims

    # determine registration pairs from input images
    g = mv_graph.build_view_adjacency_graph_from_msims(
        msims_reg,
        transform_key=transform_key,
        pairs=pairs,
        overlap_tolerance=overlap_tolerance,
    )

    # prune registration pair graph
    if pre_registration_pruning_method is not None:
        g_reg = mv_graph.prune_view_adjacency_graph(
            g,
            method=pre_registration_pruning_method,
            pruning_method_kwargs=pre_reg_pruning_method_kwargs,
        )
    else:
        g_reg = g

    # compute pairwise registrations
    g_reg_computed = compute_pairwise_registrations(
        msims_reg,
        g_reg,
        transform_key=transform_key,
        registration_binning=registration_binning,
        reg_res_level=reg_res_level,
        overlap_tolerance=overlap_tolerance,
        pairwise_reg_func=pairwise_reg_func,
        pairwise_reg_func_kwargs=pairwise_reg_func_kwargs,
        n_parallel_pairwise_regs=n_parallel_pairwise_regs,
    )

    # optionally filter obtained pairwise registrations by quality
    if post_registration_do_quality_filter:
        # filter edges by quality
        g_reg_computed = mv_graph.filter_edges(
            g_reg_computed,
            threshold=post_registration_quality_threshold,
            weight_key="quality",
        )

    # resolve global registration parameters from pairwise registrations
    params_dict, groupwise_resolution_info_dict = groupwise_resolution(
        g_reg_computed,
        method=groupwise_resolution_method,
        **groupwise_resolution_kwargs,
    )

    params = [
        params_dict[iview] for iview in sorted(g_reg_computed.nodes())
    ]

    # optionally write registration result back to the input msims
    # under a new transform key
    if new_transform_key is not None:
        for imsim, msim in enumerate(msims):
            msi_utils.set_affine_transform(
                msim,
                params[imsim],
                transform_key=new_transform_key,
                base_transform_key=transform_key,
            )

    # optionally plot registration summaries
    if plot_summary:
        plot_info = _plot_registration_summaries(
            msims,
            transform_key,
            new_transform_key,
            g_reg_computed,
            groupwise_resolution_info_dict,
            show_plot=plot_summary,
        )
    else:
        plot_info = {}

    if return_dict:
        return {
            "params": params,
            "pairwise_registration": {
                "graph": g_reg_computed,
                "metrics": {
                    "qualities": nx.get_edge_attributes(
                        g_reg_computed, "quality"
                    )
                },
                "summary_plot": None if not plot_summary
                else (
                    plot_info['fig_pair_reg'],
                    plot_info['ax_pair_reg']
                )
            },
            "groupwise_resolution": {
                "metrics": groupwise_resolution_info_dict,
                "summary_plot": None if not plot_summary
                else (
                    plot_info['fig_group_res'],
                    plot_info['ax_group_res']
                )
            },
        }
    else:
        return params
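The optional quality filter (post_registration_do_quality_filter) can be pictured with the sketch below. It assumes that mv_graph.filter_edges drops edges whose 'quality' is below the threshold while keeping every view as a node; that behavior, the plain-dict graph, and both helper names are assumptions for illustration. The sketch also shows a practical consequence: filtering can split the graph into several connected components, each of which must then be resolved separately.

```python
def filter_edges_by_quality(edge_quality, threshold):
    """Keep edges whose quality is at least `threshold` (assumed semantics)."""
    return {edge: q for edge, q in edge_quality.items() if q >= threshold}


def connected_components(nodes, edges):
    """Group nodes into connected components using a small union-find."""
    parent = {n: n for n in nodes}

    def find(n):
        while parent[n] != n:
            parent[n] = parent[parent[n]]  # path halving
            n = parent[n]
        return n

    for i, j in edges:
        parent[find(i)] = find(j)
    groups = {}
    for n in nodes:
        groups.setdefault(find(n), set()).add(n)
    return sorted(groups.values(), key=min)


qualities = {(0, 1): 0.9, (1, 2): 0.15, (2, 3): 0.6}
kept = filter_edges_by_quality(qualities, threshold=0.2)  # 0.2 is register()'s default
print(kept, connected_components([0, 1, 2, 3], kept))
```

Dropping the low-quality edge (1, 2) leaves two components, {0, 1} and {2, 3}, with no pairwise information linking them.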

registration_ITKElastix(fixed_data, moving_data, *, fixed_origin, moving_origin, fixed_spacing, moving_spacing, initial_affine, transform_types=None, **elastix_registration_kwargs)

Use ITKElastix to perform registration between two spatial images.

Parameters

transform_types : list of str, optional
    Sequence of transform types to apply in successive stages. Supported values: 'Translation', 'Rigid', 'Similarity', 'Affine'. By default ['Translation', 'Rigid'].
**elastix_registration_kwargs
    Additional keyword arguments. The following are handled explicitly and applied to the elastix parameter map for each stage:

number_of_resolutions : int, optional
    Number of resolution levels in the multi-resolution scheme,
    by default 2.
number_of_iterations : int, optional
    Maximum number of optimizer iterations per resolution level.
    If None, the elastix default for the chosen transform type is used.
metric : str, optional
    Similarity metric used by elastix. If None, the elastix default
    for the chosen transform type is used. Common values:

    - 'AdvancedMattesMutualInformation' (default for most transforms)
    - 'AdvancedMeanSquares'
    - 'AdvancedNormalizedCorrelation'
    - 'NormalizedMutualInformation'

Remaining kwargs are forwarded to ``itk.elastix_registration_method``
(e.g. ``log_to_console=True``).
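The way registration_ITKElastix prepares its inputs can be sketched in isolation: transform type names are title-cased (with ['Translation', 'Rigid'] used when none are given), the explicitly handled keys are popped with their defaults, and the remaining kwargs are merged over log_to_console=False before being forwarded. Both helper names are hypothetical; the defaults mirror the source below.

```python
def normalize_transform_types(transform_types=None):
    """Apply the default stage sequence and normalize capitalization."""
    if transform_types is None:
        transform_types = ["Translation", "Rigid"]
    return [t.title() for t in transform_types]


def split_elastix_kwargs(**kwargs):
    """Split kwargs into per-stage parameter-map settings and pass-through arguments."""
    stage_settings = {
        "number_of_resolutions": kwargs.pop("number_of_resolutions", 2),
        "number_of_iterations": kwargs.pop("number_of_iterations", None),
        "metric": kwargs.pop("metric", None),
    }
    # Caller-supplied kwargs override the defaults and would be forwarded to elastix.
    forwarded = {"log_to_console": False, **kwargs}
    return stage_settings, forwarded


print(normalize_transform_types(["rigid", "AFFINE"]))
print(split_elastix_kwargs(metric="AdvancedMeanSquares", log_to_console=True))
```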
Source code in src/multiview_stitcher/registration.py
def registration_ITKElastix(
    fixed_data,
    moving_data,
    *,
    fixed_origin,
    moving_origin,
    fixed_spacing,
    moving_spacing,
    initial_affine,
    transform_types=None,
    **elastix_registration_kwargs,
):
    """
    Use ITKElastix to perform registration between two spatial images.

    Parameters
    ----------
    transform_types : list of str, optional
        Sequence of transform types to apply in successive stages.
        Supported values: 'Translation', 'Rigid', 'Similarity', 'Affine'.
        By default ['Translation', 'Rigid'].
    **elastix_registration_kwargs
        Additional keyword arguments. The following are handled explicitly
        and applied to the elastix parameter map for each stage:

        number_of_resolutions : int, optional
            Number of resolution levels in the multi-resolution scheme,
            by default 2.
        number_of_iterations : int, optional
            Maximum number of optimizer iterations per resolution level.
            If None, the elastix default for the chosen transform type is used.
        metric : str, optional
            Similarity metric used by elastix. If None, the elastix default
            for the chosen transform type is used. Common values:

            - 'AdvancedMattesMutualInformation' (default for most transforms)
            - 'AdvancedMeanSquares'
            - 'AdvancedNormalizedCorrelation'
            - 'NormalizedMutualInformation'

        Remaining kwargs are forwarded to ``itk.elastix_registration_method``
        (e.g. ``log_to_console=True``).
    """

    try:
        global itk
        import itk
    except ImportError:
        raise ImportError(
            "Please install the itk-elastix package to use ITKElastix for registration.\n"
            "E.g. using pip:\n"
            "- `pip install multiview-stitcher[itk-elastix]` or\n"
            "- `pip install itk-elastix`"
        ) from None

    if transform_types is None:
        transform_types = ["Translation", "Rigid"]

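    transform_types = [t.title() for t in transform_types]
    if not transform_types:
        raise ValueError("transform_types must contain at least one transform type.")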

    spatial_dims = fixed_data.dims
    ndim = len(spatial_dims)

    fixed_image = _get_itk_image_from_data(
        fixed_data.data,
        origin=[fixed_origin[dim] for dim in spatial_dims],
        spacing=[fixed_spacing[dim] for dim in spatial_dims],
    )
    moving_image = _get_itk_image_from_data(
        moving_data.data,
        origin=[moving_origin[dim] for dim in spatial_dims],
        spacing=[moving_spacing[dim] for dim in spatial_dims],
    )

    number_of_iterations = elastix_registration_kwargs.pop(
        "number_of_iterations", None
    )
    number_of_resolutions = elastix_registration_kwargs.pop(
        "number_of_resolutions", 2
    )
    metric = elastix_registration_kwargs.pop("metric", None)

    default_elastix_registration_kwargs = {
        "log_to_console": False,
    }
    elastix_registration_kwargs = {
        **default_elastix_registration_kwargs,
        **elastix_registration_kwargs,
    }

    # Run one elastix call per transform type, threading the composed affine
    # forward as the initial transform for each successive stage.  This avoids
    # elastix's multi-stage chaining, which breaks when output_directory is not
    # set (InitialTransformParametersFileName becomes '' for stages beyond the
    # first) and can also partially undo the initial transform when chaining.
    with tempfile.TemporaryDirectory() as tmpdir:
        current_affine = initial_affine
        result_image = None

        for i_stage, transform_type in enumerate(transform_types):
            is_last = i_stage == len(transform_types) - 1
            stage_dir = os.path.join(tmpdir, f"stage_{i_stage}")
            os.makedirs(stage_dir)

            initial_transform_path = os.path.join(
                stage_dir, "initial_transform.txt"
            )
            _write_initial_elastix_transform(
                initial_transform_path,
                initial_affine=current_affine,
                ndim=ndim,
            )

            single_stage_po = itk.ParameterObject.New()
            single_stage_po.AddParameterMap(
                _get_elastix_parameter_map(
                    transform_type,
                    number_of_resolutions=number_of_resolutions,
                    number_of_iterations=number_of_iterations,
                    metric=metric,
                    write_result_image=is_last,
                )
            )

            result_image, result_parameter_object = itk.elastix_registration_method(
                fixed_image=fixed_image,
                moving_image=moving_image,
                parameter_object=single_stage_po,
                initial_transform_parameter_file_name=initial_transform_path,
                output_directory=stage_dir,
                **elastix_registration_kwargs,
            )

            current_affine = _get_affine_from_elastix_transform_parameter_object(
                result_parameter_object,
                moving_image=moving_image,
                ndim=ndim,
            )

        affine_matrix = current_affine

    quality = link_quality_metric_func(
        np.asarray(fixed_data.data),
        itk.array_view_from_image(result_image),
    )

    return {
        "affine_matrix": affine_matrix,
        "quality": quality,
    }
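The staging strategy described in the source comment (one elastix call per transform type, each seeded with the affine produced by the previous stage) reduces to a simple fold over stages. The sketch below replaces elastix with hypothetical stage functions that each compose a fixed correction onto the affine they receive; it illustrates only how current_affine is threaded from stage to stage, not how elastix estimates each correction.

```python
import numpy as np


def translation_matrix(shift):
    """Homogeneous 2D translation matrix."""
    m = np.eye(3)
    m[:2, 2] = shift
    return m


def rotation_matrix(angle):
    """Homogeneous 2D rotation matrix for a rotation by `angle` radians about the origin."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def run_stages(initial_affine, stages):
    """Run stages in order, seeding each with the affine returned by the previous one."""
    current_affine = np.asarray(initial_affine, dtype=float)
    for stage in stages:
        current_affine = stage(current_affine)
    return current_affine


# Hypothetical stand-ins for the 'Translation' and 'Rigid' stages; a real stage
# would estimate its correction from the image data.
def translation_stage(affine):
    return affine @ translation_matrix([2.0, -1.0])


def rigid_stage(affine):
    return affine @ rotation_matrix(np.pi / 2)


result = run_stages(translation_matrix([5.0, 0.0]), [translation_stage, rigid_stage])
print(np.round(result, 6))
```

Because each stage starts from the already composed affine, no stage depends on elastix's own multi-stage chaining, which is the failure mode the source comment describes.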