"""
Utility functions for querying and manipulating dimensional axis metadata.
"""
import logging
from typing import List, Tuple, Union
import numpy as np
import scyjava as sj
import xarray as xr
from jpype import JException, JObject
from imagej._java import jc
from imagej.images import is_arraylike as _is_arraylike
from imagej.images import is_xarraylike as _is_xarraylike
# Module-level logger, used below for deprecation notices and conversion warnings.
_logger = logging.getLogger(__name__)
[docs]
def get_axes(
    rai: "jc.RandomAccessibleInterval",
) -> List["jc.CalibratedAxis"]:
    """
    imagej.dims.get_axes(image) is deprecated. Use image.dim_axes instead.

    Return the axis of every dimension of ``rai``, cast to CalibratedAxis,
    in dimension order.
    """
    _logger.warning(
        "imagej.dims.get_axes(image) is deprecated. Use image.dim_axes instead."
    )
    dim_count = rai.numDimensions()
    return [JObject(rai.axis(d), jc.CalibratedAxis) for d in range(dim_count)]
[docs]
def get_axis_types(rai: "jc.RandomAccessibleInterval") -> List["jc.AxisType"]:
    """
    imagej.dims.get_axis_types(image) is deprecated. Use this code instead:
    axis_types = [axis.type() for axis in image.dim_axes]

    Look up an AxisType for each dimension of ``rai`` from its axis labels.
    Labels equal to "c" or "t" (any case) are first mapped to "Channel" and
    "Time" respectively.

    :param rai: A Java object with ``axis(int)`` and ``numDimensions()``.
    :return: One AxisType per dimension, in dimension order.
    :raises AttributeError: If ``rai`` is not a Java object with an ``axis``
        attribute.
    """
    _logger.warning(
        "imagej.dims.get_axis_types(image) is deprecated. Use this code instead:\n"
        + "\n"
        + " axis_types = [axis.type() for axis in image.dim_axes]"
    )
    if not _has_axis(rai):
        raise AttributeError(
            f"Unsupported Java type: {type(rai)} has no axis attribute."
        )
    # Read the labels here rather than through get_dims()/get_axes(): both are
    # deprecated as well, and calling them logged two extra warnings naming
    # functions the caller never used.
    axes = [
        JObject(rai.axis(idx), jc.CalibratedAxis)
        for idx in range(rai.numDimensions())
    ]
    aliases = {"c": "Channel", "t": "Time"}
    return [
        jc.Axes.get(aliases.get(label.lower(), label))
        for label in _get_axis_labels(axes)
    ]
[docs]
def get_dims(image) -> List[str]:
    """
    imagej.dims.get_dims(image) is deprecated. Use image.shape and image.dims instead.

    Despite the name, what is returned depends on the input type:

    * xarray-like: ``image.dims`` (dimension names).
    * array-like: ``image.shape`` (sizes, not names).
    * object with an ``axis`` attribute: the label of each axis.
    * other RandomAccessibleInterval: its dimension lengths.
    * ImagePlus: its dimension lengths that are greater than 1.

    :raises TypeError: If none of the above applies.
    """
    _logger.warning(
        "imagej.dims.get_dims(image) is deprecated. Use image.shape and image.dims "
        "instead."
    )
    if _is_xarraylike(image):
        return image.dims
    if _is_arraylike(image):
        return image.shape
    if hasattr(image, "axis"):
        # Build the axes directly: get_axes() is deprecated too and would log a
        # second warning naming a function the caller did not call.
        axes = [
            JObject(image.axis(idx), jc.CalibratedAxis)
            for idx in range(image.numDimensions())
        ]
        return _get_axis_labels(axes)
    if isinstance(image, jc.RandomAccessibleInterval):
        return list(image.dimensionsAsLongArray())
    if isinstance(image, jc.ImagePlus):
        shape = image.getDimensions()
        return [axis for axis in shape if axis > 1]
    raise TypeError(f"Unsupported image type: {image}\n No dimensions or shape found.")
[docs]
def get_shape(image) -> List[int]:
    """
    imagej.dims.get_shape(image) is deprecated. Use image.shape instead.

    Array-likes report ``list(image.shape)``; imglib2 ``Dimensions`` report
    every dimension length; an ImagePlus reports only lengths greater than 1.

    :raises TypeError: For non-Java objects that are not array-like, and for
        Java objects of any other type.
    """
    _logger.warning(
        "imagej.dims.get_shape(image) is deprecated. Use image.shape instead."
    )
    if _is_arraylike(image):
        return list(image.shape)
    if not sj.isjava(image):
        raise TypeError("Unsupported type: " + str(type(image)))

    if isinstance(image, jc.Dimensions):
        ndim = image.numDimensions()
        return [image.dimension(axis_idx) for axis_idx in range(ndim)]
    if isinstance(image, jc.ImagePlus):
        extents = image.getDimensions()
        return [extent for extent in extents if extent > 1]

    class_name = str(sj.jclass(image).getName())
    raise TypeError(f"Unsupported Java type: {class_name}")
