def fuse(
    sims: list,
    transform_key: str = None,
    fusion_func=weighted_average_fusion,
    weights_func=None,
    weights_func_kwargs=None,
    output_spacing: dict[str, float] = None,
    output_stack_mode: str = "union",
    output_origin: dict[str, float] = None,
    output_shape: dict[str, int] = None,
    output_stack_properties: BoundingBox = None,
    output_chunksize: Union[int, dict[str, int]] = None,
    overlap_in_pixels: int = None,
    interpolation_order: int = 1,
    blending_widths: dict[str, float] = None,
):
    """
    Fuse input views.

    This function fuses all (Z)YX views ("fields") contained in the
    input list of images, which can additionally contain C and T dimensions.
    The output is assembled lazily: one delayed ``fuse_np`` call is created
    per output chunk and per non-spatial coordinate.

    Parameters
    ----------
    sims : list of SpatialImage
        Input views. Dimensions, non-spatial coordinates and dtype are
        taken from the first element.
    transform_key : str, optional
        Which (extrinsic coordinate system) to use as transformation parameters.
        By default None (intrinsic coordinate system).
    fusion_func : func, optional
        Fusion function to be applied. This function receives the following
        inputs (as arrays if applicable): transformed_views, blending_weights,
        fusion_weights, params.
        By default weighted_average_fusion
    weights_func : func, optional
        Function to calculate fusion weights. This function receives the
        following inputs: transformed_views (as spatial images), params.
        It returns (non-normalized) fusion weights for each view.
        By default None.
    weights_func_kwargs : dict, optional
        Keyword arguments forwarded together with ``weights_func``; also
        used to determine the chunk overlap required by ``weights_func``.
        By default None.
    output_spacing : dict, optional
        Spacing of the fused image for each spatial dimension. By default
        None, in which case the spacing of the first view is used.
    output_stack_mode : str, optional
        Mode to determine output stack properties. Can be one of
        "union", "intersection", "sample". By default "union"
    output_origin : dict, optional
        Origin of the fused image for each spatial dimension, by default None
    output_shape : dict, optional
        Shape of the fused image for each spatial dimension, by default None
    output_stack_properties : dict, optional
        Dictionary describing the output stack with keys
        'spacing', 'origin', 'shape'. Other output_* are ignored
        if this argument is present.
    output_chunksize : int or dict, optional
        Chunksize of the dask data array of the fused image. An int is
        applied to every spatial dimension. By default None, in which case
        ``si_utils.get_default_spatial_chunksizes`` provides the values.
    overlap_in_pixels : int, optional
        Minimum number of pixels by which each output chunk is extended on
        every side before fusion (and trimmed again afterwards). The value
        actually used is the maximum of this argument and the overlap
        required by ``weights_func``. By default None (treated as 0).
    interpolation_order : int, optional
        Interpolation order forwarded to ``fuse_np``; also used as the
        additional extent (in pixels) of the input region read per chunk
        along dimensions that require resampling. By default 1.
    blending_widths : dict, optional
        Forwarded unchanged to ``fuse_np``. By default None.

    Returns
    -------
    SpatialImage
        Fused image. An identity affine is attached under ``transform_key``.
    """

    def _fuse_chunk(append_leading_axis, **kwargs):
        # Fuse one chunk; re-add the singleton leading axis that was
        # dropped when fusing a single plane.
        fused_chunk = fuse_np(**kwargs)
        return fused_chunk[np.newaxis] if append_leading_axis else fused_chunk

    ndim = si_utils.get_ndim_from_sim(sims[0])
    sdims = si_utils.get_spatial_dims_from_sim(sims[0])
    nsdims = [dim for dim in sims[0].dims if dim not in sdims]

    params = [
        si_utils.get_affine_from_sim(sim, transform_key=transform_key)
        for sim in sims
    ]

    if output_chunksize is None:
        output_chunksize = si_utils.get_default_spatial_chunksizes(ndim)
    elif isinstance(output_chunksize, int):
        output_chunksize = {dim: output_chunksize for dim in sdims}

    if output_stack_properties is None:
        if output_spacing is None:
            output_spacing = si_utils.get_spacing_from_sim(sims[0])

        output_stack_properties = calc_fusion_stack_properties(
            sims,
            params=params,
            spacing=output_spacing,
            mode=output_stack_mode,
        )

        # Explicit origin/shape only override the computed properties; a
        # caller-supplied output_stack_properties is used as is.
        if output_origin is not None:
            output_stack_properties["origin"] = output_origin

        if output_shape is not None:
            output_stack_properties["shape"] = output_shape

    # determine overlap: honour the caller's request and enlarge it if the
    # weights method needs more
    # (soon: fusion methods will also require overlap)
    if overlap_in_pixels is None:
        overlap_in_pixels = 0
    if weights_func is not None:
        overlap_in_pixels = np.max(
            [
                overlap_in_pixels,
                weights.calculate_required_overlap(
                    weights_func, weights_func_kwargs
                ),
            ]
        )

    # calculate output chunk bounding boxes
    output_chunk_bbs, block_indices = mv_graph.get_chunk_bbs(
        output_stack_properties, output_chunksize
    )

    # add overlap to output chunk bounding boxes
    output_chunk_bbs_with_overlap = [
        output_chunk_bb
        | {
            "origin": {
                dim: output_chunk_bb["origin"][dim]
                - overlap_in_pixels * output_stack_properties["spacing"][dim]
                for dim in sdims
            }
        }
        | {
            "shape": {
                dim: output_chunk_bb["shape"][dim] + 2 * overlap_in_pixels
                for dim in sdims
            }
        }
        for output_chunk_bb in output_chunk_bbs
    ]

    views_bb = [si_utils.get_stack_properties_from_sim(sim) for sim in sims]

    # tolerance for coordinate-based slicing of the input views
    tol = 1e-6

    merges = []
    for ns_coords in itertools.product(
        *tuple([sims[0].coords[nsdim] for nsdim in nsdims])
    ):
        sim_coord_dict = {
            nsdim: ns_coords[i] for i, nsdim in enumerate(nsdims)
        }
        params_coord_dict = {
            nsdim: ns_coords[i]
            for i, nsdim in enumerate(nsdims)
            if nsdim in params[0].dims
        }

        sparams = [param.sel(params_coord_dict) for param in params]

        # Collect spatial dims along which no view needs resampling: the
        # affine is identity on that axis and decoupled from the other
        # axes, the view spacing equals the output spacing, and the
        # translation relative to the output origin is a whole number of
        # pixels.
        # should this be done within the loop over output chunks?
        fix_dims = []
        for dim in sdims:
            other_dims = [odim for odim in sdims if odim != dim]
            if (
                any((param.sel(x_in=dim, x_out=dim) - 1) for param in sparams)
                or any(
                    any(param.sel(x_in=dim, x_out=other_dims))
                    for param in sparams
                )
                or any(
                    any(param.sel(x_in=other_dims, x_out=dim))
                    for param in sparams
                )
                or any(
                    output_stack_properties["spacing"][dim]
                    - views_bb[iview]["spacing"][dim]
                    for iview in range(len(sims))
                )
                or any(
                    float(
                        output_stack_properties["origin"][dim]
                        - param.sel(x_in=dim, x_out="1")
                    )
                    % output_stack_properties["spacing"][dim]
                    for param in sparams
                )
            ):
                continue
            fix_dims.append(dim)

        fused_output_chunks = np.empty(
            np.max(block_indices, 0) + 1, dtype=object
        )
        for output_chunk_bb, output_chunk_bb_with_overlap, block_index in zip(
            output_chunk_bbs, output_chunk_bbs_with_overlap, block_indices
        ):
            # calculate relevant slices for each output chunk
            # this is specific to each non spatial coordinate
            views_overlap_bb = [
                mv_graph.get_overlap_for_bbs(
                    target_bb=output_chunk_bb_with_overlap,
                    query_bbs=[view_bb],
                    param=sparams[iview],
                    additional_extent_in_pixels={
                        dim: 0 if dim in fix_dims else int(interpolation_order)
                        for dim in sdims
                    },
                )[0]
                for iview, view_bb in enumerate(views_bb)
            ]

            # indices of the views that overlap with this output chunk
            relevant_view_indices = np.where(
                [
                    view_overlap_bb is not None
                    for view_overlap_bb in views_overlap_bb
                ]
            )[0]

            # no view contributes: fill the chunk with zeros
            if not len(relevant_view_indices):
                fused_output_chunks[tuple(block_index)] = da.zeros(
                    tuple([output_chunk_bb["shape"][dim] for dim in sdims]),
                    dtype=sims[0].dtype,
                )
                continue

            # select the overlapping region of each relevant view by
            # physical coordinates
            sims_slices = [
                sims[iview].sel(
                    sim_coord_dict
                    | {
                        dim: slice(
                            views_overlap_bb[iview]["origin"][dim] - tol,
                            views_overlap_bb[iview]["origin"][dim]
                            + (views_overlap_bb[iview]["shape"][dim] - 1)
                            * views_overlap_bb[iview]["spacing"][dim]
                            + tol,
                        )
                        for dim in sdims
                    },
                    drop=True,
                )
                for iview in relevant_view_indices
            ]

            # determine whether to fuse plane by plane
            #   to avoid weighting edge artifacts
            # fuse planewise if:
            # - z dimension is present
            # - params don't affect z dimension
            # - shape in z dimension is 1 (i.e. only one plane)
            # (the last criterion above could be dropped if we find a way
            # to propagate metadata through xr.apply_ufunc)
            if (
                "z" in fix_dims
                and output_chunk_bb_with_overlap["shape"]["z"] == 1
            ):
                fuse_planewise = True

                sims_slices = [sim.isel(z=0) for sim in sims_slices]
                tmp_params = [
                    sparams[iview].sel(
                        x_in=["y", "x", "1"],
                        x_out=["y", "x", "1"],
                    )
                    for iview in relevant_view_indices
                ]

                output_chunk_bb_with_overlap = mv_graph.project_bb_along_dim(
                    output_chunk_bb_with_overlap, dim="z"
                )

                full_view_bbs = [
                    mv_graph.project_bb_along_dim(views_bb[iview], dim="z")
                    for iview in relevant_view_indices
                ]
            else:
                fuse_planewise = False
                tmp_params = [
                    sparams[iview] for iview in relevant_view_indices
                ]
                full_view_bbs = [
                    views_bb[iview] for iview in relevant_view_indices
                ]

            fused_output_chunk = delayed(_fuse_chunk)(
                append_leading_axis=fuse_planewise,
                sims=sims_slices,
                params=tmp_params,
                output_properties=output_chunk_bb_with_overlap,
                fusion_func=fusion_func,
                weights_func=weights_func,
                weights_func_kwargs=weights_func_kwargs,
                trim_overlap_in_pixels=overlap_in_pixels,
                # forward the requested order (was hard-coded to 1)
                interpolation_order=interpolation_order,
                full_view_bbs=full_view_bbs,
                blending_widths=blending_widths,
            )

            fused_output_chunk = da.from_delayed(
                fused_output_chunk,
                shape=tuple([output_chunk_bb["shape"][dim] for dim in sdims]),
                dtype=sims[0].dtype,
            )

            fused_output_chunks[tuple(block_index)] = fused_output_chunk

        # assemble the chunk grid into one array for this coordinate
        fused = da.block(fused_output_chunks.tolist())

        merge = si.to_spatial_image(
            fused,
            dims=sdims,
            scale=output_stack_properties["spacing"],
            translation=output_stack_properties["origin"],
        )

        # re-attach the non-spatial dims as length-one axes
        merge = merge.expand_dims(nsdims)
        merge = merge.assign_coords(
            {ns_coord.name: [ns_coord.values] for ns_coord in ns_coords}
        )
        merges.append(merge)

    if len(merges) > 1:
        # suppress pandas future warning occurring within xarray.concat
        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=FutureWarning)

            # if sims are named, combine_by_coords returns a dataset
            res = xr.combine_by_coords([m.rename(None) for m in merges])
    else:
        res = merge

    res = si_utils.get_sim_from_xim(res)
    si_utils.set_sim_affine(
        res,
        param_utils.identity_transform(len(sdims)),
        transform_key,
    )

    # order channels in the same way as first input sim
    # (combine_by_coords may change coordinate order)
    if "c" in res.dims:
        res = res.sel({"c": sims[0].coords["c"].values})

    return res