
OneOf

Bases: Transform

Apply one of the given transforms, chosen at random.

Parameters:

transforms : Sequence[Transform] | dict[Transform, float], required
    Sequence of transforms, or a dict mapping transforms to their relative weights. If a sequence is given, all transforms have equal probability.

**kwargs : Any, default {}
    See Transform for additional keyword arguments.
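To show what "relative weights" means in practice, here is a standalone sketch of the normalization that `OneOf.__init__` performs. `normalize_weights` is a hypothetical helper, not part of torchio, and strings stand in for `Transform` objects. Weights of 7 and 3 yield the same distribution as the 0.7 and 0.3 used in the example.

```python
from typing import Any


def normalize_weights(transforms: Any) -> tuple[list[Any], list[float]]:
    """Hypothetical helper mirroring the weight logic of OneOf.__init__."""
    if isinstance(transforms, dict):
        items = list(transforms.keys())
        raw = list(transforms.values())
        total = sum(raw)
        # Relative weights: divide by the sum so only the ratios matter.
        return items, [w / total for w in raw]
    items = list(transforms)
    n = len(items)
    # A plain sequence gets equal probability for every item.
    return items, [1.0 / n] * n


items, probs = normalize_weights({"noise": 7, "flip": 3})
print(items, probs)  # ['noise', 'flip'] [0.7, 0.3]
```

As in the original code, a non-empty dict whose weights sum to zero, or an empty sequence, raises ZeroDivisionError; neither version validates its input.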

Examples:

>>> import torchio as tio
>>> augmentation = tio.OneOf({
...     tio.Noise(std=0.1): 0.7,
...     tio.Flip(axes=(0,)): 0.3,
... })
Source code in src/torchio/transforms/compose.py
class OneOf(Transform):
    """Apply one of the given transforms, chosen at random.

    Args:
        transforms: Sequence of transforms, or a `dict` mapping
            transforms to their relative weights. If a sequence is
            given, all transforms have equal probability.
        **kwargs: See [`Transform`][torchio.Transform] for additional
            keyword arguments.

    Examples:
        >>> import torchio as tio
        >>> augmentation = tio.OneOf({
        ...     tio.Noise(std=0.1): 0.7,
        ...     tio.Flip(axes=(0,)): 0.3,
        ... })
    """

    def __init__(
        self,
        transforms: Sequence[Transform] | dict[Transform, float],
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if isinstance(transforms, dict):
            weight_dict = cast(dict[Transform, float], transforms)
            self.transforms = list(weight_dict.keys())
            w_list: list[float] = list(weight_dict.values())
            total: float = sum(w_list)
            self.weights = [w / total for w in w_list]
        else:
            self.transforms = list(transforms)
            n = len(self.transforms)
            self.weights = [1.0 / n] * n

    def forward(self, data):
        subject, unwrap = self._wrap(data)
        if torch.rand(1).item() > self.p:
            return unwrap(subject)
        idx = int(
            torch.multinomial(
                torch.tensor(self.weights),
                num_samples=1,
            ).item()
        )
        subject = self.transforms[idx](subject)
        return unwrap(subject)

    def to_hydra(self) -> dict[str, Any]:
        cfg = super().to_hydra()
        cfg["transforms"] = [t.to_hydra() for t in self.transforms]
        return cfg
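`forward` works in two stages. It first decides whether to act at all, skipping when a uniform draw exceeds `self.p` (presumably set through the base `Transform` keyword arguments). Only then does it draw a single index from the normalized weights with `torch.multinomial`. The standalone sketch below reproduces that logic with the standard library, using `random.Random.choices` in place of `torch.multinomial`. `one_of`, `double` and `negate` are hypothetical names used only for illustration.

```python
import random
from collections.abc import Callable, Sequence
from typing import Any


def one_of(
    x: Any,
    transforms: Sequence[Callable[[Any], Any]],
    weights: Sequence[float],
    p: float,
    rng: random.Random,
) -> Any:
    """Hypothetical stand-in for OneOf.forward on plain values."""
    # Stage 1: skip everything when the uniform draw exceeds p, as forward does.
    if rng.random() > p:
        return x
    # Stage 2: draw exactly one transform from the weights (torch.multinomial in torchio).
    (chosen,) = rng.choices(transforms, weights=weights, k=1)
    return chosen(x)


def double(v: float) -> float:
    return v * 2


def negate(v: float) -> float:
    return -v


rng = random.Random(0)
print(one_of(5, [double, negate], [0.0, 1.0], p=1.0, rng=rng))  # -5
```

Under this ordering, transform i is applied with overall probability p * w_i, and the input passes through unchanged with probability 1 - p. For example, with the 0.7/0.3 weights from the example and an illustrative p = 0.5, Noise would be chosen on roughly 35% of calls.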

invertible property

Whether this transform can be inverted.

make_params(batch)

Sample random parameters for this transform.

Override in subclasses that have random behavior.

Parameters:

batch : SubjectsBatch, required
    A SubjectsBatch.

Returns:

dict[str, Any]
    Dict of sampled parameters.

Source code in src/torchio/transforms/transform.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Sample random parameters for this transform.

    Override in subclasses that have random behavior.

    Args:
        batch: A `SubjectsBatch`.

    Returns:
        Dict of sampled parameters.
    """
    return {}
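`make_params` keeps random sampling separate from applying the transform. The base version returns `{}`, which suits deterministic transforms, and a random subclass overrides it. The sketch below uses hypothetical stand-in classes (`ToyTransform` and `ToyRandomScale` are not torchio classes). It passes `None` where a real `SubjectsBatch` would go, because the sampling here does not depend on the batch.

```python
import random
from typing import Any, Optional


class ToyTransform:
    """Hypothetical stand-in for the base class: deterministic, no parameters."""

    def make_params(self, batch: Any) -> dict[str, Any]:
        return {}


class ToyRandomScale(ToyTransform):
    """Hypothetical random transform that scales intensities by a sampled factor."""

    def __init__(self, low: float = 0.9, high: float = 1.1, seed: Optional[int] = None) -> None:
        self.low = low
        self.high = high
        self.rng = random.Random(seed)

    def make_params(self, batch: Any) -> dict[str, Any]:
        # All randomness is confined here, so the result can be logged or replayed.
        return {"factor": self.rng.uniform(self.low, self.high)}


params = ToyRandomScale(seed=42).make_params(batch=None)
print(sorted(params))  # ['factor']
```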

apply_transform(batch, params)

Apply the transform with the given parameters.

Must be overridden by subclasses. Receives a SubjectsBatch whose ImagesBatch entries contain 5D tensors (B, C, I, J, K). Use negative indexing (-3, -2, -1) for spatial dims.

Parameters:

batch : SubjectsBatch, required
    A SubjectsBatch to transform.

params : dict[str, Any], required
    Parameters from make_params.

Returns:

SubjectsBatch
    Transformed SubjectsBatch.

Source code in src/torchio/transforms/transform.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Apply the transform with the given parameters.

    Must be overridden by subclasses. Receives a `SubjectsBatch`
    whose `ImagesBatch` entries contain 5D tensors
    `(B, C, I, J, K)`. Use negative indexing (`-3`, `-2`,
    `-1`) for spatial dims.

    Args:
        batch: A `SubjectsBatch` to transform.
        params: Parameters from `make_params`.

    Returns:
        Transformed `SubjectsBatch`.
    """
    raise NotImplementedError
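Because the tensors are always 5D `(B, C, I, J, K)`, the negative indices `-3`, `-2` and `-1` select the spatial axes I, J and K no matter how many subjects or channels the batch holds. The stdlib-only sketch below makes that mapping explicit on shape tuples rather than real tensors. `spatial_axes` and `spatial_shape` are hypothetical helpers. Inside an actual `apply_transform`, the same indices would be passed as dimension arguments to tensor operations, for example `torch.flip(tensor, dims=(-1,))`.

```python
def spatial_axes(shape: tuple[int, ...]) -> list[int]:
    """Hypothetical helper: positive axis numbers of I, J, K in a 5D shape."""
    if len(shape) != 5:
        raise ValueError(f"expected a 5D (B, C, I, J, K) shape, got {len(shape)}D")
    # Negative indices count from the end, so they always land on I, J, K.
    return [axis % len(shape) for axis in (-3, -2, -1)]


def spatial_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Hypothetical helper: the (I, J, K) sizes of a 5D shape."""
    return tuple(shape[axis] for axis in (-3, -2, -1))


batch_shape = (4, 2, 128, 128, 64)  # 4 subjects, 2 channels
print(spatial_axes(batch_shape))   # [2, 3, 4]
print(spatial_shape(batch_shape))  # (128, 128, 64)
```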

inverse(params)

Return a transform that undoes this one.

Override in invertible subclasses. The returned transform, when applied, reverses the effect of the forward pass with the given parameters.

Parameters:

params : dict[str, Any], required
    The parameters recorded in the forward pass.

Returns:

Transform
    A new Transform instance that inverts this one.

Source code in src/torchio/transforms/transform.py
def inverse(self, params: dict[str, Any]) -> Transform:
    """Return a transform that undoes this one.

    Override in invertible subclasses. The returned transform,
    when applied, reverses the effect of the forward pass with
    the given parameters.

    Args:
        params: The parameters recorded in the forward pass.

    Returns:
        A new `Transform` instance that inverts this one.
    """
    msg = f"{type(self).__name__} is not invertible"
    raise NotImplementedError(msg)
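By default, `inverse` raises `NotImplementedError` with a message naming the class. An invertible subclass instead builds a fresh transform from the parameters recorded during the forward pass. The sketch below shows that pattern on a plain list of numbers using hypothetical stand-in classes (`ToyTransform` and `ToyShift` are not part of torchio).

```python
from typing import Any


class ToyTransform:
    """Hypothetical stand-in mirroring the default: not invertible."""

    def inverse(self, params: dict[str, Any]) -> "ToyTransform":
        msg = f"{type(self).__name__} is not invertible"
        raise NotImplementedError(msg)


class ToyShift(ToyTransform):
    """Hypothetical invertible transform: adds an offset to every value."""

    def __init__(self, offset: float) -> None:
        self.offset = offset

    def __call__(self, values: list[float]) -> list[float]:
        return [v + self.offset for v in values]

    def inverse(self, params: dict[str, Any]) -> "ToyShift":
        # Build a new transform from the recorded parameters instead of mutating self.
        return ToyShift(-params["offset"])


shift = ToyShift(2.5)
params = {"offset": shift.offset}  # what a forward pass would record
restored = shift.inverse(params)(shift([1.0, 2.0]))
print(restored)  # [1.0, 2.0]
```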