Skip to content

Affine

Bases: Spatial

Apply a random or fixed affine transform.

Convenience wrapper around Spatial exposing only the affine parameters. The affine matrix data structure is available as AffineMatrix.

Parameters:

Name Type Description Default
scales TypeParameterValue

See Spatial. Default: 1.0 (no scaling).

1.0
degrees TypeParameterValue

See Spatial. Default: 0.0 (no rotation).

0.0
translation TypeParameterValue

See Spatial.

0.0
isotropic bool

See Spatial.

False
center TypeCenter

See Spatial.

'image'
default_pad_value TypePadValue | float

See Spatial.

'minimum'
default_pad_label int | float

See Spatial.

0
image_interpolation TypeInterpolation

See Spatial.

'linear'
label_interpolation TypeInterpolation

See Spatial.

'nearest'
**kwargs Any

See Transform.

{}

Examples:

>>> import torchio as tio
>>> transform = tio.Affine(degrees=(-15, 15))
>>> transform = tio.Affine(scales=1.0, degrees=(0, 0, 90))
Source code in src/torchio/transforms/spatial/spatial.py
class Affine(Spatial):
    r"""Random or fixed affine transform (scaling, rotation, translation).

    Thin wrapper over [`Spatial`][torchio.Spatial] that only exposes the
    affine arguments.  The data structure holding the affine matrix is
    [`AffineMatrix`][torchio.AffineMatrix].

    Args:
        scales: See [`Spatial`][torchio.Spatial].
            Default: `1.0` (no scaling).
        degrees: See [`Spatial`][torchio.Spatial].
            Default: `0.0` (no rotation).
        translation: See [`Spatial`][torchio.Spatial].
        isotropic: See [`Spatial`][torchio.Spatial].
        center: See [`Spatial`][torchio.Spatial].
        default_pad_value: See [`Spatial`][torchio.Spatial].
        default_pad_label: See [`Spatial`][torchio.Spatial].
        image_interpolation: See [`Spatial`][torchio.Spatial].
        label_interpolation: See [`Spatial`][torchio.Spatial].
        **kwargs: See [`Transform`][torchio.Transform].

    Examples:
        >>> import torchio as tio
        >>> transform = tio.Affine(degrees=(-15, 15))
        >>> transform = tio.Affine(scales=1.0, degrees=(0, 0, 90))
    """

    def __init__(
        self,
        *,
        scales: TypeParameterValue = 1.0,
        degrees: TypeParameterValue = 0.0,
        translation: TypeParameterValue = 0.0,
        isotropic: bool = False,
        center: TypeCenter = "image",
        default_pad_value: TypePadValue | float = "minimum",
        default_pad_label: int | float = 0,
        image_interpolation: TypeInterpolation = "linear",
        label_interpolation: TypeInterpolation = "nearest",
        **kwargs: Any,
    ) -> None:
        # Everything this wrapper exposes is handed to the parent unchanged.
        affine_arguments = {
            "scales": scales,
            "degrees": degrees,
            "translation": translation,
            "isotropic": isotropic,
            "center": center,
            "default_pad_value": default_pad_value,
            "default_pad_label": default_pad_label,
            "image_interpolation": image_interpolation,
            "label_interpolation": label_interpolation,
        }
        super().__init__(**affine_arguments, **kwargs)

        # With unit scale, zero rotation and zero translation the transform
        # does nothing; let the base class warn and suggest a useful setting.
        leaves_geometry_unchanged = (
            self.scales.is_constant(1.0)
            and self.degrees.is_constant(0.0)
            and self.translation.is_constant(0.0)
        )
        self._warn_if_noop(
            is_noop=leaves_geometry_unchanged,
            hint="degrees=(-15, 15)",
        )

invertible property

Whether this transform can be inverted.

forward(data)

forward(data: Subject) -> Subject
forward(data: Image) -> Image
forward(data: Tensor) -> Tensor
forward(data: np.ndarray) -> np.ndarray
forward(data: sitk.Image) -> sitk.Image
forward(data: nib.Nifti1Image) -> nib.Nifti1Image
forward(data: dict) -> dict
forward(data: ImagesBatch) -> ImagesBatch
forward(data: SubjectsBatch) -> SubjectsBatch

Apply the transform.

The output type always matches the input type.

Parameters:

Name Type Description Default
data Any

Input data to transform.

