Skip to content

Fusion

fuse(sims, transform_key=None, fusion_func=weighted_average_fusion, weights_func=None, weights_func_kwargs=None, output_spacing=None, output_stack_mode='union', output_origin=None, output_shape=None, output_stack_properties=None, output_chunksize=None, overlap_in_pixels=None, interpolation_order=1)

Fuse input views.

This function fuses all (Z)YX views ("fields") contained in the input list of images, which can additionally contain C and T dimensions.

Parameters

sims : list of SpatialImage Input views. transform_key : str, optional Which (extrinsic coordinate system) to use as transformation parameters. By default None (intrinsic coordinate system). fusion_func : func, optional Fusion function to be applied. This function receives the following inputs (as arrays if applicable): transformed_views, blending_weights, fusion_weights, params. By default weighted_average_fusion weights_func : func, optional Function to calculate fusion weights. This function receives the following inputs: transformed_views (as spatial images), params. It returns (non-normalized) fusion weights for each view. By default None. output_spacing : dict, optional Spacing of the fused image for each spatial dimension, by default None output_stack_mode : str, optional Mode to determine output stack properties. Can be one of "union", "intersection", "sample". By default "union" output_origin : dict, optional Origin of the fused image for each spatial dimension, by default None output_shape : dict, optional Shape of the fused image for each spatial dimension, by default None output_stack_properties : dict, optional Dictionary describing the output stack with keys 'spacing', 'origin', 'shape'. Other output_* are ignored if this argument is present. output_chunksize : int or tuple of ints, optional Chunksize of the dask data array of the fused image. An int is applied to every spatial dimension. By default None, in which case the default spatial chunk sizes are used.

Returns

SpatialImage Fused image.

