Bases: SpatialTransform
Set the spatial metadata of an image to match a reference space.
This is useful for assigning meaningful spatial metadata to a
tensor that has lost it, such as a neural network embedding or a
downsampled feature map. The data is left unchanged; only the
affine is updated so that the (possibly lower-resolution) grid
covers the same field of view, orientation, and physical center
as the reference image.
A typical use case is visualizing or resampling the output of a
network whose spatial resolution differs from its input (see the example below).
Parameters:
| Name |
Type |
Description |
Default |
reference
|
Image
|
Full-resolution reference image whose field of view
and orientation will be matched.
|
required
|
**kwargs
|
Any
|
Keyword arguments passed to the base Transform class (see Transform).
|
{}
|
Examples:
>>> import torch
>>> import torchio as tio
>>> reference = tio.ScalarImage(tensor=torch.rand(1, 64, 64, 64))
>>> # A network embedding loses spatial metadata:
>>> embedding = torch.rand(8, 16, 16, 16)
>>> image = tio.ToReferenceSpace.from_tensor(embedding, reference)
>>> image.spatial_shape
(16, 16, 16)
Source code in src/torchio/transforms/spatial/to_reference_space.py
class ToReferenceSpace(SpatialTransform):
    r"""Set the spatial metadata of an image to match a reference space.

    This is useful for assigning meaningful spatial metadata to a
    tensor that has lost it, such as a neural network embedding or a
    downsampled feature map. The data is left unchanged; only the
    affine is updated so that the (possibly lower-resolution) grid
    covers the same field of view, orientation, and physical center
    as the *reference* image.

    A typical use case is visualizing or resampling the output of a
    network whose spatial resolution differs from its input (see the
    example below).

    Args:
        reference: Full-resolution reference image whose field of view
            and orientation will be matched.
        **kwargs: See [`Transform`][torchio.Transform].

    Raises:
        TypeError: If *reference* is not a TorchIO `Image`.

    Examples:
        >>> import torch
        >>> import torchio as tio
        >>> reference = tio.ScalarImage(tensor=torch.rand(1, 64, 64, 64))
        >>> # A network embedding loses spatial metadata:
        >>> embedding = torch.rand(8, 16, 16, 16)
        >>> image = tio.ToReferenceSpace.from_tensor(embedding, reference)
        >>> image.spatial_shape
        (16, 16, 16)
    """

    def __init__(self, reference: Image, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        if not isinstance(reference, Image):
            msg = f"reference must be a TorchIO Image, got {type(reference).__name__}"
            raise TypeError(msg)
        self.reference = reference

    def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
        """No random parameters."""
        return {}

    def apply_transform(
        self,
        batch: SubjectsBatch,
        params: dict[str, Any],
    ) -> SubjectsBatch:
        """Replace each image's affine with the reference-space affine.

        Image data is not modified. For every image batch, a new affine
        is computed from the reference and that batch's spatial shape,
        and every entry of ``img_batch.affines`` is replaced by a copy
        of it.
        """
        # Only the image batches are needed; their names are irrelevant.
        for img_batch in self._get_images(batch).values():
            # Spatial size is read from axes 2-4 of the batched data,
            # presumably laid out as (batch, channels, I, J, K); TODO confirm.
            shape = img_batch.data.shape
            output_shape = (int(shape[2]), int(shape[3]), int(shape[4]))
            new_affine = _reference_space_affine(self.reference, output_shape)
            # Clone so that entries do not alias a single affine object.
            img_batch.affines[:] = [new_affine.clone() for _ in img_batch.affines]
        return batch

    @staticmethod
    def from_tensor(tensor: Tensor, reference: Image) -> Image:
        """Build a TorchIO image from a tensor and a reference image.

        Args:
            tensor: A `(C, I, J, K)` tensor (e.g., a network
                embedding) whose spatial metadata should match the
                reference space.
            reference: Reference image whose field of view and
                orientation will be matched.

        Returns:
            A new image with *tensor* as data and a reference-space
            affine. The image class matches that of *reference*.

        Raises:
            TypeError: If *reference* is not a TorchIO `Image`.
            ValueError: If *tensor* has fewer than three dimensions.
        """
        # Same validation as ``__init__`` so both entry points fail alike.
        if not isinstance(reference, Image):
            msg = f"reference must be a TorchIO Image, got {type(reference).__name__}"
            raise TypeError(msg)
        if tensor.ndim < 3:
            msg = (
                "tensor must have at least 3 dimensions,"
                f" got shape {tuple(tensor.shape)}"
            )
            raise ValueError(msg)
        # The last three axes are taken as the spatial grid (I, J, K).
        output_shape = tuple(int(size) for size in tensor.shape[-3:])
        new_affine = _reference_space_affine(reference, output_shape)
        image_class = type(reference)
        return image_class(tensor, affine=new_affine)
|
invertible
property
Whether this transform can be inverted.
forward(data)
forward(data: Subject) -> Subject
forward(data: Image) -> Image
forward(data: Tensor) -> Tensor
forward(data: np.ndarray) -> np.ndarray
forward(data: sitk.Image) -> sitk.Image
forward(data: nib.Nifti1Image) -> nib.Nifti1Image
forward(data: dict) -> dict
forward(data: ImagesBatch) -> ImagesBatch
forward(data: SubjectsBatch) -> SubjectsBatch
Apply the transform.
The output type always matches the input type.
Parameters:
| Name |
Type |
Description |
Default |
data
|
Any
|
Input data to transform.
|
required
|
Source code in src/torchio/transforms/transform.py
def forward(self, data: Any) -> Any:
    """Apply the transform.

    The output type always matches the input type.

    Args:
        data: Input data to transform.

    Returns:
        The transformed data, of the same type as *data*. If the
        random draw skips the transform, the (possibly deep-copied)
        input is returned unchanged.
    """
    if self.copy:
        data = _copy.deepcopy(data)
    batch, unwrap = self._wrap(data)
    # Apply with probability ``self.p``. ``torch.rand`` samples from
    # [0, 1), so skipping when the draw is >= p guarantees that p=0 never
    # applies and p=1 always applies.
    if torch.rand(1).item() >= self.p:
        return unwrap(batch)
    params = self.make_params(batch)
    batch = self.apply_transform(batch, params)
    # Record history on the batch
    trace = AppliedTransform(name=type(self).__name__, params=params)
    if not hasattr(batch, "applied_transforms"):
        batch.applied_transforms = []
    batch.applied_transforms.append(trace)
    result = unwrap(batch)
    # Propagate history to outputs that can carry it. The batch is
    # guaranteed to have ``applied_transforms`` at this point.
    if not isinstance(result, (SubjectsBatch, Tensor, np.ndarray, dict)):
        with contextlib.suppress(AttributeError):
            result.applied_transforms = list(batch.applied_transforms)
    return result
|
inverse(params)
Return a transform that undoes this one.
Override in invertible subclasses. The returned transform,
when applied, reverses the effect of the forward pass with
the given parameters.
Parameters:
| Name |
Type |
Description |
Default |
params
|
dict[str, Any]
|
The parameters recorded in the forward pass.
|
required
|
Returns:
| Type |
Description |
Transform
|
A new Transform instance that inverts this one.
|
Source code in src/torchio/transforms/transform.py
def inverse(self, params: dict[str, Any]) -> Transform:
    """Build the transform that reverses this one.

    Invertible subclasses must override this method; the transform they
    return should undo the forward pass that was run with *params*.

    Args:
        params: Parameters recorded during the forward pass.

    Returns:
        A new `Transform` instance that inverts this one.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    class_name = type(self).__name__
    error_message = f"{class_name} is not invertible"
    raise NotImplementedError(error_message)
|
to_hydra()
Export as a Hydra-compatible config dict.
Returns a dict with _target_ set to the fully qualified
class name and only non-default field values included.
Returns:
| Type |
Description |
dict[str, Any]
|
Dict suitable for hydra.utils.instantiate().
|
Source code in src/torchio/transforms/transform.py
def to_hydra(self) -> dict[str, Any]:
    """Export as a Hydra-compatible config dict.

    Returns a dict with `_target_` set to an importable class path and
    only non-default field values included. Classes defined inside the
    ``torchio`` package use the public ``torchio.<ClassName>`` path.
    Classes defined elsewhere (e.g., user subclasses) use their own
    module path, because they are not reachable as attributes of
    ``torchio``.

    Returns:
        Dict suitable for `hydra.utils.instantiate()`.
    """
    from .parameter_range import _ParameterRange

    def _equals(first: Any, second: Any) -> bool:
        # Array-like values (NumPy arrays, tensors) compare element-wise
        # and raise when coerced to bool; treat them as non-default.
        try:
            return bool(first == second)
        except (TypeError, ValueError, RuntimeError):
            return False

    cls = type(self)
    module = cls.__module__
    if module == "torchio" or module.startswith("torchio."):
        target = f"torchio.{cls.__qualname__}"
    else:
        target = f"{module}.{cls.__qualname__}"
    cfg: dict[str, Any] = {"_target_": target}
    for name, default in _collect_init_params(cls).items():
        value = getattr(self, name, default)
        if isinstance(value, _ParameterRange):
            if _equals(value._original, default):
                continue
            value = _hydra_value(value._original)
        elif _equals(value, default):
            continue
        else:
            value = _hydra_value(value)
        cfg[name] = value
    return cfg
|
make_params(batch)
No random parameters.
Source code in src/torchio/transforms/spatial/to_reference_space.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Return a fresh, empty parameter dict.

    This transform has no random parameters, so *batch* is not inspected.
    """
    no_params: dict[str, Any] = {}
    return no_params
|
apply_transform(batch, params)
Replace each image's affine with the reference-space affine.
Source code in src/torchio/transforms/spatial/to_reference_space.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Replace each image's affine with the reference-space affine.

    Image data is not modified. For every image batch, a new affine is
    computed from the reference and that batch's spatial shape, and
    every entry of ``img_batch.affines`` is replaced by a copy of it.
    """
    # Only the image batches are needed; their names are irrelevant.
    for img_batch in self._get_images(batch).values():
        # Spatial size is read from axes 2-4 of the batched data,
        # presumably laid out as (batch, channels, I, J, K); TODO confirm.
        shape = img_batch.data.shape
        output_shape = (int(shape[2]), int(shape[3]), int(shape[4]))
        new_affine = _reference_space_affine(self.reference, output_shape)
        # Clone so that entries do not alias a single affine object.
        img_batch.affines[:] = [new_affine.clone() for _ in img_batch.affines]
    return batch
|
from_tensor(tensor, reference)
staticmethod
Build a TorchIO image from a tensor and a reference image.
Parameters:
| Name |
Type |
Description |
Default |
tensor
|
Tensor
|
A (C, I, J, K) tensor (e.g., a network
embedding) whose spatial metadata should match the
reference space.
|
required
|
reference
|
Image
|
Reference image whose field of view and
orientation will be matched.
|
required
|
Returns:
| Type |
Description |
Image
|
A new image with tensor as data and a reference-space affine. The image class matches that of reference.
|
Source code in src/torchio/transforms/spatial/to_reference_space.py
@staticmethod
def from_tensor(tensor: Tensor, reference: Image) -> Image:
    """Build a TorchIO image from a tensor and a reference image.

    Args:
        tensor: A `(C, I, J, K)` tensor (e.g., a network
            embedding) whose spatial metadata should match the
            reference space.
        reference: Reference image whose field of view and
            orientation will be matched.

    Returns:
        A new image with *tensor* as data and a reference-space
        affine. The image class matches that of *reference*.

    Raises:
        TypeError: If *reference* is not a TorchIO `Image`.
        ValueError: If *tensor* has fewer than three dimensions.
    """
    # Same validation as the class constructor so both entry points fail alike.
    if not isinstance(reference, Image):
        msg = f"reference must be a TorchIO Image, got {type(reference).__name__}"
        raise TypeError(msg)
    if tensor.ndim < 3:
        msg = (
            "tensor must have at least 3 dimensions,"
            f" got shape {tuple(tensor.shape)}"
        )
        raise ValueError(msg)
    # The last three axes are taken as the spatial grid (I, J, K).
    output_shape = tuple(int(size) for size in tensor.shape[-3:])
    new_affine = _reference_space_affine(reference, output_shape)
    image_class = type(reference)
    return image_class(tensor, affine=new_affine)
|