Skip to content

PCA

Bases: IntensityTransform

Apply PCA to reduce the channel dimension.

Reshapes a \((C, I, J, K)\) image to \((I \cdot J \cdot K, C)\) (one row per voxel, one column per channel), performs PCA, and reshapes back to \((\text{num\_components}, I, J, K)\).

This is useful for visualizing high-dimensional feature maps (e.g., neural network embeddings) as RGB images.

The implementation uses torch.pca_lowrank, so no external dependencies are needed.

Parameters:

Name Type Description Default
num_components int

Number of principal components to keep.

3
whiten bool

If True, normalize each component to unit variance.

True
normalize bool

If True, divide all components by the standard deviation of the first component.

True
values_range tuple[float, float]

Linear mapping range for normalization to \([0, 1]\). The default \((-2.3, 2.3)\) covers \(\approx 98\%\) of a standard normal distribution.

(-2.3, 2.3)
clip bool

If True, clip output to \([0, 1]\) after normalization.

True
**kwargs Any

See Transform.

{}

Examples:

>>> import torchio as tio
>>> transform = tio.PCA(num_components=3)
Source code in src/torchio/transforms/intensity/pca.py
class PCA(IntensityTransform):
    r"""Apply PCA to reduce the channel dimension.

    Reshapes a $(C, I, J, K)$ image to $(I \cdot J \cdot K, C)$ (one
    row per voxel, one column per channel), performs PCA, and reshapes
    the projection back to $(\text{num\_components}, I, J, K)$.

    This is useful for visualizing high-dimensional feature maps
    (e.g., neural network embeddings) as RGB images.

    The implementation uses [`torch.pca_lowrank`][torch.pca_lowrank], so no
    external dependencies are needed. Principal axes are only defined
    up to sign, so each component is flipped to make its
    largest-magnitude channel loading positive. This keeps the colors
    of a given image stable across calls.

    Args:
        num_components: Number of principal components to keep. Must be
            at least 1 and no larger than the number of channels.
        whiten: If `True`, normalize each component to unit
            variance.
        normalize: If `True`, divide all components by the
            standard deviation of the first component.
        values_range: Linear mapping range for normalization to
            $[0, 1]$.  The default $(-2.3, 2.3)$ covers
            $\approx 98\%$ of a standard normal distribution. The two
            bounds must differ.
        clip: If `True`, clip output to $[0, 1]$ after
            normalization.
        **kwargs: See [`Transform`][torchio.Transform].

    Raises:
        ValueError: If `num_components < 1` or if both bounds of
            `values_range` are equal.

    Examples:
        >>> import torchio as tio
        >>> transform = tio.PCA(num_components=3)
    """

    # Extra columns requested from the randomized solver beyond
    # ``num_components``. torch.pca_lowrank expects ``q`` to slightly
    # overestimate the target rank; oversampling improves the accuracy of
    # the leading components, which are the only ones kept.
    _oversampling = 10

    def __init__(
        self,
        num_components: int = 3,
        *,
        whiten: bool = True,
        normalize: bool = True,
        values_range: tuple[float, float] = (-2.3, 2.3),
        clip: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if num_components < 1:
            msg = f"num_components must be >= 1, got {num_components}"
            raise ValueError(msg)
        lo, hi = values_range
        if lo == hi:
            # The [0, 1] mapping divides by (hi - lo); equal bounds would
            # silently produce inf/NaN outputs.
            msg = f"values_range bounds must differ, got {values_range}"
            raise ValueError(msg)
        self.num_components = num_components
        self.whiten = whiten
        self.normalize = normalize
        self.values_range = values_range
        self.clip = clip

    def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
        """No random parameters."""
        return {}

    def apply_transform(
        self,
        batch: SubjectsBatch,
        params: dict[str, Any],
    ) -> SubjectsBatch:
        """Apply PCA independently to every element of each selected image."""
        for img_batch in self._get_images(batch).values():
            img_batch.data = torch.stack(
                [
                    self._pca_single(img_batch.data[i])
                    for i in range(img_batch.batch_size)
                ],
            )
        return batch

    @staticmethod
    def _fix_signs(v: Tensor) -> Tensor:
        """Flip each column of ``v`` so its largest-magnitude entry is positive.

        Args:
            v: ``(channels, n_comp)`` matrix of principal axes.

        Returns:
            ``v`` with some columns negated; the spanned subspace and the
            variance explained by each column are unchanged.
        """
        pivot_rows = v.abs().argmax(dim=0, keepdim=True)  # (1, n_comp)
        pivots = v.gather(0, pivot_rows)  # (1, n_comp)
        # +1 where the pivot is >= 0, -1 where it is negative.
        signs = 1 - 2 * (pivots < 0).to(v.dtype)
        return v * signs

    def _pca_single(self, tensor: Tensor) -> Tensor:
        """Apply PCA to a single `(C, I, J, K)` tensor.

        Args:
            tensor: Input with *C* channels.

        Returns:
            Float32 tensor with *num_components* channels and the same
            spatial shape as ``tensor``.

        Raises:
            ValueError: If the image has fewer channels or fewer voxels
                than ``num_components``.
        """
        c, si, sj, sk = tensor.shape
        if c < self.num_components:
            msg = (
                f"Image has {c} channels but num_components="
                f"{self.num_components}. Need at least as many "
                "channels as components."
            )
            raise ValueError(msg)

        # (C, I, J, K) → (voxels, channels): each voxel is one sample.
        flat = rearrange(tensor.float(), "c i j k -> (i j k) c")
        n_voxels = flat.shape[0]
        if n_voxels < self.num_components:
            msg = (
                f"Image has {n_voxels} voxels but num_components="
                f"{self.num_components}. Need at least as many "
                "voxels as components."
            )
            raise ValueError(msg)

        # Center each channel.
        centered = flat - flat.mean(dim=0, keepdim=True)

        # PCA via torch.pca_lowrank. The data is already centered, so the
        # solver's own centering is skipped. Request a few extra columns
        # for accuracy, then keep the leading ``num_components`` (the
        # singular values come back in descending order).
        q = min(self.num_components + self._oversampling, c, n_voxels)
        _u, s, v = torch.pca_lowrank(centered, q=q, center=False)
        s = s[: self.num_components]
        v = self._fix_signs(v[:, : self.num_components])

        # Project: (voxels, channels) @ (channels, n_comp) → (voxels, n_comp)
        projected = centered @ v

        if self.whiten:
            # s are singular values of the centered data, so the variance
            # of each component is s² / (n - 1).
            denom = (n_voxels - 1) ** 0.5 if n_voxels > 1 else 1.0
            std = (s / denom).clamp(min=1e-8)
            projected = projected / std.unsqueeze(0)

        # The unbiased std() of a single sample is NaN, so skip that case.
        if self.normalize and n_voxels > 1:
            first_std = projected[:, 0].std().clamp(min=1e-8)
            projected = projected / first_std

        # Map values_range linearly onto [0, 1].
        lo, hi = self.values_range
        projected = (projected - lo) / (hi - lo)

        if self.clip:
            projected = projected.clamp(0, 1)

        # Reshape back: (voxels, n_comp) → (n_comp, I, J, K)
        return rearrange(
            projected,
            "(i j k) c -> c i j k",
            i=si,
            j=sj,
            k=sk,
        )

supports_per_instance_params property

Whether this transform can sample parameters per batch element.

Defaults to False. Transforms that implement per-instance parameter sampling override this to return True. When False, the transform always uses batch-shared parameters regardless of the per_instance flag, preserving the legacy behavior.

supports_per_instance_p property

Whether this transform can gate each batch element independently.

Defaults to False. Shape-preserving transforms that implement per-element probability override this to return True. Shape-changing transforms must leave it False because masked and unmasked elements would have incompatible shapes.

invertible property

Whether this transform can be inverted.

forward(data)

forward(data: Subject) -> Subject
forward(data: Image) -> Image
forward(data: Tensor) -> Tensor
forward(data: np.ndarray) -> np.ndarray
forward(data: sitk.Image) -> sitk.Image
forward(data: nib.Nifti1Image) -> nib.Nifti1Image
forward(data: dict) -> dict
forward(data: ImagesBatch) -> ImagesBatch
forward(data: SubjectsBatch) -> SubjectsBatch

Apply the transform.

The output type always matches the input type.

Parameters:

Name Type Description Default
data Any

Input data to transform.

required
Source code in src/torchio/transforms/transform.py
def forward(self, data: Any) -> Any:
    """Run this transform on ``data``.

    The returned object always has the same type as ``data``.

    Args:
        data: Input data to transform.
    """
    if self.copy:
        data = _copy.deepcopy(data)
    batch, unwrap = self._wrap(data)

    # When per-element gating is active, apply_transform handles the
    # probability itself (gated-out elements get identity params), so no
    # batch-wide coin is flipped here. Otherwise apply iff rand < p:
    # p=0 never applies and p=1 always does.
    if not self._per_instance_p_active(batch):
        skip = torch.rand(1).item() >= self.p
        if skip:
            return unwrap(batch)

    params = self.make_params(batch)
    batch = self.apply_transform(batch, params)

    # A call where every element was gated out is an exact no-op. Recording
    # it would make history replay (e.g. of an invertible spatial
    # transform) trigger a needless identity resample.
    if not _all_elements_gated_out(params):
        record = AppliedTransform(name=type(self).__name__, params=params)
        if not hasattr(batch, "applied_transforms"):
            batch.applied_transforms = []
        batch.applied_transforms.append(record)

    result = unwrap(batch)
    # Copy the history onto outputs able to carry it; batches, tensors,
    # arrays and dicts are left untouched.
    if hasattr(batch, "applied_transforms") and not isinstance(
        result,
        (SubjectsBatch, Tensor, np.ndarray, dict),
    ):
        with contextlib.suppress(AttributeError):
            result.applied_transforms = list(batch.applied_transforms)
    return result

inverse(params)

Return a transform that undoes this one.

Override in invertible subclasses. The returned transform, when applied, reverses the effect of the forward pass with the given parameters.

Parameters:

Name Type Description Default
params dict[str, Any]

The parameters recorded in the forward pass.

required

Returns:

Type Description
Transform

A new Transform instance that inverts this one.

Source code in src/torchio/transforms/transform.py
def inverse(self, params: dict[str, Any]) -> Transform:
    """Build the transform that reverses this one.

    Invertible subclasses override this method. The transform they return
    undoes the forward pass that was run with ``params``.

    Args:
        params: The parameters recorded in the forward pass.

    Returns:
        A new `Transform` instance that inverts this one.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    reason = f"{type(self).__name__} is not invertible"
    raise NotImplementedError(reason)

to_hydra()

Export as a Hydra-compatible config dict.

Returns a dict with _target_ set to the fully qualified class name and only non-default field values included.

Returns:

Type Description
dict[str, Any]

Dict suitable for hydra.utils.instantiate().

Source code in src/torchio/transforms/transform.py
def to_hydra(self) -> dict[str, Any]:
    """Export as a Hydra-compatible config dict.

    Returns a dict with `_target_` set to an importable class path and
    only non-default field values included. Classes defined inside the
    ``torchio`` package are addressed through the top-level namespace
    (e.g. ``torchio.PCA``). Subclasses defined elsewhere use their real
    module path so that ``hydra.utils.instantiate()`` can import them.

    Returns:
        Dict suitable for `hydra.utils.instantiate()`.
    """
    from .parameter_range import _ParameterRange

    cls = type(self)
    module = cls.__module__
    if module == "torchio" or module.startswith("torchio."):
        target = f"torchio.{cls.__qualname__}"
    else:
        # A user subclass is not reachable as ``torchio.<name>``.
        target = f"{module}.{cls.__qualname__}"
    cfg: dict[str, Any] = {"_target_": target}

    for name, default in _collect_init_params(cls).items():
        value = getattr(self, name, default)
        if isinstance(value, _ParameterRange):
            # Compare and export the user-supplied spec, not the parsed range.
            if value._original == default:
                continue
            value = _hydra_value(value._original)
        elif value == default:
            continue
        else:
            value = _hydra_value(value)
        cfg[name] = value
    return cfg

make_params(batch)

No random parameters.

Source code in src/torchio/transforms/intensity/pca.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Return an empty parameter dict; nothing is sampled for this transform."""
    no_params: dict[str, Any] = {}
    return no_params

apply_transform(batch, params)

Apply PCA to each selected image.

Source code in src/torchio/transforms/intensity/pca.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Run PCA separately on every batch element of each selected image.

    ``params`` is not read. Each image's data is replaced in place with
    the stacked per-element PCA results, and the same ``batch`` object is
    returned.
    """
    for images in self._get_images(batch).values():
        reduced = [
            self._pca_single(images.data[index])
            for index in range(images.batch_size)
        ]
        images.data = torch.stack(reduced)
    return batch