Skip to content

Segmentation

Integration with external segmentation pipelines: SPINEPS (spine segmentation) and VibeSeg / nnU-Net (general deep learning inference).

SPINEPS

TPTBox.segmentation.spineps

get_outpaths_spineps

get_outpaths_spineps(file_path: str | Path | BIDS_FILE, dataset: str | Path | None = None, derivative_name: str = 'derivative', ignore_bids_filter: bool = True) -> dict[Literal['out_spine', 'out_spine_raw', 'out_vert', 'out_vert_raw', 'out_unc', 'out_logits', 'out_snap', 'out_ctD', 'out_snap2', 'out_debug', 'out_raw'], Path]

Return the expected output paths for a SPINEPS segmentation run.

Parameters:

Name Type Description Default
file_path str | Path | BIDS_FILE

Path to the input NIfTI image, or a BIDS_FILE object.

required
dataset str | Path | None

Optional dataset root directory. If omitted for a plain str or Path, the file's parent directory is used as the dataset root.

None
derivative_name str

Name of the derivatives sub-folder used by SPINEPS.

'derivative'
ignore_bids_filter bool

If True, disable strict BIDS filename filtering.

True

Returns:

Type Description
dict[Literal['out_spine', 'out_spine_raw', 'out_vert', 'out_vert_raw', 'out_unc', 'out_logits', 'out_snap', 'out_ctD', 'out_snap2', 'out_debug', 'out_raw'], Path]

Dictionary mapping output keys (e.g. "out_spine", "out_vert") to the corresponding Path objects.

Source code in TPTBox/segmentation/spineps.py
def get_outpaths_spineps(
    file_path: str | Path | BIDS_FILE,
    dataset: str | Path | None = None,
    derivative_name: str = "derivative",
    ignore_bids_filter: bool = True,
) -> dict[
    Literal[
        "out_spine",
        "out_spine_raw",
        "out_vert",
        "out_vert_raw",
        "out_unc",
        "out_logits",
        "out_snap",
        "out_ctD",
        "out_snap2",
        "out_debug",
        "out_raw",
    ],
    Path,
]:
    """Return the expected output paths for a SPINEPS segmentation run.

    The input is resolved the same way :func:`run_spineps` resolves it: a plain path is
    wrapped in a ``BIDS_FILE`` rooted at *dataset* (or the file's parent directory), and a
    non-None *dataset* replaces the dataset of a ``BIDS_FILE`` input.

    Args:
        file_path: Path to the input NIfTI image, or a ``BIDS_FILE`` object.
        dataset: Optional dataset root directory. For a plain ``str``/``Path`` it defaults
            to the file's parent directory. For a ``BIDS_FILE`` a non-None value overrides
            its dataset; the caller's object is not modified (a shallow copy is used).
        derivative_name: Name of the derivatives sub-folder used by SPINEPS.
        ignore_bids_filter: If True, disable strict BIDS filename filtering
            (forwarded as ``non_strict_mode``).

    Returns:
        Dictionary mapping output keys (e.g. ``"out_spine"``, ``"out_vert"``) to
        the corresponding ``Path`` objects.
    """
    from spineps.seg_run import output_paths_from_input

    if not isinstance(file_path, BIDS_FILE):
        file_path = Path(file_path)
        file_path = BIDS_FILE(file_path, file_path.parent if dataset is None else dataset)
    elif dataset is not None:
        # run_spineps writes into the overriding dataset, so the reported paths must use it
        # too. Work on a shallow copy to avoid mutating the caller's BIDS_FILE.
        import copy

        file_path = copy.copy(file_path)
        file_path.dataset = dataset
    return output_paths_from_input(
        file_path,
        derivative_name,
        None,
        input_format=file_path.format,
        non_strict_mode=ignore_bids_filter,
    )

run_spineps

run_spineps(file_path: str | Path | BIDS_FILE, dataset: str | Path | None = None, model_semantic: str | Path = 't2w', model_instance: str | Path = 'instance', model_labeling: str | None = 't2w_labeling', derivative_name: str = 'derivative', override_semantic: bool = False, override_instance: bool = False, lambda_semantic=None, save_debug_data: bool = False, verbose: bool = False, save_raw: bool = False, ignore_compatibility_issues: bool = False, use_cpu: bool = False, **args) -> dict

Run the SPINEPS spine segmentation pipeline on a single image.

Handles model loading, BIDS path resolution, and delegates to SPINEPS' process_img_nii function.

Parameters:

Name Type Description Default
file_path str | Path | BIDS_FILE

Path to the input NIfTI image, or a BIDS_FILE object.

required
dataset str | Path | None

Optional dataset root directory. For a plain path it defaults to the file's parent directory; for a BIDS_FILE, a non-None value replaces the object's dataset.

None
model_semantic str | Path

Semantic segmentation model name (e.g. "t2w") or explicit path to a model folder.

't2w'
model_instance str | Path

Instance segmentation model name or explicit path.

'instance'
model_labeling str | None

Labeling model name, or None to skip labeling.

't2w_labeling'
derivative_name str

Name of the derivatives sub-folder for outputs.

'derivative'
override_semantic bool

If True, recompute the semantic segmentation even when a cached result exists.

False
override_instance bool

If True, recompute the instance segmentation even when a cached result exists.

False
lambda_semantic

Optional callable to post-process the semantic output.

None
save_debug_data bool

If True, save intermediate debug files.

False
verbose bool

If True, enable verbose logging.

False
save_raw bool

If True, also save unprocessed (raw) model outputs.

False
ignore_compatibility_issues bool

If True, suppress BIDS compatibility checks and model/image compatibility warnings.

False
use_cpu bool

If True, force CPU inference even if a GPU is available.

False
**args

Additional keyword arguments forwarded to process_img_nii.

{}

Returns:

Type Description
dict

The output paths dictionary returned by SPINEPS' process_img_nii.

Source code in TPTBox/segmentation/spineps.py
def run_spineps(
    file_path: str | Path | BIDS_FILE,
    dataset: str | Path | None = None,
    model_semantic: str | Path = "t2w",
    model_instance: str | Path = "instance",
    model_labeling: str | None = "t2w_labeling",
    derivative_name: str = "derivative",
    override_semantic: bool = False,
    override_instance: bool = False,
    lambda_semantic=None,
    save_debug_data: bool = False,
    verbose: bool = False,
    save_raw: bool = False,
    ignore_compatibility_issues: bool = False,
    use_cpu: bool = False,
    **args,
) -> dict:
    """Run the SPINEPS spine segmentation pipeline on a single image.

    Handles model loading, BIDS path resolution, and delegates to SPINEPS'
    ``process_img_nii`` function.

    Args:
        file_path: Path to the input NIfTI image, or a ``BIDS_FILE`` object.
        dataset: Optional dataset root directory. For a plain path it defaults to the
            file's parent directory; for a ``BIDS_FILE`` a non-None value is assigned to
            the object's ``dataset`` attribute (the caller's object is modified).
        model_semantic: Semantic segmentation model name (e.g. ``"t2w"``) or, when a
            ``Path`` is given, an explicit model folder.
        model_instance: Instance segmentation model name or, when a ``Path`` is given,
            an explicit model folder.
        model_labeling: Labeling model name, or ``None`` to skip labeling. If the model
            cannot be obtained (e.g. the installed SPINEPS lacks labeling support), a
            warning is emitted and the run continues without labeling.
        derivative_name: Name of the derivatives sub-folder for outputs.
        override_semantic: If True, recompute the semantic segmentation even when
            a cached result exists.
        override_instance: If True, recompute the instance segmentation even when
            a cached result exists.
        lambda_semantic: Optional callable to post-process the semantic output.
        save_debug_data: If True, save intermediate debug files.
        verbose: If True, enable verbose logging.
        save_raw: If True, also save unprocessed (raw) model outputs.
        ignore_compatibility_issues: If True, passed to ``process_img_nii`` as both
            ``ignore_compatibility_issues`` and ``ignore_bids_filter``.
        use_cpu: If True, force CPU inference even if a GPU is available.
        **args: Additional keyword arguments forwarded to ``process_img_nii``.

    Returns:
        The output paths dictionary returned by SPINEPS' ``process_img_nii``.
    """
    from spineps import get_instance_model, get_semantic_model, process_img_nii
    from spineps.get_models import get_actual_model

    label = {}
    if model_labeling is not None:
        try:
            from spineps.get_models import get_labeling_model

            label = {"model_labeling": get_labeling_model(model_labeling, use_cpu=use_cpu)}
        except Exception as e:  # TODO remove when spineps has officially adopted labeling
            # Best effort: keep running without labeling, but do not drop the requested
            # model silently.
            import warnings

            warnings.warn(
                f"Could not load labeling model {model_labeling!r}; continuing without labeling ({e!r})",
                stacklevel=2,
            )

    if not isinstance(file_path, BIDS_FILE):
        file_path = Path(file_path)
        file_path = BIDS_FILE(file_path, file_path.parent if dataset is None else dataset)
    elif dataset is not None:
        file_path.dataset = dataset
    # A Path selects an explicit model folder; a str is resolved as a SPINEPS model name.
    if isinstance(model_semantic, Path):
        model_semantic = get_actual_model(model_semantic, use_cpu=use_cpu)
    else:
        model_semantic = get_semantic_model(model_semantic, use_cpu=use_cpu)
    if isinstance(model_instance, Path):
        model_instance = get_actual_model(model_instance, use_cpu=use_cpu)
    else:
        model_instance = get_instance_model(model_instance, use_cpu=use_cpu)
    output_paths, _errcode = process_img_nii(
        img_ref=file_path,
        derivative_name=derivative_name,
        model_semantic=model_semantic,
        model_instance=model_instance,
        **label,
        override_semantic=override_semantic,
        override_instance=override_instance,
        lambda_semantic=lambda_semantic,
        save_debug_data=save_debug_data,
        verbose=verbose,
        save_raw=save_raw,
        ignore_compatibility_issues=ignore_compatibility_issues,
        ignore_bids_filter=ignore_compatibility_issues,
        **args,
    )
    return output_paths

VibeSeg

TPTBox.segmentation.VibeSeg.vibeseg

run_nnunet

run_nnunet(i: list[Image_Reference], out_seg: str | Path, *, override: bool = False, gpu: int = 0, ddevice: Literal['cpu', 'cuda', 'mps'] = 'cuda', dataset_id: int = 80, model_path: str | Path | None = None, auto_download=False, keep_size=False, fill_holes=False, logits=False, mapping=None, crop=False, max_folds=None, mode='nearest', padd: int = 0, key_ResEnc='__nnUNet*ResEnc', **args) -> None

Run an nnU-Net model on a list of images (multi-channel) and save the segmentation.

Parameters:

Name Type Description Default
i list[Image_Reference]

List of input image references forming one multi-channel sample.

required
out_seg str | Path

Destination path for the segmentation output (NIfTI).

required
override bool

If True, recompute and overwrite an existing output file.

False
gpu int

GPU device index to use for inference.

0
ddevice Literal['cpu', 'cuda', 'mps']

Compute device: "cuda", "cpu", or "mps".

'cuda'
dataset_id int

nnU-Net dataset identifier.

80
**args

Additional keyword arguments forwarded to run_inference_on_file.

{}
Source code in TPTBox/segmentation/VibeSeg/vibeseg.py
def run_nnunet(
    i: list[Image_Reference],
    out_seg: str | Path,
    *,
    override: bool = False,
    gpu: int = 0,
    ddevice: Literal["cpu", "cuda", "mps"] = "cuda",
    dataset_id: int = 80,
    model_path: str | Path | None = None,
    auto_download=False,  # set to True if model_path is None
    keep_size=False,
    fill_holes=False,
    logits=False,
    mapping=None,
    crop=False,
    max_folds=None,
    mode="nearest",
    padd: int = 0,
    key_ResEnc="__nnUNet*ResEnc",
    **args,
) -> None:
    """Run an nnU-Net model on a list of images (multi-channel) and save the segmentation.

    Every input reference is converted with ``to_nii`` and the whole list is handed to
    ``run_inference_on_file`` as one multi-channel sample.

    Args:
        i: List of input image references forming one multi-channel sample.
        out_seg: Destination path for the segmentation output (NIfTI).
        override: If True, recompute and overwrite an existing output file.
        gpu: GPU device index to use for inference.
        ddevice: Compute device: ``"cuda"``, ``"cpu"``, or ``"mps"``.
        dataset_id: nnU-Net dataset identifier.
        model_path: Optional explicit model location, forwarded to
            ``run_inference_on_file``.
        auto_download: Forwarded to ``run_inference_on_file``. NOTE(review): presumably
            enables downloading the model when *model_path* is None — confirm.
        keep_size, fill_holes, logits, mapping, crop, max_folds, mode, padd: Forwarded
            unchanged to ``run_inference_on_file``; see that function for their meaning.
        key_ResEnc: Forwarded as ``_key_ResEnc``. NOTE(review): the default looks like a
            glob pattern for ResEnc model folders — confirm.
        **args: Additional keyword arguments forwarded to ``run_inference_on_file``.
    """
    run_inference_on_file(
        dataset_id,
        [to_nii(img) for img in i],
        out_file=out_seg,
        override=override,
        gpu=gpu,
        ddevice=ddevice,
        model_path=model_path,
        auto_download=auto_download,
        keep_size=keep_size,
        fill_holes=fill_holes,
        logits=logits,
        mapping=mapping,
        crop=crop,
        max_folds=max_folds,
        mode=mode,
        padd=padd,
        _key_ResEnc=key_ResEnc,
        **args,
    )

extract_vertebra_bodies_from_VibeSeg

extract_vertebra_bodies_from_VibeSeg(nii_vibeSeg: Image_Reference, num_thoracic_verts: int = 12, num_lumbar_verts: int = 5, out_path: str | Path | None = None, out_path_poi: str | Path | None = None) -> tuple[NII, POI]

Extracts and labels vertebra bodies from a VibeSeg segmentation NIfTI file.

This function processes a segmentation mask containing vertebrae and intervertebral discs (IVDs). It separates individual vertebra bodies by eroding and splitting the mask at IVD regions, labels the vertebrae from bottom to top (lumbar and thoracic), and optionally saves the labeled mask and point-of-interest (POI) data.

Parameters:

Name Type Description Default
nii_vibeSeg Image_Reference

Path or reference to the NIfTI file containing the VibeSeg segmentation mask.

required
num_thoracic_verts int

Number of thoracic vertebrae to include. Defaults to 12.

12
num_lumbar_verts int

Number of lumbar vertebrae to include. Defaults to 5.

5
out_path str | Path | None

Path to save the labeled vertebra-body mask. If None, the mask is not saved. Defaults to None.

None
out_path_poi str | Path | None

Path to save the centroid POI data (a .json file). If None, the POI data is not saved. Defaults to None.

None

Returns:

Name Type Description
tuple tuple[NII, POI]
  • components (NII): A labeled NIfTI mask of the segmented vertebra bodies.
  • centroids_mapped (POI): Centroids of the labeled vertebrae as a point-of-interest (POI) dataset.
Notes
  • Labels for the vertebrae follow the naming convention: L1=20 to L5=24 for lumbar and T1=8 to T12=19 for thoracic; T13 = 28.
  • Cervical vertebrae and any unclassified regions are excluded (set to 0).
  • The output files, if saved, are:
  • Mask file: <out_path>
  • POI file: <out_path_poi> (a _poi.json suffix is recommended).
Example

>>> nii_vibeSeg = "/path/to/vibe_segmentation.nii.gz"
>>> labeled_mask, centroids = extract_vertebra_bodies_from_VibeSeg(nii_vibeSeg, out_path="output_mask.nii.gz")

Source code in TPTBox/segmentation/VibeSeg/vibeseg.py
def extract_vertebra_bodies_from_VibeSeg(
    nii_vibeSeg: Image_Reference,
    num_thoracic_verts: int = 12,
    num_lumbar_verts: int = 5,
    out_path: str | Path | None = None,
    out_path_poi: str | Path | None = None,
) -> tuple[NII, POI]:
    """Extracts and labels vertebra bodies from a VibeSeg segmentation NIfTI file.

    This function processes a segmentation mask containing vertebrae and intervertebral discs (IVDs).
    It separates individual vertebra bodies by eroding and splitting the mask at IVD regions, labels the vertebrae
    from bottom to top (lumbar and thoracic), and optionally saves the labeled mask and point-of-interest (POI) data.

    Args:
        nii_vibeSeg (Image_Reference): Path or reference to the NIfTI file containing the VibeSeg segmentation mask.
        num_thoracic_verts (int, optional): Number of thoracic vertebrae to include. Defaults to 12.
        num_lumbar_verts (int, optional): Number of lumbar vertebrae to include. Defaults to 5.
        out_path (str | Path | None, optional): Path to save the labeled vertebra-body mask.
            If None, the mask is not saved. Defaults to None.
        out_path_poi (str | Path | None, optional): Path to save the centroid POI data (a ``.json`` file).
            If None, the POI data is not saved. Defaults to None.

    Returns:
        tuple:
            - components (NII): A labeled NIfTI mask of the segmented vertebra bodies.
            - centroids_mapped (POI): Centroids of the labeled vertebrae as a point-of-interest (POI) dataset.

    Notes:
        - Labels for the vertebrae follow the naming convention: L1=20 to L5=24 for lumbar and T1=8 to T12=19 for thoracic; T13 = 28.
        - Only the lowest ``num_lumbar_verts + num_thoracic_verts`` bodies are kept; all bodies above them
          (cervical vertebrae) are set to 0.
        - The output files, if saved, are:
          - Mask file: `<out_path>`
          - POI file: `<out_path_poi>` (a `_poi.json` suffix is recommended).

    Example:
        >>> nii_vibeSeg = "/path/to/vibe_segmentation.nii.gz"
        >>> labeled_mask, centroids = extract_vertebra_bodies_from_VibeSeg(nii_vibeSeg, out_path="output_mask.nii.gz")
    """
    from TPTBox import Vertebra_Instance, calc_centroids

    # Load the VibeSeg segmentation; label 69 holds the vertebrae, label 68 the IVDs.
    nii = to_nii(nii_vibeSeg, seg=True)
    vertebrae = nii.extract_label(69)
    ivds = nii.extract_label(68)

    # Erode vertebra masks and cut them at the (dilated) IVDs so neighbouring bodies separate
    split_masks = vertebrae.erode_msk(1, connectivity=3, verbose=False)
    split_masks[ivds.dilate_msk(1, connectivity=1, verbose=False) == 1] = 0

    # One connected component per body; grow back and clip to the original vertebra mask
    vert_bodys = split_masks.get_connected_components()
    vert_bodys.dilate_msk_(3, verbose=False)
    vert_bodys[vertebrae != 1] = 0

    # Calculate centroids for vertebra bodies
    centroids_unsorted = calc_centroids(vert_bodys, second_stage=50)
    centroids_unsorted_srp = centroids_unsorted.reorient(("S", "R", "P"))
    # Sort component ids by their first coordinate in S-R-P orientation (ascending).
    # Index 0 is labeled as the lowest lumbar vertebra below, i.e. this assumes an increasing
    # S coordinate means more cranial.
    centroids_sorted = dict(
        sorted(
            {i: centroids_unsorted_srp[i, 50][0] for i in centroids_unsorted_srp.keys_region()}.items(),
            key=lambda x: x[1],
        )
    )

    name2idx = Vertebra_Instance.name2idx()
    num_kept = num_thoracic_verts + num_lumbar_verts

    def map_to_label(index: int) -> int:
        """Map a bottom-up vertebra index to the corresponding anatomical label integer."""
        if index >= num_kept:
            return 0  # Remove cervical vertebrae
        if index < num_lumbar_verts:
            return name2idx[f"L{num_lumbar_verts - index}"]
        return name2idx[f"T{num_thoracic_verts - (index - num_lumbar_verts)}"]

    label_mapping = {k: map_to_label(i) for i, k in enumerate(centroids_sorted)}
    vert_bodys.map_labels_(label_mapping, verbose=False)
    centroids = centroids_unsorted.map_labels(label_map_region=label_mapping)

    # Save outputs if an output path is specified
    if out_path:
        vert_bodys.save(out_path)
    if out_path_poi:
        centroids.save(out_path_poi)

    return vert_bodys, centroids

nnU-Net Utilities

TPTBox.segmentation.nnUnet_utils.inference_api

load_inf_model

load_inf_model(model_folder: str | Path, step_size: float = 0.5, ddevice: str = 'cuda', use_folds: tuple[str | int, ...] | None = None, init_threads: bool = True, allow_non_final: bool = True, inference_augmentation: bool = False, use_gaussian: bool = True, verbose: bool = False, gpu: int | None = None, memory_base: int = 5000, memory_factor: int = 160, memory_max: int = 160000, wait_till_gpu_percent_is_free: float = 0.3) -> nnUNetPredictor

Load and initialise an nnU-Net model predictor from a trained model folder.

Parameters:

Name Type Description Default
model_folder str | Path

Path to the nnU-Net result folder containing fold_* sub-directories.

required
step_size float

Sliding-window step size as a fraction of the patch size. Larger values are faster but may reduce accuracy. Must be in (0, 1].

0.5
ddevice str

Inference device: "cuda" (GPU), "cpu", or "mps" (Apple Silicon). Do NOT use this to select a GPU index.

'cuda'
use_folds tuple[str | int, ...] | None

Tuple of fold indices or fold names to load. Loads all available folds when None.

None
init_threads bool

If True, configure PyTorch thread counts optimally for the selected device.

True
allow_non_final bool

If True, fall back to checkpoint_best.pth when checkpoint_final.pth is not found.

True
inference_augmentation bool

If True, enable test-time mirroring augmentation.

False
use_gaussian bool

If True, apply Gaussian weighting in the sliding window.

True
verbose bool

If True, print progress information during model initialisation.

False
gpu int | None

GPU device index forwarded to the predictor. None defaults to 0.

None
memory_base int

Base GPU memory reservation in MB (default 5 000 MB = 5 GB).

5000
memory_factor int

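Per-voxel memory scaling factor for the GPU-memory estimate (documented as prod(shape) * memory_factor / 1000 MB). Note: taken literally, this formula gives ≈ 2.1 × 10⁷ MB (≈ 21 TB) for a 512³ volume with the default of 160, so either the formula or its unit is inaccurate; the estimate is in any case capped by memory_max.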

160
memory_max int

Maximum GPU memory cap in MB (default 160 000 MB = 160 GB).

160000
wait_till_gpu_percent_is_free float

Fraction of GPU memory that must be free before inference is started.

0.3

Returns:

Type Description
nnUNetPredictor

Initialised nnUNetPredictor ready for inference.

Source code in TPTBox/segmentation/nnUnet_utils/inference_api.py
def load_inf_model(
    model_folder: str | Path,
    step_size: float = 0.5,
    ddevice: str = "cuda",
    use_folds: tuple[str | int, ...] | None = None,
    init_threads: bool = True,
    allow_non_final: bool = True,
    inference_augmentation: bool = False,
    use_gaussian: bool = True,
    verbose: bool = False,
    gpu: int | None = None,
    memory_base: int = 5000,
    memory_factor: int = 160,
    memory_max: int = 160000,
    wait_till_gpu_percent_is_free: float = 0.3,
) -> nnUNetPredictor:
    """Load and initialise an nnU-Net model predictor from a trained model folder.

    Args:
        model_folder: Path to the nnU-Net result folder containing ``fold_*``
            sub-directories.
        step_size: Sliding-window step size as a fraction of the patch size.
            Larger values are faster but may reduce accuracy.  Must be in (0, 1].
        ddevice: Inference device: ``"cuda"`` (GPU), ``"cpu"``, or ``"mps"``
            (Apple Silicon).  Do NOT use this to select a GPU index; use *gpu*.
        use_folds: Tuple of fold indices or fold names to load.  Loads all
            available folds when ``None``.
        init_threads: If True, set the PyTorch thread counts for the selected device
            (all cores on CPU; a single thread on CUDA).
        allow_non_final: If True, fall back to ``checkpoint_best.pth`` when
            ``checkpoint_final.pth`` is not found.
        inference_augmentation: If True, enable test-time mirroring augmentation.
        use_gaussian: If True, apply Gaussian weighting in the sliding window.
        verbose: If True, print progress information during model initialisation.
        gpu: GPU device index forwarded to the predictor.  ``None`` defaults to 0.
        memory_base: Base GPU memory reservation in MB (default 5 000 MB = 5 GB).
        memory_factor: Per-voxel memory scaling factor forwarded to the predictor's
            GPU-memory estimate. NOTE(review): the estimate is believed to be
            ``prod(shape) * memory_factor / 1000`` MB, but for a 512³ volume and 160 that
            gives ~21 TB rather than ~30 GB — confirm the units in ``nnUNetPredictor``.
        memory_max: Maximum GPU memory cap in MB (default 160 000 MB = 160 GB).
        wait_till_gpu_percent_is_free: Fraction of GPU memory that must be free
            before inference is started.

    Returns:
        Initialised ``nnUNetPredictor`` ready for inference.

    Raises:
        ValueError: If *ddevice* is not one of ``"cpu"``, ``"cuda"``, ``"mps"``.
        AssertionError: If *model_folder* does not exist.
        FileNotFoundError: If no usable checkpoint is found.
    """
    global _interop  # noqa: PLW0603
    if ddevice not in ("cpu", "cuda", "mps"):
        # Previously any unknown string silently selected "mps".
        raise ValueError(f"ddevice must be 'cpu', 'cuda' or 'mps', got {ddevice!r} (use `gpu` to select a GPU index)")
    model_folder = Path(model_folder)

    if ddevice == "cpu":
        import multiprocessing

        if init_threads:
            torch.set_num_threads(multiprocessing.cpu_count())
        device = torch.device("cpu")
    elif ddevice == "cuda":
        # multithreading in torch doesn't help nnU-Net if run on GPU
        try:
            if init_threads:
                torch.set_num_threads(1)
            # Interop threads may only be configured once per process; remember that we did.
            if not _interop:
                if init_threads:
                    torch.set_num_interop_threads(1)
                _interop = True
        except Exception as e:
            print(e)
        device = torch.device("cuda")
    else:
        device = torch.device("mps")

    assert model_folder.exists(), f"model-folder not found: got path {model_folder}"

    predictor = nnUNetPredictor(
        tile_step_size=step_size,
        use_gaussian=use_gaussian,
        use_mirroring=inference_augmentation,  # <- mirroring augmentation!
        perform_everything_on_gpu=ddevice != "cpu",
        device=device,
        verbose=verbose,
        verbose_preprocessing=False,
        cuda_id=0 if gpu is None else gpu,
        memory_base=memory_base,
        memory_factor=memory_factor,
        memory_max=memory_max,
        wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free,
    )
    try:
        predictor.initialize_from_trained_model_folder(
            str(model_folder), checkpoint_name="checkpoint_final.pth", use_folds=use_folds
        )
    except FileNotFoundError as e:
        if not allow_non_final:
            raise
        try:
            predictor.initialize_from_trained_model_folder(
                str(model_folder),
                checkpoint_name="checkpoint_best.pth",
                use_folds=use_folds,
            )
            logger.print("Checkpoint final not found, will load from best instead", Log_Type.WARNING)
        except Exception:
            # Report the original "final checkpoint missing" error, not the fallback's.
            raise e  # noqa: B904
    if verbose:
        logger.print(f"Inference Model loaded from {model_folder}")
    return predictor

run_inference

run_inference(input_nii: str | NII | list[NII], predictor: nnUNetPredictor, reorient_PIR: bool = False, logits: bool = False, verbose: bool = False) -> tuple[NII, NII | None, np.ndarray | None]

Run nnU-Net inference on a single image or list of images (multi-channel).

Parameters:

Name Type Description Default
input_nii str | NII | list[NII]

Input as a path to a .nii.gz file, a NII object, or a list of NII objects for multi-channel models.

required
predictor nnUNetPredictor

Loaded nnUNetPredictor as returned by :func:load_inf_model.

required
reorient_PIR bool

If True, reorient each input image to PIR orientation before passing it to the model.

False
logits bool

If True, return raw softmax logits. Currently not implemented and will raise NotImplementedError.

False
verbose bool

Unused; reserved for future logging support.

False

Raises:

Type Description
NotImplementedError

If logits is True.

AssertionError

If input_nii is not a supported type, or if the output shape does not match the input shape.

Returns:

Type Description
tuple[NII, NII | None, ndarray | None]

A tuple of (segmentation_NII, uncertainty_map_or_None, softmax_logits_or_None).

Source code in TPTBox/segmentation/nnUnet_utils/inference_api.py
def run_inference(
    input_nii: str | NII | list[NII],
    predictor: nnUNetPredictor,
    reorient_PIR: bool = False,  # noqa: N803
    logits: bool = False,
    verbose: bool = False,  # noqa: ARG001
) -> tuple[NII, NII | None, np.ndarray | None]:
    """Run nnU-Net inference on a single image or list of images (multi-channel).

    The first image is the geometric reference: its orientation, shape, spacing and
    header define the model input geometry and the returned segmentation.

    Args:
        input_nii: Input as a path to a ``.nii.gz`` file, a ``NII`` object, or a
            non-empty list of ``NII`` objects for multi-channel models.
        predictor: Loaded ``nnUNetPredictor`` as returned by :func:`load_inf_model`.
        reorient_PIR: If True, reorient copies of the input images with ``reorient_()``
            (presumably to PIR) before passing them to the model; the caller's ``NII``
            objects are left unchanged. The segmentation is reoriented back to the
            original orientation of the first image.
        logits: If True, return raw softmax logits.  Currently not implemented and
            will raise ``NotImplementedError``.
        verbose: Unused; reserved for future logging support.

    Raises:
        NotImplementedError: If *logits* is True.
        AssertionError: If *input_nii* is not a supported type or is an empty list, or
            if the output shape does not match the input shape.

    Returns:
        A tuple of (segmentation_NII, uncertainty_map_or_None, softmax_logits_or_None).
    """
    if logits:
        raise NotImplementedError("logits=True")
    if isinstance(input_nii, str):
        assert input_nii.endswith(".nii.gz"), f"input file is not a .nii.gz! Got {input_nii}"
        input_nii = NII.load(input_nii, seg=False)

    assert isinstance(input_nii, (NII, list)), f"input must be a NII or str or list[NII], got {type(input_nii)}"
    if isinstance(input_nii, NII):
        input_nii = [input_nii]
    assert len(input_nii) > 0, "input list must contain at least one NII"
    orientation = input_nii[0].orientation

    if reorient_PIR:
        # Reorient copies so the caller's images are not modified in place.
        model_inputs = []
        for nii in input_nii:
            nii_copy = nii.copy()
            nii_copy.reorient_()
            model_inputs.append(nii_copy)
    else:
        model_inputs = list(input_nii)
    reference = model_inputs[0]

    # Prepare for nnUNet behavior: reverse the axis order and add a leading channel axis.
    img_arrs = []
    for nii in model_inputs:
        a = nii.get_array().astype(np.float16)
        img_arrs.append(np.transpose(a, axes=a.ndim - 1 - np.arange(a.ndim))[np.newaxis, :])
    try:
        img = np.vstack(img_arrs)
    except Exception:
        print("could not stack images; shapes=", [a.shape for a in img_arrs])
        raise
    # Spacing of the reference image, reversed to match the reversed axis order.
    props = {"spacing": reference.zoom[::-1]}  # PIR
    out = predictor.predict_single_npy_array(img, props, save_or_return_probabilities=False)
    segmentation: np.ndarray = out  # type: ignore
    softmax_logits = None
    segmentation = np.transpose(segmentation, axes=segmentation.ndim - 1 - np.arange(segmentation.ndim))
    assert segmentation.shape == reference.shape, (segmentation.shape, reference.shape)
    seg_nii = reference.set_array(segmentation.astype(np.uint8), seg=True)
    seg_nii.reorient_(orientation, verbose=False)
    return seg_nii, None, softmax_logits

sliding_nd_slices

sliding_nd_slices(arr: ndarray, patch_size: tuple, overlap: int, fun) -> np.ndarray

Apply fun to an N-D array using a sliding-window strategy with overlap.

Parameters:

Name Type Description Default
arr ndarray

Input array to process.

required
patch_size tuple

Size of each patch along every dimension.

required
overlap int

Number of voxels to overlap between adjacent patches (symmetric).

required
fun

Callable applied to each patch; must return an array with the same shape as its input.

required

Returns:

Type Description
ndarray

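Reconstructed array with the same shape and dtype as arr. Each patch contributes only its interior: overlap // 2 voxels are trimmed from every side that does not touch the array border, so adjacent patches meet in the middle of their overlap (with an odd overlap, one voxel is written by both patches and the later one wins).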

Source code in TPTBox/segmentation/nnUnet_utils/inference_api.py
def sliding_nd_slices(arr: np.ndarray, patch_size: tuple, overlap: int, fun) -> np.ndarray:
    """Apply *fun* to an N-D array using a sliding-window strategy with overlap.

    Patches start every ``patch_size[d] - overlap`` voxels along each dimension. Each
    patch contributes only its interior to the result: ``overlap // 2`` voxels are
    trimmed from every side that does not touch the array border, so neighbouring
    patches meet in the middle of their overlap (with an odd overlap one voxel is
    written by both and the later patch wins). Patches at the upper border may be
    smaller than *patch_size*.

    Args:
        arr: Input array to process.
        patch_size: Size of each patch along every dimension.
        overlap: Number of voxels to overlap between adjacent patches (symmetric).
            Must be non-negative and smaller than every patch dimension whose array
            extent is not 1.
        fun: Callable applied to each patch; must return an array with the same
            shape as its input.

    Returns:
        Reconstructed array with the same shape and dtype as *arr*.

    Raises:
        ValueError: If *overlap* is negative or leaves a non-positive step along a
            dimension that has to be tiled.
    """
    print("sliding window")
    if overlap < 0:
        raise ValueError(f"overlap must be non-negative, got {overlap}")
    step = tuple(p - overlap for p in patch_size)
    shape = arr.shape
    for dim, (s, st) in enumerate(zip(shape, step)):
        # Size-1 dimensions are never tiled, so any step is fine there.
        if s != 1 and st <= 0:
            raise ValueError(
                f"overlap ({overlap}) must be smaller than patch_size[{dim}] ({patch_size[dim]})"
            )
    half_overlap = overlap // 2

    # Start positions of the patches along each dimension
    ranges = [range(0, max(s, 1), st) if s != 1 else [0] for s, st in zip(shape, step)]
    result = np.zeros_like(arr)
    for starts in np.ndindex(*[len(r) for r in ranges]):
        read_slices = []  # region of arr fed to fun
        write_slices = []  # region of result written by this patch
        patch_slices = []  # same region, relative to the patch
        for dim, i in enumerate(starts):
            start = ranges[dim][i]
            size = patch_size[dim]
            stop = min(start + size, shape[dim])
            write_start = start + half_overlap if start != 0 else 0
            write_stop = start + size - half_overlap if start + size < shape[dim] else shape[dim]
            read_slices.append(slice(start, stop))
            write_slices.append(slice(write_start, write_stop))
            # Relative bounds stay correct when half_overlap == 0 (a "-0" end would be empty).
            patch_slices.append(slice(write_start - start, write_stop - start))

        slices = tuple(read_slices)
        print("sliding window", slices)
        patch = fun(arr[slices])
        result[tuple(write_slices)] = patch[tuple(patch_slices)]
    return result