Skip to content

ToReferenceSpace

Bases: SpatialTransform

Set the spatial metadata of an image to match a reference space.

This is useful for assigning meaningful spatial metadata to a tensor that has lost it, such as a neural network embedding or a downsampled feature map. The data is left unchanged; only the affine is updated so that the (possibly lower-resolution) grid covers the same field of view, orientation, and physical center as the reference image.

A typical use case is visualizing or resampling the output of a network whose spatial resolution differs from its input (see the example below).

Parameters:

Name Type Description Default
reference Image

Full-resolution reference image whose field of view and orientation will be matched.

required
**kwargs Any

See Transform.

{}

Examples:

>>> import torch
>>> import torchio as tio
>>> reference = tio.ScalarImage(tensor=torch.rand(1, 64, 64, 64))
>>> # A network embedding loses spatial metadata:
>>> embedding = torch.rand(8, 16, 16, 16)
>>> image = tio.ToReferenceSpace.from_tensor(embedding, reference)
>>> image.spatial_shape
(16, 16, 16)
Source code in src/torchio/transforms/spatial/to_reference_space.py
class ToReferenceSpace(SpatialTransform):
    r"""Give an image the spatial metadata of a reference space.

    Tensors such as neural network embeddings or downsampled feature
    maps often lose their spatial metadata.  This transform restores
    it without touching the voxel data: only the affine is replaced,
    so that the (possibly lower-resolution) grid spans the same field
    of view, orientation, and physical center as the *reference*
    image.

    A typical use case is visualizing or resampling the output of a
    network whose spatial resolution differs from its input (see the
    example below).

    Args:
        reference: Full-resolution reference image whose field of view
            and orientation will be matched.
        **kwargs: See [`Transform`][torchio.Transform].

    Raises:
        TypeError: If *reference* is not a TorchIO `Image`.

    Examples:
        >>> import torch
        >>> import torchio as tio
        >>> reference = tio.ScalarImage(tensor=torch.rand(1, 64, 64, 64))
        >>> # A network embedding loses spatial metadata:
        >>> embedding = torch.rand(8, 16, 16, 16)
        >>> image = tio.ToReferenceSpace.from_tensor(embedding, reference)
        >>> image.spatial_shape
        (16, 16, 16)
    """

    def __init__(self, reference: Image, **kwargs: Any) -> None:
        # The base class consumes the generic transform options first;
        # the reference is validated afterwards.
        super().__init__(**kwargs)
        if isinstance(reference, Image):
            self.reference = reference
        else:
            msg = f"reference must be a TorchIO Image, got {type(reference).__name__}"
            raise TypeError(msg)

    def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
        """Return an empty dict: this transform samples nothing."""
        no_params: dict[str, Any] = {}
        return no_params

    def apply_transform(
        self,
        batch: SubjectsBatch,
        params: dict[str, Any],
    ) -> SubjectsBatch:
        """Overwrite every image affine in *batch* with the reference-space one.

        Args:
            batch: Batch whose images receive the new affine.
            params: Unused; this transform has no parameters.

        Returns:
            The same *batch* object, with its affines updated in place.
        """
        for img_batch in self._get_images(batch).values():
            dims = img_batch.data.shape
            # Axes 2, 3 and 4 of the batched data give the output grid size.
            output_shape = (int(dims[2]), int(dims[3]), int(dims[4]))
            new_affine = _reference_space_affine(self.reference, output_shape)
            # One independent clone per existing affine entry.
            clones = [new_affine.clone() for _ in img_batch.affines]
            img_batch.affines[:] = clones
        return batch

    @staticmethod
    def from_tensor(tensor: Tensor, reference: Image) -> Image:
        """Wrap *tensor* in an image that lives in the reference space.

        Args:
            tensor: A `(C, I, J, K)` tensor (e.g., a network
                embedding) whose spatial metadata should match the
                reference space.
            reference: Reference image whose field of view and
                orientation will be matched.

        Returns:
            A new image holding *tensor* as data together with a
            reference-space affine.  It is an instance of the same
            class as *reference*.
        """
        dims = tensor.shape
        # The last three axes are the spatial ones.
        output_shape = (int(dims[-3]), int(dims[-2]), int(dims[-1]))
        new_affine = _reference_space_affine(reference, output_shape)
        return type(reference)(tensor, affine=new_affine)

supports_per_instance_params property

Whether this transform can sample parameters per batch element.

Defaults to False. Transforms that implement per-instance parameter sampling override this to return True. When False, the transform always uses batch-shared parameters regardless of the per_instance flag, preserving the legacy behavior.

supports_per_instance_p property

Whether this transform can gate each batch element independently.

Defaults to False. Shape-preserving transforms that implement per-element probability override this to return True. Shape-changing transforms must leave it False because masked and unmasked elements would have incompatible shapes.

invertible property

Whether this transform can be inverted.

forward(data)

forward(data: Subject) -> Subject
forward(data: Image) -> Image
forward(data: Tensor) -> Tensor
forward(data: np.ndarray) -> np.ndarray
forward(data: sitk.Image) -> sitk.Image
forward(data: nib.Nifti1Image) -> nib.Nifti1Image
forward(data: dict) -> dict
forward(data: ImagesBatch) -> ImagesBatch
forward(data: SubjectsBatch) -> SubjectsBatch

Apply the transform.

The output type always matches the input type.

Parameters:

Name Type Description Default
data Any

Input data to transform.

