Skip to content

LabelsToImage

Bases: Transform

Generate a synthetic image from a label map.

For each label, Gaussian-distributed tissue is created with a sampled mean and standard deviation, weighted by the label mask. The per-label contributions are summed to produce the output image.

This is the building block behind SynthSeg-style synthesis. For best results, compose with Blur and BiasField.

The generated image is added to the subject under the key given by image_key. Existing images are not modified, but an image already stored under image_key is replaced.

Only LabelMap images are used as input.

Parameters:

Name Type Description Default
label_key str | None

Name of the label map to use. If None, the first LabelMap found is used.

None
image_key str

Name for the generated ScalarImage.

'image_from_labels'
mean Sequence[float | tuple[float, float]] | None

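Per-label mean ranges, assigned to the label values present in the label map in ascending order (label 0 included). Labels beyond the end of the sequence, or all labels if None, get a mean sampled from default_mean.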

None
std Sequence[float | tuple[float, float]] | None

Per-label std ranges, assigned to labels in the same order as mean. Labels beyond the end of the sequence, or all labels if None, get a std sampled from default_std.

None
default_mean float | tuple[float, float]

Fallback range for label means.

(0.1, 0.9)
default_std float | tuple[float, float]

Fallback range for label stds.

(0.01, 0.1)
ignore_background bool

If True, label 0 (background) is assigned mean 0 and std 0, so it is left as zero.

False
**kwargs Any

See Transform.

{}

Examples:

>>> import torchio as tio
>>> transform = tio.LabelsToImage(label_key="seg")
>>> transform = tio.LabelsToImage(
...     label_key="seg",
...     mean=[(0.8, 1.0), (0.3, 0.5)],
...     std=[(0.01, 0.05), (0.02, 0.08)],
... )
Source code in src/torchio/transforms/intensity/labels_to_image.py
class LabelsToImage(Transform):
    r"""Generate a synthetic image from a label map.

    For each label, Gaussian-distributed tissue is created with a
    sampled mean and standard deviation, weighted by the label mask.
    The per-label contributions are summed to produce the output
    image.

    This is the building block behind
    [SynthSeg](https://github.com/BBillot/SynthSeg)-style synthesis.
    For best results, compose with
    [`Blur`][torchio.Blur] and
    [`BiasField`][torchio.BiasField].

    The generated image is added to the subject under the key given
    by *image_key*.  Existing images are **not** modified, but an
    image already stored under *image_key* is replaced.

    Only [`LabelMap`][torchio.LabelMap] images are used as input.

    Args:
        label_key: Name of the label map to use.  If `None`, the
            first `LabelMap` found is used.
        image_key: Name for the generated `ScalarImage`.
        mean: Per-label mean ranges, assigned to the label values
            present in the batch in ascending order (label 0
            included).  Labels beyond the end of the sequence, or all
            labels if `None`, get a mean sampled from *default_mean*.
        std: Per-label std ranges, assigned to labels like *mean*.
            Labels beyond the end of the sequence, or all labels if
            `None`, get a std sampled from *default_std*.  Sampled
            stds are made non-negative with `abs`.
        default_mean: Fallback range for label means.
        default_std: Fallback range for label stds.
        ignore_background: If `True`, label 0 is assigned mean 0 and
            std 0, so it is left as zero.
        **kwargs: See [`Transform`][torchio.Transform].

    Examples:
        >>> import torchio as tio
        >>> transform = tio.LabelsToImage(label_key="seg")
        >>> transform = tio.LabelsToImage(
        ...     label_key="seg",
        ...     mean=[(0.8, 1.0), (0.3, 0.5)],
        ...     std=[(0.01, 0.05), (0.02, 0.08)],
        ... )
    """

    def __init__(
        self,
        label_key: str | None = None,
        *,
        image_key: str = "image_from_labels",
        mean: Sequence[float | tuple[float, float]] | None = None,
        std: Sequence[float | tuple[float, float]] | None = None,
        default_mean: float | tuple[float, float] = (0.1, 0.9),
        default_std: float | tuple[float, float] = (0.01, 0.1),
        ignore_background: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.label_key = label_key
        self.image_key = image_key
        # Keep the raw arguments under their parameter names:
        # Transform.to_hydra reads attributes by __init__ parameter name,
        # so without these the exported config would silently drop them.
        self.mean = mean
        self.std = std
        self.mean_ranges = [to_range(m) for m in mean] if mean is not None else None
        self.std_ranges = [to_range(s) for s in std] if std is not None else None
        self.default_mean = to_range(default_mean)
        self.default_std = to_range(default_std)
        self.ignore_background = ignore_background

    @staticmethod
    def _range_for(ranges: Any, idx: int, fallback: Any) -> Any:
        """Return ``ranges[idx]`` when it exists, otherwise *fallback*."""
        if ranges is not None and idx < len(ranges):
            return ranges[idx]
        return fallback

    def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
        """Sample per-label mean and std values.

        Returns:
            Dict with ``"means"`` and ``"stds"``, each mapping an integer
            label value to its sampled float.
        """
        label_batch = self._find_label_batch(batch)
        # Discover labels across every sample in the batch, not just the
        # first, so no label present in the batch is left without params.
        labels = sorted({int(v) for v in label_batch.data.unique().tolist()})
        means: dict[int, float] = {}
        stds: dict[int, float] = {}
        for idx, label in enumerate(labels):
            if self.ignore_background and label == 0:
                means[label] = 0.0
                stds[label] = 0.0
                continue
            mean_range = self._range_for(self.mean_ranges, idx, self.default_mean)
            std_range = self._range_for(self.std_ranges, idx, self.default_std)
            means[label] = mean_range.sample_1d()
            # abs() is applied to per-label stds as well as the default, so
            # a std range that spans negative values cannot yield a negative std.
            stds[label] = abs(std_range.sample_1d())
        return {"means": means, "stds": stds}

    def apply_transform(
        self,
        batch: SubjectsBatch,
        params: dict[str, Any],
    ) -> SubjectsBatch:
        """Generate a synthetic image and add it to the batch.

        The new image reuses the label map's affines and is stored under
        ``self.image_key``, replacing any entry already under that key.
        """
        label_batch = self._find_label_batch(batch)
        generated = _generate_from_labels(
            label_batch.data,
            params["means"],
            params["stds"],
        )
        # Create a new image batch entry.
        from ...data.batch import ImagesBatch

        batch.images[self.image_key] = ImagesBatch(
            data=generated,
            affines=label_batch.affines,
            image_class=ScalarImage,
        )
        return batch

    def _find_label_batch(self, batch: SubjectsBatch) -> Any:
        """Find the label map batch to use.

        Raises:
            KeyError: If ``label_key`` is set but absent, or if it is
                `None` and no `LabelMap` exists in the batch.
        """
        if self.label_key is not None:
            if self.label_key not in batch.images:
                msg = (
                    f"Label key '{self.label_key}' not found. "
                    f"Available: {list(batch.images.keys())}"
                )
                raise KeyError(msg)
            return batch.images[self.label_key]
        # Auto-detect first LabelMap.
        for _name, img_batch in batch.images.items():
            if issubclass(img_batch._image_class, LabelMap):
                return img_batch
        msg = "No LabelMap found in the subject"
        raise KeyError(msg)

invertible property

Whether this transform can be inverted.

forward(data)

forward(data: Subject) -> Subject
forward(data: Image) -> Image
forward(data: Tensor) -> Tensor
forward(data: np.ndarray) -> np.ndarray
forward(data: sitk.Image) -> sitk.Image
forward(data: nib.Nifti1Image) -> nib.Nifti1Image
forward(data: dict) -> dict
forward(data: ImagesBatch) -> ImagesBatch
forward(data: SubjectsBatch) -> SubjectsBatch

Apply the transform.

The output type always matches the input type.

Parameters:

Name Type Description Default
data Any

Input data to transform.

required
Source code in src/torchio/transforms/transform.py
def forward(self, data: Any) -> Any:
    """Apply the transform.

    The output type always matches the input type.

    Args:
        data: Input data to transform.
    """
    if self.copy:
        data = _copy.deepcopy(data)
    batch, unwrap = self._wrap(data)
    # torch.rand samples from [0, 1), so ">=" skips the transform with
    # probability exactly 1 - p: p=0 never applies it, p=1 always does.
    if torch.rand(1).item() >= self.p:
        return unwrap(batch)
    params = self.make_params(batch)
    batch = self.apply_transform(batch, params)
    # Record history on the batch
    trace = AppliedTransform(name=type(self).__name__, params=params)
    if not hasattr(batch, "applied_transforms"):
        batch.applied_transforms = []
    batch.applied_transforms.append(trace)
    result = unwrap(batch)
    # Propagate history to outputs that can carry it. SubjectsBatch,
    # tensors, arrays and dicts are not given a copy; other objects that
    # reject attribute assignment are skipped silently.
    if not isinstance(result, (SubjectsBatch, Tensor, np.ndarray, dict)):
        with contextlib.suppress(AttributeError):
            result.applied_transforms = list(batch.applied_transforms)
    return result

inverse(params)

Return a transform that undoes this one.

Override in invertible subclasses. The returned transform, when applied, reverses the effect of the forward pass with the given parameters.

Parameters:

Name Type Description Default
params dict[str, Any]

The parameters recorded in the forward pass.

required

Returns:

Type Description
Transform

A new Transform instance that inverts this one.

Source code in src/torchio/transforms/transform.py
def inverse(self, params: dict[str, Any]) -> Transform:
    """Build the transform that reverses this one.

    The base implementation is not invertible. Invertible subclasses
    override it and use *params*, the values recorded during the
    forward pass, to construct the reversing transform.

    Args:
        params: The parameters recorded in the forward pass.

    Returns:
        A new `Transform` instance that inverts this one.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError(f"{type(self).__name__} is not invertible")

to_hydra()

Export as a Hydra-compatible config dict.

Returns a dict with _target_ set to the fully qualified class name and only non-default field values included.

Returns:

Type Description
dict[str, Any]

Dict suitable for hydra.utils.instantiate().

Source code in src/torchio/transforms/transform.py
def to_hydra(self) -> dict[str, Any]:
    """Serialize this transform into a Hydra-compatible config dict.

    ``_target_`` holds ``torchio.<ClassName>``. Every ``__init__``
    parameter whose current value differs from its default is added.
    Parameter ranges are compared and exported through their original
    user-supplied value.

    Returns:
        Dict suitable for `hydra.utils.instantiate()`.
    """
    from .parameter_range import _ParameterRange

    klass = type(self)
    config: dict[str, Any] = {"_target_": f"torchio.{klass.__qualname__}"}

    for param_name, param_default in _collect_init_params(klass).items():
        current = getattr(self, param_name, param_default)
        # Ranges are compared and exported via the value the user passed in.
        raw = current._original if isinstance(current, _ParameterRange) else current
        if raw == param_default:
            continue
        config[param_name] = _hydra_value(raw)
    return config

make_params(batch)

Sample per-label mean and std values.

Source code in src/torchio/transforms/intensity/labels_to_image.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Sample per-label mean and std values.

    Labels are the distinct (integer-cast) values found anywhere in the
    label batch. They are processed in ascending order, and the i-th
    label uses the i-th per-label range when one exists, otherwise the
    default range.

    Returns:
        Dict with ``"means"`` and ``"stds"``, each mapping an integer
        label value to its sampled float.
    """
    label_batch = self._find_label_batch(batch)
    # Discover labels across every sample in the batch, not just the
    # first, so no label present in the batch is left without params.
    labels = sorted({int(v) for v in label_batch.data.unique().tolist()})
    means: dict[int, float] = {}
    stds: dict[int, float] = {}
    for idx, label in enumerate(labels):
        if self.ignore_background and label == 0:
            means[label] = 0.0
            stds[label] = 0.0
            continue
        if self.mean_ranges is not None and idx < len(self.mean_ranges):
            mean_range = self.mean_ranges[idx]
        else:
            mean_range = self.default_mean
        if self.std_ranges is not None and idx < len(self.std_ranges):
            std_range = self.std_ranges[idx]
        else:
            std_range = self.default_std
        means[label] = mean_range.sample_1d()
        # abs() is applied to per-label stds as well as the default, so a
        # std range that spans negative values cannot yield a negative std.
        stds[label] = abs(std_range.sample_1d())
    return {"means": means, "stds": stds}

apply_transform(batch, params)

Generate a synthetic image and add it to the batch.

Source code in src/torchio/transforms/intensity/labels_to_image.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Synthesize an intensity image from the label map and store it.

    The image is built from the sampled ``params["means"]`` and
    ``params["stds"]``. It reuses the label map's affines and is written
    to ``batch.images[self.image_key]``, replacing any existing entry.
    The same (mutated) batch is returned.
    """
    source = self._find_label_batch(batch)
    synthetic = _generate_from_labels(source.data, params["means"], params["stds"])
    # Deferred import; presumably to avoid a circular import at module load.
    from ...data.batch import ImagesBatch

    batch.images[self.image_key] = ImagesBatch(
        data=synthetic,
        affines=source.affines,
        image_class=ScalarImage,
    )
    return batch