Registration

register(msims, transform_key=None, points_key='beads', prefilter_markers=False, reg_channel_index=None, reg_channel=None, new_transform_key=None, registration_binning=None, reg_res_level=None, overlap_tolerance=0.0, pairwise_reg_func=phase_correlation_registration, pairwise_reg_func_kwargs=None, groupwise_resolution_method='global_optimization', groupwise_resolution_kwargs=None, pre_registration_pruning_method='alternating_pattern', pre_reg_pruning_method_kwargs=None, post_registration_do_quality_filter=False, post_registration_quality_threshold=0.2, plot_summary=False, pairs=None, scheduler=None, n_parallel_pairwise_regs=None, return_dict=False)

Register a list of views to a common extrinsic coordinate system.

High-level flow:

1) Build the overlap graph.
2) Run pairwise registrations for selected edges.
3) Resolve global transforms from the pairwise results.
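The three steps can be illustrated with a toy, one-dimensional sketch in plain Python. It is not the library's implementation: `build_overlap_graph` and `resolve_translations` are hypothetical helpers, views are reduced to `(start, stop)` intervals, the pairwise registration step is replaced by simulated shifts, and the resolution step simply walks the graph from a fixed reference view.

```python
from collections import deque


def build_overlap_graph(intervals, tolerance=0.0):
    """Step 1: connect views whose (start, stop) intervals overlap."""
    edges = []
    for i in range(len(intervals)):
        for j in range(i + 1, len(intervals)):
            # intersection of the two intervals, extended by the tolerance
            lo = max(intervals[i][0], intervals[j][0]) - tolerance
            hi = min(intervals[i][1], intervals[j][1]) + tolerance
            if hi > lo:
                edges.append((i, j))
    return edges


def resolve_translations(n_views, pairwise_shifts, reference_view=0):
    """Step 3: accumulate pairwise shifts outwards from a fixed reference view."""
    neighbours = {v: [] for v in range(n_views)}
    for (i, j), shift in pairwise_shifts.items():
        neighbours[i].append((j, shift))
        neighbours[j].append((i, -shift))
    corrections = {reference_view: 0.0}
    queue = deque([reference_view])
    while queue:
        current = queue.popleft()
        for other, shift in neighbours[current]:
            if other not in corrections:
                corrections[other] = corrections[current] + shift
                queue.append(other)
    return corrections


# Nominal view positions and made-up positioning errors (assumptions).
nominal = [(0, 12), (10, 22), (20, 32)]
simulated_errors = [0.0, -1.0, 1.0]

edges = build_overlap_graph(nominal)
# Step 2 (simulated): a real pairwise registration would measure these shifts.
pairwise_shifts = {
    (i, j): simulated_errors[j] - simulated_errors[i] for i, j in edges
}
corrections = resolve_translations(len(nominal), pairwise_shifts)
print(edges)
print(corrections)
```

In this sketch only neighbouring views overlap, so the resolution step has a single path to each view. With redundant edges, a global optimisation can distribute inconsistencies across the graph instead.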

Parameters

msims : list of MultiscaleSpatialImage
    Input views.
reg_channel_index : int, optional
    Index of the channel to be used for registration, by default None.
reg_channel : str, optional
    Name of the channel to be used for registration, by default None.
    Overrides reg_channel_index.
transform_key : str, optional
    Extrinsic coordinate system to use as a starting point for the
    registration, by default None.
points_key : str, optional
    Named point set to use for marker-aware pairwise registration functions,
    by default 'beads'.
prefilter_markers : bool, optional
    If True, restrict markers to each pairwise overlap before marker-based
    pairwise registration. By default False.
new_transform_key : str, optional
    If set, the registration result is stored in the input views as a new
    extrinsic coordinate system with the given name, by default None.
registration_binning : dict, optional
    Binning applied to each dimension during registration, by default None.
    If reg_res_level is also provided, the binning factors must be compatible
    with the resolution level.
reg_res_level : int, optional
    Resolution level to use for registration (e.g. 0 for scale0, 1 for scale1).
    If None and registration_binning is provided, the optimal resolution level
    is determined automatically. By default None.
overlap_tolerance : float or dict, optional
    Extend overlap regions considered for pairwise registration, by default 0.0.
    - if 0, the overlap region is the intersection of the bounding boxes.
    - if > 0, the overlap region is the intersection of the bounding boxes
      extended by this value in all spatial dimensions.
    - if None, the full images are used for registration.
pairwise_reg_func : Callable, optional
    Function used for registration. See the docs for the function API.
    By default, phase_correlation_registration is used. Another useful built-in
    registration function is pairwise_reg_func=registration.registration_ANTsPy
    for translation, rigid, similarity or affine registration using ANTsPy.
pairwise_reg_func_kwargs : dict, optional
    Additional keyword arguments passed to the registration function.
    In the case of pairwise_reg_func=registration_ANTsPy, this can include e.g.:
    - 'transform_type': ['Translation', 'Rigid', 'Affine'] or ['Similarity']
    For further parameters, see the docstring of the registration function.
groupwise_resolution_method : str, optional
    Method used to resolve global transforms from pairwise registrations:
    - 'global_optimization' (transform: translation|rigid|similarity|affine)
    - 'shortest_paths' (uses the transform type defined by the pairwise registrations)
    - 'linear_two_pass' (transform: translation|rigid)
    Custom component-level methods can be registered via
    param_resolution.register_groupwise_resolution_method(...) and
    referenced by name.
groupwise_resolution_kwargs : dict, optional
    Additional keyword arguments passed to the groupwise resolver.
    Parameters are method-specific. Common options include:
    - 'transform': final transform type (see method notes above)
    - 'reference_view': node index to keep fixed
    See the resolver docstrings for full details.
pre_registration_pruning_method : str, optional
    Method used to eliminate registration edges (e.g. diagonals) from the view
    adjacency graph before registration. Available methods:
    - None: No pruning, useful when no regular arrangement is present.
    - 'alternating_pattern': Prune to edges between squares of differing
      colors in a checkerboard pattern. Useful for regular 2D tile
      arrangements (of either 2D or 3D data).
    - 'shortest_paths_overlap_weighted': Prune to shortest paths in the overlap
      graph (weighted by overlap). Useful to minimize the number of pairwise
      registrations.
    - 'otsu_threshold_on_overlap': Prune to edges with overlap above the Otsu
      threshold. Useful for regular 2D or 3D grid arrangements, as diagonal
      edges will be pruned.
    - 'keep_axis_aligned': Keep only edges that align with tile axes. Useful
      for regular grid arrangements and to explicitly prune diagonals, e.g.
      when other methods fail.
pre_reg_pruning_method_kwargs : dict, optional
    Additional keyword arguments passed to the pre-registration pruning method, e.g.
    - 'keep_axis_aligned': 'max_angle' (larger angles between stack axis and
      pair edge are discarded, default 0.2)
post_registration_do_quality_filter : bool, optional
    If True, filter pairwise registration edges by quality after registration,
    by default False.
post_registration_quality_threshold : float, optional
    Threshold used to filter edges by quality after registration, by default 0.2.
    Only used if post_registration_do_quality_filter is True.
plot_summary : bool, optional
    If True (and new_transform_key is set), plot graphs summarising the
    registration process and results:
    1) Cross correlation values of pairwise registrations
       (stack boundaries shown as before registration)
    2) Residual distances between registration edges after global parameter
       resolution. Grey edges have been removed during global parameter
       resolution (stack boundaries shown as after registration).
       Solid edges were used by the resolver, dotted edges were unused.
    Stack boundary positions reflect the registration result.
    By default False.
pairs : list of tuples, optional
    If set, initialises the view adjacency graph using the indicated
    pairs of view/tile indices, by default None.
scheduler : str, optional
    (Deprecated since >0.1.28) Dask scheduler to use for parallel computation,
    by default None. This parameter is deprecated and no longer used.
    Use a context manager instead to set the dask scheduler used within
    register(), e.g.
    with dask.config.set(scheduler='threads'): register(...)
n_parallel_pairwise_regs : int, optional
    Number of parallel pairwise registrations to run. Setting this is
    specifically useful for limiting memory usage.
    By default None (all pairwise registrations are run in parallel).
return_dict : bool, optional
    If True, return a dict containing params, registration metrics and more,
    by default False.
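To make the 'alternating_pattern' option more concrete, here is a minimal sketch of checkerboard pruning on a regular grid. It assumes that each tile's `(row, column)` grid index is known, which may differ from how the library assigns colors internally. `grid_adjacency` and `prune_alternating_pattern` are hypothetical helpers written for illustration only.

```python
from itertools import combinations


def grid_adjacency(n_rows, n_cols):
    """All pairs of touching tiles in a regular grid, diagonals included."""
    tiles = [(r, c) for r in range(n_rows) for c in range(n_cols)]
    return [
        (a, b)
        for a, b in combinations(tiles, 2)
        if max(abs(a[0] - b[0]), abs(a[1] - b[1])) == 1
    ]


def prune_alternating_pattern(edges):
    """Keep only edges that join tiles of different checkerboard colors."""

    def color(tile):
        return (tile[0] + tile[1]) % 2

    return [(a, b) for a, b in edges if color(a) != color(b)]


edges = grid_adjacency(2, 2)
pruned = prune_alternating_pattern(edges)
print(len(edges), "edges before pruning")
print(len(pruned), "edges after pruning:", pruned)
```

Diagonal neighbours always share a color on a checkerboard, so they are exactly the edges this rule discards, while row and column neighbours are kept.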

Returns

list of xr.DataArray
    Parameters mapping each view into a new extrinsic coordinate system
or
dict
    Dictionary containing the following keys:
    - 'params': Parameters mapping each view into a new extrinsic coordinate system
    - 'pairwise_registration': Dictionary containing the following
        - 'summary_plot': Tuple containing the figure and axis of the summary plot
        - 'graph': networkx graph of pairwise registrations
        - 'metrics': Dictionary containing the following metrics:
            - 'qualities': Edge registration qualities
    - 'groupwise_resolution': Dictionary containing the following
        - 'summary_plot': Tuple containing the figure and axis of the summary plot
        - 'metrics': Dictionary containing the following metrics:
            - 'edge_residuals': Dict[int, dict[tuple, float]] mapping timepoint index
              to edge residuals
            - 'used_edges': Dict[int, list[tuple]] mapping timepoint index to edges
              used by the resolution method
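As an illustration of how the dictionary returned with `return_dict=True` might be inspected, the sketch below walks a mock result that follows the key layout documented above. The numbers are made up, `summarise_registration` is a hypothetical helper, and each edge quality is assumed to be a plain float, which may not match the actual value type stored on the graph.

```python
def summarise_registration(result, quality_threshold=0.2):
    """Collect low-quality edges and the worst residual edge per timepoint."""
    qualities = result["pairwise_registration"]["metrics"]["qualities"]
    residuals = result["groupwise_resolution"]["metrics"]["edge_residuals"]

    low_quality_edges = sorted(
        edge for edge, quality in qualities.items() if quality < quality_threshold
    )
    worst_residual_edge = {
        timepoint: max(per_edge, key=per_edge.get)
        for timepoint, per_edge in residuals.items()
        if per_edge
    }
    return {
        "low_quality_edges": low_quality_edges,
        "worst_residual_edge": worst_residual_edge,
    }


# Mock result with invented values; only the key layout follows the docs.
mock_result = {
    "params": [],
    "pairwise_registration": {
        "graph": None,
        "summary_plot": None,
        "metrics": {"qualities": {(0, 1): 0.9, (1, 2): 0.1}},
    },
    "groupwise_resolution": {
        "summary_plot": None,
        "metrics": {
            "edge_residuals": {0: {(0, 1): 0.5, (1, 2): 2.0}},
            "used_edges": {0: [(0, 1), (1, 2)]},
        },
    },
}

summary = summarise_registration(mock_result)
print(summary)
```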

Source code in src/multiview_stitcher/registration.py
def register(
    msims: list[MultiscaleSpatialImage],
    transform_key: str = None,
    points_key: str = "beads",
    prefilter_markers: bool = False,
    reg_channel_index: int = None,
    reg_channel: str = None,
    new_transform_key: str = None,
    registration_binning: dict[str, int] = None,
    reg_res_level: int = None,
    overlap_tolerance: Union[float, dict[str, float]] = 0.0,
    pairwise_reg_func=phase_correlation_registration,
    pairwise_reg_func_kwargs: dict = None,
    groupwise_resolution_method="global_optimization",
    groupwise_resolution_kwargs: dict = None,
    pre_registration_pruning_method="alternating_pattern",
    pre_reg_pruning_method_kwargs: dict = None,
    post_registration_do_quality_filter: bool = False,
    post_registration_quality_threshold: float = 0.2,
    plot_summary: bool = False,
    pairs: list[tuple[int, int]] = None,
    scheduler=None,  # deprecated, see docstring
    n_parallel_pairwise_regs: int = None,
    return_dict: bool = False,
):
    """
    Register a list of views to a common extrinsic coordinate system.

    High-level flow:
    1) Build the overlap graph.
    2) Run pairwise registrations for selected edges.
    3) Resolve global transforms from the pairwise results.

    Parameters
    ----------
    msims : list of MultiscaleSpatialImage
        Input views
    reg_channel_index : int, optional
        Index of channel to be used for registration, by default None
    reg_channel : str, optional
        Name of channel to be used for registration, by default None
        Overrides reg_channel_index
    transform_key : str, optional
        Extrinsic coordinate system to use as a starting point
        for the registration, by default None
    points_key : str, optional
        Named point set to use for marker-aware pairwise registration functions.
    prefilter_markers : bool, optional
        If True, restrict markers to each pairwise overlap before marker-based
        pairwise registration. By default False.
    new_transform_key : str, optional
        If set, the registration result will be registered as a new extrinsic
        coordinate system in the input views (with the given name), by default None
    registration_binning : dict, optional
        Binning applied to each dimension during registration, by default None.
        If reg_res_level is also provided, the binning factors must be compatible 
        with the resolution level.
    reg_res_level : int, optional
        Resolution level to use for registration (e.g., 0 for scale0, 1 for scale1).
        If None and registration_binning is provided, the optimal resolution level 
        is automatically determined. By default None.
    overlap_tolerance : float or dict, optional
        Extend overlap regions considered for pairwise registration.
        - if 0, the overlap region is the intersection of the bounding boxes.
        - if > 0, the overlap region is the intersection of the bounding boxes
            extended by this value in all spatial dimensions.
        - if None, the full images are used for registration
    pairwise_reg_func : Callable, optional
        Function used for registration. See the docs for the function API.
        By default, phase_correlation_registration is used. Another useful built-in
        registration function is `pairwise_reg_func=registration.registration_ANTsPy`
        for translation, rigid, similarity or affine registration using ANTsPy.
    pairwise_reg_func_kwargs : dict, optional
        Additional keyword arguments passed to the registration function.
        In the case of `pairwise_reg_func=registration_ANTsPy`, this can include e.g:
        - 'transform_type': ['Translation', 'Rigid', 'Affine'] or ['Similarity']
        For further parameters, see the docstring of the registration function.
    groupwise_resolution_method : str, optional
        Method used to resolve global transforms from pairwise registrations:
        - 'global_optimization' (transform: translation|rigid|similarity|affine)
        - 'shortest_paths' (uses the transform type defined by the pairwise registrations)
        - 'linear_two_pass' (transform: translation|rigid)
        Custom component-level methods can be registered via
        `param_resolution.register_groupwise_resolution_method(...)` and
        referenced by name.
    groupwise_resolution_kwargs : dict, optional
        Additional keyword arguments passed to the groupwise resolver.
        Parameters are method-specific. Common options include:
        - 'transform': final transform type (see method notes above)
        - 'reference_view': node index to keep fixed
        See the resolver docstrings for full details.
    pre_registration_pruning_method : str, optional
        Method used to eliminate registration edges (e.g. diagonals) from the view adjacency
        graph before registration. Available methods:
        - None: No pruning, useful when no regular arrangement is present.
        - 'alternating_pattern': Prune to edges between squares of differing
            colors in a checkerboard pattern. Useful for regular 2D tile arrangements (of either 2D or 3D data).
        - 'shortest_paths_overlap_weighted': Prune to shortest paths in overlap graph
            (weighted by overlap). Useful to minimize the number of pairwise registrations.
        - 'otsu_threshold_on_overlap': Prune to edges with overlap above Otsu threshold.
            This is useful for regular 2D or 3D grid arrangements, as diagonal edges will be pruned.
        - 'keep_axis_aligned': Keep only edges that align with tile axes. This is useful for regular grid
            arrangements and to explicitly prune diagonals, e.g. when other methods fail.
    pre_reg_pruning_method_kwargs : dict, optional
        Additional keyword arguments passed to the pre-registration pruning method, e.g.
        - 'keep_axis_aligned': 'max_angle' (larger angles between stack axis and pair edge are discarded, default 0.2)
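    post_registration_do_quality_filter : bool, optional
        If True, filter pairwise registration edges by quality after
        registration, by default False
    post_registration_quality_threshold : float, optional
        Threshold used to filter edges by quality after registration,
        by default 0.2. Only used if post_registration_do_quality_filter is True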
    plot_summary : bool, optional
        If True (and `new_transform_key` is set), plot graphs summarising the registration process and results:
        1) Cross correlation values of pairwise registrations
           (stack boundaries shown as before registration)
        2) Residual distances between registration edges after global parameter resolution.
           Grey edges have been removed during global parameter resolution (stack boundaries shown as after registration).
           Solid edges were used by the resolver, dotted edges were unused.
        Stack boundary positions reflect the registration result.
        By default False
    pairs : list of tuples, optional
        If set, initialises the view adjacency graph using the indicated
        pairs of view/tile indices, by default None
    scheduler : str, optional
        (Deprecated since >0.1.28) Dask scheduler to use for parallel computation, by default None
        This parameter is deprecated and no longer used.
        Use a context manager instead to set the dask scheduler used within register(), e.g.
        `with dask.config.set(scheduler='threads'): register(...)`
    n_parallel_pairwise_regs : int, optional
        Number of parallel pairwise registrations to run. Setting this is specifically
        useful for limiting memory usage.
        By default None (all pairwise registrations are run in parallel)
    return_dict : bool, optional
        If True, return a dict containing params, registration metrics and more, by default False

    Returns
    -------
    list of xr.DataArray
        Parameters mapping each view into a new extrinsic coordinate system
    or
    dict
        Dictionary containing the following keys:
        - 'params': Parameters mapping each view into a new extrinsic coordinate system
        - 'pairwise_registration': Dictionary containing the following
            - 'summary_plot': Tuple containing the figure and axis of the summary plot
            - 'graph': networkx graph of pairwise registrations
            - 'metrics': Dictionary containing the following metrics:
                - 'qualities': Edge registration qualities
        - 'groupwise_resolution': Dictionary containing the following
            - 'summary_plot': Tuple containing the figure and axis of the summary plot
            - 'metrics': Dictionary containing the following metrics:
                - 'edge_residuals': Dict[int, dict[tuple, float]] mapping timepoint index
                  to edge residuals
                - 'used_edges': Dict[int, list[tuple]] mapping timepoint index to edges
                  used by the resolution method
    """

    # warn about deprecated parameter
    if scheduler is not None:
        warnings.warn(
            "The register(..., scheduler=) parameter is deprecated, no longer used "
            "and will be removed in a future version. "
            "Use a context manager to set the dask scheduler used within register(), e.g. "
            "`with dask.config.set(scheduler='threads'): register(...)`",
            DeprecationWarning,
            stacklevel=2,
        )

    if pairwise_reg_func_kwargs is None:
        pairwise_reg_func_kwargs = {}

    if groupwise_resolution_kwargs is None:
        groupwise_resolution_kwargs = {}

    if pre_reg_pruning_method_kwargs is None:
        pre_reg_pruning_method_kwargs = {}

    sims = [msi_utils.get_sim_from_msim(msim) for msim in msims]

    if "c" in msi_utils.get_dims(msims[0]):
        if reg_channel is None:
            if reg_channel_index is None:
                # the first view has a channel axis, so a registration
                # channel must be chosen explicitly
                raise ValueError("Please choose a registration channel.")
            reg_channel = sims[0].coords["c"][reg_channel_index]

        # select the registration channel in every view that has a channel axis
        msims_reg = [
            msi_utils.multiscale_sel_coords(msim, {"c": reg_channel})
            if "c" in msi_utils.get_dims(msim)
            else msim
            for msim in msims
        ]
    else:
        msims_reg = msims

    # determine registration pairs from input images
    g = mv_graph.build_view_adjacency_graph_from_msims(
        msims_reg,
        transform_key=transform_key,
        pairs=pairs,
        overlap_tolerance=overlap_tolerance,
    )
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "Registration graph built: nodes=%s, edges=%s, edge_list=%s, "
            "transform_key=%s, overlap_tolerance=%s",
            g.number_of_nodes(),
            g.number_of_edges(),
            sorted(g.edges()),
            transform_key,
            overlap_tolerance,
        )

    # prune registration pair graph
    if pre_registration_pruning_method is not None:
        g_reg = mv_graph.prune_view_adjacency_graph(
            g,
            method=pre_registration_pruning_method,
            pruning_method_kwargs=pre_reg_pruning_method_kwargs,
        )
    else:
        g_reg = g
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "Registration graph after pruning: method=%s, nodes=%s, edges=%s, "
            "edge_list=%s",
            pre_registration_pruning_method,
            g_reg.number_of_nodes(),
            g_reg.number_of_edges(),
            sorted(g_reg.edges()),
        )

    # if required, import itk already here
    # to make sure it's available in dask threads
    if pairwise_reg_func == registration_ITKElastix:
        try:
            global itk
            import itk
        except ImportError:
            raise ImportError(
                "Please install the itk-elastix package to use ITKElastix for registration.\n"
                "E.g. using pip:\n"
                "- `pip install multiview-stitcher[itk-elastix]` or\n"
                "- `pip install itk-elastix`"
            ) from None

    # compute pairwise registrations
    g_reg_computed = compute_pairwise_registrations(
        msims_reg,
        g_reg,
        transform_key=transform_key,
        points_key=points_key,
        prefilter_markers=prefilter_markers,
        registration_binning=registration_binning,
        reg_res_level=reg_res_level,
        overlap_tolerance=overlap_tolerance,
        pairwise_reg_func=pairwise_reg_func,
        pairwise_reg_func_kwargs=pairwise_reg_func_kwargs,
        n_parallel_pairwise_regs=n_parallel_pairwise_regs,
    )

    # optionally filter obtained pairwise registrations by quality
    if post_registration_do_quality_filter:
        # filter edges by quality
        g_reg_computed = mv_graph.filter_edges(
            g_reg_computed,
            threshold=post_registration_quality_threshold,
            weight_key="quality",
        )

    # resolve global registration parameters from pairwise registrations
    params_dict, groupwise_resolution_info_dict = groupwise_resolution(
        g_reg_computed,
        method=groupwise_resolution_method,
        **groupwise_resolution_kwargs,
    )
    if logger.isEnabledFor(logging.DEBUG):
        for view_index, params_for_view in sorted(params_dict.items()):
            logger.debug(
                "Groupwise registration transform: view=%s, method=%s, affine=\n%s",
                view_index,
                groupwise_resolution_method,
                _format_array_for_log(params_for_view),
            )

    params = [
        params_dict[iview] for iview in sorted(g_reg_computed.nodes())
    ]

    # optionally write registration result back to the input msims
    # under a new transform key
    if new_transform_key is not None:
        for imsim, msim in enumerate(msims):
            msi_utils.set_affine_transform(
                msim,
                params[imsim],
                transform_key=new_transform_key,
                base_transform_key=transform_key,
            )

    # optionally plot registration summaries
    if plot_summary:
        plot_info = _plot_registration_summaries(
            msims,
            transform_key,
            new_transform_key,
            g_reg_computed,
            groupwise_resolution_info_dict,
            show_plot=plot_summary,
        )
    else:
        plot_info = {}

    if return_dict:
        return {
            "params": params,
            "pairwise_registration": {
                "graph": g_reg_computed,
                "metrics": {
                    "qualities": nx.get_edge_attributes(
                        g_reg_computed, "quality"
                    )
                },
                "summary_plot": None if plot_summary is False
                else (
                    plot_info['fig_pair_reg'],
                    plot_info['ax_pair_reg']
                    )
            },
            "groupwise_resolution": {
                "metrics": groupwise_resolution_info_dict,
                "summary_plot": None if plot_summary is False
                else (
                    plot_info['fig_group_res'],
                    plot_info['ax_group_res']
                )
            },
        }
    else:
        return params

registration_ITKElastix(fixed_data, moving_data, *, fixed_origin, moving_origin, fixed_spacing, moving_spacing, initial_affine, transform_types=None, **elastix_registration_kwargs)

Use ITKElastix to perform registration between two spatial images.

Parameters

transform_types : list of str, optional
    Sequence of transform types to apply in successive stages.
    Supported values: 'Translation', 'Rigid', 'Similarity', 'Affine'.
    By default ['Translation', 'Rigid'].
**elastix_registration_kwargs
    Additional keyword arguments. The following are handled explicitly
    and applied to the elastix parameter map for each stage:

    number_of_resolutions : int, optional
        Number of resolution levels in the multi-resolution scheme,
        by default 2.
    number_of_iterations : int, optional
        Maximum number of optimizer iterations per resolution level.
        If None, the elastix default for the chosen transform type is used.
    metric : str, optional
        Similarity metric used by elastix. If None, the elastix default
        for the chosen transform type is used. Common values:

        - 'AdvancedMattesMutualInformation' (default for most transforms)
        - 'AdvancedMeanSquares'
        - 'AdvancedNormalizedCorrelation'
        - 'NormalizedMutualInformation'

    Remaining kwargs are forwarded to ``itk.elastix_registration_method``
    (e.g. ``log_to_console=True``).
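The way keyword arguments are divided between the per-stage parameter map and `itk.elastix_registration_method` can be sketched in plain Python, without importing itk. `normalise_transform_types` and `split_elastix_kwargs` are hypothetical helpers that mirror the argument handling in the function source. The validation against the supported values is an addition of this sketch and is not claimed to be part of the library.

```python
SUPPORTED_TRANSFORM_TYPES = ("Translation", "Rigid", "Similarity", "Affine")


def normalise_transform_types(transform_types):
    """Title-case the transform types and reject unsupported names."""
    normalised = [t.title() for t in transform_types]
    unsupported = [t for t in normalised if t not in SUPPORTED_TRANSFORM_TYPES]
    if unsupported:
        raise ValueError(f"Unsupported transform types: {unsupported}")
    return normalised


def split_elastix_kwargs(**kwargs):
    """Separate per-stage options from kwargs forwarded to elastix."""
    stage_options = {
        "number_of_resolutions": kwargs.pop("number_of_resolutions", 2),
        "number_of_iterations": kwargs.pop("number_of_iterations", None),
        "metric": kwargs.pop("metric", None),
    }
    # user-supplied values override the default of a quiet console
    forwarded = {"log_to_console": False, **kwargs}
    return stage_options, forwarded


transform_types = normalise_transform_types(["translation", "RIGID"])
stage_options, forwarded = split_elastix_kwargs(
    metric="AdvancedMeanSquares", log_to_console=True
)
print(transform_types)
print(stage_options)
print(forwarded)
```

When calling `register`, such options would be supplied through `pairwise_reg_func_kwargs` together with `pairwise_reg_func=registration_ITKElastix`.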
Source code in src/multiview_stitcher/registration.py
def registration_ITKElastix(
    fixed_data,
    moving_data,
    *,
    fixed_origin,
    moving_origin,
    fixed_spacing,
    moving_spacing,
    initial_affine,
    transform_types=None,
    **elastix_registration_kwargs,
):
    """
    Use ITKElastix to perform registration between two spatial images.

    Parameters
    ----------
    transform_types : list of str, optional
        Sequence of transform types to apply in successive stages.
        Supported values: 'Translation', 'Rigid', 'Similarity', 'Affine'.
        By default ['Translation', 'Rigid'].
    **elastix_registration_kwargs
        Additional keyword arguments. The following are handled explicitly
        and applied to the elastix parameter map for each stage:

        number_of_resolutions : int, optional
            Number of resolution levels in the multi-resolution scheme,
            by default 2.
        number_of_iterations : int, optional
            Maximum number of optimizer iterations per resolution level.
            If None, the elastix default for the chosen transform type is used.
        metric : str, optional
            Similarity metric used by elastix. If None, the elastix default
            for the chosen transform type is used. Common values:

            - 'AdvancedMattesMutualInformation' (default for most transforms)
            - 'AdvancedMeanSquares'
            - 'AdvancedNormalizedCorrelation'
            - 'NormalizedMutualInformation'

        Remaining kwargs are forwarded to ``itk.elastix_registration_method``
        (e.g. ``log_to_console=True``).
    """

    try:
        global itk
        import itk
    except ImportError:
        raise ImportError(
            "Please install the itk-elastix package to use ITKElastix for registration.\n"
            "E.g. using pip:\n"
            "- `pip install multiview-stitcher[itk-elastix]` or\n"
            "- `pip install itk-elastix`"
        ) from None

    if transform_types is None:
        transform_types = ["Translation", "Rigid"]

    transform_types = [t.title() for t in transform_types]

    spatial_dims = fixed_data.dims
    ndim = len(spatial_dims)

    fixed_image = _get_itk_image_from_data(
        fixed_data.data,
        origin=[fixed_origin[dim] for dim in spatial_dims],
        spacing=[fixed_spacing[dim] for dim in spatial_dims],
    )
    moving_image = _get_itk_image_from_data(
        moving_data.data,
        origin=[moving_origin[dim] for dim in spatial_dims],
        spacing=[moving_spacing[dim] for dim in spatial_dims],
    )

    number_of_iterations = elastix_registration_kwargs.pop(
        "number_of_iterations", None
    )
    number_of_resolutions = elastix_registration_kwargs.pop(
        "number_of_resolutions", 2
    )
    metric = elastix_registration_kwargs.pop("metric", None)

    default_elastix_registration_kwargs = {
        "log_to_console": False,
    }
    elastix_registration_kwargs = {
        **default_elastix_registration_kwargs,
        **elastix_registration_kwargs,
    }

    # Run one elastix call per transform type, threading the composed affine
    # forward as the initial transform for each successive stage.  This avoids
    # elastix's multi-stage chaining, which breaks when output_directory is not
    # set (InitialTransformParameterFileName becomes '' for stages beyond the
    # first) and can also partially undo the initial transform when chaining.
    with tempfile.TemporaryDirectory() as tmpdir:
        current_affine = initial_affine
        result_image = None

        for i_stage, transform_type in enumerate(transform_types):
            is_last = i_stage == len(transform_types) - 1
            stage_dir = os.path.join(tmpdir, f"stage_{i_stage}")
            os.makedirs(stage_dir)

            initial_transform_path = os.path.join(
                stage_dir, "initial_transform.txt"
            )
            _write_initial_elastix_transform(
                initial_transform_path,
                initial_affine=current_affine,
                ndim=ndim,
            )

            single_stage_po = itk.ParameterObject.New()
            single_stage_po.AddParameterMap(
                _get_elastix_parameter_map(
                    transform_type,
                    number_of_resolutions=number_of_resolutions,
                    number_of_iterations=number_of_iterations,
                    metric=metric,
                    write_result_image=is_last,
                )
            )

            result_image, result_parameter_object = itk.elastix_registration_method(
                fixed_image=fixed_image,
                moving_image=moving_image,
                parameter_object=single_stage_po,
                initial_transform_parameter_file_name=initial_transform_path,
                output_directory=stage_dir,
                **elastix_registration_kwargs,
            )

            current_affine = _get_affine_from_elastix_transform_parameter_object(
                result_parameter_object,
                moving_image=moving_image,
                ndim=ndim,
            )

        affine_matrix = current_affine

    quality = link_quality_metric_func(
        np.asarray(fixed_data.data),
        itk.array_view_from_image(result_image),
    )

    return {
        "affine_matrix": affine_matrix,
        "quality": quality,
    }
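The staging strategy described in the comments above, one call per transform type with the composed affine passed on as the next initial transform, can be sketched without elastix. In this toy version `run_stage` is a hypothetical stand-in for a single registration call: it composes a fixed, made-up refinement with the incoming affine. The composition order and the 2D homogeneous matrices are conventions of the sketch, not of elastix.

```python
def matmul(a, b):
    """Multiply two square matrices given as nested lists."""
    n = len(a)
    return [
        [sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]


# Made-up refinements that each stage is pretended to find (assumptions).
STAGE_REFINEMENTS = {
    "Translation": [[1, 0, 2], [0, 1, 3], [0, 0, 1]],  # shift by (2, 3)
    "Rigid": [[0, -1, 0], [1, 0, 0], [0, 0, 1]],  # rotate by 90 degrees
}


def run_stage(transform_type, initial_affine):
    """Stand-in for one registration call: returns the composed affine."""
    return matmul(STAGE_REFINEMENTS[transform_type], initial_affine)


identity = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
current_affine = identity
for transform_type in ["Translation", "Rigid"]:
    # each stage starts from the result of the previous one
    current_affine = run_stage(transform_type, current_affine)

print(current_affine)
```

Because every stage returns a full affine rather than an increment, only the last `current_affine` needs to be kept, which matches how the loop above ends with `affine_matrix = current_affine`.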