Skip to content

Swap

Bases: IntensityTransform

Randomly swap patches within an image.

This is typically used in context restoration for self-supervised learning. Pairs of same-sized patches are selected at random and their contents are exchanged.

Warning

This transform is intended for self-supervised or unsupervised workflows. Because the spatial content is rearranged, aligned label maps become inconsistent with the swapped image. A warning is emitted if LabelMap images are present in the subject.

Parameters:

Name Type Description Default
patch_size int | tuple[int, int, int]

Spatial size of the patches to swap. A single integer \(n\) means \((n, n, n)\).

15
num_iterations int | tuple[int, int]

Number of patch pairs to swap. A 2-tuple \((a, b)\) samples \(n \sim \mathcal{U}(a, b)\).

100
**kwargs Any

See Transform.

{}

Examples:

>>> import torchio as tio
>>> transform = tio.Swap(patch_size=15, num_iterations=100)
Source code in src/torchio/transforms/intensity/swap.py
class Swap(IntensityTransform):
    r"""Exchange the contents of randomly chosen patch pairs in an image.

    A common use is
    [context restoration for self-supervised learning](https://www.sciencedirect.com/science/article/pii/S1361841518304699):
    two patches of equal size are picked at random and their contents
    trade places.

    Warning:
        Use this transform in **self-supervised** or **unsupervised**
        pipelines.  Swapping moves spatial content around, so any
        aligned label map no longer matches the swapped image.  If
        the subject contains `LabelMap` images, a warning is emitted.

    Args:
        patch_size: Spatial size of each swapped patch.  An integer
            $n$ is expanded to $(n, n, n)$.
        num_iterations: How many patch pairs are swapped.  Given a
            2-tuple $(a, b)$, the count is drawn as
            $n \sim \mathcal{U}(a, b)$.
        **kwargs: See [`Transform`][torchio.Transform].

    Examples:
        >>> import torchio as tio
        >>> transform = tio.Swap(patch_size=15, num_iterations=100)
    """

    def __init__(
        self,
        *,
        patch_size: int | tuple[int, int, int] = 15,
        num_iterations: int | tuple[int, int] = 100,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        # A scalar patch size applies to all three spatial axes.
        self.patch_size = (
            (patch_size,) * 3 if isinstance(patch_size, int) else patch_size
        )
        self.num_iterations = to_nonneg_range(num_iterations)

    def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
        """Draw the swap count and one shared set of patch locations."""
        # The sampled count is rounded and clamped, so it is never below 1.
        num_swaps = max(1, round(self.num_iterations.sample_1d()))

        # One warning is enough, however many label maps are present.
        has_label_map = any(
            issubclass(images._image_class, LabelMap)
            for images in batch.images.values()
        )
        if has_label_map:
            message = (
                "Swap is applied to a subject containing LabelMap "
                "images. The spatial rearrangement will make labels "
                "inconsistent with the swapped image. This transform "
                "is intended for self-supervised learning."
            )
            warnings.warn(message, stacklevel=2)

        # Locations are drawn once, using the first image's spatial shape,
        # and the same set is later applied to every selected image.
        first_image = next(iter(batch.images.values()))
        locations = _sample_swap_locations(
            first_image.data.shape[2:],  # (I, J, K)
            self.patch_size,
            num_swaps,
        )
        return {"locations": locations}

    def apply_transform(
        self,
        batch: SubjectsBatch,
        params: dict[str, Any],
    ) -> SubjectsBatch:
        """Apply the sampled swaps to every selected image in the batch."""
        swap_locations = params["locations"]
        selected = self._get_images(batch)
        for _name, images in selected.items():
            swapped = _apply_swaps(images.data, swap_locations, self.patch_size)
            images.data = swapped
        return batch

invertible property

Whether this transform can be inverted.

forward(data)

forward(data: Subject) -> Subject
forward(data: Image) -> Image
forward(data: Tensor) -> Tensor
forward(data: np.ndarray) -> np.ndarray
forward(data: sitk.Image) -> sitk.Image
forward(data: nib.Nifti1Image) -> nib.Nifti1Image
forward(data: dict) -> dict
forward(data: ImagesBatch) -> ImagesBatch
forward(data: SubjectsBatch) -> SubjectsBatch

Apply the transform.

The output type always matches the input type.

Parameters:

Name Type Description Default
data Any

Input data to transform.

required
Source code in src/torchio/transforms/transform.py
def forward(self, data: Any) -> Any:
    """Run the transform on `data` and return the same kind of object.

    The output type always matches the input type.

    Args:
        data: Input data to transform.
    """
    # Work on a deep copy of the input when `self.copy` is set.
    batch, unwrap = self._wrap(_copy.deepcopy(data) if self.copy else data)

    # Skip the transform whenever the uniform draw exceeds `self.p`.
    if torch.rand(1).item() > self.p:
        return unwrap(batch)

    sampled = self.make_params(batch)
    batch = self.apply_transform(batch, sampled)

    # Append a history entry to the batch, creating the list if needed.
    entry = AppliedTransform(name=type(self).__name__, params=sampled)
    if not hasattr(batch, "applied_transforms"):
        batch.applied_transforms = []
    batch.applied_transforms.append(entry)

    output = unwrap(batch)

    # These output types do not receive a copy of the history.
    excluded_types = (SubjectsBatch, Tensor, np.ndarray, dict)
    if not hasattr(batch, "applied_transforms"):
        return output
    if isinstance(output, excluded_types):
        return output

    # Outputs that reject the attribute assignment are returned as they are.
    with contextlib.suppress(AttributeError):
        output.applied_transforms = list(batch.applied_transforms)
    return output

inverse(params)

Return a transform that undoes this one.

Override in invertible subclasses. The returned transform, when applied, reverses the effect of the forward pass with the given parameters.

Parameters:

Name Type Description Default
params dict[str, Any]

The parameters recorded in the forward pass.

required

Returns:

Type Description
Transform

A new Transform instance that inverts this one.

Source code in src/torchio/transforms/transform.py
def inverse(self, params: dict[str, Any]) -> Transform:
    """Build the transform that reverses this one.

    This base implementation always raises; invertible subclasses
    are expected to override it and return a transform that, when
    applied, reverses the forward pass run with the given parameters.

    Args:
        params: The parameters recorded in the forward pass.

    Returns:
        A new `Transform` instance that inverts this one.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    class_name = type(self).__name__
    raise NotImplementedError(f"{class_name} is not invertible")

to_hydra()

Export as a Hydra-compatible config dict.

Returns a dict with _target_ set to the fully qualified class name and only non-default field values included.

Returns:

Type Description
dict[str, Any]

Dict suitable for hydra.utils.instantiate().

Source code in src/torchio/transforms/transform.py
def to_hydra(self) -> dict[str, Any]:
    """Serialize this transform to a Hydra-compatible config dict.

    The dict always carries `_target_` (the fully qualified class
    name); an init parameter is added only when its current value
    differs from its default.

    Returns:
        Dict suitable for `hydra.utils.instantiate()`.
    """
    from .parameter_range import _ParameterRange

    klass = type(self)
    cfg: dict[str, Any] = {"_target_": f"torchio.{klass.__qualname__}"}

    for name, default in _collect_init_params(klass).items():
        current = getattr(self, name, default)
        # Parameter ranges are compared and exported via the value stored
        # in their `_original` attribute rather than the range object.
        if isinstance(current, _ParameterRange):
            raw = current._original
        else:
            raw = current
        if raw == default:
            continue
        cfg[name] = _hydra_value(raw)
    return cfg

make_params(batch)

Sample the number of swaps and one set of swap locations, shared by every image in the batch.

Source code in src/torchio/transforms/intensity/swap.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Draw the swap count and one shared set of patch locations."""
    # The sampled count is rounded and clamped, so it is never below 1.
    num_swaps = max(1, round(self.num_iterations.sample_1d()))

    # One warning is enough, however many label maps are present.
    has_label_map = any(
        issubclass(images._image_class, LabelMap)
        for images in batch.images.values()
    )
    if has_label_map:
        message = (
            "Swap is applied to a subject containing LabelMap "
            "images. The spatial rearrangement will make labels "
            "inconsistent with the swapped image. This transform "
            "is intended for self-supervised learning."
        )
        warnings.warn(message, stacklevel=2)

    # Locations are drawn once, using the first image's spatial shape.
    first_image = next(iter(batch.images.values()))
    locations = _sample_swap_locations(
        first_image.data.shape[2:],  # (I, J, K)
        self.patch_size,
        num_swaps,
    )
    return {"locations": locations}

apply_transform(batch, params)

Swap patches in each selected image.

Source code in src/torchio/transforms/intensity/swap.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Apply the sampled swaps to every selected image in the batch."""
    swap_locations = params["locations"]
    selected = self._get_images(batch)
    for _name, images in selected.items():
        # Every image receives the same locations and patch size.
        swapped = _apply_swaps(images.data, swap_locations, self.patch_size)
        images.data = swapped
    return batch