required
Source code in src/torchio/transforms/transform.py
def forward(self, data: Any) -> Any:
    """Run the transform on *data* and return a result of the same type.

    The input is optionally deep-copied, wrapped into a batch, gated by
    the transform probability, transformed, and unwrapped again.  When
    the transform is applied, a trace of it is appended to the batch
    history and, where possible, copied onto the returned object.

    Args:
        data: Input data to transform.

    Returns:
        The transformed data, unwrapped back to the type of the input.
    """
    payload = _copy.deepcopy(data) if self.copy else data
    batch, unwrap = self._wrap(payload)
    # With per-element gating the transform deals with the probability
    # on its own (gated-out elements receive identity params), so the
    # batch-wide coin flip is only drawn when that gating is inactive.
    # The transform runs iff rand < p: p=0 never applies, p=1 always does.
    if not self._per_instance_p_active(batch):
        if torch.rand(1).item() >= self.p:
            return unwrap(batch)
    params = self.make_params(batch)
    batch = self.apply_transform(batch, params)
    # Skip the history entry when per-element probability gated out every
    # element: nothing changed, and replaying such an entry (e.g. for an
    # invertible spatial transform) would cost a pointless identity
    # resample.
    if not _all_elements_gated_out(params):
        trace = AppliedTransform(name=type(self).__name__, params=params)
        if not hasattr(batch, "applied_transforms"):
            batch.applied_transforms = []
        batch.applied_transforms.append(trace)
    result = unwrap(batch)
    # Copy the history onto the output unless it is one of the types
    # excluded below; objects refusing the attribute are left alone.
    if hasattr(batch, "applied_transforms") and not isinstance(
        result, (SubjectsBatch, Tensor, np.ndarray, dict)
    ):
        with contextlib.suppress(AttributeError):
            result.applied_transforms = list(batch.applied_transforms)
    return result

inverse(params)

Return a transform that undoes this one.

Override in invertible subclasses. The returned transform, when applied, reverses the effect of the forward pass with the given parameters.

Parameters:

Name Type Description Default
params dict[str, Any]

The parameters recorded in the forward pass.

required

Returns:

Type Description
Transform

A new Transform instance that inverts this one.

Source code in src/torchio/transforms/transform.py
def inverse(self, params: dict[str, Any]) -> Transform:
    """Build the transform that reverses this one.

    This base implementation always raises; invertible subclasses
    override it to return a transform that, once applied, undoes the
    forward pass performed with the given parameters.

    Args:
        params: The parameters recorded in the forward pass.

    Returns:
        A new `Transform` instance that inverts this one.

    Raises:
        NotImplementedError: Always, in this default implementation.
    """
    raise NotImplementedError(f"{type(self).__name__} is not invertible")

to_hydra()

Export as a Hydra-compatible config dict.

Returns a dict with _target_ set to the fully qualified class name and only non-default field values included.

Returns:

Type Description
dict[str, Any]

Dict suitable for hydra.utils.instantiate().

Source code in src/torchio/transforms/transform.py
def to_hydra(self) -> dict[str, Any]:
    """Serialize this transform as a Hydra-compatible config dict.

    The dict carries `_target_` (the class qualified name prefixed
    with `torchio.`) plus every init parameter whose current value
    differs from its default.

    Returns:
        Dict suitable for `hydra.utils.instantiate()`.
    """
    from .parameter_range import _ParameterRange

    klass = type(self)
    cfg: dict[str, Any] = {"_target_": f"torchio.{klass.__qualname__}"}

    for name, default in _collect_init_params(klass).items():
        current = getattr(self, name, default)
        # Parameter ranges are compared and exported through the value
        # they were originally built from, not the range object itself.
        raw = current._original if isinstance(current, _ParameterRange) else current
        if raw == default:
            continue
        cfg[name] = _hydra_value(raw)
    return cfg

make_params(batch)

No random parameters.

Source code in src/torchio/transforms/spatial/to_reference_space.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Return an empty dict: this transform samples nothing."""
    no_params: dict[str, Any] = {}
    return no_params

apply_transform(batch, params)

Replace each image's affine with the reference-space affine.

Source code in src/torchio/transforms/spatial/to_reference_space.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Overwrite every image affine in *batch* with the reference-space one.

    Args:
        batch: Batch whose images receive the new affine.
        params: Unused; this transform has no parameters.

    Returns:
        The same *batch* object, with its affines updated in place.
    """
    for img_batch in self._get_images(batch).values():
        dims = img_batch.data.shape
        # Axes 2, 3 and 4 of the batched data give the output grid size.
        output_shape = (int(dims[2]), int(dims[3]), int(dims[4]))
        new_affine = _reference_space_affine(self.reference, output_shape)
        # One independent clone per existing affine entry.
        clones = [new_affine.clone() for _ in img_batch.affines]
        img_batch.affines[:] = clones
    return batch

from_tensor(tensor, reference) staticmethod

Build a TorchIO image from a tensor and a reference image.

Parameters:

Name Type Description Default
tensor Tensor

A (C, I, J, K) tensor (e.g., a network embedding) whose spatial metadata should match the reference space.

required
reference Image

Reference image whose field of view and orientation will be matched.

required

Returns:

Type Description
Image

A new image with tensor as data and a reference-space affine. The image class matches that of reference.

Source code in src/torchio/transforms/spatial/to_reference_space.py
@staticmethod
def from_tensor(tensor: Tensor, reference: Image) -> Image:
    """Wrap *tensor* in an image that lives in the reference space.

    Args:
        tensor: A `(C, I, J, K)` tensor (e.g., a network
            embedding) whose spatial metadata should match the
            reference space.
        reference: Reference image whose field of view and
            orientation will be matched.

    Returns:
        A new image holding *tensor* as data together with a
        reference-space affine.  It is an instance of the same
        class as *reference*.
    """
    dims = tensor.shape
    # The last three axes are the spatial ones.
    output_shape = (int(dims[-3]), int(dims[-2]), int(dims[-1]))
    new_affine = _reference_space_affine(reference, output_shape)
    return type(reference)(tensor, affine=new_affine)