required
Source code in src/torchio/transforms/transform.py
def forward(self, data: Any) -> Any:
    """Run the transform on *data*.

    The returned object always has the same type as the input.

    Args:
        data: Input data to transform.
    """
    working = _copy.deepcopy(data) if self.copy else data
    batch, unwrap = self._wrap(working)

    # With per-element gating the transform deals with the probability
    # on its own (gated-out elements receive identity params), so the
    # batch-wide coin flip is only drawn when gating is inactive.  The
    # transform runs iff rand < p: p=0 never applies, p=1 always does.
    if not self._per_instance_p_active(batch):
        draw = torch.rand(1).item()
        if draw >= self.p:
            return unwrap(batch)

    params = self.make_params(batch)
    batch = self.apply_transform(batch, params)

    # Skip the history entry when per-element probability gated out every
    # element: that run is an exact no-op, and replaying it (e.g. for an
    # invertible spatial transform) would cause a pointless identity
    # resample.
    if not _all_elements_gated_out(params):
        record = AppliedTransform(name=type(self).__name__, params=params)
        if not hasattr(batch, "applied_transforms"):
            batch.applied_transforms = []
        batch.applied_transforms.append(record)

    result = unwrap(batch)

    # Copy the history onto the output only when the output can carry it.
    if not hasattr(batch, "applied_transforms"):
        return result
    if isinstance(result, (SubjectsBatch, Tensor, np.ndarray, dict)):
        return result
    with contextlib.suppress(AttributeError):
        result.applied_transforms = list(batch.applied_transforms)
    return result

make_params(batch)

Sample random parameters and resolve the output space.

Scales, degrees, translation, and control-point displacements are sampled per batch element when per-instance augmentation is active (the default for batches), and once otherwise.

Returns:

Type Description
dict[str, Any]

Dict of serializable parameters for apply_transform and history replay.

Source code in src/torchio/transforms/spatial/spatial.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Sample random parameters and resolve the output space.

    Scales, degrees, translation, and control-point displacements are
    sampled per batch element when per-instance augmentation is
    active (the default for batches), and once otherwise.

    Args:
        batch: Batch whose selected images will be transformed.

    Returns:
        Dict of serializable parameters for `apply_transform` and
        history replay.
    """
    images = self._get_images(batch)
    if not images:
        # Nothing selected: return a minimal dict with an empty selection.
        return {"selected_images": []}

    # The first selected image, and the first affine it holds, define the
    # reference shape and affine used for sampling and target resolution.
    _, first_batch = next(iter(images.items()))
    first_shape = _get_spatial_shape(first_batch)
    first_affine = first_batch.affines[0]

    # Settings copied from the transform so that the params dict alone is
    # enough to apply the mapping later; "original" stores the reference
    # space that `inverse` resamples back to.
    params: dict[str, Any] = {
        "selected_images": list(images.keys()),
        "original": _serialize_space((first_shape, first_affine)),
        "affine_first": self.affine_first,
        "image_interpolation": self.image_interpolation,
        "label_interpolation": self.label_interpolation,
        "antialias": self.antialias,
        "default_pad_value": self.default_pad_value,
        "default_pad_label": self.default_pad_label,
    }

    # NOTE(review): `n` is presumably the number of batch elements to
    # sample for, with None meaning per-instance sampling is inactive.
    # Confirm against `_resolve_n`.
    n = self._resolve_n(batch)
    if n is None:
        # Shared path: one set of geometry for the whole batch.
        forward_affine, control_points, max_displacement, has_geometry = (
            self._sample_one(first_shape, first_affine)
        )
        if has_geometry:
            # Only checked when geometry was actually sampled; presumably
            # verifies every selected image matches the reference space.
            _check_shared_space(images, first_shape, first_affine)
        # Resolve the (possibly random) target after sampling the
        # geometry so the RNG stream matches the batch-shared path.
        target_space = _resolve_target_space(
            self.target,
            batch,
            first_shape,
            first_affine,
        )
        params["target"] = _serialize_space(target_space)
        params["affine_matrix"] = _serialize_matrix(forward_affine)
        params["control_points"] = _serialize_control_points(control_points)
        # A falsy displacement (None or empty) is stored as None.
        params["max_displacement"] = (
            list(max_displacement) if max_displacement else None
        )
        return params

    # Per-element path: geometry is sampled for each of the `n` elements.
    # NOTE(review): `keep` looks like the per-element probability mask of
    # elements that actually get transformed. Confirm in `_keep_mask`.
    keep = self._keep_mask(batch, n)
    affine_list, control_points_list, displacement_list, any_geometry = (
        self._sample_per_element_geometry(n, first_shape, first_affine, keep)
    )
    if any_geometry:
        _check_shared_space(images, first_shape, first_affine)
    # Same ordering as the shared path: the target is resolved only after
    # the geometry has been sampled.
    target_space = _resolve_target_space(
        self.target,
        batch,
        first_shape,
        first_affine,
    )
    params["target"] = _serialize_space(target_space)
    # The per-element lists are stored as returned by the sampler.
    params["affine_matrix"] = affine_list
    params["control_points"] = control_points_list
    params["max_displacement"] = displacement_list
    # Mark the three geometry entries as per-element values; presumably
    # this is what records the "_batched_keys" entry read by `inverse`.
    self._tag_batched(
        params,
        batch,
        n,
        keep,
        ["affine_matrix", "control_points", "max_displacement"],
    )
    return params

apply_transform(batch, params)

Apply the spatial mapping to every selected image in batch.

One sampling grid is built per batch element when per-instance parameters are present, and a single shared grid otherwise.

Source code in src/torchio/transforms/spatial/spatial.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Resample every selected image in *batch* with the given mapping.

    When per-instance parameters are present, one sampling grid is
    built per batch element; otherwise a single grid is shared.
    """
    image_names = params.get("selected_images", [])
    if not image_names:
        return batch

    target_space = _deserialize_space(params["target"])
    affine_matrix, control_points, max_displacement, per_sample = (
        _resolve_spatial_params(params)
    )

    # A true no-op (no resampling and no geometry, e.g. every element
    # gated out) must leave the data and the per-sample affines untouched
    # instead of rebuilding an identity grid.
    geometry = (
        target_space,
        affine_matrix,
        control_points,
        max_displacement,
        per_sample,
    )
    if all(component is None for component in geometry):
        return batch

    # Resampling options recorded by `make_params`; "antialias" may be
    # missing from a params dict, in which case it is disabled.
    options = {
        "affine_first": params["affine_first"],
        "image_interpolation": params["image_interpolation"],
        "label_interpolation": params["label_interpolation"],
        "antialias": params.get("antialias", False),
        "default_pad_value": params["default_pad_value"],
        "default_pad_label": float(params["default_pad_label"]),
    }
    # The helper's return value is not used; the same batch object is
    # handed back to the caller.
    _apply_spatial_to_batch(
        batch=batch,
        image_names=image_names,
        target_space=target_space,
        affine_matrix=affine_matrix,
        control_points=control_points,
        max_displacement=max_displacement,
        per_sample=per_sample,
        **options,
    )
    return batch

