
SomeOf

Bases: Transform

Apply a random subset of the given transforms.

When applied to a batch with per_instance=True (the default), each batch element independently samples its own subset. This requires shape- and schema-preserving transforms so the elements can be re-stacked. Pass per_instance=False to sample a single subset for the whole batch.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `transforms` | `Sequence[Transform] \| None` | Sequence of candidate transforms. | `None` |
| `num_transforms` | `int \| tuple[int, int]` | How many transforms to apply. An `int` for a fixed count, or a `(min, max)` tuple to sample the count uniformly from that range, inclusive of both ends. Without replacement, the count is capped at the number of candidate transforms. | `1` |
| `replace` | `bool` | If `True`, sample with replacement (the same transform may be applied more than once). | `False` |
| `**kwargs` | `Any` | See `Transform` for additional keyword arguments. | `{}` |
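
The count and subset sampling described for `num_transforms` and `replace` can be sketched in plain Python. This is a minimal sketch, not torchio's code: it uses the standard `random` module in place of `torch.randint` and `torch.randperm`, plain callables on integers stand in for transforms, and `sample_subset` and `apply_some_of` are hypothetical names.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple, Union

Fn = Callable[[int], int]


def sample_subset(
    transforms: Sequence[Fn],
    num_transforms: Union[int, Tuple[int, int]],
    replace: bool = False,
    rng: Optional[random.Random] = None,
) -> List[Fn]:
    """Pick the transforms to apply, following the logic of SomeOf._apply_subset."""
    rng = random.Random() if rng is None else rng
    if isinstance(num_transforms, int):
        low = high = num_transforms
    else:
        low, high = num_transforms
    # random.randint is inclusive on both ends, like torch.randint(low, high + 1).
    n = rng.randint(low, high)
    if replace:
        # With replacement: the same transform may be picked more than once.
        indices = [rng.randrange(len(transforms)) for _ in range(n)]
    else:
        # Without replacement: cap n, then take n distinct indices in random order.
        n = min(n, len(transforms))
        indices = rng.sample(range(len(transforms)), n)
    return [transforms[i] for i in indices]


def apply_some_of(x: int, transforms, num_transforms, replace=False, rng=None) -> int:
    """Apply the sampled transforms one after another."""
    for transform in sample_subset(transforms, num_transforms, replace, rng):
        x = transform(x)
    return x


def add_one(v: int) -> int:
    return v + 1


def double(v: int) -> int:
    return v * 2


# Apply one or two of the candidates, in random order, to the value 3.
result = apply_some_of(3, [add_one, double], num_transforms=(1, 2), rng=random.Random(0))
```

Sampling indices without replacement from a random permutation is what keeps a transform from running twice when `replace` is `False`.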

Examples:

>>> import torchio as tio
>>> augmentation = tio.SomeOf(
...     [tio.Noise(), tio.Flip(), tio.Noise(std=0.5)],
...     num_transforms=2,
... )
Source code in src/torchio/transforms/compose.py
class SomeOf(Transform):
    """Apply a random subset of the given transforms.

    When applied to a batch with `per_instance=True` (the default),
    each batch element independently samples its own subset. This
    requires shape- and schema-preserving transforms so the elements
    can be re-stacked. Pass `per_instance=False` to sample a single
    subset for the whole batch.

    Args:
        transforms: Sequence of candidate transforms.
        num_transforms: How many transforms to apply. An `int` for a
            fixed count, or a `(min, max)` tuple to sample the count
            uniformly from that range.
        replace: If `True`, sample with replacement (the same
            transform may be applied more than once).
        **kwargs: See [`Transform`][torchio.Transform] for additional
            keyword arguments.

    Examples:
        >>> import torchio as tio
        >>> augmentation = tio.SomeOf(
        ...     [tio.Noise(), tio.Flip(), tio.Noise(std=0.5)],
        ...     num_transforms=2,
        ... )
    """

    def __init__(
        self,
        transforms: Sequence[Transform] | None = None,
        *,
        num_transforms: int | tuple[int, int] = 1,
        replace: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.transforms = list(transforms) if transforms else []
        self.num_transforms = num_transforms
        self.replace = replace

    @property
    def _min_n(self) -> int:
        if isinstance(self.num_transforms, int):
            return self.num_transforms
        return self.num_transforms[0]

    @property
    def _max_n(self) -> int:
        if isinstance(self.num_transforms, int):
            return self.num_transforms
        return self.num_transforms[1]

    def forward(self, data):
        if self.copy:
            data = copy.deepcopy(data)
        batch, unwrap = self._wrap(data)
        # The input is copied once above, so children apply without copying.
        with _disabled_copy(self.transforms):
            if self.per_instance and batch.batch_size > 1:
                return unwrap(self._forward_per_element(batch))
            if torch.rand(1).item() >= self.p:
                return unwrap(batch)
            batch = self._apply_subset(batch)
            return unwrap(batch)

    def _apply_subset(self, batch):
        """Apply a randomly chosen subset of transforms to *batch*."""
        n = int(torch.randint(self._min_n, self._max_n + 1, size=(1,)).item())
        n_transforms = len(self.transforms)
        if self.replace:
            indices = torch.randint(0, n_transforms, (n,))
        else:
            n = min(n, n_transforms)
            indices = torch.randperm(n_transforms)[:n]
        for idx in indices:
            batch = self.transforms[idx](batch)
        return batch

    def _forward_per_element(self, batch):
        """Apply an independently chosen subset to each batch element."""
        if self.p == 0:
            return batch
        out_subjects = []
        any_applied = False
        for subject in batch.unbatch():
            if torch.rand(1).item() < self.p:
                any_applied = True
                subject = _apply_to_element(subject, self._apply_subset)
            out_subjects.append(subject)
        if not any_applied:
            return batch
        return _rebatch_with_history(out_subjects, "SomeOf")

    def to_hydra(self) -> dict[str, Any]:
        cfg = super().to_hydra()
        cfg["transforms"] = [t.to_hydra() for t in self.transforms]
        return cfg
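
The branching in `forward` and `_forward_per_element` above can be mirrored on a plain list that stands in for a batch. This is a minimal sketch under stated assumptions: `some_of_forward` is a hypothetical name, list elements stand in for subjects, `random` replaces `torch.rand`, and only an `int` `num_transforms` without replacement is handled.

```python
import random
from typing import Callable, List, Optional, Sequence

Fn = Callable[[int], int]


def some_of_forward(
    batch: List[int],
    transforms: Sequence[Fn],
    *,
    num_transforms: int = 1,
    p: float = 1.0,
    per_instance: bool = True,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Sketch of SomeOf.forward on a plain list standing in for a batch."""
    rng = random.Random() if rng is None else rng

    def choose() -> List[Fn]:
        # Sample without replacement, capping the count at len(transforms).
        return rng.sample(list(transforms), min(num_transforms, len(transforms)))

    def run(x: int, chosen: List[Fn]) -> int:
        for transform in chosen:
            x = transform(x)
        return x

    if per_instance and len(batch) > 1:
        if p == 0:
            return batch
        out, any_applied = [], False
        for element in batch:
            # Each element gets its own coin flip and its own subset.
            if rng.random() < p:
                any_applied = True
                element = run(element, choose())
            out.append(element)
        # As in _forward_per_element, return the untouched batch if nothing ran.
        return out if any_applied else batch

    # Shared path: one coin flip and one subset for the whole batch.
    if rng.random() >= p:
        return batch
    chosen = choose()
    return [run(x, chosen) for x in batch]
```

On the shared path every element goes through the same transforms in the same order, which is what `per_instance=False` asks for.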

supports_per_instance_params property

Whether this transform can sample parameters per batch element.

Defaults to False. Transforms that implement per-instance parameter sampling override this to return True. When False, the transform always uses batch-shared parameters regardless of the per_instance flag, preserving the legacy behavior.
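
The rule above, that per-element parameters require both the `per_instance` flag and this property, can be sketched as follows. The helper `sample_params` and the `angle` parameter are hypothetical; they only illustrate shared versus per-element draws.

```python
import random
from typing import Dict, List


def sample_params(
    batch_size: int,
    per_instance: bool,
    supports_per_instance_params: bool,
    rng: random.Random,
) -> List[Dict[str, float]]:
    """Return one parameter dict per batch element (hypothetical helper)."""
    if per_instance and supports_per_instance_params:
        # Per-instance sampling: every element draws its own parameters.
        return [{"angle": rng.uniform(-10.0, 10.0)} for _ in range(batch_size)]
    # Batch-shared behavior: one draw reused for every element,
    # regardless of the per_instance flag.
    shared = {"angle": rng.uniform(-10.0, 10.0)}
    return [shared] * batch_size
```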

supports_per_instance_p property

Whether this transform can gate each batch element independently.

Defaults to False. Shape-preserving transforms that implement per-element probability override this to return True. Shape-changing transforms must leave it False because masked and unmasked elements would have incompatible shapes.
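
Why shape-changing transforms cannot be gated per element can be illustrated with lists standing in for tensors, where list length plays the role of shape. `gate_per_element`, `reverse` and `crop` are hypothetical names for this sketch, not torchio functions.

```python
from typing import Callable, List, Sequence

Element = List[int]


def gate_per_element(
    batch: Sequence[Element],
    transform: Callable[[Element], Element],
    gates: Sequence[bool],
) -> List[Element]:
    """Apply `transform` only where the gate is True, then check the result can be re-stacked."""
    out = [transform(e) if g else list(e) for e, g in zip(batch, gates)]
    lengths = {len(e) for e in out}
    if len(lengths) > 1:
        # Masked and unmasked elements disagree in shape, so stacking would fail.
        raise ValueError(f"cannot re-stack elements with lengths {sorted(lengths)}")
    return out


def reverse(e: Element) -> Element:
    # Shape-preserving: output length equals input length.
    return e[::-1]


def crop(e: Element) -> Element:
    # Shape-changing: output is shorter than the input.
    return e[:2]
```

Gating `crop` on every element still stacks; only the mix of cropped and uncropped elements breaks, which is the situation per-element probability would create.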

invertible property

Whether this transform can be inverted.

make_params(batch)

Sample random parameters for this transform.

Override in subclasses that have random behavior.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `SubjectsBatch` | A `SubjectsBatch`. | required |

Returns:

| Type | Description |
|---|---|
| `dict[str, Any]` | Dict of sampled parameters. |

Source code in src/torchio/transforms/transform.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Sample random parameters for this transform.

    Override in subclasses that have random behavior.

    Args:
        batch: A `SubjectsBatch`.

    Returns:
        Dict of sampled parameters.
    """
    return {}
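
The split between `make_params` (sampling) and `apply_transform` (applying) can be sketched with a simplified stand-in base class. `SketchTransform` and `RandomScale` are hypothetical and operate on lists of floats rather than a `SubjectsBatch`; the only assumption carried over is that a call samples parameters first and then applies them.

```python
import random
from typing import Any, Dict, List, Optional


class SketchTransform:
    """Simplified stand-in for the make_params / apply_transform split (not torchio's Transform)."""

    def make_params(self, batch: List[float]) -> Dict[str, Any]:
        # Transforms without random behavior need no parameters.
        return {}

    def apply_transform(self, batch: List[float], params: Dict[str, Any]) -> List[float]:
        raise NotImplementedError

    def __call__(self, batch: List[float]) -> List[float]:
        # Sample first, then apply, so the params could be recorded or reused.
        params = self.make_params(batch)
        return self.apply_transform(batch, params)


class RandomScale(SketchTransform):
    """Hypothetical random transform: multiplies every value by one sampled factor."""

    def __init__(self, low: float, high: float, seed: Optional[int] = None) -> None:
        self.low, self.high = low, high
        self.rng = random.Random(seed)

    def make_params(self, batch: List[float]) -> Dict[str, Any]:
        return {"factor": self.rng.uniform(self.low, self.high)}

    def apply_transform(self, batch: List[float], params: Dict[str, Any]) -> List[float]:
        return [x * params["factor"] for x in batch]
```

Keeping the randomness in `make_params` means `apply_transform` is a pure function of its inputs, which is what makes replaying or inverting a transform possible.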

apply_transform(batch, params)

Apply the transform with the given parameters.

Must be overridden by subclasses. Receives a SubjectsBatch whose ImagesBatch entries contain 5D tensors (B, C, I, J, K). Use negative indexing (-3, -2, -1) for spatial dims.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `SubjectsBatch` | A `SubjectsBatch` to transform. | required |
| `params` | `dict[str, Any]` | Parameters from `make_params`. | required |

Returns:

| Type | Description |
|---|---|
| `SubjectsBatch` | Transformed `SubjectsBatch`. |

Source code in src/torchio/transforms/transform.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Apply the transform with the given parameters.

    Must be overridden by subclasses. Receives a `SubjectsBatch`
    whose `ImagesBatch` entries contain 5D tensors
    `(B, C, I, J, K)`. Use negative indexing (`-3`, `-2`,
    `-1`) for spatial dims.

    Args:
        batch: A `SubjectsBatch` to transform.
        params: Parameters from `make_params`.

    Returns:
        Transformed `SubjectsBatch`.
    """
    raise NotImplementedError
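
Negative spatial indices stay valid however many leading dimensions precede the spatial ones. The sketch below uses nested lists in place of tensors, and `spatial_axes` and `flip_last_axis` are hypothetical helpers. With real tensors, the last-axis flip would typically be written `torch.flip(x, dims=(-1,))`.

```python
from typing import Any, List


def spatial_axes(ndim: int) -> List[int]:
    """Map the spatial indices -3, -2, -1 to absolute axes for an ndim-D array."""
    return [d % ndim for d in (-3, -2, -1)]


def flip_last_axis(nested: List[Any]) -> List[Any]:
    """Reverse the innermost lists, i.e. flip along axis -1 at any nesting depth."""
    if nested and isinstance(nested[0], list):
        return [flip_last_axis(sub) for sub in nested]
    return list(reversed(nested))
```

For a 5D `(B, C, I, J, K)` batch, `spatial_axes(5)` gives the axes of `I`, `J` and `K`; the same negative indices would still address them if the batch dimension were absent.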

inverse(params)

Return a transform that undoes this one.

Override in invertible subclasses. The returned transform, when applied, reverses the effect of the forward pass with the given parameters.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `dict[str, Any]` | The parameters recorded in the forward pass. | required |

Returns:

| Type | Description |
|---|---|
| `Transform` | A new `Transform` instance that inverts this one. |

Source code in src/torchio/transforms/transform.py
def inverse(self, params: dict[str, Any]) -> Transform:
    """Return a transform that undoes this one.

    Override in invertible subclasses. The returned transform,
    when applied, reverses the effect of the forward pass with
    the given parameters.

    Args:
        params: The parameters recorded in the forward pass.

    Returns:
        A new `Transform` instance that inverts this one.
    """
    msg = f"{type(self).__name__} is not invertible"
    raise NotImplementedError(msg)
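
The default `inverse` shown above raises `NotImplementedError`; an invertible subclass instead builds a transform from the recorded parameters. In this sketch `SketchTransform`, `Shift` and `Jitter` are hypothetical stand-ins operating on lists of integers, and the parameters are re-read from the transform rather than recorded during a real forward pass.

```python
from typing import Any, Dict, List


class SketchTransform:
    """Stand-in base mirroring the default inverse(): not invertible unless overridden."""

    def inverse(self, params: Dict[str, Any]) -> "SketchTransform":
        msg = f"{type(self).__name__} is not invertible"
        raise NotImplementedError(msg)


class Shift(SketchTransform):
    """Hypothetical invertible transform: adds a fixed offset to every value."""

    def __init__(self, offset: int) -> None:
        self.offset = offset

    def make_params(self) -> Dict[str, Any]:
        return {"offset": self.offset}

    def __call__(self, batch: List[int]) -> List[int]:
        params = self.make_params()
        return [x + params["offset"] for x in batch]

    def inverse(self, params: Dict[str, Any]) -> "Shift":
        # The recorded offset is enough to build the undoing transform.
        return Shift(-params["offset"])


class Jitter(SketchTransform):
    """Hypothetical transform that keeps the default, non-invertible inverse()."""
```

Usage: `Shift(3)([1, 2])` gives `[4, 5]`, and applying `Shift(3).inverse({"offset": 3})` to that result restores `[1, 2]`, while `Jitter().inverse({})` raises `NotImplementedError`.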