Skip to content

Fusion

fuse(sims, transform_key=None, fusion_func=weighted_average_fusion, weights_func=None, weights_func_kwargs=None, output_spacing=None, output_stack_mode='union', output_origin=None, output_shape=None, output_stack_properties=None, output_chunksize=None, overlap_in_pixels=None, interpolation_order=1, blending_widths=None)

Fuse input views.

This function fuses all (Z)YX views ("fields") contained in the input list of images, which can additionally contain C and T dimensions.

Parameters

- sims : list of SpatialImage — Input views.
- transform_key : str, optional — Which extrinsic coordinate system to use as transformation parameters. By default None (intrinsic coordinate system).
- fusion_func : func, optional — Fusion function to be applied. This function receives the following inputs (as arrays if applicable): transformed_views, blending_weights, fusion_weights, params. By default weighted_average_fusion.
- weights_func : func, optional — Function to calculate fusion weights. This function receives the following inputs: transformed_views (as spatial images), params. It returns (non-normalized) fusion weights for each view. By default None.
- output_spacing : dict, optional — Spacing of the fused image for each spatial dimension. By default None.
- output_stack_mode : str, optional — Mode to determine output stack properties. Can be one of "union", "intersection", "sample". By default "union".
- output_origin : dict, optional — Origin of the fused image for each spatial dimension. By default None.
- output_shape : dict, optional — Shape of the fused image for each spatial dimension. By default None.
- output_stack_properties : dict, optional — Dictionary describing the output stack with keys 'spacing', 'origin', 'shape'. The other output_* arguments are ignored if this argument is present.
- output_chunksize : int or dict, optional — Chunksize of the dask data array of the fused image. By default None, in which case a default chunksize per spatial dimension is used.

Returns

SpatialImage Fused image.

