Skip to content

ElasticDeformation

Bases: Spatial

Apply a dense random elastic deformation.

Convenience wrapper around Spatial exposing only the elastic parameters.

A random displacement is assigned to a coarse grid of control points and trilinearly upsampled to the image resolution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `control_points` | `TypeControlPoints \| None` | See `Spatial`. | `None` |
| `num_control_points` | `int \| TypeThreeInts` | See `Spatial`. | `7` |
| `max_displacement` | `TypeParameterValue` | See `Spatial`. Default: 7.5 mm. | `7.5` |
| `locked_borders` | `int` | See `Spatial`. | `2` |
| `image_interpolation` | `TypeInterpolation` | See `Spatial`. | `'linear'` |
| `label_interpolation` | `TypeInterpolation` | See `Spatial`. | `'nearest'` |
| `**kwargs` | `Any` | See `Transform`. | `{}` |

Examples:

>>> import torchio as tio
>>> transform = tio.ElasticDeformation()
>>> transform = tio.ElasticDeformation(
...     max_displacement=10,
...     num_control_points=5,
... )
Source code in src/torchio/transforms/spatial/spatial.py
class ElasticDeformation(Spatial):
    r"""Warp images with a smooth, randomly sampled displacement field.

    Thin subclass of [`Spatial`][torchio.Spatial] that exposes only the
    arguments controlling the elastic (non-linear) component.

    Random displacements are drawn on a coarse lattice of control
    points and trilinearly upsampled to the resolution of the image.

    Args:
        control_points: Forwarded to [`Spatial`][torchio.Spatial].
        num_control_points: Forwarded to [`Spatial`][torchio.Spatial].
        max_displacement: Forwarded to [`Spatial`][torchio.Spatial].
            Default: `7.5` mm.
        locked_borders: Forwarded to [`Spatial`][torchio.Spatial].
        image_interpolation: Forwarded to [`Spatial`][torchio.Spatial].
        label_interpolation: Forwarded to [`Spatial`][torchio.Spatial].
        **kwargs: Forwarded to [`Transform`][torchio.Transform].

    Examples:
        >>> import torchio as tio
        >>> transform = tio.ElasticDeformation()
        >>> transform = tio.ElasticDeformation(
        ...     max_displacement=10,
        ...     num_control_points=5,
        ... )
    """

    def __init__(
        self,
        *,
        control_points: TypeControlPoints | None = None,
        num_control_points: int | TypeThreeInts = 7,
        max_displacement: TypeParameterValue = 7.5,
        locked_borders: int = 2,
        image_interpolation: TypeInterpolation = "linear",
        label_interpolation: TypeInterpolation = "nearest",
        **kwargs: Any,
    ) -> None:
        # Elastic options, kept in the same order as the signature. These
        # names are keyword-only parameters here, so they can never also
        # appear in ``kwargs``; every other option passes through unchanged.
        elastic_options = {
            "control_points": control_points,
            "num_control_points": num_control_points,
            "max_displacement": max_displacement,
            "locked_borders": locked_borders,
            "image_interpolation": image_interpolation,
            "label_interpolation": label_interpolation,
        }
        super().__init__(**elastic_options, **kwargs)

invertible property

Whether this transform can be inverted.

forward(data)

forward(data: Subject) -> Subject
forward(data: Image) -> Image
forward(data: Tensor) -> Tensor
forward(data: np.ndarray) -> np.ndarray
forward(data: sitk.Image) -> sitk.Image
forward(data: nib.Nifti1Image) -> nib.Nifti1Image
forward(data: dict) -> dict
forward(data: ImagesBatch) -> ImagesBatch
forward(data: SubjectsBatch) -> SubjectsBatch

Apply the transform.

