Skip to content

OneHot

Bases: Transform

One-hot encode label maps.

Each label map with \(C\) classes (including background) is converted from shape \((1, I, J, K)\) to \((C, I, J, K)\), where channel \(c\) is 1 where the label equals \(c\) and 0 elsewhere.

Only LabelMap images are affected. ScalarImage instances are left unchanged.

The inverse takes the argmax across channels, restoring the original single-channel label map.

Parameters:

Name Type Description Default
num_classes int

Total number of classes. -1 (default) infers it from the data as max_label + 1.

-1
**kwargs Any

See Transform.

{}

Examples:

>>> import torchio as tio
>>> transform = tio.OneHot()
>>> transform = tio.OneHot(num_classes=5)
>>> transformed = transform(subject)  # subject contains a LabelMap
>>> # Invert back to single-channel
>>> restored = transformed.apply_inverse_transform()
Source code in src/torchio/transforms/label/one_hot.py
class OneHot(Transform):
    r"""One-hot encode label maps.

    Each label map with $C$ classes (including background) is converted
    from shape $(1, I, J, K)$ to $(C, I, J, K)$, where channel $c$
    is 1 where the label equals $c$ and 0 elsewhere. The encoded data
    is stored as ``float32``.

    Only [`LabelMap`][torchio.LabelMap] images are affected.
    [`ScalarImage`][torchio.ScalarImage] instances are left unchanged.

    The inverse takes the argmax across channels, restoring the
    original single-channel label map.

    Args:
        num_classes: Total number of classes. `-1` (default) infers
            it from the data as `max_label + 1`, computed separately
            for each label map every time the transform is applied.
        **kwargs: See [`Transform`][torchio.Transform].

    Raises:
        ValueError: When applied to a label map that does not have
            exactly one channel (e.g. one that is already one-hot
            encoded), since the extra channels would otherwise be
            silently discarded.

    Examples:
        >>> import torchio as tio
        >>> transform = tio.OneHot()
        >>> transform = tio.OneHot(num_classes=5)
        >>> transformed = transform(subject)  # subject contains a LabelMap
        >>> # Invert back to single-channel
        >>> restored = transformed.apply_inverse_transform()
    """

    def __init__(
        self,
        *,
        num_classes: int = -1,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.num_classes = num_classes

    def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
        """No random parameters."""
        return {"num_classes": self.num_classes}

    def apply_transform(
        self,
        batch: SubjectsBatch,
        params: dict[str, Any],
    ) -> SubjectsBatch:
        """One-hot encode each label map in the batch.

        Args:
            batch: Batch whose label maps are encoded in place.
            params: Must contain ``"num_classes"``.

        Returns:
            The same batch, with each label map's data replaced by a
            ``(B, num_classes, I, J, K)`` float tensor.

        Raises:
            ValueError: If a label map does not have exactly one channel.
        """
        num_classes = params["num_classes"]
        for name, img_batch in batch.images.items():
            if not issubclass(img_batch._image_class, LabelMap):
                continue
            data = img_batch.data
            # Only channel 0 can be encoded; refuse multi-channel input
            # instead of silently dropping the other channels.
            num_channels = data.shape[1]
            if num_channels != 1:
                msg = (
                    f'OneHot expects single-channel label maps, but "{name}"'
                    f" has {num_channels} channels"
                )
                raise ValueError(msg)
            # (B, 1, I, J, K) -> (B, I, J, K) integer labels
            labels = data[:, 0].long()
            encoded = functional.one_hot(labels, num_classes=num_classes)
            # (B, I, J, K, num_classes) -> (B, num_classes, I, J, K)
            img_batch.data = encoded.permute(0, 4, 1, 2, 3).float()
        return batch

    @property
    def invertible(self) -> bool:
        """Whether this transform can be inverted."""
        return True

    def inverse(self, params: dict[str, Any]) -> _OneHotInverse:
        """Invert by taking argmax."""
        return _OneHotInverse(copy=False)

invertible property

Whether this transform can be inverted.

forward(data)

forward(data: Subject) -> Subject
forward(data: Image) -> Image
forward(data: Tensor) -> Tensor
forward(data: np.ndarray) -> np.ndarray
forward(data: sitk.Image) -> sitk.Image
forward(data: nib.Nifti1Image) -> nib.Nifti1Image
forward(data: dict) -> dict
forward(data: ImagesBatch) -> ImagesBatch
forward(data: SubjectsBatch) -> SubjectsBatch

Apply the transform.

The output type always matches the input type.

Parameters:

Name Type Description Default
data Any

Input data to transform.

required
Source code in src/torchio/transforms/transform.py
def forward(self, data: Any) -> Any:
    """Apply the transform.

    The output type always matches the input type. The transform is
    applied with probability ``self.p``; otherwise the (possibly copied)
    input is returned unchanged and no history is recorded.

    Args:
        data: Input data to transform.
    """
    if self.copy:
        data = _copy.deepcopy(data)
    batch, unwrap = self._wrap(data)
    # torch.rand samples from [0, 1). Using ``>=`` guarantees that p=0
    # never applies and p=1 always applies; ``>`` would apply at p=0
    # whenever the draw is exactly 0.0.
    if torch.rand(1).item() >= self.p:
        return unwrap(batch)
    params = self.make_params(batch)
    batch = self.apply_transform(batch, params)
    # Record history on the batch
    trace = AppliedTransform(name=type(self).__name__, params=params)
    if not hasattr(batch, "applied_transforms"):
        batch.applied_transforms = []
    batch.applied_transforms.append(trace)
    result = unwrap(batch)
    # Propagate history to outputs that can carry it. Batches, tensors,
    # arrays and dicts are excluded; objects rejecting the attribute
    # assignment are skipped silently.
    if not isinstance(result, (SubjectsBatch, Tensor, np.ndarray, dict)):
        with contextlib.suppress(AttributeError):
            result.applied_transforms = list(batch.applied_transforms)
    return result

to_hydra()

Export as a Hydra-compatible config dict.

Returns a dict with _target_ set to the fully qualified class name and only non-default field values included.

Returns:

Type Description
dict[str, Any]

Dict suitable for hydra.utils.instantiate().

Source code in src/torchio/transforms/transform.py
def to_hydra(self) -> dict[str, Any]:
    """Export as a Hydra-compatible config dict.

    The result always holds ``_target_`` (``torchio.`` followed by the
    class's qualified name). Any other key is an ``__init__`` parameter
    whose current value differs from its default; for parameter ranges,
    the originally supplied value is what gets compared and exported.

    Returns:
        Dict suitable for `hydra.utils.instantiate()`.
    """
    from .parameter_range import _ParameterRange

    klass = type(self)
    config: dict[str, Any] = {"_target_": f"torchio.{klass.__qualname__}"}
    for param_name, param_default in _collect_init_params(klass).items():
        current = getattr(self, param_name, param_default)
        # A range remembers the user-supplied spec; compare/export that.
        if isinstance(current, _ParameterRange):
            current = current._original
        if current == param_default:
            continue
        config[param_name] = _hydra_value(current)
    return config

make_params(batch)

No random parameters.

Source code in src/torchio/transforms/label/one_hot.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Return the fixed (non-random) parameters for one call.

    The batch is not inspected; ``num_classes`` is forwarded as configured.
    """
    params: dict[str, Any] = dict(num_classes=self.num_classes)
    return params

apply_transform(batch, params)

One-hot encode each label map in the batch.

Source code in src/torchio/transforms/label/one_hot.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """One-hot encode each label map in the batch.

    Args:
        batch: Batch whose label maps are encoded in place.
        params: Must contain ``"num_classes"``.

    Returns:
        The same batch, with each label map's data replaced by a
        ``(B, num_classes, I, J, K)`` float tensor.

    Raises:
        ValueError: If a label map does not have exactly one channel.
    """
    num_classes = params["num_classes"]
    for name, img_batch in batch.images.items():
        if not issubclass(img_batch._image_class, LabelMap):
            continue
        data = img_batch.data
        # Only channel 0 can be encoded; refuse multi-channel input
        # instead of silently dropping the other channels.
        num_channels = data.shape[1]
        if num_channels != 1:
            msg = (
                f'OneHot expects single-channel label maps, but "{name}"'
                f" has {num_channels} channels"
            )
            raise ValueError(msg)
        # (B, 1, I, J, K) -> (B, I, J, K) integer labels
        labels = data[:, 0].long()
        encoded = functional.one_hot(labels, num_classes=num_classes)
        # (B, I, J, K, num_classes) -> (B, num_classes, I, J, K)
        img_batch.data = encoded.permute(0, 4, 1, 2, 3).float()
    return batch

inverse(params)

Invert by taking argmax.

Source code in src/torchio/transforms/label/one_hot.py
def inverse(self, params: dict[str, Any]) -> _OneHotInverse:
    """Build the inverse transform, which inverts by taking argmax.

    ``params`` is accepted for interface compatibility and is not used.
    """
    inverse_transform = _OneHotInverse(copy=False)
    return inverse_transform