
OneOf

Bases: Transform

Apply one of the given transforms, chosen at random.

When applied to a batch with per_instance=True (the default), each batch element independently chooses which transform to apply. This requires shape- and schema-preserving transforms so the elements can be re-stacked. Pass per_instance=False to choose a single transform for the whole batch.
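The difference between the two modes can be sketched in plain Python. This is a minimal illustration, not torchio code. It assumes that lists stand in for batch elements and plain functions stand in for transforms. `choose_indices` and `apply_one_of` are hypothetical helpers. The weighted draw uses the standard library's `random.choices`, whereas `OneOf` itself uses `torch.multinomial`.

```python
import random
from collections.abc import Sequence


def choose_indices(
    batch_size: int,
    weights: Sequence[float],
    per_instance: bool,
    rng: random.Random,
) -> list[int]:
    """Return the index of the transform applied to each batch element (hypothetical helper)."""
    population = range(len(weights))
    if per_instance:
        # Each element draws its own transform independently.
        return rng.choices(population, weights=weights, k=batch_size)
    # A single draw is shared by every element in the batch.
    idx = rng.choices(population, weights=weights, k=1)[0]
    return [idx] * batch_size


def apply_one_of(batch, transforms, weights, per_instance, rng):
    """Apply the chosen transform to each element (hypothetical helper)."""
    indices = choose_indices(len(batch), weights, per_instance, rng)
    return [transforms[i](x) for i, x in zip(indices, batch)]


rng = random.Random(0)
transforms = [lambda x: x + 1, lambda x: -x]
mixed = apply_one_of([1, 2, 3, 4], transforms, [0.7, 0.3], per_instance=True, rng=rng)
shared = apply_one_of([1, 2, 3, 4], transforms, [0.7, 0.3], per_instance=False, rng=rng)
```

In the per-element mode, the outputs of different transforms end up side by side in one batch. This is why the real transforms must preserve shape and schema before the elements can be re-stacked.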

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `transforms` | `Sequence[Transform] \| dict[Transform, float]` | Sequence of transforms, or a dict mapping transforms to their relative weights. If a sequence is given, all transforms have equal probability. | required |
| `**kwargs` | `Any` | See `Transform` for additional keyword arguments. | `{}` |
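"Relative weights" means the dict values only need to be proportional. The constructor divides each value by their sum. The sketch below mirrors that normalization. It is a hypothetical helper, and strings stand in for transforms. The two `ValueError` guards are an addition of this sketch: the `__init__` shown in the source does not check for an empty sequence or a zero total.

```python
def normalize_weights(transforms):
    """Turn a sequence or a {transform: weight} dict into (transforms, probabilities).

    Hypothetical helper mirroring OneOf.__init__; the guards are extra.
    """
    if isinstance(transforms, dict):
        items = list(transforms.items())
        total = sum(w for _, w in items)
        if total <= 0:
            raise ValueError("weights must sum to a positive number")
        return [t for t, _ in items], [w / total for _, w in items]
    items = list(transforms)
    if not items:
        raise ValueError("at least one transform is required")
    # A plain sequence gets a uniform distribution.
    return items, [1.0 / len(items)] * len(items)


names, probs = normalize_weights({"noise": 7, "flip": 3})
```

With this normalization, `{"noise": 7, "flip": 3}` and `{"noise": 0.7, "flip": 0.3}` describe the same distribution.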

Examples:

>>> import torchio as tio
>>> augmentation = tio.OneOf({
...     tio.Noise(std=0.1): 0.7,
...     tio.Flip(axes=(0,)): 0.3,
... })
Source code in src/torchio/transforms/compose.py
class OneOf(Transform):
    """Apply one of the given transforms, chosen at random.

    When applied to a batch with `per_instance=True` (the default),
    each batch element independently chooses which transform to apply.
    This requires shape- and schema-preserving transforms so the
    elements can be re-stacked. Pass `per_instance=False` to choose a
    single transform for the whole batch.

    Args:
        transforms: Sequence of transforms, or a `dict` mapping
            transforms to their relative weights. If a sequence is
            given, all transforms have equal probability.
        **kwargs: See [`Transform`][torchio.Transform] for additional
            keyword arguments.

    Examples:
        >>> import torchio as tio
        >>> augmentation = tio.OneOf({
        ...     tio.Noise(std=0.1): 0.7,
        ...     tio.Flip(axes=(0,)): 0.3,
        ... })
    """

    def __init__(
        self,
        transforms: Sequence[Transform] | dict[Transform, float],
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if isinstance(transforms, dict):
            weight_dict = cast(dict[Transform, float], transforms)
            self.transforms = list(weight_dict.keys())
            w_list: list[float] = list(weight_dict.values())
            total: float = sum(w_list)
            self.weights = [w / total for w in w_list]
        else:
            self.transforms = list(transforms)
            n = len(self.transforms)
            self.weights = [1.0 / n] * n

    def forward(self, data):
        if self.copy:
            data = copy.deepcopy(data)
        batch, unwrap = self._wrap(data)
        # The input is copied once above, so children apply without copying.
        with _disabled_copy(self.transforms):
            if self.per_instance and batch.batch_size > 1:
                return unwrap(self._forward_per_element(batch))
            if torch.rand(1).item() >= self.p:
                return unwrap(batch)
            idx = int(
                torch.multinomial(
                    torch.tensor(self.weights),
                    num_samples=1,
                ).item()
            )
            batch = self.transforms[idx](batch)
            return unwrap(batch)

    def _forward_per_element(self, batch):
        """Apply an independently chosen transform to each batch element."""
        if self.p == 0:
            return batch
        weights = torch.tensor(self.weights)
        out_subjects = []
        any_applied = False
        for subject in batch.unbatch():
            if torch.rand(1).item() < self.p:
                any_applied = True
                idx = int(torch.multinomial(weights, num_samples=1).item())
                subject = _apply_to_element(subject, self.transforms[idx])
            out_subjects.append(subject)
        if not any_applied:
            return batch
        return _rebatch_with_history(out_subjects, "OneOf")

    def to_hydra(self) -> dict[str, Any]:
        cfg = super().to_hydra()
        cfg["transforms"] = [t.to_hydra() for t in self.transforms]
        return cfg
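Reading `forward` and `_forward_per_element` together, each draw first passes a probability gate `p` and only then picks a transform by weight. There is one draw per batch, or one per element when the per-element path runs. So transform i is applied with probability p · w_i, and the input is left untouched with probability 1 − p. The stdlib-only sketch below computes those probabilities and checks them against a simulation of the same gate-then-choose logic. Both function names are hypothetical.

```python
import random


def outcome_probabilities(p: float, weights: list[float]) -> tuple[list[float], float]:
    """Return (probability of each transform, probability that nothing is applied)."""
    total = sum(weights)
    return [p * w / total for w in weights], 1.0 - p


def simulate(p: float, weights: list[float], n: int, rng: random.Random) -> list[float]:
    """Empirical frequencies; the last entry counts draws that skipped every transform."""
    counts = [0] * (len(weights) + 1)
    for _ in range(n):
        if rng.random() >= p:
            # Same skip condition as forward: rand >= p leaves the input unchanged.
            counts[-1] += 1
        else:
            idx = rng.choices(range(len(weights)), weights=weights, k=1)[0]
            counts[idx] += 1
    return [c / n for c in counts]


probs, skip = outcome_probabilities(0.5, [0.7, 0.3])
freqs = simulate(0.5, [0.7, 0.3], 100_000, random.Random(0))
```

For example, with `p=0.5` and the weights from the example above, the two transforms are applied to 35% and 15% of draws, and half of the draws are left untouched.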

supports_per_instance_params property

Whether this transform can sample parameters per batch element.

Defaults to False. Transforms that implement per-instance parameter sampling override this to return True. When False, the transform always uses batch-shared parameters regardless of the per_instance flag, preserving the legacy behavior.
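The rule described above can be sketched as a small dispatcher. This is a hypothetical illustration, not torchio's internal code. `sample_params` and `sample_angle` are made-up names, and each "parameter set" is a plain dict. Per-element sampling happens only when both the `per_instance` flag and the property are true. Otherwise one sampled dict is shared by the whole batch.

```python
import random


def sample_params(batch_size, per_instance, supports_per_instance_params, sample_one, rng):
    """Return one parameter dict per batch element (hypothetical dispatcher)."""
    if per_instance and supports_per_instance_params:
        # Independent parameters for every element.
        return [sample_one(rng) for _ in range(batch_size)]
    # Legacy behavior: a single draw reused by every element.
    shared = sample_one(rng)
    return [shared] * batch_size


def sample_angle(rng):
    return {"angle": rng.uniform(-10.0, 10.0)}


shared = sample_params(4, True, False, sample_angle, random.Random(0))
independent = sample_params(4, True, True, sample_angle, random.Random(0))
```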

supports_per_instance_p property

Whether this transform can gate each batch element independently.

Defaults to False. Shape-preserving transforms that implement per-element probability override this to return True. Shape-changing transforms must leave it False because masked and unmasked elements would have incompatible shapes.
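The shape constraint can be illustrated with lists standing in for tensors, where list length plays the role of shape. `gate_per_element` is a hypothetical helper. It applies `fn` to each element with probability `p` and then refuses to "re-stack" the results if gated and ungated elements no longer agree in shape.

```python
import random


def gate_per_element(elements, fn, p, rng):
    """Apply fn to each element with probability p; others pass through unchanged."""
    out = [fn(x) if rng.random() < p else x for x in elements]
    shapes = {len(y) for y in out}
    if len(shapes) > 1:
        # Masked and unmasked elements now differ in shape, so they
        # could not be stacked back into a single batch tensor.
        raise ValueError(f"cannot re-stack elements with shapes {sorted(shapes)}")
    return out


def double(x):
    return [2 * v for v in x]  # shape-preserving


def crop(x):
    return x[:1]  # shape-changing


safe = gate_per_element([[1, 2], [3, 4], [5, 6]], double, 0.5, random.Random(0))
```

With `double`, any mix of gated and ungated elements stacks fine. With `crop`, a batch where some elements were cropped and others were not cannot be stacked, which is why such transforms keep the property `False`.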

invertible property

Whether this transform can be inverted.

make_params(batch)

Sample random parameters for this transform.

Override in subclasses that have random behavior.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `SubjectsBatch` | A `SubjectsBatch`. | required |

Returns:

| Type | Description |
|---|---|
| `dict[str, Any]` | Dict of sampled parameters. |

Source code in src/torchio/transforms/transform.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Sample random parameters for this transform.

    Override in subclasses that have random behavior.

    Args:
        batch: A `SubjectsBatch`.

    Returns:
        Dict of sampled parameters.
    """
    return {}

apply_transform(batch, params)

Apply the transform with the given parameters.

Must be overridden by subclasses. Receives a SubjectsBatch whose ImagesBatch entries contain 5D tensors (B, C, I, J, K). Use negative indexing (-3, -2, -1) for spatial dims.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `SubjectsBatch` | A `SubjectsBatch` to transform. | required |
| `params` | `dict[str, Any]` | Parameters from `make_params`. | required |

Returns:

| Type | Description |
|---|---|
| `SubjectsBatch` | Transformed `SubjectsBatch`. |

Source code in src/torchio/transforms/transform.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Apply the transform with the given parameters.

    Must be overridden by subclasses. Receives a `SubjectsBatch`
    whose `ImagesBatch` entries contain 5D tensors
    `(B, C, I, J, K)`. Use negative indexing (`-3`, `-2`,
    `-1`) for spatial dims.

    Args:
        batch: A `SubjectsBatch` to transform.
        params: Parameters from `make_params`.

    Returns:
        Transformed `SubjectsBatch`.
    """
    raise NotImplementedError
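Negative indices select the three spatial dims of a `(B, C, I, J, K)` tensor no matter what the batch and channel sizes are. The sketch below works on shape tuples so it needs only the standard library. `downsampled_shape` is a hypothetical helper, not part of torchio. With a real tensor, the same indices can be passed to PyTorch calls that take dims, for example `torch.flip(tensor, dims=(-1,))`.

```python
def downsampled_shape(shape: tuple[int, ...], factor: int) -> tuple[int, ...]:
    """Shrink the spatial dims of a (B, C, I, J, K) shape by an integer factor.

    Hypothetical helper: indexing with -3, -2, -1 leaves B and C untouched.
    """
    if len(shape) != 5:
        raise ValueError(f"expected a 5D (B, C, I, J, K) shape, got {shape}")
    out = list(shape)
    for dim in (-3, -2, -1):
        out[dim] = max(1, shape[dim] // factor)
    return tuple(out)
```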

inverse(params)

Return a transform that undoes this one.

Override in invertible subclasses. The returned transform, when applied, reverses the effect of the forward pass with the given parameters.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `dict[str, Any]` | The parameters recorded in the forward pass. | required |

Returns:

| Type | Description |
|---|---|
| `Transform` | A new `Transform` instance that inverts this one. |

Source code in src/torchio/transforms/transform.py
def inverse(self, params: dict[str, Any]) -> Transform:
    """Return a transform that undoes this one.

    Override in invertible subclasses. The returned transform,
    when applied, reverses the effect of the forward pass with
    the given parameters.

    Args:
        params: The parameters recorded in the forward pass.

    Returns:
        A new `Transform` instance that inverts this one.
    """
    msg = f"{type(self).__name__} is not invertible"
    raise NotImplementedError(msg)
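An invertible subclass builds its inverse from the parameters recorded during the forward pass. The sketch below is hypothetical. `Shift` is not a torchio class, and a plain list stands in for a batch. It follows the same shape as the contract above: `inverse(params)` returns a new transform that undoes the forward effect.

```python
from typing import Any


class Shift:
    """Hypothetical invertible transform that adds an offset to every value."""

    def __init__(self, offset: float) -> None:
        self.offset = offset

    def make_params(self, batch) -> dict[str, Any]:
        return {"offset": self.offset}

    def apply_transform(self, batch, params: dict[str, Any]):
        return [x + params["offset"] for x in batch]

    def inverse(self, params: dict[str, Any]) -> "Shift":
        # A new transform built from the recorded parameters undoes the shift.
        return Shift(-params["offset"])


shift = Shift(3)
params = shift.make_params([1, 2])
shifted = shift.apply_transform([1, 2], params)
undo = shift.inverse(params)
restored = undo.apply_transform(shifted, undo.make_params(shifted))
```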