Source code in src/multiview_stitcher/fusion.py
def fuse(
    sims: list,
    transform_key: str = None,
    fusion_func=weighted_average_fusion,
    weights_func=None,
    weights_func_kwargs=None,
    output_spacing: dict[str, float] = None,
    output_stack_mode: str = "union",
    output_origin: dict[str, float] = None,
    output_shape: dict[str, int] = None,
    output_stack_properties: BoundingBox = None,
    output_chunksize: Union[int, dict[str, int]] = None,
    overlap_in_pixels: int = None,
    interpolation_order: int = 1,
    blending_widths: dict[str, float] = None,
):
    """

    Fuse input views.

    This function fuses all (Z)YX views ("fields") contained in the
    input list of images, which can additionally contain C and T dimensions.

    Parameters
    ----------
    sims : list of SpatialImage
        Input views.
    transform_key : str, optional
        Which extrinsic coordinate system to use as transformation parameters.
        By default None (intrinsic coordinate system).
    fusion_func : func, optional
        Fusion function to be applied. This function receives the following
        inputs (as arrays if applicable): transformed_views, blending_weights, fusion_weights, params.
        By default weighted_average_fusion
    weights_func : func, optional
        Function to calculate fusion weights. This function receives the
        following inputs: transformed_views (as spatial images), params.
        It returns (non-normalized) fusion weights for each view.
        By default None.
    weights_func_kwargs : dict, optional
        Keyword arguments forwarded to `weights_func`. They are also used to
        determine the overlap required by `weights_func`. By default None.
    output_spacing : dict, optional
        Spacing of the fused image for each spatial dimension, by default None
    output_stack_mode : str, optional
        Mode to determine output stack properties. Can be one of
        "union", "intersection", "sample". By default "union"
    output_origin : dict, optional
        Origin of the fused image for each spatial dimension, by default None
    output_shape : dict, optional
        Shape of the fused image for each spatial dimension, by default None
    output_stack_properties : dict, optional
        Dictionary describing the output stack with keys
        'spacing', 'origin', 'shape'. Other output_* are ignored
        if this argument is present.
    output_chunksize : int or dict, optional
        Chunksize of the dask data array of the fused image. An int is
        applied to every spatial dimension. By default None, in which case
        `si_utils.get_default_spatial_chunksizes` determines it.
    overlap_in_pixels : int, optional
        Minimum number of pixels by which each output chunk is extended on
        every side before fusion. The value is passed to `fuse_np` as
        `trim_overlap_in_pixels`. The overlap actually used is the maximum
        of this value and the overlap required by `weights_func`
        (if given). By default None (treated as 0).
    interpolation_order : int, optional
        Interpolation order forwarded to `fuse_np`. It is also used to
        enlarge the region read from each view along dimensions that
        require interpolation. By default 1.
    blending_widths : dict, optional
        Forwarded to `fuse_np`. By default None.

    Returns
    -------
    SpatialImage
        Fused image. Its affine under `transform_key` is set to the
        identity transform.
    """

    ndim = si_utils.get_ndim_from_sim(sims[0])
    sdims = si_utils.get_spatial_dims_from_sim(sims[0])
    nsdims = [dim for dim in sims[0].dims if dim not in sdims]

    params = [
        si_utils.get_affine_from_sim(sim, transform_key=transform_key)
        for sim in sims
    ]

    if output_chunksize is None:
        output_chunksize = si_utils.get_default_spatial_chunksizes(ndim)
    elif isinstance(output_chunksize, int):
        output_chunksize = {dim: output_chunksize for dim in sdims}

    if output_stack_properties is None:
        if output_spacing is None:
            output_spacing = si_utils.get_spacing_from_sim(sims[0])

        output_stack_properties = calc_fusion_stack_properties(
            sims,
            params=params,
            spacing=output_spacing,
            mode=output_stack_mode,
        )

        if output_origin is not None:
            output_stack_properties["origin"] = output_origin

        if output_shape is not None:
            output_stack_properties["shape"] = output_shape

    # Determine the chunk overlap: at least what the caller requested and at
    # least what the weights method requires.
    # (soon: fusion methods will also require overlap)
    if overlap_in_pixels is None:
        overlap_in_pixels = 0
    if weights_func is not None:
        overlap_in_pixels = max(
            overlap_in_pixels,
            weights.calculate_required_overlap(
                weights_func, weights_func_kwargs
            ),
        )
    overlap_in_pixels = int(overlap_in_pixels)

    # calculate output chunk bounding boxes
    output_chunk_bbs, block_indices = mv_graph.get_chunk_bbs(
        output_stack_properties, output_chunksize
    )

    # Grow each chunk bounding box by `overlap_in_pixels` on every side
    # (the origin moves by overlap * spacing, the shape grows by 2 * overlap).
    output_chunk_bbs_with_overlap = [
        output_chunk_bb
        | {
            "origin": {
                dim: output_chunk_bb["origin"][dim]
                - overlap_in_pixels * output_stack_properties["spacing"][dim]
                for dim in sdims
            }
        }
        | {
            "shape": {
                dim: output_chunk_bb["shape"][dim] + 2 * overlap_in_pixels
                for dim in sdims
            }
        }
        for output_chunk_bb in output_chunk_bbs
    ]

    views_bb = [si_utils.get_stack_properties_from_sim(sim) for sim in sims]

    def _fuse_chunk(append_leading_axis, **kwargs):
        # Re-add the z axis dropped for planewise fusion so the chunk has
        # the same number of dimensions as the output array.
        fused_chunk = fuse_np(**kwargs)
        if append_leading_axis:
            return fused_chunk[np.newaxis]
        return fused_chunk

    merges = []
    for ns_coords in itertools.product(
        *tuple([sims[0].coords[nsdim] for nsdim in nsdims])
    ):
        sim_coord_dict = {
            nsdim: ns_coords[i] for i, nsdim in enumerate(nsdims)
        }
        params_coord_dict = {
            nsdim: ns_coords[i]
            for i, nsdim in enumerate(nsdims)
            if nsdim in params[0].dims
        }

        sparams = [param.sel(params_coord_dict) for param in params]

        # A spatial dimension is "fixed" when no view needs interpolation
        # along it. That holds when all of the following are true:
        # - the diagonal affine entry is 1;
        # - there is no coupling to the other spatial dimensions;
        # - the output and view spacings are equal;
        # - the translation is a whole multiple of the output spacing.
        # should this be done within the loop over output chunks?
        fix_dims = []
        for dim in sdims:
            other_dims = [odim for odim in sdims if odim != dim]
            if (
                any((param.sel(x_in=dim, x_out=dim) - 1) for param in sparams)
                or any(
                    any(param.sel(x_in=dim, x_out=other_dims))
                    for param in sparams
                )
                or any(
                    any(param.sel(x_in=other_dims, x_out=dim))
                    for param in sparams
                )
                or any(
                    output_stack_properties["spacing"][dim]
                    - views_bb[iview]["spacing"][dim]
                    for iview in range(len(sims))
                )
                or any(
                    float(
                        output_stack_properties["origin"][dim]
                        - param.sel(x_in=dim, x_out="1")
                    )
                    % output_stack_properties["spacing"][dim]
                    for param in sparams
                )
            ):
                continue
            fix_dims.append(dim)

        fused_output_chunks = np.empty(
            np.max(block_indices, 0) + 1, dtype=object
        )
        for output_chunk_bb, output_chunk_bb_with_overlap, block_index in zip(
            output_chunk_bbs, output_chunk_bbs_with_overlap, block_indices
        ):
            # calculate relevant slices for each output chunk
            # this is specific to each non spatial coordinate
            views_overlap_bb = [
                mv_graph.get_overlap_for_bbs(
                    target_bb=output_chunk_bb_with_overlap,
                    query_bbs=[view_bb],
                    param=sparams[iview],
                    additional_extent_in_pixels={
                        dim: 0 if dim in fix_dims else int(interpolation_order)
                        for dim in sdims
                    },
                )[0]
                for iview, view_bb in enumerate(views_bb)
            ]

            relevant_view_indices = np.where(
                [
                    view_overlap_bb is not None
                    for view_overlap_bb in views_overlap_bb
                ]
            )[0]

            # no view overlaps this chunk: fill it with zeros
            if not len(relevant_view_indices):
                fused_output_chunks[tuple(block_index)] = da.zeros(
                    tuple([output_chunk_bb["shape"][dim] for dim in sdims]),
                    dtype=sims[0].dtype,
                )
                continue

            # small tolerance so label-based slicing includes boundary pixels
            tol = 1e-6
            sims_slices = [
                sims[iview].sel(
                    sim_coord_dict
                    | {
                        dim: slice(
                            views_overlap_bb[iview]["origin"][dim] - tol,
                            views_overlap_bb[iview]["origin"][dim]
                            + (views_overlap_bb[iview]["shape"][dim] - 1)
                            * views_overlap_bb[iview]["spacing"][dim]
                            + tol,
                        )
                        for dim in sdims
                    },
                    drop=True,
                )
                for iview in relevant_view_indices
            ]

            # determine whether to fuse plane by plane
            #  to avoid weighting edge artifacts
            # fuse planewise if:
            # - z dimension is present
            # - params don't affect z dimension
            # - shape in z dimension is 1 (i.e. only one plane)
            # (the last criterium above could be dropped if we find a way
            # (to propagate metadata through xr.apply_ufunc)

            if (
                "z" in fix_dims
                and output_chunk_bb_with_overlap["shape"]["z"] == 1
            ):
                fuse_planewise = True

                sims_slices = [sim.isel(z=0) for sim in sims_slices]
                tmp_params = [
                    sparams[iview].sel(
                        x_in=["y", "x", "1"],
                        x_out=["y", "x", "1"],
                    )
                    for iview in relevant_view_indices
                ]

                output_chunk_bb_with_overlap = mv_graph.project_bb_along_dim(
                    output_chunk_bb_with_overlap, dim="z"
                )

                full_view_bbs = [
                    mv_graph.project_bb_along_dim(views_bb[iview], dim="z")
                    for iview in relevant_view_indices
                ]

            else:
                fuse_planewise = False
                tmp_params = [
                    sparams[iview] for iview in relevant_view_indices
                ]
                full_view_bbs = [
                    views_bb[iview] for iview in relevant_view_indices
                ]

            fused_output_chunk = delayed(_fuse_chunk)(
                append_leading_axis=fuse_planewise,
                sims=sims_slices,
                params=tmp_params,
                output_properties=output_chunk_bb_with_overlap,
                fusion_func=fusion_func,
                weights_func=weights_func,
                weights_func_kwargs=weights_func_kwargs,
                trim_overlap_in_pixels=overlap_in_pixels,
                # previously hard-coded to 1, ignoring the argument
                interpolation_order=interpolation_order,
                full_view_bbs=full_view_bbs,
                blending_widths=blending_widths,
            )

            fused_output_chunk = da.from_delayed(
                fused_output_chunk,
                shape=tuple([output_chunk_bb["shape"][dim] for dim in sdims]),
                dtype=sims[0].dtype,
            )

            fused_output_chunks[tuple(block_index)] = fused_output_chunk

        fused = da.block(fused_output_chunks.tolist())

        merge = si.to_spatial_image(
            fused,
            dims=sdims,
            scale=output_stack_properties["spacing"],
            translation=output_stack_properties["origin"],
        )

        merge = merge.expand_dims(nsdims)
        merge = merge.assign_coords(
            {ns_coord.name: [ns_coord.values] for ns_coord in ns_coords}
        )
        merges.append(merge)

    if len(merges) > 1:
        # suppress pandas future warning occuring within xarray.concat
        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=FutureWarning)

            # if sims are named, combine_by_coord returns a dataset
            res = xr.combine_by_coords([m.rename(None) for m in merges])
    else:
        res = merges[0]

    res = si_utils.get_sim_from_xim(res)
    si_utils.set_sim_affine(
        res,
        param_utils.identity_transform(len(sdims)),
        transform_key,
    )

    # order channels in the same way as first input sim
    # (combine_by_coords may change coordinate order)
    if "c" in res.dims:
        res = res.sel({"c": sims[0].coords["c"].values})

    return res