Skip to content

EnsureShapeMultiple

Bases: SpatialTransform

Ensure that all values in the image shape are divisible by \(n\).

Some convolutional neural network architectures need the size of the input along every spatial dimension to be a multiple of a power of 2.

For example, a 3D U-Net with 3 downsampling (pooling) operations needs all spatial dimensions to be multiples of \(2^3 = 8\).

This transform computes the nearest valid shape and delegates to CropOrPad to reach it.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `target_multiple` | `TargetMultipleParam` | Tuple \((n_i, n_j, n_k)\) so that the output size along axis \(d\) is a multiple of \(n_d\). If a single value \(n\) is provided, then \(n_i = n_j = n_k = n\). | *required* |
| `method` | `str` | Either `'pad'` (default) to pad up to the next multiple, or `'crop'` to crop down to the previous multiple. | `'pad'` |
| `padding_mode` | `str` | Padding mode forwarded to `CropOrPad` when `method='pad'`. One of `'constant'`, `'reflect'`, `'replicate'`, or `'circular'`. | `'constant'` |
| `fill` | `float` | Fill value when `padding_mode='constant'`. | `0` |
| `**kwargs` | `Any` | See `Transform` for additional keyword arguments. | `{}` |

Examples:

>>> import torchio as tio
>>> transform = tio.EnsureShapeMultiple(8)
>>> transform = tio.EnsureShapeMultiple(2**3, method='pad')
>>> transform = tio.EnsureShapeMultiple(16, method='crop')
>>> transform = tio.EnsureShapeMultiple((4, 8, 16))
Source code in src/torchio/transforms/spatial/ensure_shape_multiple.py
class EnsureShapeMultiple(SpatialTransform):
    r"""Ensure that all values in the image shape are divisible by $n$.

    Some convolutional neural network architectures need the size of the
    input along every spatial dimension to be a multiple of a power of 2.

    For example, a 3D U-Net with 3 downsampling (pooling) operations
    needs all spatial dimensions to be multiples of $2^3 = 8$.

    This transform computes the nearest valid shape and delegates to
    [`CropOrPad`][torchio.CropOrPad] to reach it.

    Args:
        target_multiple: Tuple $(n_i, n_j, n_k)$ so that the output
            size along axis $d$ is a multiple of $n_d$. If a single
            value $n$ is provided, then $n_i = n_j = n_k = n$.
        method: Either `'pad'` (default) to pad up to the next
            multiple, or `'crop'` to crop down to the previous
            multiple.
        padding_mode: Padding mode forwarded to `CropOrPad` when
            `method='pad'`. One of `'constant'`, `'reflect'`,
            `'replicate'`, or `'circular'`.
        fill: Fill value when `padding_mode='constant'`.
        **kwargs: See [`Transform`][torchio.Transform] for additional
            keyword arguments.

    Raises:
        ValueError: If `method` is neither `'crop'` nor `'pad'`.

    Examples:
        >>> import torchio as tio
        >>> transform = tio.EnsureShapeMultiple(8)
        >>> transform = tio.EnsureShapeMultiple(2**3, method='pad')
        >>> transform = tio.EnsureShapeMultiple(16, method='crop')
        >>> transform = tio.EnsureShapeMultiple((4, 8, 16))
    """

    def __init__(
        self,
        target_multiple: TargetMultipleParam,
        *,
        method: str = "pad",
        padding_mode: str = "constant",
        fill: float = 0,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.target_multiple = _parse_target_multiple(target_multiple)
        # Validate eagerly so a bad value fails at construction time
        # rather than the first time the transform is applied.
        if method not in ("crop", "pad"):
            msg = f"method must be 'crop' or 'pad', got {method!r}"
            raise ValueError(msg)
        self.method = method
        self.padding_mode = padding_mode
        self.fill = fill

    def forward(self, data: Any) -> Any:
        """Apply the transform.

        For `Subject` and `Image` inputs, delegates to `CropOrPad`
        for lazy operation without loading data from disk.
        """
        if isinstance(data, (Subject, Image)):
            return self._build_crop_or_pad(data).forward(data)
        return super().forward(data)

    def _build_crop_or_pad(self, data: Subject | Image) -> CropOrPad:
        """Build a `CropOrPad` targeting the nearest valid shape for `data`.

        Both `Subject` and `Image` expose `spatial_shape`, so no type
        dispatch is needed to read the current shape.
        """
        target_shape = _compute_target_shape(
            data.spatial_shape,
            self.target_multiple,
            self.method,
        )
        return CropOrPad(
            target_shape=target_shape,
            padding_mode=self.padding_mode,
            fill=self.fill,
            # Restrict CropOrPad to the single direction chosen by `method`.
            only_crop=self.method == "crop",
            only_pad=self.method == "pad",
            p=self.p,
            copy=self.copy,
            include=self.include,
            exclude=self.exclude,
        )

    def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
        """Compute the batch-wide target shape.

        The current shape is read from the last three dimensions of the
        first image tensor in `batch`. This presumably assumes every
        image in the batch shares that spatial shape (TODO confirm).

        Args:
            batch: Batch whose images determine the current shape.

        Returns:
            A dict with a single `"target_shape"` entry.

        Raises:
            ValueError: If `batch` contains no images.
        """
        first_image = next(iter(batch.images.values()), None)
        if first_image is None:
            # Avoid leaking a bare StopIteration, which is converted to
            # RuntimeError when raised inside a generator (PEP 479).
            msg = "Cannot compute a target shape for a batch with no images"
            raise ValueError(msg)
        data_tensor = first_image.data
        current_shape: TypeThreeInts = (
            data_tensor.shape[-3],
            data_tensor.shape[-2],
            data_tensor.shape[-1],
        )
        target_shape = _compute_target_shape(
            current_shape,
            self.target_multiple,
            self.method,
        )
        return {"target_shape": target_shape}

    def apply_transform(
        self,
        batch: SubjectsBatch,
        params: dict[str, Any],
    ) -> SubjectsBatch:
        """Crop or pad `batch` to `params["target_shape"]`.

        Builds a one-off `CropOrPad` for the precomputed target shape and
        applies it directly to the batch. The inner transform is created
        with `copy=False`, and `p` is not forwarded. Presumably copying
        and probability gating are handled by the caller for batch
        inputs (TODO confirm).

        Args:
            batch: Batch to transform.
            params: Output of `make_params`.

        Returns:
            The transformed batch.
        """
        target_shape = params["target_shape"]
        crop_or_pad = CropOrPad(
            target_shape=target_shape,
            padding_mode=self.padding_mode,
            fill=self.fill,
            only_crop=self.method == "crop",
            only_pad=self.method == "pad",
            copy=False,
            include=self.include,
            exclude=self.exclude,
        )
        return crop_or_pad.apply_transform(
            batch,
            crop_or_pad.make_params(batch),
        )

supports_per_instance_params property

Whether this transform can sample parameters per batch element.

Defaults to False. Transforms that implement per-instance parameter sampling override this to return True. When False, the transform always uses batch-shared parameters regardless of the per_instance flag, preserving the legacy behavior.

supports_per_instance_p property

Whether this transform can gate each batch element independently.

Defaults to False. Shape-preserving transforms that implement per-element probability override this to return True. Shape-changing transforms must leave it False because masked and unmasked elements would have incompatible shapes.

invertible property

Whether this transform can be inverted.

inverse(params)

Return a transform that undoes this one.

Override in invertible subclasses. The returned transform, when applied, reverses the effect of the forward pass with the given parameters.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `dict[str, Any]` | The parameters recorded in the forward pass. | *required* |

Returns:

| Type | Description |
|---|---|
| `Transform` | A new `Transform` instance that inverts this one. |

Source code in src/torchio/transforms/transform.py
def inverse(self, params: dict[str, Any]) -> Transform:
    """Build the transform that reverses this one.

    This base implementation always raises. Invertible subclasses
    override it to return a transform that, when applied, undoes the
    forward pass described by ``params``.

    Args:
        params: Parameters recorded during the forward pass.

    Returns:
        A new `Transform` instance that inverts this one.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    class_name = type(self).__name__
    error_message = f"{class_name} is not invertible"
    raise NotImplementedError(error_message)

to_hydra()

Export as a Hydra-compatible config dict.

Returns a dict with `_target_` set to the fully qualified class name and only non-default field values included.

Returns:

| Type | Description |
|---|---|
| `dict[str, Any]` | Dict suitable for `hydra.utils.instantiate()`. |

Source code in src/torchio/transforms/transform.py
def to_hydra(self) -> dict[str, Any]:
    """Serialize this transform into a Hydra-style config dict.

    The result always contains `_target_`, set to
    `torchio.<qualified class name>`. Each `__init__` parameter reported
    by `_collect_init_params` is added only when the instance's value
    differs from that parameter's default. For `_ParameterRange`
    values, the comparison and serialization use the range's
    `_original` value.

    Returns:
        Dict suitable for `hydra.utils.instantiate()`.
    """
    from .parameter_range import _ParameterRange

    klass = type(self)
    config: dict[str, Any] = {"_target_": f"torchio.{klass.__qualname__}"}

    for param_name, default_value in _collect_init_params(klass).items():
        current = getattr(self, param_name, default_value)
        # Ranges are compared and exported via the value the user passed in.
        if isinstance(current, _ParameterRange):
            comparable = current._original
        else:
            comparable = current
        if comparable == default_value:
            continue
        config[param_name] = _hydra_value(comparable)
    return config

forward(data)

Apply the transform.

For Subject and Image inputs, delegates to CropOrPad for lazy operation without loading data from disk.

Source code in src/torchio/transforms/spatial/ensure_shape_multiple.py
def forward(self, data: Any) -> Any:
    """Run the transform on ``data``.

    `Subject` and `Image` inputs are handed to a `CropOrPad` built for
    their current spatial shape, for lazy operation without loading
    data from disk. Any other input falls back to the parent class's
    `forward`.
    """
    if not isinstance(data, (Subject, Image)):
        return super().forward(data)
    delegate = self._build_crop_or_pad(data)
    return delegate.forward(data)