inverse(params)

Build the inverse transform from recorded parameters.

The affine component is inverted exactly. The elastic component is approximated by negating the sampled displacement field. The affine_first flag is flipped so that the inverse operations run in the opposite order. Per-instance parameters are inverted element by element.

Parameters:

Name Type Description Default
params dict[str, Any]

The parameter dict produced by make_params.

required

Returns:

Type Description
_SpatialInverse

A _SpatialInverse that resamples back to the original grid.

Source code in src/torchio/transforms/spatial/spatial.py
def inverse(self, params: dict[str, Any]) -> _SpatialInverse:
    """Create the transform that undoes a recorded forward pass.

    The affine part is inverted exactly, while the elastic part is
    approximated by negating the sampled displacement field.
    `affine_first` is flipped so that the inverse runs its operations
    in the opposite order.  Per-instance parameters are inverted
    element by element.

    Args:
        params: The parameter dict produced by `make_params`.

    Returns:
        A `_SpatialInverse` that resamples back to the original grid.

    Raises:
        RuntimeError: If the recorded original space deserializes to
            `None`.
    """
    source_space = _deserialize_space(params["original"])
    if source_space is None:
        raise RuntimeError("Spatial inverse needs the original output space")

    shared_kwargs: dict[str, Any] = {
        "target": source_space,
        "affine_first": not params["affine_first"],
        "image_interpolation": params["image_interpolation"],
        "label_interpolation": params["label_interpolation"],
        "default_pad_value": params["default_pad_value"],
        "default_pad_label": float(params["default_pad_label"]),
        "copy": False,
        # Restrict the inverse to the images the forward pass actually
        # transformed, so excluded images (e.g. label maps) are left
        # alone.
        "include": params["selected_images"],
    }

    # Per-element geometry: invert every element and pass the result as
    # `per_sample`; no shared matrix or control points apply.
    if "affine_matrix" in (params.get("_batched_keys") or []):
        return _SpatialInverse(
            affine_matrix=None,
            control_points=None,
            per_sample=_invert_per_sample(params),
            **shared_kwargs,
        )

    # Shared geometry: exact matrix inverse, negated displacement field.
    forward_matrix = _deserialize_matrix(params["affine_matrix"])
    backward_matrix = (
        None if forward_matrix is None else np.linalg.inv(forward_matrix)
    )
    forward_points = _deserialize_control_points(params["control_points"])
    backward_points = None if forward_points is None else -forward_points
    return _SpatialInverse(
        affine_matrix=backward_matrix,
        control_points=backward_points,
        **shared_kwargs,
    )

to_hydra()

Export as a Hydra-compatible config dict.

Returns a dict with _target_ set to the fully qualified class name and only non-default field values included.

Returns:

Type Description
dict[str, Any]

Dict suitable for hydra.utils.instantiate().

Source code in src/torchio/transforms/transform.py
def to_hydra(self) -> dict[str, Any]:
    """Serialize this transform as a Hydra-compatible config dict.

    The dict has `_target_` set to `torchio.<class qualname>` and
    contains only the constructor fields whose value differs from the
    default.

    Returns:
        Dict suitable for `hydra.utils.instantiate()`.
    """
    from .parameter_range import _ParameterRange

    cls = type(self)
    cfg: dict[str, Any] = {"_target_": f"torchio.{cls.__qualname__}"}

    for name, default in _collect_init_params(cls).items():
        current = getattr(self, name, default)
        # Parameter ranges are compared and exported through the value
        # stored in `_original` rather than the range object itself.
        raw = current._original if isinstance(current, _ParameterRange) else current
        if raw == default:
            continue
        cfg[name] = _hydra_value(raw)
    return cfg