The output type always matches the input type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data` | `Any` | Input data to transform. | *required* |
Source code in src/torchio/transforms/transform.py
def forward(self, data: Any) -> Any:
    """Run the transform on *data* and return a result of the same type.

    The input is deep-copied first when ``self.copy`` is set. The
    transform is skipped whenever ``torch.rand(1).item() > self.p``;
    otherwise parameters are sampled, applied, and appended to the
    batch history.

    Args:
        data: Input data to transform.
    """
    payload = _copy.deepcopy(data) if self.copy else data
    batch, unwrap = self._wrap(payload)

    # Probabilistic skip: return the (possibly copied) input untouched.
    if torch.rand(1).item() > self.p:
        return unwrap(batch)

    params = self.make_params(batch)
    batch = self.apply_transform(batch, params)

    # Append this application to the batch history, creating the list lazily.
    record = AppliedTransform(name=type(self).__name__, params=params)
    if not hasattr(batch, "applied_transforms"):
        batch.applied_transforms = []
    batch.applied_transforms.append(record)

    output = unwrap(batch)
    # Batches, tensors, arrays and dicts are excluded. Any other output gets
    # its own copy of the history when it accepts the attribute.
    excluded_types = (SubjectsBatch, Tensor, np.ndarray, dict)
    if not isinstance(output, excluded_types):
        with contextlib.suppress(AttributeError):
            output.applied_transforms = list(batch.applied_transforms)
    return output

make_params(batch)

Sample random parameters and resolve the output space.

Scales, degrees, translation, and control-point displacements are sampled once and applied identically to every sample and every image in the batch.

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Any]` | Dict of serializable parameters for `apply_transform` and history replay. |

Source code in src/torchio/transforms/spatial/spatial.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Sample random parameters and resolve the output space.

    Scales, degrees, translation, and control-point displacements are
    sampled once and applied identically to every sample and every
    image in the batch.

    Args:
        batch: Batch whose selected images are transformed. The first
            selected image provides the reference shape and affine.

    Returns:
        Dict of serializable parameters for `apply_transform` and
        history replay. When no image is selected, the dict contains
        only ``{"selected_images": []}``.
    """
    images = self._get_images(batch)
    if not images:
        return {"selected_images": []}

    # The first selected image defines the reference grid.
    first_batch = next(iter(images.values()))
    first_shape = _get_spatial_shape(first_batch)
    first_affine = first_batch.affines[0]

    sampled_scales = _sample_scales(self.scales, self.isotropic)
    sampled_degrees = self.degrees.sample()
    sampled_translation = self.translation.sample()
    has_affine = _has_affine_component(
        sampled_scales,
        sampled_degrees,
        sampled_translation,
    )
    control_points, max_displacement = _resolve_control_points(
        self.control_points,
        self.num_control_points,
        self.max_displacement,
        self.locked_borders,
    )
    has_elastic = control_points is not None

    # Only require all images to share the reference grid when an affine
    # or elastic component will actually be applied.
    if has_affine or has_elastic:
        _check_shared_space(images, first_shape, first_affine)

    target_space = _resolve_target_space(
        self.target,
        batch,
        first_shape,
        first_affine,
    )
    forward_affine = None
    if has_affine:
        forward_affine = _build_forward_affine(
            scales=sampled_scales,
            degrees=sampled_degrees,
            translation=sampled_translation,
            center=self.center,
            shape=first_shape,
            affine=first_affine,
        )

    # Test against None explicitly, as done for control_points above.
    # Truthiness raises for multi-element NumPy arrays, in case
    # _resolve_control_points returns an array here.
    serialized_max_displacement = (
        None if max_displacement is None else list(max_displacement)
    )

    return {
        "selected_images": list(images.keys()),
        "target": _serialize_space(target_space),
        "original": _serialize_space((first_shape, first_affine)),
        "affine_matrix": _serialize_matrix(forward_affine),
        "control_points": _serialize_control_points(control_points),
        "max_displacement": serialized_max_displacement,
        "affine_first": self.affine_first,
        "image_interpolation": self.image_interpolation,
        "label_interpolation": self.label_interpolation,
        "antialias": self.antialias,
        "default_pad_value": self.default_pad_value,
        "default_pad_label": self.default_pad_label,
    }

apply_transform(batch, params)

Apply the spatial mapping to every selected image in batch.

The sampling grid is built once from the parameters produced by make_params and reused for all images and all batch samples.

Source code in src/torchio/transforms/spatial/spatial.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Resample every selected image in *batch* and return the batch.

    The target space, affine matrix, control points and maximum
    displacement are decoded once from the output of `make_params`.
    The sampling grid is then shared by all selected images and all
    batch samples. The images are updated through
    `_apply_spatial_to_batch`, and the same batch object is returned.
    """
    image_names = params.get("selected_images", [])
    if not image_names:
        # Nothing was selected when the parameters were sampled.
        return batch

    _apply_spatial_to_batch(
        batch=batch,
        image_names=image_names,
        target_space=_deserialize_space(params["target"]),
        affine_matrix=_deserialize_matrix(params["affine_matrix"]),
        control_points=_deserialize_control_points(params["control_points"]),
        max_displacement=_deserialize_max_displacement(
            params["max_displacement"]
        ),
        affine_first=params["affine_first"],
        image_interpolation=params["image_interpolation"],
        label_interpolation=params["label_interpolation"],
        # Parameters recorded without this key fall back to no antialiasing.
        antialias=params.get("antialias", False),
        default_pad_value=params["default_pad_value"],
        default_pad_label=float(params["default_pad_label"]),
    )
    return batch