[docs]
def reorganize(
    rai: "jc.RandomAccessibleInterval", permute_order: List[int]
) -> "jc.ImgPlus":
    """Reorganize the dimension order of a RandomAccessibleInterval.

    Permute the dimension order of an input RandomAccessibleInterval using
    a List of ints (i.e. permute_order) to determine the shape of the output
    ImgPlus: output dimension ``i`` is input dimension ``permute_order[i]``,
    and the calibrated axes are reordered the same way.

    :param rai: A RandomAccessibleInterval. A Dataset is unwrapped to its
        ImgPlus; ``axis()``, ``getImg()`` and ``getName()`` are then called on it.
    :param permute_order: A permutation of ``range(rai.numDimensions())``.
        The list is not modified.
    :return: A permuted ImgPlus.
    :raises ValueError: If permute_order has the wrong length or is not a
        permutation.
    """
    img = _dataset_to_imgplus(rai)
    # check for dimension count mismatch
    dim_num = rai.numDimensions()
    if len(permute_order) != dim_num:
        raise ValueError(
            f"Mismatched dimension count: {len(permute_order)} != {dim_num}"
        )
    # Duplicates or out-of-range entries would otherwise yield a wrongly
    # permuted view with duplicated axes.
    if sorted(permute_order) != list(range(dim_num)):
        raise ValueError(
            f"Invalid permute order {list(permute_order)}: "
            f"expected a permutation of 0..{dim_num - 1}"
        )
    # Work on a copy: the swap bookkeeping below rewrites entries, and the
    # caller's list must not change.
    order = list(permute_order)

    # get ImageJ resources
    ImgView = sj.jimport("net.imglib2.img.ImgView")

    # axes of the output, in output order
    axes = [img.axis(src) for src in order]

    # Reach the target order through successive pairwise axis swaps.
    # Invariant: order[k] is where (in the current view) the dimension destined
    # for output slot k currently lives; slots < i are already in place.
    view = img.getImg()
    for i in range(dim_num):
        src = order[i]
        if src == i:
            continue
        view = jc.Views.permute(view, src, i)
        # The dimension that sat at i now sits at src; redirect the (later)
        # output slot that was waiting for it.
        order[order.index(i, i + 1)] = src
        order[i] = i

    return jc.ImgPlus(ImgView.wrap(view), img.getName(), axes)
[docs]
def prioritize_rai_axes_order(
    axis_types: List["jc.AxisType"], ref_order: List["jc.AxisType"]
) -> List[int]:
    """Compute a permutation that sorts an image's axes by a reference order.

    Indices of ``axis_types`` entries that appear in ``ref_order`` come first,
    grouped in ``ref_order`` sequence. Indices of entries absent from
    ``ref_order`` follow, in their original relative order.

    :param axis_types: List of 'net.imagej.axis.AxisType' from image.
    :param ref_order: List of 'net.imagej.axis.AxisType' giving the reference
        order (e.g. _python_rai_ref_order()).
    :return: List of int for permuting an image (e.g. [0, 4, 3, 1, 2]).
    """
    matched = [
        idx
        for ref_axis in ref_order
        for idx, candidate in enumerate(axis_types)
        if ref_axis == candidate
    ]
    leftovers = [
        idx for idx, candidate in enumerate(axis_types) if candidate not in ref_order
    ]
    return matched + leftovers
def _assign_axes(
    xarr: xr.DataArray,
) -> List[Union["jc.DefaultLinearAxis", "jc.EnumeratedAxis"]]:
    """Build one ImageJ calibrated axis per dimension of an xarray.

    Each dimension's coordinates decide the axis kind. Evenly spaced
    coordinates (judged with numpy.isclose, to tolerate floating point noise)
    become a DefaultLinearAxis. Otherwise an EnumeratedAxis holding every
    coordinate is created. If constructing that EnumeratedAxis raises a
    JException or TypeError, a DefaultLinearAxis is used instead.
    Non-numeric coordinates are replaced by 0, 1, 2, ... with a warning.

    :param xarr: xarray that holds the data.
    :return: A list of ImageJ axes, placed at the Java-side index of each
        dimension as given by _get_axis_num.
    """
    java_axes = [""] * xarr.ndim
    for dim in xarr.dims:
        axis_type = jc.Axes.get(_convert_dim(dim, "java"))
        slot = _get_axis_num(xarr, dim)
        coord = xarr.coords[dim]

        if _is_numeric_scale(coord):
            positions = coord.to_numpy().astype(np.double)
        else:
            # Non-numeric labels cannot calibrate an axis; use a plain index.
            _logger.warning(
                f"The {axis_type.getLabel()} axis is non-numeric and is translated "
                "to a linear index."
            )
            positions = [np.double(i) for i in np.arange(len(coord))]

        steps = np.diff(positions)
        # NOTE: with fewer than two coordinates `steps` is empty, so the test
        # below is falsy and such axes take the enumerated path.
        is_linear = steps.size and np.all(np.isclose(steps, steps[0]))

        if is_linear:
            java_axes[slot] = _get_default_linear_axis(positions, axis_type)
            continue
        try:
            boxed = [jc.Double(p) for p in positions]
            java_axes[slot] = jc.EnumeratedAxis(axis_type, sj.to_java(boxed))
        except (JException, TypeError):
            # EnumeratedAxis not usable here: fall back to a linear axis.
            java_axes[slot] = _get_default_linear_axis(positions, axis_type)
    return java_axes
