Skip to content

Spike

Bases: IntensityTransform

Add random MRI spike artifacts.

Also known as herringbone artifact, crisscross artifact, or corduroy artifact. Spikes in k-space create stripes in image space.

The artifact is simulated by adding point impulses to the Fourier spectrum of the image. All operations use torch.fft and run on GPU.

Parameters:

Name Type Description Default
num_spikes int | tuple[int, int]

Number of spikes. A scalar \(n\) is deterministic; a 2-tuple \((a, b)\) samples \(n \sim \mathcal{U}(a, b) \cap \mathbb{N}\). The sampled value is rounded to the nearest integer and clamped to at least 1.

1
intensity float | tuple[float, float]

Ratio between the spike amplitude and the spectrum maximum. A scalar is deterministic; a 2-tuple \((a, b)\) means \(r \sim \mathcal{U}(a, b)\). The default intensity=0 is a no-op (and warns).

0.0
**kwargs Any

See Transform.

{}
Note

Execution time does not depend on the number of spikes.

Examples:

>>> import torchio as tio
>>> transform = tio.Spike(intensity=2.0)
>>> transform = tio.Spike(num_spikes=3, intensity=2.0)
Source code in src/torchio/transforms/intensity/spike.py
class Spike(IntensityTransform):
    r"""Simulate MRI spike artifacts by corrupting k-space.

    Spike noise is also called the
    [herringbone artifact](https://radiopaedia.org/articles/herringbone-artifact),
    crisscross artifact, or corduroy artifact: spikes in k-space
    show up as stripes in image space.

    The simulation adds point impulses to the Fourier spectrum of the
    image.  All operations use `torch.fft` and run on GPU.

    Args:
        num_spikes: Number of spikes.  A scalar $n$ is deterministic;
            a 2-tuple $(a, b)$ samples
            $n \sim \mathcal{U}(a, b) \cap \mathbb{N}$.
            Each sampled value is rounded and clamped to at least 1.
        intensity: Ratio between the spike amplitude and the spectrum
            maximum.  A scalar is deterministic; a 2-tuple $(a, b)$
            means $r \sim \mathcal{U}(a, b)$.
            The default `intensity=0` is a no-op (and warns).
        **kwargs: See [`Transform`][torchio.Transform].

    Note:
        Execution time does not depend on the number of spikes.

    Examples:
        >>> import torchio as tio
        >>> transform = tio.Spike(intensity=2.0)
        >>> transform = tio.Spike(num_spikes=3, intensity=2.0)
    """

    def __init__(
        self,
        *,
        num_spikes: int | tuple[int, int] = 1,
        intensity: float | tuple[float, float] = 0.0,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.num_spikes = to_nonneg_range(num_spikes)
        self.intensity = to_range(intensity)
        # NOTE(review): a constant num_spikes of 0 is reported as a no-op
        # here, yet make_params clamps the sampled count to at least 1 --
        # confirm which of the two is intended.
        zero_intensity = self.intensity.is_constant(0.0)
        is_noop = zero_intensity or self.num_spikes.is_constant(0.0)
        self._warn_if_noop(is_noop=is_noop, hint="intensity=(1, 3)")

    def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
        """Draw spike positions and intensity, once or per batch element.

        When `_resolve_n` returns `None`, a single set of parameters is
        returned.  Otherwise one entry is produced per element; elements
        excluded by the keep mask get no positions and an intensity of 0.
        """

        def draw_spikes() -> tuple[list[list[float]], float]:
            # Draw order (count, positions, intensity) is kept fixed so the
            # sequence of random draws stays the same.
            count = max(1, round(self.num_spikes.sample_1d()))
            coords = torch.rand(count, 3).tolist()
            return coords, self.intensity.sample_1d()

        n = self._resolve_n(batch)
        if n is None:
            positions, intensity = draw_spikes()
            return {"positions": positions, "intensity": intensity}

        keep = self._keep_mask(batch, n)
        flags = keep.tolist() if keep is not None else [True] * n
        positions_list: list[list[list[float]]] = []
        intensity_list: list[float] = []
        for flag in flags:
            # Skipped elements receive placeholder values instead of a draw.
            coords, strength = draw_spikes() if flag else ([], 0.0)
            positions_list.append(coords)
            intensity_list.append(strength)

        params = {"positions": positions_list, "intensity": intensity_list}
        self._tag_batched(params, batch, n, keep, ["positions", "intensity"])
        return params

    @property
    def supports_per_instance_params(self) -> bool:
        """Always `True`: `make_params` can sample per batch element."""
        return True

    @property
    def supports_per_instance_p(self) -> bool:
        """Always `True`.

        Presumably tells the base class that the probability may be
        applied per element -- confirm against `Transform`.
        """
        return True

    def apply_transform(
        self,
        batch: SubjectsBatch,
        params: dict[str, Any],
    ) -> SubjectsBatch:
        """Write spike-corrupted data back into every selected image."""
        # One check decides between the per-element and the shared variant.
        if self._is_per_instance_params(params):
            add_spikes = _add_spikes_per_instance
        else:
            add_spikes = _add_spikes
        for _name, img_batch in self._get_images(batch).items():
            img_batch.data = add_spikes(
                img_batch.data,
                params["positions"],
                params["intensity"],
            )
        return batch

invertible property

Whether this transform can be inverted.

forward(data)

forward(data: Subject) -> Subject
forward(data: Image) -> Image
forward(data: Tensor) -> Tensor
forward(data: np.ndarray) -> np.ndarray
forward(data: sitk.Image) -> sitk.Image
forward(data: nib.Nifti1Image) -> nib.Nifti1Image
forward(data: dict) -> dict
forward(data: ImagesBatch) -> ImagesBatch
forward(data: SubjectsBatch) -> SubjectsBatch

Apply the transform.

The output type always matches the input type.

Parameters:

Name Type Description Default
data Any

Input data to transform.

required
Source code in src/torchio/transforms/transform.py
def forward(self, data: Any) -> Any:
    """Run the transform on `data` and return the same kind of object.

    The input is wrapped into a batch, parameters are sampled with
    `make_params`, and `apply_transform` is run.  The result is
    unwrapped again, so the output type always matches the input type.

    Args:
        data: Input data to transform.
    """
    if self.copy:
        data = _copy.deepcopy(data)
    batch, unwrap = self._wrap(data)

    # With per-element gating the transform handles the probability itself
    # (masked-out elements get identity params), so the batch-wide coin
    # flip is only made when that gating is inactive.
    if not self._per_instance_p_active(batch):
        # Apply iff rand < p: p=0 is always a no-op, p=1 always applies.
        coin = torch.rand(1).item()
        if coin >= self.p:
            return unwrap(batch)

    params = self.make_params(batch)
    batch = self.apply_transform(batch, params)

    # A batch whose elements were all gated out is an exact no-op.
    # Recording it would let history replay (e.g. an invertible spatial
    # transform) trigger an unnecessary identity resample, so skip it.
    every_element_skipped = _all_elements_gated_out(params)
    if not every_element_skipped:
        entry = AppliedTransform(name=type(self).__name__, params=params)
        if not hasattr(batch, "applied_transforms"):
            batch.applied_transforms = []
        batch.applied_transforms.append(entry)

    result = unwrap(batch)

    # Copy the history onto the output, except for these result types.
    no_history_types = (SubjectsBatch, Tensor, np.ndarray, dict)
    if hasattr(batch, "applied_transforms") and not isinstance(
        result, no_history_types
    ):
        with contextlib.suppress(AttributeError):
            result.applied_transforms = list(batch.applied_transforms)
    return result

inverse(params)

Return a transform that undoes this one.

Override in invertible subclasses. The returned transform, when applied, reverses the effect of the forward pass with the given parameters.

Parameters:

Name Type Description Default
params dict[str, Any]

The parameters recorded in the forward pass.

required

Returns:

Type Description
Transform

A new Transform instance that inverts this one.

Source code in src/torchio/transforms/transform.py
def inverse(self, params: dict[str, Any]) -> Transform:
    """Build the transform that reverses this one.

    This default implementation always raises.  Invertible subclasses
    override it and return a transform that, when applied, reverses
    the effect of the forward pass with the given parameters.

    Args:
        params: The parameters recorded in the forward pass.

    Returns:
        A new `Transform` instance that inverts this one.

    Raises:
        NotImplementedError: Always, in this default implementation.
    """
    raise NotImplementedError(f"{type(self).__name__} is not invertible")

to_hydra()

Export as a Hydra-compatible config dict.

Returns a dict with _target_ set to the class's qualified name under the top-level torchio namespace (torchio.<ClassName>) and only non-default field values included.

Returns:

Type Description
dict[str, Any]

Dict suitable for hydra.utils.instantiate().

Source code in src/torchio/transforms/transform.py
def to_hydra(self) -> dict[str, Any]:
    """Serialize this transform to a Hydra-compatible config dict.

    `_target_` is set to `torchio.<qualified class name>`.  Every init
    parameter whose current value equals its default is left out; for
    parameter ranges the comparison uses the originally supplied value.

    Returns:
        Dict suitable for `hydra.utils.instantiate()`.
    """
    from .parameter_range import _ParameterRange

    cls = type(self)
    # NOTE(review): the "torchio." prefix is hard-coded, so a subclass
    # defined outside torchio would get the same prefix -- confirm intent.
    config: dict[str, Any] = {"_target_": f"torchio.{cls.__qualname__}"}

    for field_name, field_default in _collect_init_params(cls).items():
        current = getattr(self, field_name, field_default)
        # Ranges are compared and exported through the value the user passed.
        raw = current._original if isinstance(current, _ParameterRange) else current
        if raw == field_default:
            continue
        config[field_name] = _hydra_value(raw)
    return config

make_params(batch)

Sample the number and positions of spikes (per element when batched).

Source code in src/torchio/transforms/intensity/spike.py
def make_params(self, batch: SubjectsBatch) -> dict[str, Any]:
    """Draw spike positions and intensity, once or per batch element.

    When `_resolve_n` returns `None`, a single set of parameters is
    returned.  Otherwise one entry is produced per element; elements
    excluded by the keep mask get no positions and an intensity of 0.
    """

    def draw_spikes() -> tuple[list[list[float]], float]:
        # Draw order (count, positions, intensity) is kept fixed so the
        # sequence of random draws stays the same.
        count = max(1, round(self.num_spikes.sample_1d()))
        coords = torch.rand(count, 3).tolist()
        return coords, self.intensity.sample_1d()

    n = self._resolve_n(batch)
    if n is None:
        positions, intensity = draw_spikes()
        return {"positions": positions, "intensity": intensity}

    keep = self._keep_mask(batch, n)
    flags = keep.tolist() if keep is not None else [True] * n
    positions_list: list[list[list[float]]] = []
    intensity_list: list[float] = []
    for flag in flags:
        # Skipped elements receive placeholder values instead of a draw.
        coords, strength = draw_spikes() if flag else ([], 0.0)
        positions_list.append(coords)
        intensity_list.append(strength)

    params = {"positions": positions_list, "intensity": intensity_list}
    self._tag_batched(params, batch, n, keep, ["positions", "intensity"])
    return params

apply_transform(batch, params)

Add spike artifacts to each selected image.

Source code in src/torchio/transforms/intensity/spike.py
def apply_transform(
    self,
    batch: SubjectsBatch,
    params: dict[str, Any],
) -> SubjectsBatch:
    """Write spike-corrupted data back into every selected image."""
    # One check decides between the per-element and the shared variant.
    if self._is_per_instance_params(params):
        add_spikes = _add_spikes_per_instance
    else:
        add_spikes = _add_spikes
    for _name, img_batch in self._get_images(batch).items():
        img_batch.data = add_spikes(
            img_batch.data,
            params["positions"],
            params["intensity"],
        )
    return batch