inverse(params)

Build the inverse transform from recorded parameters.

The affine component is inverted exactly. The elastic component is approximated by negating the sampled displacement field. The affine_first flag is flipped so that the inverse operations run in the opposite order.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `dict[str, Any]` | The parameter dict produced by `make_params`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `_SpatialInverse` | A `_SpatialInverse` that resamples back to the original grid. |

Source code in src/torchio/transforms/spatial/spatial.py
def inverse(self, params: dict[str, Any]) -> _SpatialInverse:
    """Create a transform that undoes a recorded application.

    The recorded affine matrix is inverted exactly with
    `np.linalg.inv`. The elastic part is only approximated, by negating
    the recorded control-point displacements. `affine_first` is flipped
    so the inverse applies the two components in reverse order.

    Args:
        params: The parameter dict produced by `make_params`.

    Returns:
        A `_SpatialInverse` that resamples back to the original grid.

    Raises:
        RuntimeError: If the recorded original space deserializes to
            ``None``.
    """
    # NOTE(review): when no image was selected, make_params returns only
    # {"selected_images": []}. Such params have no "affine_matrix" key, so
    # this method raises KeyError; confirm whether callers skip them.
    forward_matrix = _deserialize_matrix(params["affine_matrix"])
    inverse_affine = (
        None if forward_matrix is None else np.linalg.inv(forward_matrix)
    )

    forward_points = _deserialize_control_points(params["control_points"])
    inverse_points = None if forward_points is None else -forward_points

    original_space = _deserialize_space(params["original"])
    if original_space is None:
        msg = "Spatial inverse needs the original output space"
        raise RuntimeError(msg)

    return _SpatialInverse(
        target=original_space,
        affine_matrix=inverse_affine,
        control_points=inverse_points,
        affine_first=not params["affine_first"],
        image_interpolation=params["image_interpolation"],
        label_interpolation=params["label_interpolation"],
        default_pad_value=params["default_pad_value"],
        default_pad_label=float(params["default_pad_label"]),
        copy=False,
    )

to_hydra()

Export as a Hydra-compatible config dict.

Returns a dict with `_target_` set to the fully qualified class name and only non-default field values included.

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Any]` | Dict suitable for `hydra.utils.instantiate()`. |

Source code in src/torchio/transforms/transform.py
def to_hydra(self) -> dict[str, Any]:
    """Export as a Hydra-compatible config dict.

    Returns a dict with ``_target_`` set to an importable path for the
    class, plus only the init parameters whose values differ from their
    defaults.

    Classes defined inside the ``torchio`` package keep the short public
    path ``torchio.<QualName>``; this assumes they are re-exported at the
    package top level. Classes defined elsewhere, such as user subclasses,
    use their own module path, because ``torchio.<QualName>`` would
    generally not be importable for them.

    Returns:
        Dict suitable for `hydra.utils.instantiate()`.
    """
    from .parameter_range import _ParameterRange

    cls = type(self)
    module = cls.__module__
    if module == "torchio" or module.startswith("torchio."):
        target = f"torchio.{cls.__qualname__}"
    else:
        target = f"{module}.{cls.__qualname__}"
    cfg: dict[str, Any] = {"_target_": target}

    for name, default in _collect_init_params(cls).items():
        value = getattr(self, name, default)
        if isinstance(value, _ParameterRange):
            # Ranges are compared and exported through the raw value the
            # user originally passed, not the parsed range object.
            if value._original == default:
                continue
            value = value._original
        elif value == default:
            continue
        cfg[name] = _hydra_value(value)
    return cfg