def _ends_with_channel_axis(xarr: "xr.DataArray") -> bool:
    """Check if xarray.DataArray ends in the channel dimension.

    The last dimension counts as a channel axis when its name is, ignoring
    case, "c", "ch" or "channel".

    :param xarr: xarray.DataArray to check (only its ``dims`` are read).
    :return: True if the last dimension is a channel axis. A zero-dimensional
        array, or a last dimension whose name is not a string, gives False
        (previously these raised IndexError / AttributeError).
    """
    if not xarr.dims:
        return False
    last_dim = xarr.dims[-1]
    return isinstance(last_dim, str) and last_dim.lower() in ("c", "ch", "channel")
def _get_axis_num(xarr: xr.DataArray, axis):
    """Map an xarray dimension to its axis index on the Java (ImageJ) side.

    Axis order is inverted for C-style numpy arrays (the default): a
    dimension's Java index is its Python index counted from the end. When the
    last dimension is a channel axis, it keeps the last Java index and the
    remaining dimensions are reversed in front of it. For Fortran-ordered
    data the Python index is returned unchanged.

    :param xarr: xarray.DataArray being converted.
    :param axis: Dimension name (anything accepted by ``xarr.get_axis_num``).
    :return: Axis idx in java.
    """
    py_axnum = xarr.get_axis_num(axis)
    if np.isfortran(xarr.values):
        return py_axnum

    last = len(xarr.dims) - 1
    if _ends_with_channel_axis(xarr):
        # Compare the numeric index, not the dimension name: the old
        # `axis == len(...) - 1` compared a name with an int and never matched,
        # so the channel dimension came back as -1.
        if py_axnum == last:
            return py_axnum
        return last - py_axnum - 1
    return last - py_axnum
def _get_axes_coords(
    axes: List["jc.CalibratedAxis"], dims: List[str], shape: Tuple[int]
) -> dict:
    """Build an xarray-style coordinate dictionary from ImageJ axes.

    For each index ``i`` of ``dims``, the key ``dims[i]`` maps to the
    calibrated values of positions ``0 .. shape[i] - 1`` on ``axes[i]``.

    :param axes: ImageJ calibrated axes, indexed in step with ``dims``.
    :param dims: Axis label for each dimension; these become the keys.
    :param shape: F-style, or reversed C-style, shape of axes numpy array.
    :return: Dictionary of coordinates for each axis.
    """
    coords = {}
    for idx, label in enumerate(dims):
        coords[label] = [
            axes[idx].calibratedValue(pos) for pos in range(shape[idx])
        ]
    return coords
def _get_default_linear_axis(coords_arr: np.ndarray, ax_type: "jc.AxisType"):
    """
    Create a DefaultLinearAxis calibrated from the first coordinates.

    The origin is the first coordinate (0 when there are none). The scale is
    the difference between the first two coordinates (1 when there are fewer
    than two).

    :param coords_arr: A 1D sequence of coordinates (e.g. a NumPy array).
    :param ax_type: The net.imagej.axis.AxisType of the new axis.
    :return: An instance of net.imagej.axis.DefaultLinearAxis.
    """
    count = len(coords_arr)
    scale = coords_arr[1] - coords_arr[0] if count > 1 else 1
    origin = coords_arr[0] if count > 0 else 0
    return jc.DefaultLinearAxis(ax_type, jc.Double(scale), jc.Double(origin))
def _is_numeric_scale(coords_array: np.ndarray) -> bool:
    """
    Tell whether a coordinate array holds numbers.

    :param coords_array: Any array-like with a ``dtype``, e.g. a NumPy array
        or an xarray coordinate.
    :return: True for integer, floating-point and complex dtypes; False
        otherwise (including bool, string and object dtypes).
    """
    element_type = coords_array.dtype
    return np.issubdtype(element_type, np.number)
