Skip to content

Registration

Image registration via ANTs (rigid and deformable) and DeepALI (deep learning–based).

Point-Based Rigid Registration

TPTBox.registration.Point_Registration

Source code in TPTBox/registration/_ridged_points/point_registration.py
class Point_Registration:
    """Rigid (versor) registration computed from corresponding POI landmarks.

    The stored SimpleITK transform is applied directly by the resampler (whose
    reference grid is the fixed image) and is inverted by ``transform_cord`` to
    map moving points into the fixed space.
    """

    def __init__(
        self,
        poi_fixed: POI,
        poi_moving: POI,
        exclusion=None,
        log: Logger_Interface = No_Logger(),  # noqa: B008
        verbose=True,
        ax_code=None,
        zooms=None,
        leave_worst_percent_out=0.0,
    ):
        """Compute a rigid (versor) registration from two sets of corresponding landmarks.

        Only landmarks whose keys are present in both POIs take part in the fit.
        Landmarks whose first key element is listed in ``exclusion`` are also left out.

        Args:
            poi_fixed (POI): Landmarks in the fixed (target) space. Modified in place
                (``reorient_`` / ``rescale_``) when ``ax_code`` / ``zooms`` are given.
            poi_moving (POI): Landmarks in the moving (source) space.
            exclusion (list, optional): Values matched against the first element of
                each fixed-POI key; matching landmarks are ignored. Defaults to [].
            log (Logger_Interface, optional): Logger for progress / failure messages.
                Defaults to No_Logger().
            verbose (bool, optional): Forwarded to the logger and to the final
                ``_compute_versor`` call. Defaults to True.
            ax_code (optional): Orientation applied to ``poi_fixed`` first. Defaults to None.
            zooms (optional): Voxel spacing applied to ``poi_fixed``; ``(-1, -1, -1)``
                is treated like None. Defaults to None.
            leave_worst_percent_out (float, optional): Fraction in [0, 1). When non-zero,
                an initial fit is computed first. The landmarks with the largest
                per-point values in its ``delta_after`` result (presumably residual
                distances) are then dropped before the final fit. Defaults to 0.0.

        Raises:
            ValueError: If fewer than two shared landmarks are available, either
                initially or after leaving out the worst ones.
        """
        # NOTE: asserts are stripped under `python -O`; kept to preserve the existing exception type.
        assert leave_worst_percent_out < 1.0
        assert leave_worst_percent_out >= 0.0
        if exclusion is None:
            exclusion = []
        if ax_code is not None:
            poi_fixed.reorient_(ax_code)
        if zooms is not None and zooms != (-1, -1, -1):
            poi_fixed.rescale_(zooms)
        # Empty images that carry only the grid geometry of each POI; used for index <-> physical conversion.
        representative_f_sitk = nii_to_sitk(poi_fixed.make_empty_nii())
        representative_m_sitk = nii_to_sitk(poi_moving.make_empty_nii())

        # Filter fixed landmarks by excluded ids, then keep only keys shared with the moving POI.
        f_keys = list(filter(lambda x: x[0] not in exclusion, poi_fixed.keys()))
        m_keys = list(poi_moving.keys())
        inter = [x for x in f_keys if x in m_keys]
        log.print(f_keys, verbose=verbose)
        log.print(poi_fixed.orientation, verbose=verbose)

        if len(inter) < 2:
            log.print("[!] To few points, skip registration", Log_Type.FAIL)
            raise ValueError(
                f"[!] To few points, skip registration; {poi_fixed.keys()=}; {poi_moving.keys()=}",
            )
        img_moving = poi_moving.make_empty_nii()
        assert img_moving.shape == poi_moving.shape_int, (img_moving, poi_moving.shape)
        assert img_moving.orientation == poi_moving.orientation
        if leave_worst_percent_out != 0.0:
            poi_fixed = poi_fixed.intersect(poi_moving)
            # Initial fit on all shared points; only its error and per-point deltas are used.
            _, error_reg, _, delta_after = _compute_versor(
                inter,
                poi_fixed,
                representative_f_sitk,
                poi_moving,
                representative_m_sitk,
                verbose=False,
                log=log,
            )
            ranked = sorted(delta_after.items(), key=lambda x: -x[1])
            drop_limit = len(ranked) * leave_worst_percent_out
            out_str = f"Did not use the following keys for registaiton (worst {leave_worst_percent_out * 100} %) "
            for i, key in enumerate(ranked):
                if i >= drop_limit:
                    break
                poi_fixed.remove_centroid_(key[0])
                out_str += f"{key}, "
            log.print(out_str, verbose=verbose)
            log.print("Error with all points", error_reg, Log_Type.STAGE, verbose=verbose)
            f_keys = list(filter(lambda x: x[0] not in exclusion, poi_fixed.keys()))
            m_keys = list(poi_moving.keys())
        # limit to only shared labels
        inter = [x for x in f_keys if x in m_keys]
        if len(inter) < 2:
            # Only reachable when leave_worst_percent_out pruned too many landmarks.
            log.print("[!] Too few points left after leaving out the worst ones, skip registration", Log_Type.FAIL)
            raise ValueError(
                f"[!] Too few points left after leaving out the worst {leave_worst_percent_out * 100} %; {inter=}",
            )
        init_transform, error_reg, error_natural, _ = _compute_versor(
            inter,
            poi_fixed,
            representative_f_sitk,
            poi_moving,
            representative_m_sitk,
            verbose=verbose,
            log=log,
        )
        self._transform: sitk.VersorRigid3DTransform = init_transform

        ### Point Reg
        self._img_moving: sitk.Image = representative_m_sitk
        self._img_fixed: sitk.Image = representative_f_sitk
        self.error_reg: float = error_reg
        self.error_natural: float = error_natural
        self.input_poi: Has_Grid = poi_moving.to_gird()
        self.out_poi: Has_Grid = poi_fixed.to_gird()

    def get_resampler(self, seg: bool, c_val: float, output_space: NII | None = None) -> sitk.ResampleImageFilter:
        """Build a configured SimpleITK resampler for this registration transform.

        Args:
            seg: If True, use nearest-neighbour interpolation (segmentation mode).
            c_val: Default fill value for background voxels (floored; ignored when seg=True).
            output_space: Optional target image space. Defaults to the fixed image space.

        Returns:
            A configured ``sitk.ResampleImageFilter`` ready to be executed.
        """
        resampler: sitk.ResampleImageFilter = sitk.ResampleImageFilter()
        if output_space is None:
            resampler.SetReferenceImage(self._img_fixed)
        else:
            resampler.SetReferenceImage(nii_to_sitk(output_space))
        if seg:
            resampler.SetInterpolator(sitk.sitkNearestNeighbor)
            resampler.SetDefaultPixelValue(0)
        else:
            resampler.SetInterpolator(sitk.sitkBSplineResampler)
            resampler.SetDefaultPixelValue(math.floor(c_val))
        resampler.SetTransform(self._transform)
        return resampler

    def transform(self, x: NII_or_POI) -> NII_or_POI:
        """Apply the registration transform to a NII image or a POI landmark set.

        Args:
            x: A ``NII`` image or ``POI`` landmark set to transform.

        Returns:
            Transformed object of the same type as the input.

        Raises:
            ValueError: If *x* is neither a ``NII`` nor a ``POI``.
        """
        if isinstance(x, POI):
            return self.transform_poi(x)
        if isinstance(x, NII):
            return self.transform_nii(x)
        raise ValueError(f"Expected a NII or POI, got {type(x).__name__}")

    def transform_poi(
        self,
        poi_moving: POI,
        allow_only_same_grid_as_moving: bool = True,
        output_space: NII | POI | None = None,
    ) -> POI:
        """Transform a set of landmarks (POI) from the moving to the fixed image space.

        Args:
            poi_moving: Landmark set defined in the moving image space.
            allow_only_same_grid_as_moving: If True, assert that *poi_moving* shares
                the grid of the moving image used during registration.
            output_space: Optional target space to resample the result into.

        Returns:
            Transformed ``POI`` in the fixed (or *output_space*) coordinate frame.
        """
        if allow_only_same_grid_as_moving:
            text = "input image must be in the same space as moving.  If you are sure that this input is in same space as the moving image you can turn off 'allow_only_same_grid_as_moving'"
            poi_moving.assert_affine(self.input_poi, text=text)
        out = {
            (key, key2): self.transform_cord((x, y, z))
            for key, key2, (x, y, z) in poi_moving.items()
        }
        poi = self.out_poi.make_empty_POI(out)
        if output_space is not None:
            poi = poi.resample_from_to(output_space)
        return poi

    def transform_poi_inverse(self, poi_moving: POI, allow_only_same_grid_as_moving=True, output_space=None):
        """Transform landmarks from the fixed image space back into the moving image space.

        Args:
            poi_moving: Landmarks given on the fixed grid. The name is kept for
                backward compatibility; the input lives in the *fixed* space.
            allow_only_same_grid_as_moving: If True, assert that the input shares
                the fixed grid used during registration.
            output_space: Optional target space to resample the result into.

        Returns:
            ``POI`` on the moving grid (or resampled to *output_space*).
        """
        if allow_only_same_grid_as_moving:
            text = "input must be in the same space as fixed.  If you are sure that this input is in same space as the fixed image you can turn off 'allow_only_same_grid_as_moving'"
            poi_moving.assert_affine(self.out_poi, text=text)
        out = {
            (key, key2): self.transform_cord_inverse((x, y, z))
            for key, key2, (x, y, z) in poi_moving.items()
        }
        # transform_cord_inverse returns continuous indices of the moving image,
        # so the result must be attached to the moving grid, not the fixed one.
        poi = self.input_poi.make_empty_POI(out)
        if output_space is not None:
            poi = poi.resample_from_to(output_space)
        return poi

    def transform_cord(self, cord: tuple[float, ...], out: sitk.Image | None = None) -> np.ndarray:
        """Transform a single voxel coordinate from moving to fixed image space.

        Args:
            cord: Continuous voxel index (x, y, z) in the moving image.
            out: Reference SimpleITK image defining the output grid.
                Defaults to the fixed image used during registration.

        Returns:
            Transformed continuous index as a NumPy array.
        """
        if out is None:
            out = self._img_fixed
        ctr_b = self._img_moving.TransformContinuousIndexToPhysicalPoint(cord)
        # The stored transform maps fixed -> moving, so invert it for moving -> fixed.
        ctr_b = self._transform.GetInverse().TransformPoint(ctr_b)
        ctr_b = out.TransformPhysicalPointToContinuousIndex(ctr_b)
        return np.array(ctr_b)

    def transform_cord_inverse(self, cord: tuple[float, ...], out: sitk.Image | None = None) -> np.ndarray:
        """Transform a single voxel coordinate from fixed to moving image space (inverse direction).

        Args:
            cord: Continuous voxel index (x, y, z) given on the grid of ``out``.
            out: Reference SimpleITK image defining the grid ``cord`` is expressed in
                (the *input* space, despite the name). Defaults to the fixed image
                used during registration.

        Returns:
            Continuous voxel index in the moving image, as a NumPy array.
        """
        if out is None:
            out = self._img_fixed
        ctr_b = out.TransformContinuousIndexToPhysicalPoint(cord)
        ctr_b = self._transform.TransformPoint(ctr_b)
        ctr_b = self._img_moving.TransformPhysicalPointToContinuousIndex(ctr_b)
        return np.array(ctr_b)

    def transform_nii(
        self,
        moving_img_nii: NII,
        allow_only_same_grid_as_moving: bool = True,
        output_space: NII | None = None,
        c_val: float | None = None,
    ) -> NII:
        """Resample a NII image from the moving into the fixed image space.

        Args:
            moving_img_nii: Image defined in the moving image space.
            allow_only_same_grid_as_moving: If True, assert that *moving_img_nii*
                shares the grid of the moving image used during registration.
            output_space: Optional target space for the resampled output.
                Defaults to the fixed image space.
            c_val: Background fill value. Derived from the image automatically
                when not provided.

        Returns:
            Resampled ``NII`` in the fixed (or *output_space*) coordinate frame.
        """
        if allow_only_same_grid_as_moving:
            text = "input image must be in the same space as moving.  If you are sure that this input is in same space as the moving image you can turn off 'allow_only_same_grid_as_moving'"
            moving_img_nii.assert_affine(self.input_poi, text=text, shape_tolerance=0.9)
        if c_val is None:
            c_val = moving_img_nii.get_c_val()
        resampler = self.get_resampler(moving_img_nii.seg, c_val, output_space=output_space)
        img_sitk = nii_to_sitk(moving_img_nii)
        transformed_img = resampler.Execute(img_sitk)
        if moving_img_nii.seg:
            # Round label values so the segmentation output stays integral.
            transformed_img = sitk.Round(transformed_img)
        return sitk_to_nii(transformed_img, seg=moving_img_nii.seg)

    def get_affine(self) -> np.ndarray:
        """Return the 4x4 affine matrix of the registration, in physical coordinates.

        The matrix is built from the *inverse* of the stored transform. That is the
        same direction ``transform_cord`` uses to map moving points to fixed points.
        For a SimpleITK rigid transform with matrix A, center c and translation t:
            T(x) = A(x - c) + (t + c) = A x + (t + c - A c)

        Returns:
            A (4, 4) NumPy array representing the homogeneous affine transform.
        """
        assert isinstance(self._transform, sitk.VersorRigid3DTransform)
        A = np.eye(4)  # noqa: N806
        v = self._transform.GetInverse()
        A[:3, :3] = np.array(v.GetMatrix()).reshape(3, 3)  # Rotation matrix
        c = np.array(v.GetCenter())  # Center of rotation
        t = np.array(v.GetTranslation())  # Translation vector
        trans = -A[:3, :3] @ c + c + t
        A[:3, 3] = trans  # Set translation part
        return A

    def get_dump(self) -> tuple:
        """Collect the serialisable state of this registration object.

        Returns:
            ``(version=1, moving grid, fixed grid, transform, error_reg,
            error_natural, input_poi, out_poi)``, the layout ``load_`` unpacks.
        """
        return (
            1,  # version
            sitk_to_nii(self._img_moving, True).to_gird(),
            sitk_to_nii(self._img_fixed, True).to_gird(),
            self._transform,
            self.error_reg,
            self.error_natural,
            self.input_poi,
            self.out_poi,
        )

    def save(self, path: str | Path) -> None:
        """Serialise the registration state to a pickle file.

        The state is collected before the file is opened. If ``get_dump`` fails,
        an existing file at ``path`` is left untouched instead of being truncated.

        Args:
            path: Destination file path (overwritten if it exists).
        """
        state = self.get_dump()
        with open(path, "wb") as w:
            pickle.dump(state, w)

    @classmethod
    def load(cls, path: str | Path) -> Point_Registration:
        """Load a ``Point_Registration`` from a previously saved pickle file.

        Security: this uses ``pickle``; only load files from trusted sources.

        Args:
            path: Path to the pickle file created by :meth:`save`.

        Returns:
            Reconstructed ``Point_Registration`` instance.
        """
        with open(path, "rb") as w:
            return cls.load_(pickle.load(w))

    @classmethod
    def load_(cls, w: tuple) -> Point_Registration:
        """Reconstruct a ``Point_Registration`` from a raw state tuple (as returned by :meth:`get_dump`).

        Args:
            w: Serialised state tuple ``(version, moving grid, fixed grid, transform,
                error_reg, error_natural, input_poi, out_poi)``.

        Returns:
            Reconstructed ``Point_Registration`` instance.
        """
        version = w[0]
        # Check the version before unpacking so a layout change gives a clear error.
        assert version == 1, f"Version mismatch {version=}"
        self = cls.__new__(cls)
        (
            _,
            moving_grid,
            fixed_grid,
            self._transform,
            self.error_reg,
            self.error_natural,
            self.input_poi,
            self.out_poi,
        ) = w
        moving_grid: Has_Grid
        fixed_grid: Has_Grid
        # get_dump stores the moving grid first and the fixed grid second.
        self._img_moving = nii_to_sitk(moving_grid.make_nii())
        self._img_fixed = nii_to_sitk(fixed_grid.make_nii())
        return self

__init__

__init__(poi_fixed: POI, poi_moving: POI, exclusion=None, log: Logger_Interface = No_Logger(), verbose=True, ax_code=None, zooms=None, leave_worst_percent_out=0.0)

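Compute a rigid (versor) registration from two POI landmark sets.

Parameters:

Name Type Description Default
poi_fixed POI

Landmarks in the fixed (target) space. Reoriented/rescaled in place when ax_code/zooms are given.

required
poi_moving POI

Landmarks in the moving (source) space.

required
exclusion list

Values matched against the first element of each fixed-POI key; matching landmarks are ignored. Defaults to [].

None
log Logger_Interface

Logger used for progress and failure messages.

No_Logger()
verbose bool

Forwarded to the logger and to the versor computation.

True
ax_code

Optional orientation applied to poi_fixed before registration.

None
zooms

Optional voxel spacing applied to poi_fixed; (-1, -1, -1) means keep the current spacing.

None
leave_worst_percent_out float

Fraction in [0, 1). When non-zero, an initial fit is computed first. The worst-fitting landmarks are then dropped before the final fit.

0.0

Raises:

Type Description
ValueError

Fewer than two shared landmarks are available.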

Source code in TPTBox/registration/_ridged_points/point_registration.py
def __init__(
    self,
    poi_fixed: POI,
    poi_moving: POI,
    exclusion=None,
    log: Logger_Interface = No_Logger(),  # noqa: B008
    verbose=True,
    ax_code=None,
    zooms=None,
    leave_worst_percent_out=0.0,
):
    """Compute a rigid (versor) registration from two sets of corresponding landmarks.

    Only landmarks whose keys are present in both POIs take part in the fit.
    Landmarks whose first key element is listed in ``exclusion`` are also left out.

    Args:
        poi_fixed (POI): Landmarks in the fixed (target) space. Modified in place
            (``reorient_`` / ``rescale_``) when ``ax_code`` / ``zooms`` are given.
        poi_moving (POI): Landmarks in the moving (source) space.
        exclusion (list, optional): Values matched against the first element of
            each fixed-POI key; matching landmarks are ignored. Defaults to [].
        log (Logger_Interface, optional): Logger for progress / failure messages.
            Defaults to No_Logger().
        verbose (bool, optional): Forwarded to the logger and to the final
            ``_compute_versor`` call. Defaults to True.
        ax_code (optional): Orientation applied to ``poi_fixed`` first. Defaults to None.
        zooms (optional): Voxel spacing applied to ``poi_fixed``; ``(-1, -1, -1)``
            is treated like None. Defaults to None.
        leave_worst_percent_out (float, optional): Fraction in [0, 1). When non-zero,
            an initial fit is computed first. The landmarks with the largest
            per-point values in its ``delta_after`` result (presumably residual
            distances) are then dropped before the final fit. Defaults to 0.0.

    Raises:
        ValueError: If fewer than two shared landmarks are available, either
            initially or after leaving out the worst ones.
    """
    # NOTE: asserts are stripped under `python -O`; kept to preserve the existing exception type.
    assert leave_worst_percent_out < 1.0
    assert leave_worst_percent_out >= 0.0
    if exclusion is None:
        exclusion = []
    if ax_code is not None:
        poi_fixed.reorient_(ax_code)
    if zooms is not None and zooms != (-1, -1, -1):
        poi_fixed.rescale_(zooms)
    # Empty images that carry only the grid geometry of each POI; used for index <-> physical conversion.
    representative_f_sitk = nii_to_sitk(poi_fixed.make_empty_nii())
    representative_m_sitk = nii_to_sitk(poi_moving.make_empty_nii())

    # Filter fixed landmarks by excluded ids, then keep only keys shared with the moving POI.
    f_keys = list(filter(lambda x: x[0] not in exclusion, poi_fixed.keys()))
    m_keys = list(poi_moving.keys())
    inter = [x for x in f_keys if x in m_keys]
    log.print(f_keys, verbose=verbose)
    log.print(poi_fixed.orientation, verbose=verbose)

    if len(inter) < 2:
        log.print("[!] To few points, skip registration", Log_Type.FAIL)
        raise ValueError(
            f"[!] To few points, skip registration; {poi_fixed.keys()=}; {poi_moving.keys()=}",
        )
    img_moving = poi_moving.make_empty_nii()
    assert img_moving.shape == poi_moving.shape_int, (img_moving, poi_moving.shape)
    assert img_moving.orientation == poi_moving.orientation
    if leave_worst_percent_out != 0.0:
        poi_fixed = poi_fixed.intersect(poi_moving)
        # Initial fit on all shared points; only its error and per-point deltas are used.
        _, error_reg, _, delta_after = _compute_versor(
            inter,
            poi_fixed,
            representative_f_sitk,
            poi_moving,
            representative_m_sitk,
            verbose=False,
            log=log,
        )
        ranked = sorted(delta_after.items(), key=lambda x: -x[1])
        drop_limit = len(ranked) * leave_worst_percent_out
        out_str = f"Did not use the following keys for registaiton (worst {leave_worst_percent_out * 100} %) "
        for i, key in enumerate(ranked):
            if i >= drop_limit:
                break
            poi_fixed.remove_centroid_(key[0])
            out_str += f"{key}, "
        log.print(out_str, verbose=verbose)
        log.print("Error with all points", error_reg, Log_Type.STAGE, verbose=verbose)
        f_keys = list(filter(lambda x: x[0] not in exclusion, poi_fixed.keys()))
        m_keys = list(poi_moving.keys())
    # limit to only shared labels
    inter = [x for x in f_keys if x in m_keys]
    if len(inter) < 2:
        # Only reachable when leave_worst_percent_out pruned too many landmarks.
        log.print("[!] Too few points left after leaving out the worst ones, skip registration", Log_Type.FAIL)
        raise ValueError(
            f"[!] Too few points left after leaving out the worst {leave_worst_percent_out * 100} %; {inter=}",
        )
    init_transform, error_reg, error_natural, _ = _compute_versor(
        inter,
        poi_fixed,
        representative_f_sitk,
        poi_moving,
        representative_m_sitk,
        verbose=verbose,
        log=log,
    )
    self._transform: sitk.VersorRigid3DTransform = init_transform

    ### Point Reg
    self._img_moving: sitk.Image = representative_m_sitk
    self._img_fixed: sitk.Image = representative_f_sitk
    self.error_reg: float = error_reg
    self.error_natural: float = error_natural
    self.input_poi: Has_Grid = poi_moving.to_gird()
    self.out_poi: Has_Grid = poi_fixed.to_gird()

get_resampler

get_resampler(seg: bool, c_val: float, output_space: NII | None = None) -> sitk.ResampleImageFilter

Build a configured SimpleITK resampler for this registration transform.

Parameters:

Name Type Description Default
seg bool

If True, use nearest-neighbour interpolation (segmentation mode).

required
c_val float

Default fill value for background voxels (ignored when seg=True).

required
output_space NII | None

Optional target image space. Defaults to the fixed image space.

None

Returns:

Type Description
ResampleImageFilter

A configured sitk.ResampleImageFilter ready to be executed.

Source code in TPTBox/registration/_ridged_points/point_registration.py
def get_resampler(self, seg: bool, c_val: float, output_space: NII | None = None) -> sitk.ResampleImageFilter:
    """Create a SimpleITK resampler that applies this registration.

    Args:
        seg: Segmentation mode, which uses nearest-neighbour interpolation and a
            background value of 0.
        c_val: Background value for non-segmentation images. It is floored
            before being passed to SimpleITK and is unused when ``seg`` is True.
        output_space: Image whose grid defines the output. ``None`` selects the
            fixed image grid stored on this object.

    Returns:
        A ``sitk.ResampleImageFilter`` with reference grid, interpolator,
        default pixel value and transform already set.
    """
    reference = self._img_fixed if output_space is None else nii_to_sitk(output_space)
    if seg:
        interpolator, background = sitk.sitkNearestNeighbor, 0
    else:
        interpolator, background = sitk.sitkBSplineResampler, math.floor(c_val)

    resampler: sitk.ResampleImageFilter = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(reference)
    resampler.SetInterpolator(interpolator)
    resampler.SetDefaultPixelValue(background)
    resampler.SetTransform(self._transform)
    return resampler

transform

transform(x: NII_or_POI) -> NII_or_POI

Apply the registration transform to a NII image or a POI landmark set.

Parameters:

Name Type Description Default
x NII_or_POI

A NII image or POI landmark set to transform.

required

Returns:

Type Description
NII_or_POI

Transformed object of the same type as the input.

Raises:

Type Description
ValueError

If x is neither a NII nor a POI.

Source code in TPTBox/registration/_ridged_points/point_registration.py
def transform(self, x: NII_or_POI) -> NII_or_POI:
    """Apply the registration transform to a NII image or a POI landmark set.

    Args:
        x: A ``NII`` image or ``POI`` landmark set to transform.

    Returns:
        Transformed object of the same type as the input.

    Raises:
        ValueError: If *x* is neither a ``NII`` nor a ``POI``.
    """
    if isinstance(x, POI):
        return self.transform_poi(x)
    if isinstance(x, NII):
        return self.transform_nii(x)
    # Name the offending type so the failure is diagnosable from the traceback alone.
    raise ValueError(f"Expected a NII or POI, got {type(x).__name__}")

transform_poi

transform_poi(poi_moving: POI, allow_only_same_grid_as_moving: bool = True, output_space: NII | POI | None = None) -> POI

Transform a set of landmarks (POI) from the moving to the fixed image space.

Parameters:

Name Type Description Default
poi_moving POI

Landmark set defined in the moving image space.

required
allow_only_same_grid_as_moving bool

If True, assert that poi_moving shares the grid of the moving image used during registration.

True
output_space NII | POI | None

Optional target space to resample the result into.

None

Returns:

Type Description
POI

Transformed POI in the fixed (or output_space) coordinate frame.

Source code in TPTBox/registration/_ridged_points/point_registration.py
def transform_poi(
    self,
    poi_moving: POI,
    allow_only_same_grid_as_moving: bool = True,
    output_space: NII | POI | None = None,
) -> POI:
    """Transform a set of landmarks (POI) from the moving to the fixed image space.

    Args:
        poi_moving: Landmark set defined in the moving image space.
        allow_only_same_grid_as_moving: If True, assert that *poi_moving* shares
            the grid of the moving image used during registration.
        output_space: Optional target space to resample the result into.

    Returns:
        Transformed ``POI`` in the fixed (or *output_space*) coordinate frame.
    """
    if allow_only_same_grid_as_moving:
        # The message names the real keyword argument so users know how to bypass the check.
        text = "input image must be in the same space as moving.  If you are sure that this input is in same space as the moving image you can turn off 'allow_only_same_grid_as_moving'"
        poi_moving.assert_affine(self.input_poi, text=text)
    out = {
        (key, key2): self.transform_cord((x, y, z))
        for key, key2, (x, y, z) in poi_moving.items()
    }
    poi = self.out_poi.make_empty_POI(out)
    if output_space is not None:
        poi = poi.resample_from_to(output_space)
    return poi

transform_cord

transform_cord(cord: tuple[float, ...], out: Image | None = None) -> np.ndarray

Transform a single voxel coordinate from moving to fixed image space.

Parameters:

Name Type Description Default
cord tuple[float, ...]

Voxel coordinate (x, y, z) in the moving image.

required
out Image | None

Reference SimpleITK image defining the output space. Defaults to the fixed image used during registration.

None

Returns:

Type Description
ndarray

Transformed coordinate as a NumPy array of shape (3,).

Source code in TPTBox/registration/_ridged_points/point_registration.py
def transform_cord(self, cord: tuple[float, ...], out: sitk.Image | None = None) -> np.ndarray:
    """Map one continuous voxel index from the moving grid onto another grid.

    The moving image's geometry turns the index into a physical point. That
    point is pushed through the inverse of the stored transform (moving -> fixed)
    and then expressed as a continuous index of ``out``.

    Args:
        cord: Continuous voxel index in the moving image.
        out: Image whose grid the result is expressed in; ``None`` selects the
            fixed image used during registration.

    Returns:
        The mapped continuous index as a NumPy array.
    """
    target = self._img_fixed if out is None else out
    physical_moving = self._img_moving.TransformContinuousIndexToPhysicalPoint(cord)
    physical_fixed = self._transform.GetInverse().TransformPoint(physical_moving)
    return np.array(target.TransformPhysicalPointToContinuousIndex(physical_fixed))

transform_cord_inverse

transform_cord_inverse(cord: tuple[float, ...], out: Image | None = None) -> np.ndarray

Transform a single voxel coordinate from fixed to moving image space (inverse direction).

Parameters:

Name Type Description Default
cord tuple[float, ...]

Voxel coordinate (x, y, z) in the fixed image.

required
out Image | None

Reference SimpleITK image defining the grid in which cord is given (the input space). Defaults to the fixed image used during registration.

None

Returns:

Type Description
ndarray

Transformed coordinate as a NumPy array of shape (3,).

Source code in TPTBox/registration/_ridged_points/point_registration.py
def transform_cord_inverse(self, cord: tuple[float, ...], out: sitk.Image | None = None) -> np.ndarray:
    """Transform a single voxel coordinate from fixed to moving image space (inverse direction).

    Args:
        cord: Continuous voxel index (x, y, z) given on the grid of ``out``.
        out: Reference SimpleITK image defining the grid ``cord`` is expressed in
            (the *input* space, despite the name). Defaults to the fixed image
            used during registration.

    Returns:
        Continuous voxel index in the moving image, as a NumPy array.
    """
    if out is None:
        out = self._img_fixed
    # index on `out` -> physical point
    ctr_b = out.TransformContinuousIndexToPhysicalPoint(cord)
    # the stored transform is applied as-is (fixed -> moving direction)
    ctr_b = self._transform.TransformPoint(ctr_b)
    # physical point -> continuous index on the moving grid
    ctr_b = self._img_moving.TransformPhysicalPointToContinuousIndex(ctr_b)
    return np.array(ctr_b)

transform_nii

transform_nii(moving_img_nii: NII, allow_only_same_grid_as_moving: bool = True, output_space: NII | None = None, c_val: float | None = None) -> NII

Resample a NII image from the moving into the fixed image space.

Parameters:

Name Type Description Default
moving_img_nii NII

Image defined in the moving image space.

required
allow_only_same_grid_as_moving bool

If True, assert that moving_img_nii shares the grid of the moving image used during registration.

True
output_space NII | None

Optional target space for the resampled output. Defaults to the fixed image space.

None
c_val float | None

Background fill value. Derived from the image automatically when not provided.

None

Returns:

Type Description
NII

Resampled NII in the fixed (or output_space) coordinate frame.

Source code in TPTBox/registration/_ridged_points/point_registration.py
def transform_nii(
    self,
    moving_img_nii: NII,
    allow_only_same_grid_as_moving: bool = True,
    output_space: NII | None = None,
    c_val: float | None = None,
) -> NII:
    """Resample a NII image from the moving into the fixed image space.

    Args:
        moving_img_nii: Image defined in the moving image space.
        allow_only_same_grid_as_moving: If True, assert that *moving_img_nii*
            shares the grid of the moving image used during registration.
        output_space: Optional target space for the resampled output.
            Defaults to the fixed image space.
        c_val: Background fill value. Derived from the image automatically
            when not provided.

    Returns:
        Resampled ``NII`` in the fixed (or *output_space*) coordinate frame.
    """
    if allow_only_same_grid_as_moving:
        # The message names the real keyword argument so users know how to bypass the check.
        text = "input image must be in the same space as moving.  If you are sure that this input is in same space as the moving image you can turn off 'allow_only_same_grid_as_moving'"
        moving_img_nii.assert_affine(self.input_poi, text=text, shape_tolerance=0.9)
    if c_val is None:
        c_val = moving_img_nii.get_c_val()
    resampler = self.get_resampler(moving_img_nii.seg, c_val, output_space=output_space)
    img_sitk = nii_to_sitk(moving_img_nii)
    transformed_img = resampler.Execute(img_sitk)
    if moving_img_nii.seg:
        # Round label values so the segmentation output stays integral.
        transformed_img = sitk.Round(transformed_img)
    return sitk_to_nii(transformed_img, seg=moving_img_nii.seg)

get_affine

get_affine() -> np.ndarray

Return the 4x4 affine matrix corresponding to the rigid registration transform.

The matrix follows the convention

T(x) = A(x - c) + (t + c)

Returns:

Type Description
ndarray

A (4, 4) NumPy array representing the homogeneous affine transform.

Source code in TPTBox/registration/_ridged_points/point_registration.py
def get_affine(self) -> np.ndarray:
    """Express the registration as a homogeneous 4x4 affine matrix (physical coordinates).

    The matrix is built from the *inverse* of the stored versor transform, the
    same direction ``transform_cord`` uses to map moving points to fixed points.
    For a SimpleITK rigid transform with matrix A, center c and translation t:

        T(x) = A(x - c) + (t + c) = A x + (t + c - A c)

    so the upper-left 3x3 block is A and the translation column is
    ``t + c - A c``.

    Returns:
        A (4, 4) NumPy array.
    """
    assert isinstance(self._transform, sitk.VersorRigid3DTransform)
    affine = np.eye(4)
    inverse = self._transform.GetInverse()
    affine[:3, :3] = np.array(inverse.GetMatrix()).reshape(3, 3)
    rotation = affine[:3, :3]
    center = np.array(inverse.GetCenter())
    translation = np.array(inverse.GetTranslation())
    # Same evaluation order as the algebra above: (-A) @ c, then + c, then + t.
    affine[:3, 3] = -rotation @ center + center + translation
    return affine

get_dump

get_dump() -> tuple

Collect the serialisable state of this registration object.

Returns:

Type Description
tuple

A tuple containing the version tag followed by all state components needed to reconstruct the object via load_.

Source code in TPTBox/registration/_ridged_points/point_registration.py
def get_dump(self) -> tuple:
    """Gather everything needed to rebuild this registration into one tuple.

    Returns:
        ``(version=1, moving grid, fixed grid, transform, error_reg,
        error_natural, input_poi, out_poi)``. The two grids are produced from
        the stored SimpleITK images via ``sitk_to_nii(..., True).to_gird()``.
    """
    moving_grid = sitk_to_nii(self._img_moving, True).to_gird()
    fixed_grid = sitk_to_nii(self._img_fixed, True).to_gird()
    remaining_state = (
        self._transform,
        self.error_reg,
        self.error_natural,
        self.input_poi,
        self.out_poi,
    )
    version = 1
    return (version, moving_grid, fixed_grid, *remaining_state)

save

save(path: str | Path) -> None

Serialise the registration state to a pickle file.

Parameters:

Name Type Description Default
path str | Path

Destination file path.

required
Source code in TPTBox/registration/_ridged_points/point_registration.py
def save(self, path: str | Path) -> None:
    """Serialise the registration state to a pickle file.

    The state is collected before the file is opened. If ``get_dump`` fails,
    an existing file at ``path`` is left untouched instead of being truncated.

    Args:
        path: Destination file path (overwritten if it exists).
    """
    state = self.get_dump()
    with open(path, "wb") as w:
        pickle.dump(state, w)

load classmethod

load(path: str | Path) -> Point_Registration

Load a Point_Registration from a previously saved pickle file.

Parameters:

Name Type Description Default
path str | Path

Path to the pickle file created by save.

required

Returns:

Type Description
Point_Registration

Reconstructed Point_Registration instance.

Source code in TPTBox/registration/_ridged_points/point_registration.py
@classmethod
def load(cls, path: str | Path) -> Point_Registration:
    """Rebuild a ``Point_Registration`` from a file written by :meth:`save`.

    Warning:
        This unpickles *path*; only load files from trusted sources.

    Args:
        path: Location of the pickle file.

    Returns:
        The reconstructed ``Point_Registration`` (built by :meth:`load_`).
    """
    with open(path, "rb") as handle:
        state = pickle.load(handle)
    return cls.load_(state)

load_ classmethod

load_(w: tuple) -> Point_Registration

Reconstruct a Point_Registration from a raw state tuple (as returned by `get_dump`).

Parameters:

Name Type Description Default
w tuple

Serialised state tuple.

required

Returns:

Type Description
Point_Registration

Reconstructed Point_Registration instance.

Source code in TPTBox/registration/_ridged_points/point_registration.py
@classmethod
def load_(cls, w: tuple) -> Point_Registration:
    """Reconstruct a ``Point_Registration`` from a raw state tuple (as returned by :meth:`get_dump`).

    :meth:`get_dump` stores the *moving* grid before the *fixed* grid. They are restored in
    that same order here, so a dump/load round trip keeps fixed and moving apart.

    Args:
        w: Serialised state tuple.

    Returns:
        Reconstructed ``Point_Registration`` instance.

    Raises:
        AssertionError: If *w* was written with an unsupported format version.
    """
    # Check the version before unpacking: other versions may have a different tuple length.
    version = w[0]
    assert version == 1, f"Version mismatch {version=}"
    self = cls.__new__(cls)
    (
        _,
        moving_grid,
        fixed_grid,
        self._transform,
        self.error_reg,
        self.error_natural,
        self.input_poi,
        self.out_poi,
    ) = w
    moving_grid: Has_Grid
    fixed_grid: Has_Grid
    # Element 1 is the moving grid, element 2 the fixed grid (see get_dump).
    self._img_moving = nii_to_sitk(moving_grid.make_nii())
    self._img_fixed = nii_to_sitk(fixed_grid.make_nii())
    return self

TPTBox.registration.ridged_points_from_poi

ridged_points_from_poi(poi_fixed: POI, poi_moving: POI, exclusion: list | None = None, log: Logger_Interface = No_Logger(), verbose: bool = True, ax_code=None, zooms=None, c_val: float | None = None, leave_worst_percent_out: float = 0.0) -> Point_Registration

Compute a rigid point-based registration from two POI landmark sets.

Parameters:

Name Type Description Default
poi_fixed POI

Landmark set in the fixed (target) image space.

required
poi_moving POI

Landmark set in the moving (source) image space.

required
exclusion list | None

List of landmark keys to exclude from the alignment.

None
log Logger_Interface

Logger used for progress and diagnostic output.

No_Logger()
verbose bool

If True, print detailed per-landmark information.

True
ax_code

Optional orientation code to reorient poi_fixed before registration.

None
zooms

Optional target voxel spacing to rescale poi_fixed before registration.

None
c_val float | None

Deprecated. Background fill value — has no effect and will trigger a DeprecationWarning if supplied.

None
leave_worst_percent_out float

Fraction (0–1) of worst-fitting landmarks to discard before computing the final transform.

0.0

Returns:

Type Description
Point_Registration

Fitted Point_Registration object.

Source code in TPTBox/registration/_ridged_points/point_registration.py
def ridged_points_from_poi(
    poi_fixed: POI,
    poi_moving: POI,
    exclusion: list | None = None,
    log: Logger_Interface = No_Logger(),  # noqa: B008
    verbose: bool = True,
    ax_code=None,
    zooms=None,
    c_val: float | None = None,
    leave_worst_percent_out: float = 0.0,
) -> Point_Registration:
    """Compute a rigid point-based registration from two POI landmark sets.

    Thin wrapper around ``Point_Registration`` that additionally accepts the deprecated
    ``c_val`` argument.

    Args:
        poi_fixed: Landmark set in the fixed (target) image space.
        poi_moving: Landmark set in the moving (source) image space.
        exclusion: List of landmark keys to exclude from the alignment.
        log: Logger used for progress and diagnostic output.
        verbose: If True, print detailed per-landmark information.
        ax_code: Optional orientation code to reorient *poi_fixed* before registration.
        zooms: Optional target voxel spacing to rescale *poi_fixed* before registration.
        c_val: Deprecated. Background fill value — has no effect and will trigger a
            ``DeprecationWarning`` if supplied.
        leave_worst_percent_out: Fraction (0–1) of worst-fitting landmarks to discard
            before computing the final transform.

    Returns:
        Fitted ``Point_Registration`` object.
    """
    if c_val is not None:
        # stacklevel=2 attributes the warning to the line that called this function.
        warnings.warn(
            "c_val of ridged_points_from_poi is never used.",
            DeprecationWarning,
            stacklevel=2,
        )
    return Point_Registration(
        poi_fixed,
        poi_moving,
        exclusion=exclusion,
        log=log,
        verbose=verbose,
        ax_code=ax_code,
        zooms=zooms,
        leave_worst_percent_out=leave_worst_percent_out,
    )

TPTBox.registration.ridged_points_from_subreg_vert

ridged_points_from_subreg_vert(poi_moving: POI_Reference, vert: Image_Reference, subreg: POI_Reference, poi_target_buffer: Path | str | None = None, orientation=None, zoom: tuple[float, float, float] = (-1, -1, -1), subreg_id: int | Location | list[int | Location] | list[Location] | list[int] = 50, c_val: float = -1050, verbose: bool = True, save_buffer_file: bool = True) -> Point_Registration

Compute a rigid point-based registration using vertebra sub-region centroids.

Derives a fixed-space POI from instance and semantic segmentation masks, then aligns it to a pre-computed moving-space POI.

Parameters:

Name Type Description Default
poi_moving POI_Reference

POI landmark set (or loadable reference) for the moving image.

required
vert Image_Reference

Instance (vertebra label) segmentation in the fixed image space.

required
subreg POI_Reference

Semantic sub-region segmentation in the fixed image space.

required
poi_target_buffer Path | str | None

Optional path to cache / load the computed fixed POI.

None
orientation

Optional orientation code to reorient the fixed POI.

None
zoom tuple[float, float, float]

Target voxel spacing for rescaling the fixed POI. Pass (-1, -1, -1) to skip rescaling.

(-1, -1, -1)
subreg_id int | Location | list[int | Location] | list[Location] | list[int]

Sub-region label(s) used to extract centroid landmarks.

50
c_val float

Deprecated background fill value. It has no effect on the registration and is kept only for API compatibility.

-1050
verbose bool

If True, print progress information.

True
save_buffer_file bool

If True, save the computed fixed POI to poi_target_buffer.

True

Returns:

Type Description
Point_Registration

Fitted Point_Registration object.

Source code in TPTBox/registration/_ridged_points/point_registration.py
def ridged_points_from_subreg_vert(
    poi_moving: POI_Reference,
    vert: Image_Reference,
    subreg: Image_Reference,
    poi_target_buffer: Path | str | None = None,
    orientation=None,
    zoom: tuple[float, float, float] = (-1, -1, -1),
    subreg_id: int | Location | list[int | Location] | list[Location] | list[int] = 50,
    c_val: float = -1050,
    verbose: bool = True,
    save_buffer_file: bool = True,
) -> Point_Registration:
    """Compute a rigid point-based registration using vertebra sub-region centroids.

    Derives a fixed-space POI from instance and semantic segmentation masks, then
    aligns it to a pre-computed moving-space POI.

    Args:
        poi_moving: POI landmark set (or loadable reference) for the moving image.
        vert: Instance (vertebra label) segmentation in the fixed image space.
        subreg: Semantic sub-region segmentation in the fixed image space.
        poi_target_buffer: Optional path to cache / load the computed fixed POI.
        orientation: Optional orientation code to reorient the fixed POI.
        zoom: Target voxel spacing for rescaling the fixed POI. Pass ``(-1, -1, -1)``
            (as tuple, list, or array) or ``None`` to skip rescaling.
        subreg_id: Sub-region label(s) used to extract centroid landmarks.
        c_val: Deprecated and ignored; kept only for API compatibility. It is not
            forwarded, so it does not trigger the deprecation warning of
            ``ridged_points_from_poi``.
        verbose: If True, print progress information.
        save_buffer_file: If True, save the computed fixed POI to *poi_target_buffer*.

    Returns:
        Fitted ``Point_Registration`` object.
    """
    if not isinstance(subreg_id, (list, tuple)):
        subreg_id = [subreg_id]
    instance_nii = to_nii(vert, True).copy()
    semantic_nii = to_nii(subreg, True).copy()
    target_poi = (
        calc_poi_from_subreg_vert(
            instance_nii,
            semantic_nii,
            subreg_id=subreg_id,
            buffer_file=poi_target_buffer,
            save_buffer_file=save_buffer_file,
        )
        .copy()
        .extract_subregion_(*subreg_id)
    )
    if orientation is not None:
        target_poi.reorient_(orientation)
    # Compare as a tuple so list/array inputs of (-1, -1, -1) are also treated as "no rescale".
    if zoom is not None and tuple(zoom) != (-1, -1, -1):
        target_poi.rescale_(zoom)
    moving_poi = POI.load(poi_moving)
    return ridged_points_from_poi(target_poi, moving_poi, verbose=verbose)

Deformable Registration

TPTBox.registration.Deformable_Registration

Bases: General_Registration

Deformable registration between a fixed and moving image using deepali.

Wraps General_Registration with default loss terms (LNCC + BSpline bending energy) and a Stationary Velocity Field Free-Form Deformation (SVFFD) transform.

Attributes:

Name Type Description
transform

The learned deformation field resulting from the registration.

ref_nii

Reference NII object used for registration.

grid

Target grid for image warping.

mov

Processed version of the moving image.

Source code in TPTBox/registration/_deformable/deformable_reg.py
class Deformable_Registration(General_Registration):
    """Deformable registration between a fixed and moving image using deepali.

    Wraps ``General_Registration`` with default loss terms (LNCC + BSpline bending
    energy) and a Stationary Velocity Field Free-Form Deformation (SVFFD) transform.

    Attributes:
        transform: The learned deformation field resulting from the registration.
        ref_nii: Reference NII object used for registration.
        grid: Target grid for image warping.
        mov: Processed version of the moving image.
    """

    # Transforms whose constructor does not take a "transpose" argument.
    _NO_TRANSPOSE_TRANSFORMS = (
        "StationaryVelocityFieldTransform",
        "SVF",
        "SVField",
        "DenseVectorFieldTransform",
    )

    def __init__(
        self,
        fixed_image: Image_Reference,
        moving_image: Image_Reference,
        fixed_seg: Image_Reference | None = None,
        moving_seg: Image_Reference | None = None,
        reference_image: Image_Reference | None = None,
        source_pset=None,
        target_pset=None,
        source_landmarks=None,
        target_landmarks=None,
        # source_seg: Optional[Union[Image, PathStr]] = None,  # Masking the registration source
        # target_seg: Optional[Union[Image, PathStr]] = None,  # Masking the registration target
        device: Union[torch.device, str, int] | None = None,
        gpu=0,
        ddevice: DEVICES = "cuda",
        # foreground_mask
        fixed_mask: Image_Reference | None = None,
        moving_mask: Image_Reference | None = None,
        # normalize
        normalize_strategy: (
            Literal["auto", "CT", "MRI"] | None
        ) = "auto",  # Override on_normalize for a finer normalization schema, or normalize beforehand and set to None. auto: [min,max] -> [0,1]; None: do nothing
        # Pyramid
        pyramid_levels: int | None = 3,  # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest)
        finest_level: int = 0,
        coarsest_level: int | None = None,
        pyramid_finest_spacing: Sequence[int] | torch.Tensor | None = None,
        pyramid_min_size=16,
        dims=("x", "y", "z"),
        align=False,
        transform_name: str = "SVFFD",  # Names defined in deepali.spatial.LINEAR_TRANSFORMS and deepali.spatial.NONRIGID_TRANSFORMS. Override on_make_transform for finer control
        transform_args: dict | None = None,
        transform_init: PathStr | None = None,  # reload initial flowfield from file
        optim_name="Adam",  # Optimizer name defined in torch.optim, or override on_optimizer for finer control
        lr: float | Sequence[float] = 0.001,  # Learning rate
        lr_end_factor: float | None = None,  # if set, will use a LinearLR scheduler to reduce the learning rate to this factor * lr
        optim_args=None,  # args of the optimizer without lr
        smooth_grad=0.0,
        verbose=0,
        max_steps: int | Sequence[int] = 1000,  # Early stopping. Override on_converged for finer control
        max_history: int | None = 100,
        min_value=0.0,  # Early stopping. Override on_converged for finer control
        min_delta: float | Sequence[float] = -0.0001,  # Early stopping. Override on_converged for finer control
        loss_terms: list[LOSS | str] | dict[str, LOSS] | dict[str, str] | dict[str, tuple[str, dict]] | None = None,
        weights: list[float] | dict[str, float | list[float]] | None = None,
        auto_run=True,
        stride=8,
    ):
        """Build the registration; all arguments are forwarded to ``General_Registration``.

        Defaults specific to this class:

        * ``transform_args`` defaults to ``{"stride": [stride] * 3, "transpose": False}``.
          A caller-supplied dict is copied, never modified in place.
        * ``"transpose"`` is dropped for transforms listed in ``_NO_TRANSPOSE_TRANSFORMS``.
        * ``loss_terms`` defaults to BSpline bending energy (``"be"``) plus LNCC (``"lncc"``),
          weighted ``{"be": 0.001, "lncc": 1}``.
        """
        if transform_args is None:
            transform_args = {"stride": [stride, stride, stride], "transpose": False}
        else:
            # Shallow copy: the pop below must not leak into a dict the caller may reuse.
            transform_args = dict(transform_args)
        if "transpose" in transform_args and transform_name in self._NO_TRANSPOSE_TRANSFORMS:
            transform_args.pop("transpose")

        if loss_terms is None:
            loss_terms = {
                "be": BSplineBending(stride=1),
                "lncc": LNCC(),
            }
        if weights is None:
            weights = {"be": 0.001, "lncc": 1}
        super().__init__(
            fixed_image=fixed_image,
            moving_image=moving_image,
            fixed_seg=fixed_seg,
            moving_seg=moving_seg,
            reference_image=reference_image,
            source_pset=source_pset,
            target_pset=target_pset,
            source_landmarks=source_landmarks,
            target_landmarks=target_landmarks,
            device=device,
            gpu=gpu,
            ddevice=ddevice,
            fixed_mask=fixed_mask,
            moving_mask=moving_mask,
            normalize_strategy=normalize_strategy,
            pyramid_levels=pyramid_levels,
            finest_level=finest_level,
            coarsest_level=coarsest_level,
            pyramid_finest_spacing=pyramid_finest_spacing,
            pyramid_min_size=pyramid_min_size,
            dims=dims,
            align=align,
            transform_name=transform_name,
            transform_args=transform_args,
            transform_init=transform_init,
            optim_name=optim_name,
            lr=lr,
            lr_end_factor=lr_end_factor,
            optim_args=optim_args,
            smooth_grad=smooth_grad,
            verbose=verbose,
            max_steps=max_steps,
            max_history=max_history,
            min_value=min_value,
            min_delta=min_delta,
            loss_terms=loss_terms,
            weights=weights,
            auto_run=auto_run,
        )

TPTBox.registration.Template_Registration

Multi-stage registration between two multi-label segmentations.

Supports optional POI landmark alignment and deformable registration; landmarks are computed on the fly if not provided. Particularly useful for MRI/CT atlas alignment with optional body-side flip handling.

Attributes:

Name Type Description
same_side bool

Whether the target and atlas represent the same anatomical side (e.g., both right sides).

reg_point Point_Registration

The rigid point-based registration component.

reg_deform Deformable_Registration

The deformable registration component.

crop tuple

The crop applied to both target and atlas after registration.

target_grid_org NII

Original spatial grid of the target.

atlas_org NII

Original spatial grid of the atlas.

target_grid NII

Cropped spatial grid used for deformable registration.

Source code in TPTBox/registration/_deformable/multilabel_segmentation.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
class Template_Registration:
    """Multi-stage registration between two multi-label segmentations.

    Supports optional POI landmark alignment and deformable registration; landmarks are computed
    on the fly if not provided.  Particularly useful for MRI/CT atlas alignment with optional
    body-side flip handling.

    Attributes:
        same_side (bool): Whether the target and atlas represent the same anatomical side (e.g., both right sides).
        reg_point (Point_Registration): The rigid point-based registration component.
        reg_deform (Deformable_Registration): The deformable registration component.
        crop (tuple): The crop applied to both target and atlas after registration.
        target_grid_org (NII): Original spatial grid of the target.
        atlas_org (NII): Original spatial grid of the atlas.
        target_grid (NII): Cropped spatial grid used for deformable registration.
    """

    def __init__(  # noqa: C901
        self,
        target_seg: NII,
        atlas_seg: NII,
        target_img: NII | None = None,
        atlas_img: NII | None = None,
        poi_cms: POI | None = None,
        same_side: bool = True,
        verbose=99,
        gpu=0,
        ddevice: DEVICES = "cuda",
        loss_terms=None,  # type: ignore
        weights=None,
        lr=0.01,
        lr_end_factor=None,
        max_steps=1500,
        min_delta: float | list[float] = 1e-06,
        pyramid_levels=4,
        coarsest_level=3,
        finest_level=0,
        crop: bool = True,
        cms_ids: list | None = None,
        poi_target_cms: POI | None = None,
        max_history=100,
        change_after_point_reg=lambda x, y, z, w: (x, y, z, w),
        **args,
    ):
        """Initialize a multi-stage registration pipeline from an atlas to a target image.

        The atlas is first aligned with a rigid :class:`Point_Registration` on POI centroids,
        then refined with a :class:`Deformable_Registration`. When ``same_side`` is False the
        target is mirrored along its "R" axis before registering.

        Args:
            target_seg (NII): Target segmentation (e.g., from a subject). A copy is used.
            atlas_seg (NII): Atlas segmentation (e.g., a reference or template). A copy is used.
            target_img (NII): Target image; if None the segmentation is used as the image.
            atlas_img (NII): Atlas image; if None the segmentation is used as the image.
            poi_cms (POI | None): POI centroids of the atlas, used for initial point registration.
                Computed from *atlas_seg* if None.
            same_side (bool): Whether atlas and target represent the same body side.
            verbose (int): Verbosity level forwarded to the deformable registration.
            gpu (int): GPU device ID (only relevant if using GPU).
            ddevice (DEVICES): Device type ('cuda' or 'cpu').
            loss_terms (dict): Dictionary of loss terms for deformable registration.
            weights (dict): Weights for the loss terms.
            lr (float): Learning rate for the deformable registration optimizer.
            lr_end_factor (float | None): Forwarded to ``Deformable_Registration``.
            max_steps (int): Maximum optimization steps.
            min_delta (float): Minimum delta for convergence.
            pyramid_levels (int): Number of resolution levels in multi-scale deformable registration.
            coarsest_level (int): Coarsest level index.
            finest_level (int): Finest level index.
            crop (bool): If True, search for a crop (falling back to a pad) of the target for
                which the point-registered atlas does not touch the image border, then crop
                target and registered atlas to their joint extent before deformable registration.
            cms_ids (list | None): List of segmentation labels used to extract POI centroids.
            poi_target_cms (POI | None): Optional precomputed centroids for the target image.
                Never modified in place, even when it has to be mirrored.
            max_history (int): Forwarded to ``Deformable_Registration``.
            change_after_point_reg (Callable): Called as
                ``f(target_seg, atlas_reg, target_img, atlas_img_reg)`` after point registration
                and cropping. It must return those four values, possibly modified.
            **args: Additional keyword arguments passed to Deformable_Registration.

        Raises:
            ValueError: If an invalid axis is detected during flipping.
        """
        if weights is None:
            weights = {"be": 0.0001, "seg": 1, "Dice": 0.01, "Tether": 0.001}
        if loss_terms is None:
            loss_terms = {
                "be": ("BSplineBending", {"stride": 1}),
                "seg": "MSE",
                "Dice": "Dice",
                "Tether": Tether_Seg(delta=5),
            }

        assert target_seg.seg, target_seg.seg
        assert atlas_seg.seg
        target_seg = target_seg.copy()
        atlas_seg = atlas_seg.copy()
        if target_img is not None:
            target_img = target_img.resample_from_to(target_seg)
        if atlas_img is not None:
            atlas_img = atlas_img.resample_from_to(atlas_seg)
        self.same_side = same_side
        self.target_grid_org = target_seg.to_gird()
        self.atlas_org = atlas_seg.to_gird()
        if not same_side:
            axis = target_seg.get_axis("R")
            if axis == 0:
                target_seg = target_seg.set_array(target_seg.get_array()[::-1]).copy()
                target_img = target_img.set_array(target_img.get_array()[::-1]).copy() if target_img is not None else None
            elif axis == 1:
                target_seg = target_seg.set_array(target_seg.get_array()[:, ::-1]).copy()
                target_img = target_img.set_array(target_img.get_array()[:, ::-1]).copy() if target_img is not None else None
            elif axis == 2:
                target_seg = target_seg.set_array(target_seg.get_array()[:, :, ::-1]).copy()
                target_img = target_img.set_array(target_img.get_array()[:, :, ::-1]).copy() if target_img is not None else None
            else:
                raise ValueError(axis)
            if poi_target_cms is not None:
                # Mirror a copy: flipping the caller's object in place would double-flip it on reuse.
                poi_target_cms = poi_target_cms.copy()
                axis = poi_target_cms.get_axis("R")
                for k1, k2, (x, y, z) in poi_target_cms.copy().items():
                    if axis == 0:
                        poi_target_cms[k1, k2] = (poi_target_cms.shape[0] - 1 - x, y, z)
                    elif axis == 1:
                        poi_target_cms[k1, k2] = (x, poi_target_cms.shape[1] - 1 - y, z)
                    elif axis == 2:
                        poi_target_cms[k1, k2] = (x, y, poi_target_cms.shape[2] - 1 - z)
                    else:
                        raise ValueError(axis)
        if poi_target_cms is None:
            x = target_seg.extract_label(cms_ids, keep_label=True) if cms_ids else target_seg
            poi_target = calc_centroids(x, second_stage=40, bar=True)  # TODO REMOVE
        else:
            poi_target = poi_target_cms.resample_from_to(target_seg)
        if poi_cms is None:
            x = atlas_seg.extract_label(cms_ids, keep_label=True) if cms_ids else atlas_seg
            poi_cms = calc_centroids(x, second_stage=40, bar=True)
        if not poi_cms.assert_affine(atlas_seg, raise_error=False):
            poi_cms = poi_cms.resample_from_to(atlas_seg)
        if crop:
            print("crop")

            crop_pad_size = 50
            _step = 50
            _max_iter = 10

            resize_mode = "crop"
            # Crop slices (crop mode) or pad widths (pad mode); None means "leave target_seg as is".
            resize_param: tuple | None = None
            target_tmp = target_seg

            atlas_seg_ = atlas_seg.apply_pad(((1, 1), (1, 1), (1, 1))) if atlas_seg.is_segmentation_in_border() else atlas_seg

            for i in range(_max_iter):
                if resize_mode == "crop":
                    if i != 0:
                        crop_pad_size += _step

                    # --- try crop first ---
                    t_crop = target_seg.compute_crop(0, crop_pad_size)
                    # NOTE(review): this evaluates as crop_pad_size - (50 // 4); confirm intent.
                    cropped = target_seg.apply_crop(t_crop).apply_pad(crop_pad_size - 50 // 4)

                    if any(c < o for c, o in zip(cropped.shape, target_seg.shape)):
                        resize_mode = "crop"
                        resize_param = t_crop
                        target_tmp = cropped
                    else:
                        # --- fallback to padding ---
                        crop_pad_size = crop_pad_size // 2
                        target_tmp = target_seg
                        resize_mode = "pad"
                        # This iteration tests the unmodified target; drop stale crop slices so
                        # they are never passed to apply_pad below.
                        resize_param = None
                else:
                    if i != 0:
                        crop_pad_size += _step // 2
                    t_pad = tuple((crop_pad_size, crop_pad_size) for _ in range(3))
                    resize_param = t_pad
                    target_tmp = target_seg.apply_pad(t_pad)

                # --- Point registration ---
                print(f"iter {i}: using {resize_mode} ({crop_pad_size})")

                poi_target = poi_target.resample_from_to(target_tmp)

                # poi_cms is guaranteed to be set at this point (computed above if not given).
                if not poi_cms.assert_affine(atlas_seg_, raise_error=False):
                    poi_cms = poi_cms.resample_from_to(atlas_seg_)

                self.reg_point = Point_Registration(poi_target, poi_cms, verbose=False)
                atlas_reg = self.reg_point.transform_nii(atlas_seg_)

                if not atlas_reg.is_segmentation_in_border():
                    print("point registration ok")
                    break
                else:
                    print("atlas_reg touches border → expanding")

            # --- FINAL STEP: apply once to original target (skip if nothing was selected) ---
            if resize_param is not None:
                if resize_mode == "crop":
                    target_seg = target_seg.apply_crop(resize_param)
                    target_img = target_img.apply_crop(resize_param) if target_img is not None else None
                elif resize_mode == "pad":
                    target_seg = target_seg.apply_pad(resize_param)
                    target_img = target_img.apply_pad(resize_param) if target_img is not None else None

        self.reg_point = Point_Registration(poi_target.resample_from_to(target_seg), poi_cms.resample_from_to(atlas_seg))
        atlas_reg = self.reg_point.transform_nii(atlas_seg)
        atlas_img_reg = self.reg_point.transform_nii(atlas_img) if atlas_img is not None else None

        if crop:
            # Crop to the joint extent of target and rigidly aligned atlas (5 voxel margin).
            self.crop = (target_seg + atlas_reg).compute_crop(0, 5)
            target_seg = target_seg.apply_crop(self.crop)
            target_img = target_img.apply_crop(self.crop) if target_img is not None else None
            atlas_reg = atlas_reg.apply_crop(self.crop)
            atlas_img_reg = atlas_img_reg.apply_crop(self.crop) if atlas_img_reg is not None else None
        else:
            self.crop = None

        self.target_grid = target_seg.to_gird()
        target_seg, atlas_reg, target_img, atlas_img_reg = change_after_point_reg(target_seg, atlas_reg, target_img, atlas_img_reg)
        self.reg_deform = Deformable_Registration(
            target_seg if target_img is None else target_img,
            atlas_reg if atlas_img_reg is None else atlas_img_reg,
            target_seg.copy(),
            atlas_reg.copy(),
            loss_terms=loss_terms,
            weights=weights,
            lr=lr,
            lr_end_factor=lr_end_factor,
            max_steps=max_steps,
            min_delta=min_delta,
            pyramid_levels=pyramid_levels,
            coarsest_level=coarsest_level,
            finest_level=finest_level,
            verbose=verbose,
            gpu=gpu,
            ddevice=ddevice,
            max_history=max_history,
            **args,
        )

    def get_dump(self) -> tuple:
        """Return a picklable snapshot of the whole pipeline.

        Layout (format version ``1``)::

            (1, reg_point_dump, reg_deform_dump,
             (same_side, atlas_org, target_grid_org, target_grid, crop))

        :meth:`load_` expects exactly this layout.
        """
        point_state = self.reg_point.get_dump()
        deform_state = self.reg_deform.get_dump()
        metadata = (
            self.same_side,
            self.atlas_org,
            self.target_grid_org,
            self.target_grid,
            self.crop,
        )
        return (1, point_state, deform_state, metadata)

    def save(self, path: str | Path) -> None:
        """Serialise the registration state to a pickle file.

        The state is collected via :meth:`get_dump` *before* the file is opened. If collecting
        it fails, an existing file at *path* is left untouched instead of being truncated.

        Args:
            path: Destination file path (created or overwritten).
        """
        state = self.get_dump()
        with open(path, "wb") as w:
            pickle.dump(state, w)

    @classmethod
    def load(cls, path: str | Path) -> Template_Registration:
        """Rebuild a ``Template_Registration`` from a file written by :meth:`save`.

        Warning:
            This unpickles *path*; only load files from trusted sources.

        Args:
            path: Location of the pickle file.

        Returns:
            The reconstructed instance (built by :meth:`load_`).
        """
        with open(path, "rb") as handle:
            state = pickle.load(handle)
        return cls.load_(state)

    @classmethod
    def load_(cls, w: tuple) -> Template_Registration:
        """Rebuild a ``Template_Registration`` from the tuple produced by :meth:`get_dump`.

        ``__init__`` is bypassed, so no registration is re-run. The rigid and deformable
        components are restored through their own ``load_`` methods.

        Args:
            w: State tuple ``(version, reg_point_dump, reg_deform_dump, metadata)``.

        Returns:
            The reconstructed instance.

        Raises:
            AssertionError: If the format version is not ``1``.
        """
        version, point_state, deform_state, metadata = w
        assert version == 1, f"Version mismatch {version=}"
        self = cls.__new__(cls)
        self.reg_point = Point_Registration.load_(point_state)
        self.reg_deform = Deformable_Registration.load_(deform_state)
        self.same_side, self.atlas_org, self.target_grid_org, self.target_grid, self.crop = metadata
        return self

    def transform_nii(self, nii_atlas: NII, allow_only_same_grid_as_moving: bool = True, only_rigid=False) -> NII:
        """Apply the rigid and (optionally) the deformable registration to a NII image.

        Args:
            nii_atlas: Atlas image to be transformed (must share the atlas grid).
            allow_only_same_grid_as_moving: If True, assert that *nii_atlas* matches
                the grid of the moving image used during point registration.
            only_rigid: If True, return the result of the point registration directly.
                It is then neither cropped, deformed, resampled to the original target grid,
                nor mirrored back.

        Returns:
            Transformed ``NII``. Unless *only_rigid* is set, it lies in the original target
            image space and is mirrored back along the "R" axis when ``same_side`` is False.

        Raises:
            ValueError: If the "R" axis of the output is not 0, 1 or 2.
        """
        rigid = self.reg_point.transform_nii(nii_atlas, allow_only_same_grid_as_moving=allow_only_same_grid_as_moving)
        if only_rigid:
            return rigid

        warped = self.reg_deform.transform_nii(rigid.apply_crop(self.crop))
        if warped.seg:
            warped.set_dtype_("smallest_uint")
        result = warped.resample_from_to(self.target_grid_org, mode="constant")
        if self.same_side:
            return result

        # Undo the left/right mirroring applied to the target during __init__.
        r_axis = result.get_axis("R")
        if r_axis not in (0, 1, 2):
            raise ValueError(r_axis)
        reverse_r = tuple(slice(None, None, -1) if dim == r_axis else slice(None) for dim in range(r_axis + 1))
        return result.set_array(result.get_array()[reverse_r]).copy()

    def transform_poi(self, poi_atlas: POI_Global | POI) -> POI:
        """Apply both rigid and deformable registration to a POI landmark set.

        Args:
            poi_atlas: Atlas landmarks to be transformed (defined in the atlas space).

        Returns:
            Transformed ``POI`` landmarks aligned to the original target image space. When
            ``same_side`` is False, the coordinates are mirrored back along the "R" axis.

        Raises:
            ValueError: If the "R" axis of the result is not 0, 1 or 2 (only checked when
                mirroring a non-empty POI).
        """
        poi_atlas = poi_atlas.resample_from_to(self.atlas_org)

        # Point Reg
        poi_atlas = self.reg_point.transform_poi(poi_atlas)
        # Deformable
        poi_atlas = poi_atlas.apply_crop(self.crop)

        poi_reg = self.reg_deform.transform_poi(poi_atlas)
        poi_reg = poi_reg.resample_from_to(self.target_grid_org)
        if self.same_side:
            return poi_reg

        # Undo the left/right mirroring applied to the target during __init__.
        # The axis and shape do not change inside the loop, so look them up once.
        axis = poi_reg.get_axis("R")
        shape = poi_reg.shape
        poi_reg_flip = poi_reg.make_empty_POI()
        for k1, k2, (x, y, z) in poi_reg.items():
            if axis == 0:
                poi_reg_flip[k1, k2] = (shape[0] - 1 - x, y, z)
            elif axis == 1:
                poi_reg_flip[k1, k2] = (x, shape[1] - 1 - y, z)
            elif axis == 2:
                poi_reg_flip[k1, k2] = (x, y, shape[2] - 1 - z)
            else:
                raise ValueError(axis)
        return poi_reg_flip

    def transform_poi_inverse(self, poi_target: POI_Global | POI):
        """Transform POIs from target space back into atlas space.

        This reverses ``transform_poi``: resample to the original target grid, undo the
        left/right mirror, invert the deformable registration, invert the rigid point
        registration and finally resample to the original atlas grid.

        Args:
            poi_target (POI_Global | POI): POIs defined in target space.

        Returns:
            POI: POIs mapped back into atlas space.

        Raises:
            ValueError: If the "R" axis of the target grid is not 0, 1 or 2.
        """
        # The forward transform mirrors on ``target_grid_org``; the inverse mirror must use
        # the same grid (shape/orientation), so resample there first. This also gives
        # POI_Global inputs a voxel grid.
        poi = poi_target.resample_from_to(self.target_grid_org)

        # --- undo left/right flip if needed ---
        if not self.same_side:
            axis = poi.get_axis("R")
            if axis not in (0, 1, 2):
                raise ValueError(axis)
            size = poi.shape[axis]
            poi_flip = poi.make_empty_POI()
            for k1, k2, coords in poi.items():
                mirrored = list(coords)
                mirrored[axis] = size - 1 - mirrored[axis]
                poi_flip[k1, k2] = tuple(mirrored)
            poi = poi_flip

        # --- resample into deformable registration grid ---
        poi = poi.resample_from_to(self.target_grid)

        # --- inverse deformable registration ---
        reg_deform_inv = self.reg_deform.inverse()
        poi = reg_deform_inv.transform_poi(poi)

        # NOTE(review): the crop is not undone explicitly; the rigid inverse below is called
        # with allow_only_same_grid_as_moving=False, presumably so it accepts the cropped grid.

        # --- inverse rigid point registration ---
        poi = self.reg_point.transform_poi_inverse(poi, allow_only_same_grid_as_moving=False)

        # --- back to atlas grid ---
        return poi.resample_from_to(self.atlas_org)

__init__

__init__(target_seg: NII, atlas_seg: NII, target_img: NII | None = None, atlas_img: NII | None = None, poi_cms: POI | None = None, same_side: bool = True, verbose=99, gpu=0, ddevice: DEVICES = 'cuda', loss_terms=None, weights=None, lr=0.01, lr_end_factor=None, max_steps=1500, min_delta: float | list[float] = 1e-06, pyramid_levels=4, coarsest_level=3, finest_level=0, crop: bool = True, cms_ids: list | None = None, poi_target_cms: POI | None = None, max_history=100, change_after_point_reg=lambda x, y, z, w: (x, y, z, w), **args)

Initialize a multi-stage registration pipeline from an atlas to a target image.

Parameters:

Name Type Description Default
target_seg NII

Target image segmentation (e.g., from a subject).

required
atlas_seg NII

Atlas image segmentation (e.g., a reference or template).

required
target_img NII

Target intensity image; if None, the segmentation is used as the image.

None
atlas_img NII

Atlas intensity image; if None, the segmentation is used as the image.

None
poi_cms POI | None

POI centroids of the atlas, used for initial point registration.

None
same_side bool

Whether atlas and target represent the same body side.

True
verbose int

Verbosity level for logging.

99
gpu int

GPU device ID (only relevant if using GPU).

0
ddevice DEVICES

Device type ('cuda' or 'cpu').

'cuda'
loss_terms dict

Dictionary of loss terms for deformable registration.

None
weights dict

Weights for the loss terms.

None
lr float

Learning rate for deformable registration optimizer.

0.01
max_steps int

Maximum optimization steps.

1500
min_delta float

Minimum delta for convergence.

1e-06
pyramid_levels int

Number of resolution levels in multi-scale deformable registration.

4
coarsest_level int

Coarsest level index.

3
finest_level int

Finest level index.

0
cms_ids list | None

List of segmentation labels used to extract POI centroids.

None
poi_target_cms POI | None

Optional precomputed centroids for the target image.

None
**args

Additional keyword arguments passed to Deformable_Registration.

{}

Raises:

Type Description
ValueError

If an invalid axis is detected during flipping.

Source code in TPTBox/registration/_deformable/multilabel_segmentation.py
def __init__(  # noqa: C901
    self,
    target_seg: NII,
    atlas_seg: NII,
    target_img: NII | None = None,
    atlas_img: NII | None = None,
    poi_cms: POI | None = None,
    same_side: bool = True,
    verbose=99,
    gpu=0,
    ddevice: DEVICES = "cuda",
    loss_terms=None,  # type: ignore
    weights=None,
    lr=0.01,
    lr_end_factor=None,
    max_steps=1500,
    min_delta: float | list[float] = 1e-06,
    pyramid_levels=4,
    coarsest_level=3,
    finest_level=0,
    crop: bool = True,
    cms_ids: list | None = None,
    poi_target_cms: POI | None = None,
    max_history=100,
    change_after_point_reg=lambda x, y, z, w: (x, y, z, w),
    **args,
):
    """Initialize a multi-stage registration pipeline from an atlas to a target image.

    Stages:
        1. If ``same_side`` is False, mirror the target along its "R" axis.
        2. Rigid point registration of the atlas centroids onto the target centroids.
           With ``crop=True`` the target is first cropped (or padded) until the rigidly
           registered atlas no longer touches the image border.
        3. Deformable registration (``Deformable_Registration``) of the rigidly
           registered atlas onto the (cropped) target.

    Args:
        target_seg (NII): Target segmentation (e.g., from a subject). Must have ``seg`` set.
        atlas_seg (NII): Atlas segmentation (e.g., a reference or template). Must have ``seg`` set.
        target_img (NII | None): Target intensity image; if None the segmentation is used as the image.
        atlas_img (NII | None): Atlas intensity image; if None the segmentation is used as the image.
        poi_cms (POI | None): Atlas centroids used for the initial point registration. Computed
            from ``atlas_seg`` (restricted to ``cms_ids`` if given) when None.
        same_side (bool): Whether atlas and target represent the same body side.
        verbose (int): Verbosity level, forwarded to ``Deformable_Registration``.
        gpu (int): GPU device ID, forwarded to ``Deformable_Registration``.
        ddevice (DEVICES): Device type ('cuda' or 'cpu').
        loss_terms (dict | None): Loss terms for the deformable registration. Defaults to
            ``BSplineBending`` ("be"), ``"MSE"`` ("seg"), ``"Dice"`` and ``Tether_Seg(delta=5)`` ("Tether").
        weights (dict | None): Weights for the loss terms. Defaults to
            ``{"be": 0.0001, "seg": 1, "Dice": 0.01, "Tether": 0.001}``.
        lr (float): Learning rate for the deformable registration optimizer.
        lr_end_factor: Forwarded to ``Deformable_Registration``.
        max_steps (int): Maximum optimization steps.
        min_delta (float | list[float]): Minimum delta for convergence.
        pyramid_levels (int): Number of resolution levels in multi-scale deformable registration.
        coarsest_level (int): Coarsest level index.
        finest_level (int): Finest level index.
        crop (bool): Crop/pad the target for the point registration and, afterwards, crop
            both images to ``(target_seg + atlas_reg).compute_crop(0, 5)`` before the
            deformable stage.
        cms_ids (list | None): Segmentation labels used to extract centroids when they are not given.
        poi_target_cms (POI | None): Optional precomputed centroids of the target image.
            The object passed in is not modified.
        max_history (int): Forwarded to ``Deformable_Registration``.
        change_after_point_reg (callable): Called as
            ``f(target_seg, atlas_reg, target_img, atlas_img_reg)`` after the point
            registration; must return the same four values (possibly modified).
        **args: Additional keyword arguments passed to ``Deformable_Registration``.

    Raises:
        ValueError: If an invalid axis is detected during flipping.
    """
    if weights is None:
        weights = {"be": 0.0001, "seg": 1, "Dice": 0.01, "Tether": 0.001}
    if loss_terms is None:
        loss_terms = {
            "be": ("BSplineBending", {"stride": 1}),
            "seg": "MSE",
            "Dice": "Dice",
            "Tether": Tether_Seg(delta=5),
        }

    assert target_seg.seg, "target_seg must be a segmentation (NII.seg is False)"
    assert atlas_seg.seg, "atlas_seg must be a segmentation (NII.seg is False)"
    target_seg = target_seg.copy()
    atlas_seg = atlas_seg.copy()
    if target_img is not None:
        target_img = target_img.resample_from_to(target_seg)
    if atlas_img is not None:
        atlas_img = atlas_img.resample_from_to(atlas_seg)
    self.same_side = same_side
    self.target_grid_org = target_seg.to_gird()
    self.atlas_org = atlas_seg.to_gird()
    if not same_side:
        # Mirror the target (and its image) along the R/L axis.
        axis = target_seg.get_axis("R")
        if axis == 0:
            target_seg = target_seg.set_array(target_seg.get_array()[::-1]).copy()
            target_img = target_img.set_array(target_img.get_array()[::-1]).copy() if target_img is not None else None
        elif axis == 1:
            target_seg = target_seg.set_array(target_seg.get_array()[:, ::-1]).copy()
            target_img = target_img.set_array(target_img.get_array()[:, ::-1]).copy() if target_img is not None else None
        elif axis == 2:
            target_seg = target_seg.set_array(target_seg.get_array()[:, :, ::-1]).copy()
            target_img = target_img.set_array(target_img.get_array()[:, :, ::-1]).copy() if target_img is not None else None
        else:
            raise ValueError(axis)
        if poi_target_cms is not None:
            # Work on a copy on the target grid: the caller's POI must not be mutated, and the
            # mirror has to use the same shape/orientation as the mirrored target_seg.
            poi_target_cms = poi_target_cms.copy().resample_from_to(self.target_grid_org)
            poi_axis = poi_target_cms.get_axis("R")
            for k1, k2, (x, y, z) in poi_target_cms.copy().items():
                if poi_axis == 0:
                    poi_target_cms[k1, k2] = (poi_target_cms.shape[0] - 1 - x, y, z)
                elif poi_axis == 1:
                    poi_target_cms[k1, k2] = (x, poi_target_cms.shape[1] - 1 - y, z)
                elif poi_axis == 2:
                    poi_target_cms[k1, k2] = (x, y, poi_target_cms.shape[2] - 1 - z)
                else:
                    raise ValueError(poi_axis)
    if poi_target_cms is None:
        x = target_seg.extract_label(cms_ids, keep_label=True) if cms_ids else target_seg
        poi_target = calc_centroids(x, second_stage=40, bar=True)  # TODO REMOVE
    else:
        poi_target = poi_target_cms.resample_from_to(target_seg)
    if poi_cms is None:
        x = atlas_seg.extract_label(cms_ids, keep_label=True) if cms_ids else atlas_seg
        poi_cms = calc_centroids(x, second_stage=40, bar=True)
    if not poi_cms.assert_affine(atlas_seg, raise_error=False):
        poi_cms = poi_cms.resample_from_to(atlas_seg)
    if crop:
        print("crop")

        crop_pad_size = 50
        _step = 50
        _max_iter = 10

        resize_mode = "crop"
        # Crop slices (crop mode), pad widths (pad mode), or None when the target is used unchanged.
        resize_param: tuple | None = None
        target_tmp = target_seg

        atlas_seg_ = atlas_seg.apply_pad(((1, 1), (1, 1), (1, 1))) if atlas_seg.is_segmentation_in_border() else atlas_seg

        for i in range(_max_iter):
            if resize_mode == "crop":
                if i != 0:
                    crop_pad_size += _step

                # --- try crop first ---
                t_crop = target_seg.compute_crop(0, crop_pad_size)
                # NOTE(review): `crop_pad_size - 50 // 4` evaluates as `crop_pad_size - 12`;
                # confirm whether `(crop_pad_size - 50) // 4` was intended.
                cropped = target_seg.apply_crop(t_crop).apply_pad(crop_pad_size - 50 // 4)

                if any(c < o for c, o in zip(cropped.shape, target_seg.shape)):
                    resize_param = t_crop
                    target_tmp = cropped
                else:
                    # --- fallback to padding ---
                    crop_pad_size = crop_pad_size // 2
                    target_tmp = target_seg
                    resize_mode = "pad"
                    # This iteration registers against the unchanged target; a stale crop
                    # tuple must not later be passed to apply_pad.
                    resize_param = None
            else:
                if i != 0:
                    crop_pad_size += _step // 2
                t_pad = tuple((crop_pad_size, crop_pad_size) for _ in range(3))
                resize_param = t_pad
                target_tmp = target_seg.apply_pad(t_pad)

            # --- Point registration ---
            print(f"iter {i}: using {resize_mode} ({crop_pad_size})")

            poi_target = poi_target.resample_from_to(target_tmp)

            if not poi_cms.assert_affine(atlas_seg_, raise_error=False):
                poi_cms = poi_cms.resample_from_to(atlas_seg_)

            self.reg_point = Point_Registration(poi_target, poi_cms, verbose=False)
            atlas_reg = self.reg_point.transform_nii(atlas_seg_)

            if not atlas_reg.is_segmentation_in_border():
                print("point registration ok")
                break
            print("atlas_reg touches border → expanding")

        # --- FINAL STEP: apply once to original target ---
        if resize_param is not None:
            if resize_mode == "crop":
                target_seg = target_seg.apply_crop(resize_param)
                target_img = target_img.apply_crop(resize_param) if target_img is not None else None
            elif resize_mode == "pad":
                target_seg = target_seg.apply_pad(resize_param)
                target_img = target_img.apply_pad(resize_param) if target_img is not None else None

    self.reg_point = Point_Registration(poi_target.resample_from_to(target_seg), poi_cms.resample_from_to(atlas_seg))
    atlas_reg = self.reg_point.transform_nii(atlas_seg)
    atlas_img_reg = self.reg_point.transform_nii(atlas_img) if atlas_img is not None else None

    if crop:
        # Tight crop (margin 5) around target and rigidly registered atlas for the deformable stage.
        self.crop = (target_seg + atlas_reg).compute_crop(0, 5)
        target_seg = target_seg.apply_crop(self.crop)
        target_img = target_img.apply_crop(self.crop) if target_img is not None else None
        atlas_reg = atlas_reg.apply_crop(self.crop)
        atlas_img_reg = atlas_img_reg.apply_crop(self.crop) if atlas_img_reg is not None else None
    else:
        self.crop = None

    self.target_grid = target_seg.to_gird()
    target_seg, atlas_reg, target_img, atlas_img_reg = change_after_point_reg(target_seg, atlas_reg, target_img, atlas_img_reg)
    self.reg_deform = Deformable_Registration(
        target_seg if target_img is None else target_img,
        atlas_reg if atlas_img_reg is None else atlas_img_reg,
        target_seg.copy(),
        atlas_reg.copy(),
        loss_terms=loss_terms,
        weights=weights,
        lr=lr,
        lr_end_factor=lr_end_factor,
        max_steps=max_steps,
        min_delta=min_delta,
        pyramid_levels=pyramid_levels,
        coarsest_level=coarsest_level,
        finest_level=finest_level,
        verbose=verbose,
        gpu=gpu,
        ddevice=ddevice,
        max_history=max_history,
        **args,
    )

get_dump

get_dump() -> tuple

Collect the serialisable state of this registration object.

Returns:

Type Description
tuple

A tuple containing the version tag followed by all state components

tuple

needed to reconstruct the object via `load_`.

Source code in TPTBox/registration/_deformable/multilabel_segmentation.py
def get_dump(self) -> tuple:
    """Return the picklable state of this registration.

    Layout: ``(version, point_state, deform_state, meta)``, where ``version`` is 1 and
    ``meta`` is ``(same_side, atlas_org, target_grid_org, target_grid, crop)``.
    ``load_`` unpacks exactly this layout.
    """
    point_state = self.reg_point.get_dump()
    deform_state = self.reg_deform.get_dump()
    meta = (
        self.same_side,
        self.atlas_org,
        self.target_grid_org,
        self.target_grid,
        self.crop,
    )
    version = 1
    return version, point_state, deform_state, meta

save

save(path: str | Path) -> None

Serialise the registration state to a pickle file.

Parameters:

Name Type Description Default
path str | Path

Destination file path.

required
Source code in TPTBox/registration/_deformable/multilabel_segmentation.py
def save(self, path: str | Path) -> None:
    """Write the state returned by ``get_dump`` to *path* using pickle.

    The file is created (or overwritten) before the state is collected.

    Args:
        path: Destination file path.
    """
    with open(path, "wb") as handle:
        payload = self.get_dump()
        pickle.dump(payload, handle)

load classmethod

load(path: str | Path) -> Template_Registration

Load a previously saved registration state from a pickle file.

Parameters:

Name Type Description Default
path str | Path

Path to the pickle file created by `save`.

required

Returns:

Type Description
Template_Registration

Reconstructed Template_Registration instance.

Source code in TPTBox/registration/_deformable/multilabel_segmentation.py
@classmethod
def load(cls, path: str | Path) -> Template_Registration:
    """Restore a registration previously written by ``save``.

    Args:
        path: Path to the pickle file created by ``save``.

    Returns:
        The instance rebuilt by ``cls.load_`` from the unpickled state.
    """
    # NOTE: pickle.load can execute arbitrary code; only load files from trusted sources.
    with open(path, "rb") as handle:
        state = pickle.load(handle)
    return cls.load_(state)

load_ classmethod

load_(w: tuple) -> Template_Registration

Reconstruct a Template_Registration from a raw state tuple (as returned by `get_dump`).

Parameters:

Name Type Description Default
w tuple

Serialised state tuple.

required

Returns:

Type Description
Template_Registration

Reconstructed Template_Registration instance.

Source code in TPTBox/registration/_deformable/multilabel_segmentation.py
@classmethod
def load_(cls, w: tuple) -> Template_Registration:
    """Rebuild an instance from the tuple produced by ``get_dump``.

    ``__init__`` is bypassed; the point and deformable registrations are restored via
    their own ``load_`` methods and the remaining attributes are assigned directly.

    Args:
        w: State tuple ``(version, point_state, deform_state, meta)``.

    Returns:
        Reconstructed ``Template_Registration`` instance.
    """
    version, point_state, deform_state, meta = w
    assert version == 1, f"Version mismatch {version=}"
    restored = cls.__new__(cls)
    restored.reg_point = Point_Registration.load_(point_state)
    restored.reg_deform = Deformable_Registration.load_(deform_state)
    restored.same_side, restored.atlas_org, restored.target_grid_org, restored.target_grid, restored.crop = meta
    return restored

transform_nii

transform_nii(nii_atlas: NII, allow_only_same_grid_as_moving: bool = True, only_rigid=False) -> NII

Apply both rigid and deformable registration to a NII image.

Parameters:

Name Type Description Default
nii_atlas NII

Atlas image to be transformed (must share the atlas grid).

required
allow_only_same_grid_as_moving bool

If True, assert that nii_atlas matches the grid of the moving image used during point registration.

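True
only_rigid bool

If True, skip the deformable stage and return the output of the rigid point registration (not resampled to the original target grid).

False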

Returns:

Type Description
NII

Transformed NII aligned with the original target image space.

Source code in TPTBox/registration/_deformable/multilabel_segmentation.py
def transform_nii(self, nii_atlas: NII, allow_only_same_grid_as_moving: bool = True, only_rigid=False) -> NII:
    """Apply the rigid and (optionally) the deformable registration to a NII image.

    Args:
        nii_atlas: Atlas image to be transformed (must share the atlas grid).
        allow_only_same_grid_as_moving: If True, assert that *nii_atlas* matches
            the grid of the moving image used during point registration
            (forwarded to ``Point_Registration.transform_nii``).
        only_rigid: If True, return directly after the rigid point registration.
            The result is then NOT resampled to the original target grid and NOT
            mirrored back. It stays in the output grid of the point registration,
            presumably the cropped/padded (and, for ``same_side=False``, mirrored)
            target used during ``__init__``; verify before relying on it.

    Returns:
        Transformed ``NII`` aligned with the original target image space
        (unless ``only_rigid`` is set, see above).

    Raises:
        ValueError: If the "R" axis of the output is not 0, 1 or 2.
    """
    nii_atlas = self.reg_point.transform_nii(nii_atlas, allow_only_same_grid_as_moving=allow_only_same_grid_as_moving)
    if only_rigid:
        return nii_atlas

    # The deformable model was fitted on the cropped target; ``crop`` is None when
    # cropping was disabled, in which case there is nothing to apply.
    if self.crop is not None:
        nii_atlas = nii_atlas.apply_crop(self.crop)
    nii_reg = self.reg_deform.transform_nii(nii_atlas)
    if nii_reg.seg:
        # Label maps: shrink to the "smallest_uint" dtype (presumably the smallest unsigned int fitting the labels).
        nii_reg.set_dtype_("smallest_uint")
    out = nii_reg.resample_from_to(self.target_grid_org, mode="constant")
    if self.same_side:
        return out

    # __init__ mirrored the target along its R/L axis; mirror the result back.
    axis = out.get_axis("R")
    if axis not in (0, 1, 2):
        raise ValueError(axis)
    # Reverse only the R/L axis: same as [::-1], [:, ::-1] or [:, :, ::-1].
    flip = tuple(slice(None, None, -1) if i == axis else slice(None) for i in range(3))
    return out.set_array(out.get_array()[flip]).copy()

transform_poi

transform_poi(poi_atlas: POI_Global | POI) -> POI

Apply both rigid and deformable registration to a POI landmark set.

Parameters:

Name Type Description Default
poi_atlas POI_Global | POI

Atlas landmarks to be transformed (defined in the atlas space).

required

Returns:

Type Description
POI

Transformed POI landmarks aligned to the target image space.

Source code in TPTBox/registration/_deformable/multilabel_segmentation.py
def transform_poi(self, poi_atlas: POI_Global | POI) -> POI:
    """Apply both rigid and deformable registration to a POI landmark set.

    Args:
        poi_atlas: Atlas landmarks to be transformed (defined in the atlas space).

    Returns:
        Transformed ``POI`` landmarks aligned to the target image space. When
        ``same_side`` is False the points are mirrored back along the "R" axis
        of the original target grid.

    Raises:
        ValueError: If the "R" axis of the target grid is not 0, 1 or 2.
    """
    poi_atlas = poi_atlas.resample_from_to(self.atlas_org)

    # Rigid point registration.
    poi_atlas = self.reg_point.transform_poi(poi_atlas)
    # Deformable registration works on the cropped grid; ``crop`` is None when
    # cropping was disabled.
    if self.crop is not None:
        poi_atlas = poi_atlas.apply_crop(self.crop)

    poi_reg = self.reg_deform.transform_poi(poi_atlas)
    poi_reg = poi_reg.resample_from_to(self.target_grid_org)
    if self.same_side:
        return poi_reg

    # Undo the mirror applied to the target in __init__: index i -> shape[axis] - 1 - i,
    # which matches reversing the array along that axis.
    axis = poi_reg.get_axis("R")
    if axis not in (0, 1, 2):
        raise ValueError(axis)
    size = poi_reg.shape[axis]
    poi_reg_flip = poi_reg.make_empty_POI()
    for k1, k2, coords in poi_reg.items():
        mirrored = list(coords)
        mirrored[axis] = size - 1 - mirrored[axis]
        poi_reg_flip[k1, k2] = tuple(mirrored)
    return poi_reg_flip

transform_poi_inverse

transform_poi_inverse(poi_target: POI_Global | POI)

Transform POIs from target space back into atlas space.

Parameters:

Name Type Description Default
poi_target POI_Global | POI

POIs defined in target space.

required

Returns:

Name Type Description
POI

POIs mapped back into atlas space.

Source code in TPTBox/registration/_deformable/multilabel_segmentation.py
def transform_poi_inverse(self, poi_target: POI_Global | POI):
    """Transform POIs from target space back into atlas space.

    This reverses ``transform_poi``: resample to the original target grid, undo the
    left/right mirror, invert the deformable registration, invert the rigid point
    registration and finally resample to the original atlas grid.

    Args:
        poi_target (POI_Global | POI): POIs defined in target space.

    Returns:
        POI: POIs mapped back into atlas space.

    Raises:
        ValueError: If the "R" axis of the target grid is not 0, 1 or 2.
    """
    # The forward transform mirrors on ``target_grid_org``; the inverse mirror must use
    # the same grid (shape/orientation), so resample there first. This also gives
    # POI_Global inputs a voxel grid.
    poi = poi_target.resample_from_to(self.target_grid_org)

    # --- undo left/right flip if needed ---
    if not self.same_side:
        axis = poi.get_axis("R")
        if axis not in (0, 1, 2):
            raise ValueError(axis)
        size = poi.shape[axis]
        poi_flip = poi.make_empty_POI()
        for k1, k2, coords in poi.items():
            mirrored = list(coords)
            mirrored[axis] = size - 1 - mirrored[axis]
            poi_flip[k1, k2] = tuple(mirrored)
        poi = poi_flip

    # --- resample into deformable registration grid ---
    poi = poi.resample_from_to(self.target_grid)

    # --- inverse deformable registration ---
    reg_deform_inv = self.reg_deform.inverse()
    poi = reg_deform_inv.transform_poi(poi)

    # NOTE(review): the crop is not undone explicitly; the rigid inverse below is called
    # with allow_only_same_grid_as_moving=False, presumably so it accepts the cropped grid.

    # --- inverse rigid point registration ---
    poi = self.reg_point.transform_poi_inverse(poi, allow_only_same_grid_as_moving=False)

    # --- back to atlas grid ---
    return poi.resample_from_to(self.atlas_org)

Deep Learning Registration (DeepALI)

TPTBox.registration.General_Registration

Bases: DeepaliPairwiseImageTrainer

A class for performing deformable registration between a fixed and moving image.

Attributes:

Name Type Description
transform Tensor

The transformation matrix resulting from the registration.

ref_nii NII

Reference NII object used for registration.

grid Tensor

Target grid for image warping.

mov NII

Processed version of the moving image.

Source code in TPTBox/registration/_deepali/deepali_model.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
class General_Registration(DeepaliPairwiseImageTrainer):
    """Register a moving image onto a fixed image with deepali (linear or deformable).

    The optimisation itself is implemented by ``DeepaliPairwiseImageTrainer``. This
    class prepares the inputs from TPTBox ``NII`` objects and provides helpers to
    apply the resulting transform to images, POIs and raw points. It can also
    save and load the result.

    Attributes:
        transform: Spatial transform resulting from the registration; read by the
            ``transform_*`` methods and restored by :meth:`load_`.
        target_grid: Grid of the fixed image (after optional resampling to
            ``reference_image``); default output grid for warped images and POIs.
        input_grid: Grid of the moving image; images are resampled onto it before
            warping.
        device: Torch device used for warping. Set explicitly by :meth:`load_`;
            presumably set by the base trainer otherwise.
    """

    def __init__(
        self,
        fixed_image: Image_Reference,
        moving_image: Image_Reference,
        fixed_seg: Image_Reference | None = None,
        moving_seg: Image_Reference | None = None,
        reference_image: Image_Reference | None = None,
        source_pset=None,
        target_pset=None,
        source_landmarks: POI | None = None,
        target_landmarks: POI | None = None,
        # source_seg: Optional[Union[Image, PathStr]] = None,  # Masking the registration source
        # target_seg: Optional[Union[Image, PathStr]] = None,  # Masking the registration target
        device: Union[torch.device, str, int] | None = None,
        gpu=0,
        ddevice: DEVICES = "cuda",
        # foreground_mask
        fixed_mask: Image_Reference | None = None,
        moving_mask: Image_Reference | None = None,
        # normalize
        normalize_strategy: (
            Literal["auto", "CT", "MRI"] | None
        ) = "auto",  # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do noting
        # Pyramid
        pyramid_levels: int | None = None,  # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest)
        finest_level: int = 0,
        coarsest_level: int | None = None,
        pyramid_finest_spacing: Sequence[int] | torch.Tensor | None = None,
        pyramid_min_size=16,
        dims=("x", "y", "z"),
        align=False,
        transform_name: str = "SVFFD",  # Names that are defined in deepali.spatial.LINEAR_TRANSFORMS and deepali.spatialNONRIGID_TRANSFORMS. Override on_make_transform for finer control
        transform_args: dict | None = None,
        transform_init: PathStr | None = None,  # reload initial flowfield from file
        optim_name="Adam",  # Optimizer name defined in torch.optim. or override on_optimizer finer control
        lr: float | Sequence[float] = 0.01,  # Learning rate
        lr_end_factor: float | None = None,  # if set, will use a LinearLR scheduler to reduce the learning rate to this factor * lr
        optim_args=None,  # args of Optimizer with out lr
        smooth_grad=0.0,
        verbose=99,
        max_steps: int | Sequence[int] = 250,  # Early stopping.  override on_converged finer control
        max_history: int | None = 100,
        min_value=0.0,  # Early stopping.  override on_converged finer control
        min_delta: float | Sequence[float] = 0.0,  # Early stopping.  override on_converged finer control
        loss_terms: list[LOSS | str] | dict[str, LOSS] | dict[str, str] | dict[str, tuple[str, dict]] | None = None,
        weights: list[float] | dict[str, float | list[float]] | None = None,
        auto_run=True,
    ) -> None:
        """Prepare the fixed/moving inputs and configure (and optionally run) the registration.

        Args:
            fixed_image: Fixed (target) image.
            moving_image: Moving (source) image.
            fixed_seg: Optional segmentation of the fixed image. Resampled to
                ``reference_image`` when one is given.
            moving_seg: Optional segmentation of the moving image.
            reference_image: Optional image/grid the fixed image (and ``fixed_seg``)
                is resampled to before registration. Defaults to the fixed image.
            source_landmarks: Optional POI forwarded to the trainer; also kept as
                ``source_landmarks_poi``.
            target_landmarks: Optional POI forwarded to the trainer; also kept as
                ``target_landmarks_poi``.
            device: Explicit torch device. When None it is derived from
                ``ddevice`` and ``gpu``.
            fixed_mask: Optional foreground mask of the fixed image; resampled onto
                the fixed image grid.
            moving_mask: Optional foreground mask of the moving image.
            auto_run: Call :meth:`run` at the end of construction.

        All remaining arguments are forwarded unchanged to
        ``DeepaliPairwiseImageTrainer``. See the inline comments on the signature.
        """
        if device is None:
            # self.gpu = gpu
            # self.ddevice: DEVICES = ddevice
            device = get_device(ddevice, gpu)
        fix = to_nii(fixed_image).copy()
        mov = to_nii(moving_image).copy()
        if reference_image is None:
            reference_image = fix
        else:
            fix = fix.resample_from_to(reference_image)
            if fixed_seg is not None:
                fixed_seg = to_nii(fixed_seg, True).resample_from_to(reference_image)
        ## Resample and save images
        source = mov  # .resample_from_to_(reference_image)
        ## Load configuration and perform registration
        self.target_grid = fix.to_gird()
        self.input_grid = mov.to_gird()
        self.source_landmarks_poi = source_landmarks
        self.target_landmarks_poi = target_landmarks
        self._is_inverted = False

        super().__init__(
            source=source.to_deepali(),
            target=fix.to_deepali(),
            # Each segmentation is gated on its *own* presence. Gating it on the
            # other one would call to_nii(None) when only one segmentation is given.
            source_seg=to_nii(moving_seg, True).to_deepali() if moving_seg is not None else None,
            target_seg=to_nii(fixed_seg, True).to_deepali() if fixed_seg is not None else None,
            source_pset=source_pset,
            target_pset=target_pset,
            source_landmarks=source_landmarks,
            target_landmarks=target_landmarks,
            device=device,
            target_mask=to_nii(fixed_mask, True).resample_from_to(fix, verbose=False).to_deepali() if fixed_mask is not None else None,
            source_mask=to_nii(moving_mask, True).to_deepali() if moving_mask is not None else None,
            normalize_strategy=normalize_strategy,
            pyramid_levels=pyramid_levels,
            finest_level=finest_level,
            coarsest_level=coarsest_level,
            pyramid_finest_spacing=pyramid_finest_spacing,
            pyramid_min_size=pyramid_min_size,
            dims=dims,
            align=align,
            transform_name=transform_name,
            transform_args=transform_args,
            transform_init=transform_init,
            optim_name=optim_name,
            lr=lr,
            lr_end_factor=lr_end_factor,
            optim_args=optim_args,
            smooth_grad=smooth_grad,
            verbose=verbose,
            max_steps=max_steps,
            max_history=max_history,
            min_value=min_value,
            min_delta=min_delta,
            loss_terms=loss_terms,
            weights=weights,
        )
        if auto_run:
            self.run()

    # def on_transform_update(self, transform: SpatialTransform):
    #    if self.source_landmarks_poi is not None and self.target_landmarks_poi is not None:
    #        lm = self.source_landmarks_poi.copy()
    #        tm = self.target_landmarks_poi.copy()
    #        for k in lm.keys().copy():
    #            if k not in tm:
    #                lm.remove_(k)
    #        for k in tm.keys().copy():
    #            if k not in lm:
    #                tm.remove_(k)
    #        self.source_landmarks = self.poi_to_deepali(lm, transform)
    #        self.target_landmarks = self.poi_to_deepali(tm, transform)
    #        assert self.source_landmarks.shape == self.target_landmarks.shape, (self.source_landmarks.shape, self.target_landmarks.shape)

    # def poi_to_deepali(self, poi: POI, transform: SpatialTransform):
    #    import torch
    #    from deepali.core import Axes
    #    keys: list[tuple[int, int]] = []
    #    points = []
    #    for key, key2, (x, y, z) in poi.items(sort=True):
    #        keys.append((key, key2))
    #        points.append((x, y, z))
    #        print(key, key2)
    #    with torch.inference_mode():
    #        data = torch.Tensor(points).unsqueeze(0)
    #        # data = (
    #        #    poi.to_deepali_grid()
    #        #    .transform_points(data, axes=Axes.GRID, to_grid=transform.grid(), to_axes=transform.axes(), decimals=None)
    #        #    .unsqueeze(0)
    #        # )
    #    return data.clone()
    def inverse(self) -> Self:
        """Return a copy of this registration with the transformation direction flipped.

        ``self`` is left unchanged. The returned object is a shallow copy that
        shares the transform but has ``_is_inverted`` toggled, so its
        ``transform_*`` methods warp in the opposite direction.

        Returns:
            Self: A new instance with the inverted transformation.
        """
        from copy import copy

        # Only the copy is flipped. Mutating self would silently change the
        # direction for every other reference to this registration.
        out = copy(self)
        out._is_inverted = not self._is_inverted
        return out

    # def on_run_end(
    #    self,
    #    grid_transform,
    #    target_image: deepaliImage,
    #    source_image: deepaliImage,
    #    target_image_seg: deepaliImage,
    #    source_image_seg: deepaliImage,
    #    opt,
    #    lr_sq,
    #    num_steps,
    #    level,
    # ):
    #    import numpy as np
    #
    #    arr_target = (
    #        target_image.tensor()
    #        .squeeze()
    #        .permute(2, 1, 0)
    #        .detach()
    #        .cpu()
    #        .float()
    #        .numpy()
    #    )
    #    grid = NII.from_deepali_grid(target_image.grid())
    #    nii_target = grid.make_nii(arr_target, False)
    #    nii_target.save(
    #        f"/DATA/NAS/datasets_processed/CT_spine/dataset-myelom/target_img{level}.nii.gz"
    #    )
    #    arr_source = (
    #        source_image.tensor()
    #        .squeeze()
    #        .permute(2, 1, 0)
    #        .detach()
    #        .cpu()
    #        .float()
    #        .numpy()
    #    )
    #    nii_source = grid.make_nii(arr_source, False)
    #    nii_source.save(
    #        f"/DATA/NAS/datasets_processed/CT_spine/dataset-myelom/source_img{level}.nii.gz"
    #    )
    #    arr = source_image_seg.tensor().permute(0, 3, 2, 1).detach().cpu().numpy()
    #
    #    arr_new_source_seg = np.zeros(arr.shape[-3:])
    #    print(arr_new_source_seg.shape)
    #    print(arr.shape)
    #    for i in range(arr.shape[0]):
    #        arr_new_source_seg[arr[i] >= 0.5] = i
    #    nii_source = grid.make_nii(arr_new_source_seg.astype(np.uint16), True)
    #    nii_source.save(
    #        f"/DATA/NAS/datasets_processed/CT_spine/dataset-myelom/source{level}.nii.gz"
    #    )
    #    arr_target_seg = target_image_seg.tensor().permute(0, 3, 2, 1).detach().cpu().numpy()
    #
    #    arr_new_target_seg = np.zeros(arr_target_seg.shape[-3:])
    #    for i in range(arr_target_seg.shape[0]):
    #        arr_new_target_seg[arr_target_seg[i] >= 0.5] = i
    #    nii_target_seg = grid.make_nii(arr_new_target_seg.astype(np.uint16), True)
    #    nii_target_seg.save(
    #        f"/DATA/NAS/datasets_processed/CT_spine/dataset-myelom/target{level}.nii.gz"
    #    )
    #    out = self.transform_nii(nii_target_seg)
    #    out.save(
    #        f"/DATA/NAS/datasets_processed/CT_spine/dataset-myelom/moved{level}.nii.gz"
    #    )
    #    dice = out.resample_from_to(nii_source).dice(nii_source)
    #    from TPTBox import Print_Logger
    #
    #    Print_Logger().on_debug(np.mean(list(dice.values())), dice)
    #    # exit()

    @torch.no_grad()
    def transform_nii(
        self,
        img: NII,
        gpu: int | None = None,
        ddevice: DEVICES | None = None,
        target: Has_Grid | None = None,
        align_corners=True,
        inverse=False,
    ) -> NII:
        """Apply the computed transformation to a given NII image.

        The image is first resampled onto the moving-image grid (``input_grid``)
        and then warped onto ``target`` (default: ``target_grid``). Segmentations
        (``img.seg``) use nearest-neighbour interpolation; other images use linear.

        Args:
            img (NII): The NII image to be transformed.
            gpu (int, optional): GPU index; only used together with ``ddevice``.
            ddevice (DEVICES, optional): Device type override. Defaults to None,
                which uses ``self.device``.
            target (Has_Grid, optional): Output grid. Defaults to ``self.target_grid``.
            align_corners (bool): Forwarded to ``to_deepali_grid`` of the output grid.
            inverse (bool): Warp with the inverse transform. Toggled again when this
                registration is marked as inverted (see :meth:`inverse`).

        Returns:
            NII: The transformed image as an NII object on the output grid.
        """
        if self._is_inverted:
            inverse = not inverse
        device = get_device(ddevice, 0 if gpu is None else gpu) if ddevice is not None else self.device
        target_grid_nii = self.target_grid if target is None else target
        target_grid = target_grid_nii.to_deepali_grid(align_corners)
        source_image = img.resample_from_to(self.input_grid, mode="constant").to_deepali()
        data = _warp_image(
            source_image,
            target_grid,
            self.transform,
            "nearest" if img.seg else "linear",
            device=device,
            inverse=inverse,
        ).squeeze()
        # Reverse the axis order of the warped tensor before converting it to numpy.
        data: torch.Tensor = data.permute(*torch.arange(data.ndim - 1, -1, -1))  # type: ignore
        out = target_grid_nii.make_nii(data.detach().cpu().numpy(), img.seg)
        return out

    def transform_poi(
        self,
        poi: POI,
        gpu: int | None = None,
        ddevice: DEVICES | None = None,
        align_corners: bool = True,
        inverse: bool = True,
    ) -> POI:
        """Apply the computed registration to a POI object.

        Args:
            poi: Source ``POI`` with centroid coordinates to transform.
            gpu: GPU index; only used together with ``ddevice``.
            ddevice: Device type override (e.g. ``"cuda"``). Defaults to None,
                which uses the device of this registration.
            align_corners: Whether corners or centres are aligned during warping.
            inverse: Apply the inverse transform when ``True`` (default, maps
                moving centroids to fixed space). Toggled again when this
                registration is marked as inverted.

        Returns:
            A new ``POI`` resampled to the target grid with transformed coordinates.
        """
        if self._is_inverted:
            inverse = not inverse
        device = get_device(ddevice, 0 if gpu is None else gpu) if ddevice is not None else self.device
        source_image = poi.resample_from_to(self.target_grid)
        data = _warp_poi(
            source_image,
            self.target_grid,
            self.transform,
            align_corners,
            device=device,
            inverse=inverse,
        )
        return data.resample_from_to(self.target_grid)

    def transform_points(
        self,
        points,
        axes: Axes,
        to_axes: Axes,
        grid: Deepali_Grid | Has_Grid,
        to_grid: Deepali_Grid | Has_Grid,
        gpu: int | None = None,
        ddevice: DEVICES | None = None,
        inverse=True,
    ):
        """Transform a set of points using the registered transformation.

        Args:
            points (list): List of points to warp: (b,n) b points with n coordinates.
            axes (Axes): Axes of the input points.
            to_axes (Axes): Axes of the output points.
            grid (Deepali_Grid | Has_Grid): The grid to which the points belong.
            to_grid (Deepali_Grid | Has_Grid): The target grid for the transformed points.
            gpu (int, optional): GPU index; only used together with ``ddevice``. Defaults to None.
            ddevice (DEVICES, optional): Device type override. Defaults to None,
                which uses the device of this registration.
            inverse (bool, optional): Whether to apply the inverse transformation. Defaults to True.
                Toggled again when this registration is marked as inverted.

        Returns:
            Whatever ``_warp_points`` returns for the given points.
        """
        if self._is_inverted:
            inverse = not inverse
        if isinstance(grid, Has_Grid):
            grid = grid.to_deepali_grid()
        if isinstance(to_grid, Has_Grid):
            to_grid = to_grid.to_deepali_grid()
        device = get_device(ddevice, 0 if gpu is None else gpu) if ddevice is not None else self.device
        return _warp_points(
            points,
            axes,
            to_axes,
            grid,
            to_grid,
            transform=self.transform,
            device=device,
            # Forward the resolved direction. A hard-coded True would ignore both
            # the argument and the _is_inverted state.
            inverse=inverse,
        )

    def __call__(self, *args, **kwds) -> NII:
        """Call method to apply the transformation using the transform_nii method.

        Args:
            *args: Positional arguments for the transform_nii method.
            **kwds: Keyword arguments for the transform_nii method.

        Returns:
            NII: The transformed image.
        """
        return self.transform_nii(*args, **kwds)

    def get_dump(self) -> tuple:
        """Return a serialisable tuple of the registration state for pickling.

        Returns:
            Tuple of ``(transform, target_grid, input_grid, _is_inverted)``, the
            layout unpacked by :meth:`load_`.
        """
        return (self.transform, self.target_grid, self.input_grid, self._is_inverted)

    def save(self, path: str | Path) -> None:
        """Serialise the registration result to a pickle file.

        Args:
            path: Destination file path (created or overwritten).
        """
        with open(path, "wb") as w:
            pickle.dump(self.get_dump(), w)

    @classmethod
    def load(cls, path: str | Path, gpu: int = 0, ddevice: DEVICES = "cuda") -> Self:
        """Load a previously saved ``General_Registration`` from a pickle file.

        Warning:
            This uses :func:`pickle.load`, which can execute arbitrary code.
            Only load files from trusted sources.

        Args:
            path: Path to the pickle file written by :meth:`save`.
            gpu: GPU index to map the transform to.
            ddevice: Device type (e.g. ``"cuda"`` or ``"cpu"``).

        Returns:
            Reconstructed ``General_Registration`` instance.
        """
        with open(path, "rb") as w:
            return cls.load_(pickle.load(w), gpu, ddevice)

    @classmethod
    def load_(cls, w: tuple, gpu: int = 0, ddevice: DEVICES = "cuda") -> Self:
        """Reconstruct a ``General_Registration`` from a raw dump tuple.

        ``__init__`` is bypassed, so no registration is run. Only the state
        needed by the ``transform_*`` methods is restored.

        Args:
            w: Tuple as returned by :meth:`get_dump`.
            gpu: GPU index for device placement.
            ddevice: Device type string.

        Returns:
            Reconstructed ``General_Registration`` instance.
        """
        transform, grid, mov, _is_inverted = w
        self = cls.__new__(cls)
        self.transform = transform
        self.target_grid = grid
        self.input_grid = mov
        self._is_inverted = _is_inverted
        self.device = get_device(ddevice, gpu)
        return self

inverse

inverse() -> Self

Invert the registration transformation.

Returns:

Name Type Description
Self Self

The instance with the inverted transformation.

Source code in TPTBox/registration/_deepali/deepali_model.py
def inverse(self) -> Self:
    """Return a copy of the registration with the transformation direction flipped.

    The instance this is called on is not modified. The returned shallow copy
    shares the underlying transform but has its ``_is_inverted`` flag toggled,
    so the ``transform_*`` methods of the copy warp in the opposite direction.

    Returns:
        Self: A new instance with the inverted transformation.
    """
    from copy import copy

    # Only the copy is flipped. Mutating self would silently change the
    # direction for every other reference to this registration.
    out = copy(self)
    out._is_inverted = not self._is_inverted
    return out

transform_nii

transform_nii(img: NII, gpu: int | None = None, ddevice: DEVICES | None = None, target: Has_Grid | None = None, align_corners=True, inverse=False) -> NII

Apply the computed transformation to a given NII image.

Parameters:

Name Type Description Default
img NII

The NII image to be transformed.

required

Returns:

Name Type Description
NII NII

The transformed image as an NII object.

Source code in TPTBox/registration/_deepali/deepali_model.py
@torch.no_grad()
def transform_nii(
    self,
    img: NII,
    gpu: int | None = None,
    ddevice: DEVICES | None = None,
    target: Has_Grid | None = None,
    align_corners=True,
    inverse=False,
) -> NII:
    """Warp an NII image with the computed registration.

    The image is resampled onto ``self.input_grid`` and then warped onto the
    output grid. Segmentations (``img.seg``) use nearest-neighbour interpolation;
    all other images use linear interpolation.

    Args:
        img (NII): The NII image to be transformed.
        gpu: GPU index; only consulted when ``ddevice`` is given.
        ddevice: Device type override; when None, ``self.device`` is used.
        target: Output grid; defaults to ``self.target_grid``.
        align_corners: Forwarded to ``to_deepali_grid`` of the output grid.
        inverse: Warp with the inverse transform; toggled again when
            ``self._is_inverted`` is set.

    Returns:
        NII: The warped image on the output grid.
    """
    use_inverse = (not inverse) if self._is_inverted else inverse
    if ddevice is None:
        device = self.device
    else:
        device = get_device(ddevice, gpu if gpu is not None else 0)
    out_grid = target if target is not None else self.target_grid
    deepali_grid = out_grid.to_deepali_grid(align_corners)
    moving = img.resample_from_to(self.input_grid, mode="constant").to_deepali()
    mode = "nearest" if img.seg else "linear"
    warped = _warp_image(
        moving,
        deepali_grid,
        self.transform,
        mode,
        device=device,
        inverse=use_inverse,
    ).squeeze()
    # Reverse the axis order of the tensor before converting it to numpy.
    reversed_axes = torch.arange(warped.ndim - 1, -1, -1)
    warped = warped.permute(*reversed_axes)
    return out_grid.make_nii(warped.detach().cpu().numpy(), img.seg)

transform_poi

transform_poi(poi: POI, gpu: int | None = None, ddevice: DEVICES | None = None, align_corners: bool = True, inverse: bool = True) -> POI

Apply the computed registration to a POI object.

Parameters:

Name Type Description Default
poi POI

Source POI with centroid coordinates to transform.

required
gpu int | None

GPU index override. Defaults to the device used during registration.

None
ddevice DEVICES | None

Device type override (e.g. "cuda").

None
align_corners bool

Whether corners or centres are aligned during warping.

True
inverse bool

Apply the inverse transform when True (default, maps moving centroids to fixed space).

True

Returns:

Type Description
POI

A new POI resampled to the target grid with transformed coordinates.

Source code in TPTBox/registration/_deepali/deepali_model.py
def transform_poi(
    self,
    poi: POI,
    gpu: int | None = None,
    ddevice: DEVICES | None = None,
    align_corners: bool = True,
    inverse: bool = True,
) -> POI:
    """Warp the coordinates of a POI with the computed registration.

    Args:
        poi: Source ``POI`` with centroid coordinates to transform.
        gpu: GPU index; only consulted when ``ddevice`` is given.
        ddevice: Device type override (e.g. ``"cuda"``); when None the
            registration's own device is used.
        align_corners: Whether corners or centres are aligned during warping.
        inverse: Apply the inverse transform when ``True`` (default, maps
            moving centroids to fixed space); toggled again when
            ``self._is_inverted`` is set.

    Returns:
        A new ``POI`` on the target grid with transformed coordinates.
    """
    apply_inverse = inverse
    if self._is_inverted:
        apply_inverse = not apply_inverse
    device = self.device if ddevice is None else get_device(ddevice, gpu if gpu is not None else 0)
    grid = self.target_grid
    poi_on_grid = poi.resample_from_to(grid)
    moved = _warp_poi(
        poi_on_grid,
        grid,
        self.transform,
        align_corners,
        device=device,
        inverse=apply_inverse,
    )
    return moved.resample_from_to(grid)

transform_points

transform_points(points, axes: Axes, to_axes: Axes, grid: Grid | Has_Grid, to_grid: Grid | Has_Grid, gpu: int | None = None, ddevice: DEVICES | None = None, inverse=True)

Transform a set of points using the registered transformation.

Parameters:

Name Type Description Default
points list

List of points to warp: (b,n) b points with n coordinates.

required
axes Axes

Axes of the input points.

required
to_axes Axes

Axes of the output points.

required
grid Grid | Has_Grid

The grid to which the points belong.

required
to_grid Grid | Has_Grid

The target grid for the transformed points.

required
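gpu int | None

GPU index to use; only applied together with ddevice. Defaults to None.

None
ddevice DEVICES | None

Device type override. Defaults to None, which uses the device of the registration.

None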
inverse bool

Whether to apply the inverse transformation. Defaults to True.

True
Source code in TPTBox/registration/_deepali/deepali_model.py
def transform_points(
    self,
    points,
    axes: Axes,
    to_axes: Axes,
    grid: Deepali_Grid | Has_Grid,
    to_grid: Deepali_Grid | Has_Grid,
    gpu: int | None = None,
    ddevice: DEVICES | None = None,
    inverse=True,
):
    """Transform a set of points using the registered transformation.

    Args:
        points (list): List of points to warp: (b,n) b points with n coordinates.
        axes (Axes): Axes of the input points.
        to_axes (Axes): Axes of the output points.
        grid (Deepali_Grid | Has_Grid): The grid to which the points belong.
        to_grid (Deepali_Grid | Has_Grid): The target grid for the transformed points.
        gpu (int, optional): GPU index; only used together with ``ddevice``. Defaults to None.
        ddevice (DEVICES, optional): Device type override. Defaults to None,
            which uses the device of this registration.
        inverse (bool, optional): Whether to apply the inverse transformation. Defaults to True.
            Toggled again when this registration is marked as inverted.

    Returns:
        Whatever ``_warp_points`` returns for the given points.
    """
    if self._is_inverted:
        inverse = not inverse
    if isinstance(grid, Has_Grid):
        grid = grid.to_deepali_grid()
    if isinstance(to_grid, Has_Grid):
        to_grid = to_grid.to_deepali_grid()
    device = get_device(ddevice, 0 if gpu is None else gpu) if ddevice is not None else self.device
    return _warp_points(
        points,
        axes,
        to_axes,
        grid,
        to_grid,
        transform=self.transform,
        device=device,
        # Forward the resolved direction. A hard-coded True would ignore both
        # the argument and the _is_inverted state.
        inverse=inverse,
    )

__call__

__call__(*args, **kwds) -> NII

Call method to apply the transformation using the transform_nii method.

Parameters:

Name Type Description Default
*args

Positional arguments for the transform_nii method.

()
**kwds

Keyword arguments for the transform_nii method.

{}

Returns:

Name Type Description
NII NII

The transformed image.

Source code in TPTBox/registration/_deepali/deepali_model.py
def __call__(self, *args, **kwds) -> NII:
    """Shortcut for ``transform_nii``.

    Every positional and keyword argument is forwarded unchanged.

    Returns:
        NII: Whatever ``transform_nii`` returns for these arguments.
    """
    warp = self.transform_nii
    return warp(*args, **kwds)

get_dump

get_dump() -> tuple

Return a serialisable tuple of the registration state for pickling.

Returns:

Type Description
tuple

Tuple of (transform, target_grid, input_grid, _is_inverted).

Source code in TPTBox/registration/_deepali/deepali_model.py
def get_dump(self) -> tuple:
    """Collect the state needed to rebuild this registration later.

    Returns:
        tuple: ``(transform, target_grid, input_grid, _is_inverted)`` in that
        order, matching what ``load_`` unpacks.
    """
    state = (
        self.transform,
        self.target_grid,
        self.input_grid,
        self._is_inverted,
    )
    return state

save

save(path: str | Path) -> None

Serialise the registration result to a pickle file.

Parameters:

Name Type Description Default
path str | Path

Destination file path.

required
Source code in TPTBox/registration/_deepali/deepali_model.py
def save(self, path: str | Path) -> None:
    """Write the registration state to ``path`` with :mod:`pickle`.

    The pickled object is the tuple returned by ``get_dump``.

    Args:
        path: File to create or overwrite.
    """
    with open(path, "wb") as out_file:
        state = self.get_dump()
        pickle.dump(state, out_file)

load classmethod

load(path: str | Path, gpu: int = 0, ddevice: DEVICES = 'cuda') -> Self

Load a previously saved General_Registration from a pickle file.

Parameters:

Name Type Description Default
path str | Path

Path to the pickle file written by save().

required
gpu int

GPU index to map the transform to.

0
ddevice DEVICES

Device type (e.g. "cuda" or "cpu").

'cuda'

Returns:

Type Description
Self

Reconstructed General_Registration instance.

Source code in TPTBox/registration/_deepali/deepali_model.py
@classmethod
def load(cls, path: str | Path, gpu: int = 0, ddevice: DEVICES = "cuda") -> Self:
    """Rebuild a registration from a pickle file written by ``save``.

    Warning:
        ``pickle.load`` can execute arbitrary code while unpickling. Only
        load files from trusted sources.

    Args:
        path: Path to the pickle file.
        gpu: GPU index to map the transform to.
        ddevice: Device type (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        The reconstructed registration instance, as built by ``load_``.
    """
    with open(path, "rb") as in_file:
        state = pickle.load(in_file)
    return cls.load_(state, gpu, ddevice)

load_ classmethod

load_(w: tuple, gpu: int = 0, ddevice: DEVICES = 'cuda') -> Self

Reconstruct a General_Registration from a raw dump tuple.

Parameters:

Name Type Description Default
w tuple

Tuple as returned by get_dump().

required
gpu int

GPU index for device placement.

0
ddevice DEVICES

Device type string.

'cuda'

Returns:

Type Description
Self

Reconstructed General_Registration instance.

Source code in TPTBox/registration/_deepali/deepali_model.py
@classmethod
def load_(cls, w: tuple, gpu: int = 0, ddevice: DEVICES = "cuda") -> Self:
    """Rebuild a registration from the tuple produced by ``get_dump``.

    ``__init__`` is bypassed, so no optimisation runs. Only the attributes read
    by the ``transform_*`` methods are restored.

    Args:
        w: ``(transform, target_grid, input_grid, _is_inverted)`` tuple.
        gpu: GPU index for device placement.
        ddevice: Device type string.

    Returns:
        The reconstructed registration instance.
    """
    transform, target_grid, input_grid, is_inverted = w
    restored = cls.__new__(cls)
    restored.transform = transform
    restored.target_grid = target_grid
    restored.input_grid = input_grid
    restored._is_inverted = is_inverted
    restored.device = get_device(ddevice, gpu)
    return restored