Cellpose/StarDist Segmentation

In this notebook we will use PyImageJ, Cellpose and StarDist to produce index images (i.e. label images, where each object is labeled with a unique integer value) of HeLa cell cytoplasms and nuclei. The resulting label images are then converted into ImgLib2 and legacy ImageJ regions of interest (ROIs). Note that this notebook requires Cellpose, StarDist and their dependencies. If you haven’t already installed these packages into your environment, you can do so by running the cell below.

!mamba install -y -c conda-forge tensorflow stardist
!pip install cellpose

⚠️Warning: The label image to ROIs features are not yet included in the current released version of PyImageJ (version 1.4.1). Until these features are included in the next release of PyImageJ, you can use them by installing the developer version of PyImageJ by cloning the PyImageJ GitHub repository and installing PyImageJ with pip install -e . in the cloned repository. Be sure to install PyImageJ in the correct conda/mamba environment. This notebook will not work with PyImageJ 1.4.1.

import sys
import imagej
import imagej.convert as convert
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from csbdeep.utils import normalize
from stardist.models import StarDist2D

# initialize ImageJ
mode = "interactive"
if sys.platform == "darwin":
    mode = "headless"

ij = imagej.init(mode=mode)
print(f"ImageJ2 version: {ij.getVersion()}")
2024-01-08 14:06:27.145566: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-01-08 14:06:27.167676: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
ImageJ2 version: 2.14.0/1.54f

Before we get started with Cellpose and StarDist, let’s add some helper functions. Below is a simple way to filter an index image’s labels by size, where size is the number of pixels that make up a given label and must fall within a minimum and maximum value. It’s important to note that this is not the same as the physical area of the label/object; it is simply the number of pixels the label/object contains.

def filter_index_image(narr: np.ndarray, min_size: int, max_size: int):
    """
    Filter an index image's labels with a pixel size range.

    Labels outside the range are set to 0 in place.
    """
    unique = np.unique(narr)
    for label in unique:
        if label == 0:
            # skip the background
            continue

        # create a crop for each label
        bbox = get_bounding_box(np.where(narr == label))
        bbox_crop = narr[bbox[0]:bbox[2] + 1, bbox[1]:bbox[3] + 1].copy()
        bbox_crop[bbox_crop != label] = 0

        # get the number of pixels in label
        bbox_crop = bbox_crop.astype(bool)
        label_size = np.sum(bbox_crop)

        # remove labels whose pixel count is outside the allowed range
        if not min_size <= label_size <= max_size:
            narr[narr == label] = 0

    return narr

def get_bounding_box(indices: np.ndarray):
    """Get the bounding box coordinates from the label indices."""
    # get min and max bounds of indices array
    min_row = np.min(indices[0])
    min_col = np.min(indices[1])
    max_row = np.max(indices[0])
    max_col = np.max(indices[1])

    return (min_row, min_col, max_row, max_col)
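
The per-label loop above can also be written without cropping, because `np.unique` can return the pixel count of every label in one pass. The following is a minimal sketch of that alternative, not part of PyImageJ; the name `filter_index_image_by_counts` is hypothetical, and unlike `filter_index_image()` it returns a new array instead of modifying the input.

```python
import numpy as np


def filter_index_image_by_counts(narr, min_size, max_size):
    """Return a copy of narr keeping only labels within the pixel size range.

    Hypothetical helper for illustration; not part of PyImageJ.
    """
    # count the pixels of every label (including background 0) in one pass
    labels, counts = np.unique(narr, return_counts=True)

    # keep labels whose pixel count is inside the range, never the background
    keep = (counts >= min_size) & (counts <= max_size) & (labels != 0)

    # zero out every pixel that does not belong to a kept label
    return np.where(np.isin(narr, labels[keep]), narr, 0)


# small synthetic index image: label 1 has 4 px, label 2 has 1 px, label 3 has 3 px
demo = np.array([
    [1, 1, 0, 2],
    [1, 1, 0, 0],
    [0, 3, 3, 3],
])
filtered = filter_index_image_by_counts(demo, 2, 3)
print(filtered)
```

With a range of 2 to 3 pixels only label 3 survives, and `demo` itself is left untouched.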

Next let’s open a 3D dataset (X, Y, Channel) of HeLa cells expressing a YFP tagged APOBEC3G (A3G) protein with a nuclear stain. This dataset has two channels:

  1. DAPI (nuclear stain)

  2. YFP (cytoplasmic A3G)

Both channels can be displayed side by side with matplotlib.pyplot as shown below. Note that we first need to convert the image data from a Java object (net.imagej.DefaultDataset) to either a numpy.ndarray or, if dimensional information is available, xarray.DataArray. Here we will let PyImageJ decide what’s best as both a numpy.ndarray and xarray.DataArray will work with matplotlib.pyplot.

# open image data and convert to Python from Java
data = ij.io().open('https://media.imagej.net/pyimagej/3d/hela_a3g.tif')
xdata = ij.py.from_java(data)

# create a matplotlib subplot with 1 row and 2 columns
fig, axes = plt.subplots(1, 2, figsize=(8, 5))

# show each channel in each column of the matplotlib subplot
# xdata's image dimensions are: (row, col, ch)
axes[0].imshow(xdata[:, :, 0], cmap="gray")
axes[1].imshow(xdata[:, :, 1] * 125, cmap="gray")

# display the plot
plt.show()

Note: The cytoplasmic channel of this dataset is a little dark and the image is hard to see. Here we are multiplying the cytoplasmic channel input image by 125 to make it brighter. This only changes how the image is displayed and does not alter the original data.
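
As an alternative to a fixed multiplier, a dim channel can be rescaled for display with a percentile-based contrast stretch. This is a minimal sketch using only NumPy; the function name `stretch_for_display` and the 1st/99th percentile defaults are assumptions chosen for illustration, and the random array stands in for the cytoplasmic channel. As with the multiplication above, only a display copy is produced and the original data is unchanged.

```python
import numpy as np


def stretch_for_display(img, low_pct=1.0, high_pct=99.0):
    """Return a float copy of img rescaled to [0, 1] between two percentiles.

    Hypothetical helper for illustration; the input array is not modified.
    """
    img = np.asarray(img, dtype=np.float64)
    lo, hi = np.percentile(img, [low_pct, high_pct])
    if hi <= lo:
        # flat image: nothing to stretch
        return np.zeros_like(img)
    # map lo -> 0 and hi -> 1, clipping values outside that window
    return np.clip((img - lo) / (hi - lo), 0.0, 1.0)


# synthetic dim image standing in for the cytoplasmic channel
rng = np.random.default_rng(0)
dim = rng.integers(0, 40, size=(64, 64)).astype(np.uint16)
display = stretch_for_display(dim)
```

The result can be passed to `imshow()` in place of `xdata[:, :, 1] * 125`.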

With the data loaded we can now run StarDist on the nuclear channel and Cellpose on the cytoplasmic channel.

# run StarDist on nuclei channel
model = StarDist2D.from_pretrained('2D_versatile_fluo')
nuc_labels, _ = model.predict_instances(normalize(xdata[:, :, 0]))

# run Cellpose on cytoplasm (grayscale)
model = models.CellposeModel(gpu=False, model_type='cyto')
ch = [0, 0]
cyto_labels = model.eval(xdata[:, :, 1].data, channels=ch, diameter=72.1)
Found model '2D_versatile_fluo' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.479071, nms_thresh=0.3.

StarDist and Cellpose both produce index images (i.e. images where each label/object is defined by a single unique value, an index), but they return them in different objects. StarDist returns the index image as a numpy.ndarray together with a dict that contains additional information such as object probability values. Cellpose returns a tuple containing the numpy.ndarray index image as well as the vector flow data. Additional information about the outputs from StarDist can be found here. For more information on Cellpose’s output, see this page.

Now that StarDist and Cellpose have finished segmenting the nuclear and cytoplasmic channels, respectively, let’s look at the results:

# display StarDist results
fig_n, axes_n = plt.subplots(1, 2, figsize=(8, 5))
axes_n[0].imshow(xdata[:, :, 0], cmap="gray")
axes_n[1].imshow(nuc_labels, cmap="cividis")
axes_n[1].set_title('StarDist output')

# display cytoplasmic results
fig_c, axes_c = plt.subplots(1, 2, figsize=(8, 5))
axes_c[0].imshow(xdata[:, :, 1] * 125, cmap="gray")
axes_c[1].imshow(cyto_labels[0], cmap="cividis")
axes_c[1].set_title('Cellpose output')

Unfortunately, the input nuclear data for StarDist is noisy and contains debris that we do not want. These objects have been segmented by StarDist and can be seen in the output index image. To clean up the StarDist result, we can use the filter_index_image() function we added earlier in this notebook to remove the small micronuclei and debris from the index image.

The filter_index_image() function accepts 3 parameters: narr (the index image), min_size and max_size. Here, we can guess that all of the objects we wish to remove are smaller than 500 pixels. We can also guess that all of our nuclei are smaller than 10,000 pixels.

# filter small labels out from the nucleus index image
# by specifying a pixel size range
ij.py.show(filter_index_image(nuc_labels, 500, 10000))
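
The 500 and 10,000 pixel bounds used above are guesses. One way to check them is to list the pixel count of every label before filtering. The sketch below assumes only NumPy; `label_sizes` is a hypothetical helper name and the small array stands in for `nuc_labels`.

```python
import numpy as np


def label_sizes(narr):
    """Map each non-background label to its pixel count.

    Hypothetical helper for illustration; not part of PyImageJ.
    """
    labels, counts = np.unique(narr, return_counts=True)
    return {int(l): int(c) for l, c in zip(labels, counts) if l != 0}


# synthetic index image standing in for nuc_labels
demo = np.zeros((6, 6), dtype=np.int32)
demo[0:2, 0:2] = 1  # 4 pixels
demo[3:6, 2:6] = 2  # 12 pixels
demo[0, 5] = 3      # 1 pixel

sizes = label_sizes(demo)
for label, size in sorted(sizes.items(), key=lambda item: item[1]):
    print(f"label {label}: {size} px")
```

Sorting by size makes it easier to see where debris ends and real nuclei begin, which in turn suggests values for min_size and max_size.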

Finally let’s display the two segmentations side-by-side:

# display StarDist and Cellpose results
fig_c, axes_c = plt.subplots(1, 2, figsize=(8, 5))
axes_c[0].imshow(nuc_labels, cmap="cividis")
axes_c[0].set_title('StarDist output')
axes_c[1].imshow(cyto_labels[0], cmap="cividis")
axes_c[1].set_title('Cellpose output')
Text(0.5, 1.0, 'Cellpose output')

With the index images from StarDist and Cellpose in hand, we can convert them into two forms that ImageJ/Fiji can use:

  • Convert index image labels to DefaultROITree

  • Convert index image labels to ROIs in the RoiManager

Let’s first convert the index images to DefaultROITrees that contain ImgLib2 ROIs. A DefaultROITree can be used in headless environments and can be attached to net.imagej.Datasets. ImgLib2 ROIs can easily be converted to the legacy ImageJ ROIs using the ConvertService.

# convert to ImgLib2 ROI in a ROITree
nuc_roi_tree = convert.index_img_to_roi_tree(ij, nuc_labels)
cyto_roi_tree = convert.index_img_to_roi_tree(ij, cyto_labels[0])

# print the contents of the ROITree (nuclear ROIs)
for i in range(len(nuc_roi_tree.children())):
    print(nuc_roi_tree.children().get(i).data())
DefaultWritablePolygon2D [(125.0, 69.0) -- (179.0, 127.0)]
DefaultWritablePolygon2D [(8.0, 151.0) -- (70.0, 215.0)]
DefaultWritablePolygon2D [(253.0, 119.0) -- (307.0, 182.0)]
DefaultWritablePolygon2D [(190.0, 82.0) -- (254.0, 129.0)]
DefaultWritablePolygon2D [(53.0, 31.0) -- (110.0, 92.0)]
DefaultWritablePolygon2D [(76.0, 142.0) -- (128.0, 216.0)]
DefaultWritablePolygon2D [(58.0, 254.0) -- (125.0, 305.0)]
DefaultWritablePolygon2D [(286.0, 248.0) -- (329.0, 290.0)]
DefaultWritablePolygon2D [(226.0, 291.0) -- (275.0, 329.0)]
DefaultWritablePolygon2D [(106.0, 292.0) -- (177.0, 329.0)]
DefaultWritablePolygon2D [(173.0, 0.0) -- (234.0, 35.0)]
DefaultWritablePolygon2D [(186.0, 204.0) -- (258.0, 251.0)]
DefaultWritablePolygon2D [(313.0, 147.0) -- (329.0, 198.0)]

⚠️Warning: The following section requires an interactive PyImageJ session. Unfortunately, macOS does not support PyImageJ’s interactive mode, so this notebook can only be run in headless mode on macOS. Visit the interactive mode on macOS section in the PyImageJ troubleshooting documentation for more information.

# convert a single ImgLib2 ROI to a legacy ImageJ ROI with the ConvertService
import scyjava as sj

imglib_polygon_roi = nuc_roi_tree.children().get(0).data()
ij_polygon_roi = ij.convert().convert(imglib_polygon_roi, sj.jimport('ij.gui.PolygonRoi'))
print(type(ij_polygon_roi))
<java class 'net.imagej.legacy.convert.roi.polygon2d.Polygon2DWrapper'>

Unlike the conversion to a DefaultROITree, converting index image labels to ROIs in the RoiManager requires an interactive or GUI PyImageJ session. The RoiManager does not function in headless environments.

After the RoiManager has been populated with ROIs, display the source image (e.g. ij.ui().show()) and check the Show All check box in the RoiManager to transfer the ROIs onto the displayed image.

# display the input data, select channel 2 and enhance the contrast
ij.ui().show(data)
imp = ij.WindowManager.getCurrentImage()
imp.setC(2)
ij.IJ.run(imp, "Enhance Contrast", "saturated=0.35")

# convert index images to ImageJ ROI in RoiManager
convert.index_img_to_roi_manager(ij, nuc_labels)
convert.index_img_to_roi_manager(ij, cyto_labels[0])

Cellpose StarDist result