Skip to content

LabelsToImage

Bases: Transform

Generate a synthetic image from a label map.

For each label, Gaussian-distributed tissue is created with a sampled mean and standard deviation, weighted by the label mask. The per-label contributions are summed to produce the output image.

This is the building block behind SynthSeg-style synthesis. For best results, compose with Blur and BiasField.

The generated image is added to the subject under the key given by image_key. Other existing images are not modified; an image already stored under image_key is replaced.

Only LabelMap images are used as input.

Parameters:

Name Type Description Default
label_key str | None

Name of the label map to use. If None, the first LabelMap found is used.

None
image_key str

Name for the generated ScalarImage.

'image_from_labels'
mean Sequence[float | tuple[float, float]] | None

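Per-label mean ranges, matched by position to the sorted unique label values found in the label map. Labels beyond the end of the sequence get a mean sampled from default_mean. If None, each label gets a mean sampled from default_mean.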

None
std Sequence[float | tuple[float, float]] | None

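Per-label std ranges, matched by position to the sorted unique label values found in the label map. Labels beyond the end of the sequence get a std sampled from default_std. If None, each label gets a std sampled from default_std.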

None
default_mean float | tuple[float, float]

Fallback range for label means.

(0.1, 0.9)
default_std float | tuple[float, float]

Fallback range for label stds.

(0.01, 0.1)
ignore_background bool

If True, label 0 is left as zero.

False
**kwargs Any

See Transform.

{}

Examples:

>>> import torchio as tio
>>> transform = tio.LabelsToImage(label_key="seg")
>>> transform = tio.LabelsToImage(
...     label_key="seg",
...     mean=[(0.8, 1.0), (0.3, 0.5)],
...     std=[(0.01, 0.05), (0.02, 0.08)],
... )
Source code in src/torchio/transforms/intensity/labels_to_image.py
class LabelsToImage(Transform):
    r"""Generate a synthetic image from a label map.

    For each label, Gaussian-distributed tissue is created with a
    sampled mean and standard deviation, weighted by the label mask.
    The per-label contributions are summed to produce the output
    image.

    This is the building block behind
    [SynthSeg](https://github.com/BBillot/SynthSeg)-style synthesis.
    For best results, compose with
    [`Blur`][torchio.Blur] and
    [`BiasField`][torchio.BiasField].

    The generated image is added to the subject under the key given
    by *image_key*.  Other existing images are **not** modified; an
    image already stored under *image_key* is replaced.

    Only [`LabelMap`][torchio.LabelMap] images are used as input.

    Args:
        label_key: Name of the label map to use.  If `None`, the
            first `LabelMap` found is used.
        image_key: Name for the generated `ScalarImage`.
        mean: Per-label mean ranges, matched by position to the
            sorted unique label values.  Labels beyond the end of the
            sequence, or all labels if `None`, get a mean sampled
            from *default_mean*.
        std: Per-label std ranges, matched by position to the sorted
            unique label values.  Labels beyond the end of the
            sequence, or all labels if `None`, get a std sampled from
            *default_std*.  The absolute value of each sampled std is
            used.
        default_mean: Fallback range for label means.
        default_std: Fallback range for label stds.
        ignore_background: If `True`, label 0 is left as zero.
        **kwargs: See [`Transform`][torchio.Transform].

    Examples:
        >>> import torchio as tio
        >>> transform = tio.LabelsToImage(label_key="seg")
        >>> transform = tio.LabelsToImage(
        ...     label_key="seg",
        ...     mean=[(0.8, 1.0), (0.3, 0.5)],
        ...     std=[(0.01, 0.05), (0.02, 0.08)],
        ... )
    """

    def __init__(
        self,
        label_key: str | None = None,
        *,
        image_key: str = "image_from_labels",
        mean: Sequence[float | tuple[float, float]] | None = None,
        std: Sequence[float | tuple[float, float]] | None = None,
        default_mean: float | tuple[float, float] = (0.1, 0.9),
        default_std: float | tuple[float, float] = (0.01, 0.1),
        ignore_background: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.label_key = label_key
        self.image_key = image_key
        # Every user-facing value is normalised through ``to_range`` so the
        # sampling code only ever deals with range objects.
        self.mean_ranges = [to_range(m) for m in mean] if mean is not None else None
        self.std_ranges = [to_range(s) for s in std] if std is not None else None
        self.default_mean = to_range(default_mean)
        self.default_std = to_range(default_std)
        self.ignore_background = ignore_background

    def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
        """Sample per-label mean and std values (per element when batched).

        Args:
            batch: Batch containing the label map to synthesise from.

        Returns:
            A dict with `"means"` and `"stds"`.  Each is a single
            per-label dict, or a list of per-label dicts (one per
            element) when `_resolve_n` returns a count.
        """
        label_batch = self._find_label_batch(batch)
        # Discover unique labels from the first sample.
        # NOTE(review): labels that only occur in later batch elements get
        # no sampled values here -- confirm the generators tolerate that.
        unique = sorted(int(v) for v in label_batch.data[0].unique().tolist())
        n = self._resolve_n(batch)
        if n is None:
            means, stds = self._sample_label_values(unique)
            return {"means": means, "stds": stds}
        means_list: list[dict[int, float]] = []
        stds_list: list[dict[int, float]] = []
        for _ in range(n):
            means, stds = self._sample_label_values(unique)
            means_list.append(means)
            stds_list.append(stds)
        params = {"means": means_list, "stds": stds_list}
        self._tag_batched(params, batch, n, None, ["means", "stds"])
        return params

    @property
    def supports_per_instance_params(self) -> bool:
        """Always `True` for this transform."""
        return True

    def _sample_label_values(
        self,
        unique: list[int],
    ) -> tuple[dict[int, float], dict[int, float]]:
        """Sample one mean and std per label.

        Ranges are matched to labels by *position* in `unique`, not by
        label value; positions past the end of the user-supplied
        sequences fall back to the default ranges.

        Args:
            unique: Sorted list of unique label values.

        Returns:
            A `(means, stds)` pair of per-label dictionaries.
        """
        means: dict[int, float] = {}
        stds: dict[int, float] = {}
        for idx, label in enumerate(unique):
            if self.ignore_background and label == 0:
                # Background keeps zero intensity and zero noise.  It still
                # occupies its position, so later labels keep their index.
                means[label] = 0.0
                stds[label] = 0.0
                continue
            mean_range = self.default_mean
            if self.mean_ranges is not None and idx < len(self.mean_ranges):
                mean_range = self.mean_ranges[idx]
            std_range = self.default_std
            if self.std_ranges is not None and idx < len(self.std_ranges):
                std_range = self.std_ranges[idx]
            means[label] = mean_range.sample_1d()
            # A standard deviation cannot be negative.  Take the magnitude
            # for user-supplied ranges too, not only for the default range.
            stds[label] = abs(std_range.sample_1d())
        return means, stds

    def apply_transform(
        self,
        batch: SubjectsBatch,
        params: dict[str, Any],
    ) -> SubjectsBatch:
        """Generate a synthetic image and add it to the batch.

        Args:
            batch: Batch containing the source label map.
            params: Output of `make_params`.

        Returns:
            The same batch, with the generated image stored under
            `image_key`.
        """
        label_batch = self._find_label_batch(batch)
        if self._is_per_instance_params(params):
            generated = _generate_per_element(
                label_batch.data,
                params["means"],
                params["stds"],
            )
        else:
            generated = _generate_from_labels(
                label_batch.data,
                params["means"],
                params["stds"],
            )
        # Create a new image batch entry.
        from ...data.batch import ImagesBatch

        # The synthetic image reuses the label map's affines.
        new_batch = ImagesBatch(
            data=generated,
            affines=label_batch.affines,
            image_class=ScalarImage,
        )
        batch.images[self.image_key] = new_batch
        return batch

    def _find_label_batch(self, batch: SubjectsBatch) -> Any:
        """Find the label map batch to use.

        Args:
            batch: Batch whose images are searched.

        Returns:
            The entry stored under `label_key` when one is configured,
            otherwise the first entry whose image class is a
            `LabelMap` subclass.

        Raises:
            KeyError: If `label_key` is set but missing from the batch,
                or if no `LabelMap` can be auto-detected.
        """
        if self.label_key is not None:
            if self.label_key not in batch.images:
                msg = (
                    f"Label key '{self.label_key}' not found. "
                    f"Available: {list(batch.images.keys())}"
                )
                raise KeyError(msg)
            return batch.images[self.label_key]
        # Auto-detect first LabelMap.
        for _name, img_batch in batch.images.items():
            if issubclass(img_batch._image_class, LabelMap):
                return img_batch
        msg = "No LabelMap found in the subject"
        raise KeyError(msg)

supports_per_instance_p property

Whether this transform can gate each batch element independently.

Defaults to False. Shape-preserving transforms that implement per-element probability override this to return True. Shape-changing transforms must leave it False because masked and unmasked elements would have incompatible shapes.

invertible property

Whether this transform can be inverted.

forward(data)

forward(data: Subject) -> Subject
forward(data: Image) -> Image
forward(data: Tensor) -> Tensor
forward(data: np.ndarray) -> np.ndarray
forward(data: sitk.Image) -> sitk.Image
forward(data: nib.Nifti1Image) -> nib.Nifti1Image
forward(data: dict) -> dict
forward(data: ImagesBatch) -> ImagesBatch
forward(data: SubjectsBatch) -> SubjectsBatch

Apply the transform.

The output type always matches the input type.

Parameters:

Name Type Description Default
data Any

Input data to transform.

required
Source code in src/torchio/transforms/transform.py
def forward(self, data: Any) -> Any:
    """Run the transform on *data* and return the transformed result.

    The output type always matches the input type.

    Args:
        data: Input data to transform.
    """
    if self.copy:
        data = _copy.deepcopy(data)
    batch, unwrap = self._wrap(data)

    # With per-element gating the transform handles the probability itself
    # (masked-out elements get identity params), so the batch-wide coin
    # flip only happens when gating is inactive.  Apply iff rand < p, so
    # p=0 is always a no-op and p=1 always applies.
    per_element_gating = self._per_instance_p_active(batch)
    if not per_element_gating and torch.rand(1).item() >= self.p:
        return unwrap(batch)

    params = self.make_params(batch)
    batch = self.apply_transform(batch, params)

    # Skip the history entry when every element was gated out: that run is
    # an exact no-op, and recording it would let history replay (e.g. an
    # invertible spatial transform) trigger a needless identity resample.
    if not _all_elements_gated_out(params):
        entry = AppliedTransform(name=type(self).__name__, params=params)
        if not hasattr(batch, "applied_transforms"):
            batch.applied_transforms = []
        batch.applied_transforms.append(entry)

    result = unwrap(batch)

    # Copy the history onto the unwrapped output, except for batches,
    # tensors, arrays and dicts; outputs that reject the attribute are
    # silently left without it.
    excluded_types = (SubjectsBatch, Tensor, np.ndarray, dict)
    if hasattr(batch, "applied_transforms") and not isinstance(
        result, excluded_types
    ):
        with contextlib.suppress(AttributeError):
            result.applied_transforms = list(batch.applied_transforms)
    return result

inverse(params)

Return a transform that undoes this one.

Override in invertible subclasses. The returned transform, when applied, reverses the effect of the forward pass with the given parameters.

Parameters:

Name Type Description Default
params dict[str, Any]

The parameters recorded in the forward pass.

required

Returns:

Type Description
Transform

A new Transform instance that inverts this one.

Source code in src/torchio/transforms/transform.py
def inverse(self, params: dict[str, Any]) -> Transform:
    """Build the transform that reverses this one.

    This base implementation always raises; invertible subclasses
    override it to return a transform that, when applied, reverses
    the effect of the forward pass with the given parameters.

    Args:
        params: The parameters recorded in the forward pass.

    Returns:
        A new `Transform` instance that inverts this one.

    Raises:
        NotImplementedError: Always, in this implementation.
    """
    class_name = type(self).__name__
    raise NotImplementedError(f"{class_name} is not invertible")

to_hydra()

Export as a Hydra-compatible config dict.

Returns a dict with _target_ set to the fully qualified class name and only non-default field values included.

Returns:

Type Description
dict[str, Any]

Dict suitable for hydra.utils.instantiate().

Source code in src/torchio/transforms/transform.py
def to_hydra(self) -> dict[str, Any]:
    """Serialise this transform into a Hydra-style config dict.

    The dict carries `_target_` (``torchio.`` plus the class
    qualified name) and one entry per init parameter whose current
    value differs from its default.

    Returns:
        Dict suitable for `hydra.utils.instantiate()`.
    """
    from .parameter_range import _ParameterRange

    cls = type(self)
    cfg: dict[str, Any] = {"_target_": f"torchio.{cls.__qualname__}"}

    for name, default in _collect_init_params(cls).items():
        current = getattr(self, name, default)
        # Ranges are compared and exported through the value they were
        # originally built from, not through the range object itself.
        raw = current._original if isinstance(current, _ParameterRange) else current
        if raw == default:
            continue
        cfg[name] = _hydra_value(raw)
    return cfg

make_params(batch)

Sample per-label mean and std values (per element when batched).

Source code in src/torchio/transforms/intensity/labels_to_image.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Draw the per-label means and stds used for synthesis.

    Returns one `(means, stds)` pair of dicts when `_resolve_n` yields
    `None`, otherwise one pair per batch element, stored as parallel
    lists under `"means"` and `"stds"`.
    """
    # Only the first sample of the label map is inspected for labels.
    first_sample = self._find_label_batch(batch).data[0]
    labels = sorted(map(int, first_sample.unique().tolist()))

    count = self._resolve_n(batch)
    if count is None:
        shared_means, shared_stds = self._sample_label_values(labels)
        return {"means": shared_means, "stds": shared_stds}

    # One independent draw per batch element, in element order.
    draws = [self._sample_label_values(labels) for _ in range(count)]
    params = {
        "means": [element_means for element_means, _ in draws],
        "stds": [element_stds for _, element_stds in draws],
    }
    self._tag_batched(params, batch, count, None, ["means", "stds"])
    return params

apply_transform(batch, params)

Generate a synthetic image and add it to the batch.

Source code in src/torchio/transforms/intensity/labels_to_image.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Synthesise an image from the label map and store it in the batch."""
    source = self._find_label_batch(batch)

    # Per-instance params carry one means/stds entry per batch element;
    # otherwise a single pair is shared by the whole batch.
    if self._is_per_instance_params(params):
        generator = _generate_per_element
    else:
        generator = _generate_from_labels
    synthetic = generator(source.data, params["means"], params["stds"])

    # Imported at call time rather than at module level.
    from ...data.batch import ImagesBatch

    # The new entry reuses the label map's affines and is typed as a
    # scalar image; it is stored under ``image_key``.
    batch.images[self.image_key] = ImagesBatch(
        data=synthetic,
        affines=source.affines,
        image_class=ScalarImage,
    )
    return batch