Source code in src/multiview_stitcher/fusion.py
def fuse(
    sims: list,
    transform_key: str = None,
    fusion_func=weighted_average_fusion,
    weights_func=None,
    weights_func_kwargs=None,
    output_spacing=None,
    output_stack_mode="union",
    output_origin=None,
    output_shape=None,
    output_stack_properties=None,
    output_chunksize=None,
    overlap_in_pixels=None,
    interpolation_order=1,
):
    """
    Fuse input views.

    This function fuses all (Z)YX views ("fields") contained in the
    input list of images, which can additionally contain C and T dimensions.
    One field is fused per combination of non-spatial coordinates of the
    first view; the fused fields are then recombined into a single image.

    Parameters
    ----------
    sims : list of SpatialImage
        Input views. Must contain at least one view. Dimensions, dtype,
        default spacing and non-spatial coordinates are taken from the
        first view.
    transform_key : str, optional
        Which (extrinsic coordinate system) to use as transformation parameters.
        By default None (intrinsic coordinate system).
    fusion_func : func, optional
        Fusion function to be applied. This function receives the following
        inputs (as arrays if applicable): transformed_views, blending_weights, fusion_weights, params.
        By default weighted_average_fusion
    weights_func : func, optional
        Function to calculate fusion weights. This function receives the
        following inputs: transformed_views (as spatial images), params.
        It returns (non-normalized) fusion weights for each view.
        By default None.
    weights_func_kwargs : dict, optional
        Forwarded unchanged to fuse_field alongside weights_func
        (presumably extra keyword arguments for weights_func; confirm
        against fuse_field). By default None.
    output_spacing : dict, optional
        Spacing of the fused image for each spatial dimension. By default
        None, in which case the spacing of the first view is used.
    output_stack_mode : str, optional
        Mode to determine output stack properties. Can be one of
        "union", "intersection", "sample". By default "union"
    output_origin : dict, optional
        Origin of the fused image for each spatial dimension, by default None
    output_shape : dict, optional
        Shape of the fused image for each spatial dimension, by default None
    output_stack_properties : dict, optional
        Dictionary describing the output stack with keys
        'spacing', 'origin', 'shape'. Other output_* are ignored
        if this argument is present.
    output_chunksize : int or tuple of ints, optional
        Chunksize of the dask data array of the fused image. An int is
        applied to every spatial dimension. By default None, in which
        case the default spatial chunk sizes for the dimensionality of
        the first view are used.
    overlap_in_pixels : optional
        Forwarded unchanged to fuse_field (see fuse_field for its
        meaning and default handling). By default None.
    interpolation_order : int, optional
        Forwarded unchanged to fuse_field (presumably the interpolation
        order used when resampling the views; confirm against
        fuse_field). By default 1.

    Returns
    -------
    SpatialImage
        Fused image. An identity transform is registered on it under
        ``transform_key``.

    Raises
    ------
    ValueError
        If ``sims`` is empty, or if no field could be fused because a
        non-spatial dimension of the first view has no coordinates.
    """

    # fail early with a clear message instead of an IndexError on sims[0]
    if len(sims) == 0:
        raise ValueError("fuse requires at least one input view in `sims`.")

    ndim = si_utils.get_ndim_from_sim(sims[0])
    sdims = si_utils.get_spatial_dims_from_sim(sims[0])
    nsdims = [dim for dim in sims[0].dims if dim not in sdims]

    # per-view affines for the requested coordinate system, inverted
    # before being handed to the downstream fusion helpers
    params = [
        param_utils.invert_xparams(
            si_utils.get_affine_from_sim(sim, transform_key=transform_key)
        )
        for sim in sims
    ]

    # normalize output_chunksize to one entry per spatial dimension
    if output_chunksize is None:
        default_chunksizes = si_utils.get_default_spatial_chunksizes(ndim)
        output_chunksize = tuple([default_chunksizes[dim] for dim in sdims])
    elif isinstance(output_chunksize, Iterable):
        output_chunksize = tuple(output_chunksize)
    else:
        output_chunksize = (output_chunksize,) * len(sdims)

    # output_spacing/origin/shape only matter when the caller did not
    # provide complete output stack properties
    if output_stack_properties is None:
        if output_spacing is None:
            output_spacing = si_utils.get_spacing_from_sim(sims[0])

        output_stack_properties = calc_fusion_stack_properties(
            sims,
            params=params,
            spacing=output_spacing,
            mode=output_stack_mode,
        )

        if output_origin is not None:
            output_stack_properties["origin"] = output_origin

        if output_shape is not None:
            output_stack_properties["shape"] = output_shape

    # identical for every fused field, so computed once outside the loop
    fused_shape = [output_stack_properties["shape"][dim] for dim in sdims]
    fused_dtype = sims[0].dtype

    merges = []
    # one fused field per combination of non-spatial coordinates
    # (with no non-spatial dims, product() yields a single empty tuple)
    for ns_coords in itertools.product(
        *tuple([sims[0].coords[nsdim] for nsdim in nsdims])
    ):
        sim_coord_dict = {
            ndsim: ns_coords[i] for i, ndsim in enumerate(nsdims)
        }
        # transformation parameters may lack some non-spatial dims
        params_coord_dict = {
            ndsim: ns_coords[i]
            for i, ndsim in enumerate(nsdims)
            if ndsim in params[0].dims
        }

        ssims = [sim.sel(sim_coord_dict) for sim in sims]
        sparams = [param.sel(params_coord_dict) for param in params]

        # convert ssims into dask arrays + metadata to get them
        # through fuse_field without triggering compute
        # https://dask.discourse.group/t/passing-dask-objects-to-delayed-computations-without-triggering-compute/1441
        sims_datas = [
            delayed(da.Array)(
                ssim.data.dask,
                ssim.data.name,
                ssim.data.chunks,
                ssim.data.dtype,
            )
            if isinstance(ssim.data, da.Array)
            else ssim.data
            for ssim in ssims
        ]

        sims_metas = [
            {
                "dims": ssim.dims,
                "scale": si_utils.get_spacing_from_sim(ssim),
                "translation": si_utils.get_origin_from_sim(ssim),
            }
            for ssim in ssims
        ]

        merge_d = delayed(fuse_field)(
            sims_datas,
            sims_metas,
            sparams,
            fusion_func=fusion_func,
            weights_func=weights_func,
            weights_func_kwargs=weights_func_kwargs,
            output_stack_properties=output_stack_properties,
            output_chunksize=output_chunksize,
            overlap_in_pixels=overlap_in_pixels,
            interpolation_order=interpolation_order,
        )

        # continue working with dask array
        merge_data = da.from_delayed(
            delayed(lambda x: x.data)(merge_d),
            shape=fused_shape,
            dtype=fused_dtype,
        )

        # rechunk to get a chunked dask array from the delayed object
        # (hacky, is there a better way to do this?)
        merge_data = merge_data.rechunk(output_chunksize)

        # trigger compute here
        merge_data = merge_data.map_blocks(
            lambda x: x.compute(),
            dtype=fused_dtype,
        )

        merge = si.to_spatial_image(
            merge_data,
            dims=sdims,
            scale=output_stack_properties["spacing"],
            translation=output_stack_properties["origin"],
        )

        # restore the non-spatial dims as length-1 axes carrying the
        # coordinates of this field so the fields can be recombined
        merge = merge.expand_dims(nsdims)
        merge = merge.assign_coords(
            {ns_coord.name: [ns_coord.values] for ns_coord in ns_coords}
        )
        merges.append(merge)

    # the loop body never runs if a non-spatial dimension is empty
    if not merges:
        raise ValueError(
            "No fused fields were produced: a non-spatial dimension of "
            "the first input view has no coordinates."
        )

    if len(merges) > 1:
        # suppress pandas future warning occurring within xarray.concat
        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=FutureWarning)

            # if sims are named, combine_by_coord returns a dataset
            res = xr.combine_by_coords([m.rename(None) for m in merges])
    else:
        res = merges[0]

    res = si_utils.get_sim_from_xim(res)
    # the fused image already lives in the target coordinate system
    si_utils.set_sim_affine(
        res,
        param_utils.identity_transform(len(sdims)),
        transform_key,
    )

    # order channels in the same way as first input sim
    # (combine_by_coords may change coordinate order)

    if "c" in res.dims:
        res = res.sel({"c": sims[0].coords["c"].values})

    return res