def _dataset_to_imgplus(rai: "jc.RandomAccessibleInterval") -> "jc.ImgPlus":
    """Unwrap a Dataset to its ImgPlus.

    :param rai: A RandomAccessibleInterval, possibly a Dataset.
    :return: ``rai.getImgPlus()`` when ``rai`` is a Dataset; otherwise ``rai``
        itself, unchanged.
    """
    return rai.getImgPlus() if isinstance(rai, jc.Dataset) else rai
def _get_axis_labels(axes: List["jc.CalibratedAxis"]) -> List[str]:
    """Collect the label of each axis's type.

    :param axes: A List of 'CalibratedAxis'.
    :return: ``str(axis.type().getLabel())`` for each axis, in order.
    """
    return [str(axis.type().getLabel()) for axis in axes]
def _python_rai_ref_order() -> List["jc.AxisType"]:
    """Return the reference axis order used when permuting RAIs for Python.

    The order is CHANNEL, X, Y, Z, TIME, which is the Python/scikit-image
    preferred order (TIME, Z, Y, X, CHANNEL) reversed.

    :return: List of AxisType in that reference order.
    """
    axes = jc.Axes
    return [axes.CHANNEL, axes.X, axes.Y, axes.Z, axes.TIME]
def _convert_dim(dim: str, direction: str) -> str:
    """Convert one dimension name between ImageJ and Python/NumPy conventions.

    :param dim: A dimension to be converted.
    :param direction: Case-insensitive.
        'python': Convert a single dimension from ImageJ to Python/NumPy convention.
        'java': Convert a single dimension from Python/NumPy to ImageJ convention.
        Any other value returns ``dim`` unchanged.
    :return: A single converted dimension.
    """
    converters = {"python": _to_pydim, "java": _to_ijdim}
    convert = converters.get(direction.lower())
    return dim if convert is None else convert(dim)
def _convert_dims(dimensions: List[str], direction: str) -> List[str]:
    """Convert a list of dimension names between ImageJ and Python/NumPy conventions.

    :param dimensions: List of dimensions (e.g. X, Y, Channel, Z, Time).
    :param direction: Case-insensitive.
        'python': Convert dimensions from ImageJ to Python/NumPy conventions.
        'java': Convert dimensions from Python/NumPy to ImageJ conventions.
    :return: A new list of converted dimensions. For any other direction the
        input ``dimensions`` object itself is returned.
    """
    convert = {"python": _to_pydim, "java": _to_ijdim}.get(direction.lower())
    if convert is None:
        return dimensions
    return [convert(dim) for dim in dimensions]
def _validate_dim_order(dim_order: List[str], shape: tuple) -> List[str]:
    """
    Validate a List of dimensions against an image shape.

    If there are fewer dimension names than shape entries, the missing trailing
    names are filled in as "dim_0", "dim_1", ... (xarray-style placeholder
    names, numbered from 0 for the first one added).

    :param dim_order: List of dimensions (e.g. X, Y, Channel, Z, Time).
        It is never modified.
    :param shape: Shape of the image the dimension order describes.
    :return: ``dim_order`` itself when its length already matches ``shape``;
        otherwise a new list with the "dim_n" names appended.
    :raises ValueError: If there are more dimension names than shape entries.
    """
    dim_len = len(dim_order)
    shape_len = len(shape)
    if dim_len > shape_len:
        raise ValueError(f"Expected {shape_len} dimensions but got {dim_len}.")
    if dim_len < shape_len:
        # Build a new list instead of appending in place, so the caller's list
        # is not mutated (and tuples are accepted too).
        return list(dim_order) + [f"dim_{i}" for i in range(shape_len - dim_len)]
    return dim_order
def _has_axis(rai: "jc.RandomAccessibleInterval") -> bool:
    """Check if a RandomAccessibleInterval has axes.

    :param rai: Any object; only Java objects can qualify.
    :return: True if ``rai`` is a Java object with an ``axis`` attribute,
        otherwise False.
    """
    # The old non-Java branch evaluated `False` without returning it, so
    # callers got None; always return a real bool.
    return bool(sj.isjava(rai) and hasattr(rai, "axis"))
def _to_pydim(key: str) -> str:
    """Translate an ImageJ axis label into its Python/NumPy short name.

    Labels without a known translation are returned unchanged.
    """
    return {
        "Time": "t",
        "slice": "pln",
        "Z": "pln",
        "Y": "row",
        "X": "col",
        "Channel": "ch",
    }.get(key, key)
def _to_ijdim(key: str) -> str:
    """Translate a Python/NumPy dimension name into its ImageJ axis label.

    Names without a known translation are returned unchanged.
    """
    return {
        "col": "X",
        "x": "X",
        "row": "Y",
        "y": "Y",
        "ch": "Channel",
        "c": "Channel",
        "pln": "Z",
        "z": "Z",
        "t": "Time",
    }.get(key, key)