Skip to content

Data loading

Loaders

SubjectsLoader

Bases: DataLoader

DataLoader that returns SubjectsBatch instances.

A thin wrapper around torch.utils.data.DataLoader that collates Subject instances into SubjectsBatch.

Parameters:

Name Type Description Default
dataset Dataset

A dataset that returns Subject instances.

required
**kwargs Any

Passed to DataLoader.__init__.

{}

Examples:

>>> loader = tio.SubjectsLoader(dataset, batch_size=4)
>>> batch = next(iter(loader))
>>> batch.t1.data.shape
torch.Size([4, 1, 256, 256, 176])
Source code in src/torchio/loader.py
class SubjectsLoader(DataLoader):
    """Data loader whose batches are `SubjectsBatch` objects.

    Subclass of `torch.utils.data.DataLoader` that always uses
    `collate_subjects` as its collate function.

    Args:
        dataset: A dataset that returns `Subject` instances.
        **kwargs: Passed to `DataLoader.__init__`. Must not include
            `collate_fn`.

    Raises:
        ValueError: If `collate_fn` is present in `kwargs`.

    Examples:
        >>> loader = tio.SubjectsLoader(dataset, batch_size=4)
        >>> batch = next(iter(loader))
        >>> batch.t1.data.shape
        torch.Size([4, 1, 256, 256, 176])
    """

    def __init__(self, dataset: Dataset, **kwargs: Any) -> None:
        # The collate function is fixed by this class, so an explicit
        # override is rejected instead of being silently replaced.
        caller_set_collate = "collate_fn" in kwargs
        if caller_set_collate:
            raise ValueError(
                "SubjectsLoader sets collate_fn automatically; "
                "pass a plain DataLoader if you need a custom collate_fn"
            )
        kwargs["collate_fn"] = collate_subjects
        super().__init__(dataset, **kwargs)

ImagesLoader

Bases: DataLoader

DataLoader that returns ImagesBatch instances.

A thin wrapper around torch.utils.data.DataLoader that collates Image instances into ImagesBatch.

Parameters:

Name Type Description Default
dataset Dataset

A dataset that returns Image instances.

required
**kwargs Any

Passed to DataLoader.__init__.

{}

Examples:

>>> loader = tio.ImagesLoader(dataset, batch_size=4)
>>> batch = next(iter(loader))
>>> batch.data.shape
torch.Size([4, 1, 256, 256, 176])
Source code in src/torchio/loader.py
class ImagesLoader(DataLoader):
    """Data loader whose batches are `ImagesBatch` objects.

    Subclass of `torch.utils.data.DataLoader` that always uses
    `collate_images` as its collate function.

    Args:
        dataset: A dataset that returns `Image` instances.
        **kwargs: Passed to `DataLoader.__init__`. Must not include
            `collate_fn`.

    Raises:
        ValueError: If `collate_fn` is present in `kwargs`.

    Examples:
        >>> loader = tio.ImagesLoader(dataset, batch_size=4)
        >>> batch = next(iter(loader))
        >>> batch.data.shape
        torch.Size([4, 1, 256, 256, 176])
    """

    def __init__(self, dataset: Dataset, **kwargs: Any) -> None:
        # The collate function is fixed by this class, so an explicit
        # override is rejected instead of being silently replaced.
        caller_set_collate = "collate_fn" in kwargs
        if caller_set_collate:
            raise ValueError(
                "ImagesLoader sets collate_fn automatically; "
                "pass a plain DataLoader if you need a custom collate_fn"
            )
        kwargs["collate_fn"] = collate_images
        super().__init__(dataset, **kwargs)

Collation functions

collate_subjects(batch)

Collate a list of Subjects into a SubjectsBatch.

Parameters:

Name Type Description Default
batch Sequence[Any]

Sequence of Subject instances.

required

Returns:

Type Description
SubjectsBatch

A SubjectsBatch with stacked 5D tensors.

Source code in src/torchio/loader.py
def collate_subjects(batch: Sequence[Any]) -> SubjectsBatch:
    """Turn the samples of one batch into a single `SubjectsBatch`.

    Used as the `collate_fn` of `SubjectsLoader`.

    Args:
        batch: Sequence of `Subject` instances.

    Returns:
        The result of `SubjectsBatch.from_subjects` on those subjects.
    """
    # from_subjects is annotated to take a list, so materialise the
    # incoming sequence first.
    subjects = list(batch)
    return SubjectsBatch.from_subjects(subjects)

collate_images(batch)

Collate a list of Images into an ImagesBatch.

Parameters:

Name Type Description Default
batch Sequence[Any]

Sequence of Image instances.

required

Returns:

Type Description
ImagesBatch

An ImagesBatch with a stacked 5D tensor.

Source code in src/torchio/loader.py
def collate_images(batch: Sequence[Any]) -> ImagesBatch:
    """Turn the samples of one batch into a single `ImagesBatch`.

    Used as the `collate_fn` of `ImagesLoader`.

    Args:
        batch: Sequence of `Image` instances.

    Returns:
        The result of `ImagesBatch.from_images` on those images.
    """
    # from_images is annotated to take a list, so materialise the
    # incoming sequence first.
    images = list(batch)
    return ImagesBatch.from_images(images)

Batch containers

SubjectsBatch

Bases: Invertible

A batch of subjects with stacked image data.

Each named image entry becomes an ImagesBatch. Metadata is stored as lists (one value per sample).

Created by SubjectsLoader or SubjectsBatch.from_subjects().

Source code in src/torchio/data/batch.py
class SubjectsBatch(Invertible):
    """A batch of subjects with stacked image data.

    Each named image entry becomes an `ImagesBatch`. Metadata is
    stored as lists (one value per sample).

    Created by `SubjectsLoader` or `SubjectsBatch.from_subjects()`.

    Args:
        images: Mapping from image name to its `ImagesBatch`.
        metadata: Mapping from metadata key to a list with one value
            per sample. `None` or an empty mapping means no metadata.
    """

    def __init__(
        self,
        images: dict[str, ImagesBatch],
        *,
        metadata: dict[str, list[Any]] | None = None,
    ) -> None:
        self._images = images
        self._metadata: dict[str, list[Any]] = metadata or {}
        # History of transforms applied to this batch; starts empty.
        self.applied_transforms: list[Any] = []

    @classmethod
    def from_subjects(cls, subjects: list[Any]) -> Self:
        """Stack a list of subjects into a batch.

        The first subject defines which image names and metadata keys
        the batch carries; every subject is indexed with those keys.

        Args:
            subjects: List of `Subject` instances.

        Returns:
            A new batch with one `ImagesBatch` per image name and one
            list per metadata key.

        Raises:
            ValueError: If `subjects` is empty.
        """
        from .subject import Subject

        if not subjects:
            msg = "Cannot create batch from empty list"
            raise ValueError(msg)

        first: Subject = subjects[0]
        image_names = list(first.images.keys())

        # Stack each named image across all subjects.
        images: dict[str, ImagesBatch] = {
            name: ImagesBatch.from_images([sub.images[name] for sub in subjects])
            for name in image_names
        }

        # Gather each metadata value across all subjects.
        metadata: dict[str, list[Any]] = {
            key: [sub.metadata[key] for sub in subjects]
            for key in first.metadata
        }

        return cls(images, metadata=metadata)

    @property
    def batch_size(self) -> int:
        """Number of samples in the batch.

        Taken from the first image batch. A batch without images falls
        back to the length of the first metadata list, or 0 when the
        batch holds neither images nor metadata.
        """
        # Use a default with next() so that an image-less batch does not
        # leak a bare StopIteration out of this property (and out of
        # __len__ / __repr__, which rely on it).
        first = next(iter(self._images.values()), None)
        if first is not None:
            return first.batch_size
        first_values = next(iter(self._metadata.values()), None)
        return 0 if first_values is None else len(first_values)

    @property
    def images(self) -> dict[str, ImagesBatch]:
        """Dict of named image batches."""
        return self._images

    @property
    def metadata(self) -> dict[str, list[Any]]:
        """Metadata lists (one value per sample)."""
        return self._metadata

    @property
    def device(self) -> torch.device:
        """Device of the batch data (read from the first image batch).

        Raises:
            RuntimeError: If the batch holds no images.
        """
        first = next(iter(self._images.values()), None)
        if first is None:
            # Deliberately not AttributeError: raising that from a
            # property would fall through to __getattr__ and hide the
            # real cause.
            msg = "SubjectsBatch has no images, so its device is undefined"
            raise RuntimeError(msg)
        return first.device

    def to(self, *args: Any, **kwargs: Any) -> Self:
        """Move all data to a device and/or cast dtype.

        The arguments are forwarded unchanged to `to()` of every image
        batch.

        Returns:
            This same batch, to allow chaining.
        """
        for batch in self._images.values():
            batch.to(*args, **kwargs)
        return self

    def __getitem__(self, key: str) -> ImagesBatch:
        """Get a named image batch."""
        return self._images[key]

    def __getattr__(self, name: str) -> ImagesBatch:
        """Attribute-style access to image batches.

        Raises:
            AttributeError: If `name` starts with an underscore or is
                not the name of an image batch.
        """
        # Private names are rejected up front so that a lookup of
        # `_images` before it is set cannot recurse into this method.
        if name.startswith("_"):
            raise AttributeError(name)
        if name in self._images:
            return self._images[name]
        msg = f"SubjectsBatch has no attribute {name!r}"
        raise AttributeError(msg)

    def unbatch(self) -> list[Any]:
        """Split the batch back into individual Subjects.

        Each subject receives sample `i` of every image batch and
        metadata list, plus its own copy of the batch's transform
        history.
        """
        from .subject import Subject

        n = self.batch_size
        subjects = []
        for i in range(n):
            kwargs: dict[str, Any] = {}
            for name, img_batch in self._images.items():
                kwargs[name] = img_batch[i]
            # Metadata is added after images, so a metadata key equal to
            # an image name overrides the image entry.
            for key, values in self._metadata.items():
                kwargs[key] = values[i]
            sub = Subject(**kwargs)
            # Copy the list so subjects do not share one history object.
            sub.applied_transforms = list(self.applied_transforms)
            subjects.append(sub)
        return subjects

    def __len__(self) -> int:
        return self.batch_size

    def __repr__(self) -> str:
        names = ", ".join(self._images.keys())
        return f"SubjectsBatch(batch_size={self.batch_size}, images=[{names}])"

batch_size property

Number of samples in the batch.

images property

Dict of named image batches.

metadata property

Metadata lists (one value per sample).

device property

Device of the batch data.

get_inverse_transform(*, warn=True, ignore_intensity=False)

Get a composed transform that inverts the applied history.

Returns a Compose of the inverse of each applied transform, in reverse order. Non-invertible transforms are skipped (with a warning if warn=True).

Parameters:

Name Type Description Default
warn bool

Issue a warning for non-invertible transforms.

True
ignore_intensity bool

Skip all intensity transforms.

False

Returns:

Type Description
Any

A Compose transform that undoes the history.

Source code in src/torchio/data/invertible.py
def get_inverse_transform(
    self,
    *,
    warn: bool = True,
    ignore_intensity: bool = False,
) -> Any:
    """Build the transform that undoes this object's transform history.

    The result is a [`Compose`][torchio.Compose] holding the inverse
    of each applied transform, in reverse order. Transforms that
    cannot be inverted are skipped (with a warning if `warn=True`).

    Args:
        warn: Issue a warning for non-invertible transforms.
        ignore_intensity: Skip all intensity transforms.

    Returns:
        A `Compose` transform that undoes the history.
    """
    # Imported locally and under an alias, because the helper has the
    # same name as this method.
    from ..transforms.inverse import get_inverse_transform as _build_inverse

    history = self.applied_transforms
    return _build_inverse(
        history,
        warn=warn,
        ignore_intensity=ignore_intensity,
    )

apply_inverse_transform(**kwargs)

Apply the inverse of all applied transforms, in reverse order.

Non-invertible transforms are skipped. Intensity transforms can be ignored with ignore_intensity=True.

Parameters:

Name Type Description Default
**kwargs Any

Forwarded to get_inverse_transform() (warn, ignore_intensity).

{}

Returns:

Type Description
Self

Data with transforms undone.

Examples:

>>> transformed = transform(subject)
>>> restored = transformed.apply_inverse_transform()
Source code in src/torchio/data/invertible.py
def apply_inverse_transform(self, **kwargs: Any) -> Self:
    """Apply the inverse of all applied transforms, in reverse order.

    Non-invertible transforms are skipped. Intensity transforms
    can be ignored with `ignore_intensity=True`.

    Args:
        **kwargs: Forwarded to
            `get_inverse_transform()` (`warn`,
            `ignore_intensity`).

    Returns:
        Data with transforms undone.

    Examples:
        >>> transformed = transform(subject)
        >>> restored = transformed.apply_inverse_transform()
    """
    inverse_transform = self.get_inverse_transform(**kwargs)
    result = inverse_transform(self)
    if hasattr(result, "applied_transforms"):
        result.applied_transforms = []
    return result

clear_history()

Remove all applied transform records.

Source code in src/torchio/data/invertible.py
def clear_history(self) -> None:
    """Forget every transform recorded in `applied_transforms`."""
    # Rebind to a new list rather than clearing in place, so any other
    # reference to the previous history list is left untouched.
    empty_history = []
    self.applied_transforms = empty_history

from_subjects(subjects) classmethod

Stack a list of subjects into a batch.

Parameters:

Name Type Description Default
subjects list[Any]

List of Subject instances.

required
Source code in src/torchio/data/batch.py
@classmethod
def from_subjects(cls, subjects: list[Any]) -> Self:
    """Build a batch by stacking the given subjects.

    The first subject determines which image names and metadata keys
    are collected; every subject is indexed with those keys.

    Args:
        subjects: List of `Subject` instances.

    Returns:
        A new batch with one `ImagesBatch` per image name and one list
        per metadata key.

    Raises:
        ValueError: If `subjects` is empty.
    """
    from .subject import Subject

    if not subjects:
        raise ValueError("Cannot create batch from empty list")

    template: Subject = subjects[0]

    # One stacked ImagesBatch per image name found in the first subject.
    stacked = {
        name: ImagesBatch.from_images([entry.images[name] for entry in subjects])
        for name in list(template.images.keys())
    }

    # One list per metadata key, holding that key's value for each subject.
    per_sample_metadata = {
        key: [entry.metadata[key] for entry in subjects]
        for key in template.metadata
    }

    return cls(stacked, metadata=per_sample_metadata)

to(*args, **kwargs)

Move all data to a device and/or cast dtype.

Source code in src/torchio/data/batch.py
def to(self, *args: Any, **kwargs: Any) -> Self:
    """Move all data to a device and/or cast dtype."""
    for batch in self._images.values():
        batch.to(*args, **kwargs)
    return self

unbatch()

Split the batch back into individual Subjects.

Source code in src/torchio/data/batch.py
def unbatch(self) -> list[Any]:
    """Rebuild one `Subject` per sample in the batch.

    Each subject receives sample `i` of every image batch and metadata
    list, plus its own copy of the batch's transform history.
    """
    from .subject import Subject

    restored = []
    for index in range(self.batch_size):
        fields: dict[str, Any] = {
            name: image_batch[index]
            for name, image_batch in self._images.items()
        }
        # Metadata is merged after images, so a metadata key equal to an
        # image name overrides the image entry.
        fields.update(
            {key: values[index] for key, values in self._metadata.items()}
        )
        subject = Subject(**fields)
        # Copy the list so subjects do not share one history object.
        subject.applied_transforms = list(self.applied_transforms)
        restored.append(subject)
    return restored

ImagesBatch

Bases: Invertible

A batch of images with per-sample affines.

Wraps a 5D tensor (B, C, I, J, K) and a list of AffineMatrix matrices (one per sample). Created by stacking multiple Image objects or directly from a 5D tensor.

Parameters:

Name Type Description Default
data Tensor

5D tensor with shape (B, C, I, J, K).

required
affines list[AffineMatrix]

List of affine matrices, one per sample.

required
image_class type[Image]

The Image subclass to use when unbatching.

ScalarImage
Source code in src/torchio/data/batch.py
class ImagesBatch(Invertible):
    """A batch of images with per-sample affines.

    Wraps a 5D tensor `(B, C, I, J, K)` and a list of `AffineMatrix`
    matrices (one per sample). Created by stacking multiple `Image`
    objects or directly from a 5D tensor.

    Args:
        data: 5D tensor with shape `(B, C, I, J, K)`.
        affines: List of affine matrices, one per sample.
        image_class: The `Image` subclass to use when unbatching.

    Raises:
        ValueError: If `data` is not 5D, or if the number of affines
            differs from the batch dimension of `data`.
    """

    def __init__(
        self,
        data: Tensor,
        affines: list[AffineMatrix],
        *,
        image_class: type[Image] = ScalarImage,
    ) -> None:
        if data.ndim != 5:
            msg = f"Expected 5D tensor (B, C, I, J, K), got {data.ndim}D"
            raise ValueError(msg)
        if len(affines) != data.shape[0]:
            msg = f"Expected {data.shape[0]} affines, got {len(affines)}"
            raise ValueError(msg)
        self._data = data
        self._affines = affines
        self._image_class = image_class
        # History of transforms applied to this batch; starts empty.
        self.applied_transforms: list[Any] = []

    @classmethod
    def from_images(cls, images: list[Image]) -> Self:
        """Stack a list of images into a batch.

        All images must have the same shape. The batch remembers the
        class of the first image and stores a clone of each affine.

        Args:
            images: List of `Image` instances to stack.

        Returns:
            A new batch whose data is `torch.stack` of the image data.

        Raises:
            ValueError: If `images` is empty.
        """
        if not images:
            msg = "Cannot create batch from empty list"
            raise ValueError(msg)
        tensors = [img.data for img in images]
        stacked = torch.stack(tensors)
        affines = [img.affine.clone() for img in images]
        image_class = type(images[0])
        return cls(stacked, affines, image_class=image_class)

    @property
    def data(self) -> Tensor:
        """5D tensor with shape `(B, C, I, J, K)`."""
        return self._data

    @data.setter
    def data(self, value: Tensor) -> None:
        """Replace the batch tensor.

        Raises:
            ValueError: If `value` is not 5D, or if its batch dimension
                differs from the number of stored affines.
        """
        if value.ndim != 5:
            msg = f"Expected 5D tensor, got {value.ndim}D"
            raise ValueError(msg)
        # Preserve the invariant checked in __init__ (one affine per
        # sample); otherwise batch_size and affines would disagree and
        # indexing / unbatching would drop samples or fail.
        expected = len(self._affines)
        if value.shape[0] != expected:
            msg = f"Expected batch size {expected}, got {value.shape[0]}"
            raise ValueError(msg)
        self._data = value

    @property
    def affines(self) -> list[AffineMatrix]:
        """List of affine matrices, one per sample."""
        return self._affines

    @property
    def batch_size(self) -> int:
        """Number of samples in the batch."""
        return self._data.shape[0]

    @property
    def device(self) -> torch.device:
        """Device the batch data resides on."""
        return self._data.device

    def to(self, *args: Any, **kwargs: Any) -> Self:
        """Move batch data to a device and/or cast dtype.

        The arguments are forwarded unchanged to the tensor's `to()`
        and to each affine's `to()`.

        Returns:
            This same batch, to allow chaining.
        """
        self._data = self._data.to(*args, **kwargs)
        # NOTE(review): the return value of affine.to() is discarded, so
        # this relies on AffineMatrix.to() working in place - confirm.
        for affine in self._affines:
            affine.to(*args, **kwargs)
        return self

    def __getitem__(self, index: int) -> Image:
        """Get a single image from the batch by index.

        The image is built with the stored image class, the indexed
        slice of the data, and a clone of the matching affine.
        """
        return self._image_class(
            self._data[index],
            affine=self._affines[index].clone(),
        )

    def __len__(self) -> int:
        return self.batch_size

    def unbatch(self) -> list[Image]:
        """Split the batch into individual images."""
        return [self[i] for i in range(self.batch_size)]

    def __repr__(self) -> str:
        b, c, i, j, k = self._data.shape
        cls = self._image_class.__name__
        return f"ImagesBatch({cls}, batch_size={b}, shape=({c}, {i}, {j}, {k}))"

data property writable

5D tensor with shape (B, C, I, J, K).

affines property

List of affine matrices, one per sample.

batch_size property

Number of samples in the batch.

device property

Device the batch data resides on.

get_inverse_transform(*, warn=True, ignore_intensity=False)

Get a composed transform that inverts the applied history.

Returns a Compose of the inverse of each applied transform, in reverse order. Non-invertible transforms are skipped (with a warning if warn=True).

Parameters:

Name Type Description Default
warn bool

Issue a warning for non-invertible transforms.

True
ignore_intensity bool

Skip all intensity transforms.

False

Returns:

Type Description
Any

A Compose transform that undoes the history.

Source code in src/torchio/data/invertible.py
def get_inverse_transform(
    self,
    *,
    warn: bool = True,
    ignore_intensity: bool = False,
) -> Any:
    """Build the transform that undoes this object's transform history.

    The result is a [`Compose`][torchio.Compose] holding the inverse
    of each applied transform, in reverse order. Transforms that
    cannot be inverted are skipped (with a warning if `warn=True`).

    Args:
        warn: Issue a warning for non-invertible transforms.
        ignore_intensity: Skip all intensity transforms.

    Returns:
        A `Compose` transform that undoes the history.
    """
    # Imported locally and under an alias, because the helper has the
    # same name as this method.
    from ..transforms.inverse import get_inverse_transform as _build_inverse

    history = self.applied_transforms
    return _build_inverse(
        history,
        warn=warn,
        ignore_intensity=ignore_intensity,
    )

apply_inverse_transform(**kwargs)

Apply the inverse of all applied transforms, in reverse order.

Non-invertible transforms are skipped. Intensity transforms can be ignored with ignore_intensity=True.

Parameters:

Name Type Description Default
**kwargs Any

Forwarded to get_inverse_transform() (warn, ignore_intensity).

{}

Returns:

Type Description
Self

Data with transforms undone.

Examples:

>>> transformed = transform(subject)
>>> restored = transformed.apply_inverse_transform()
Source code in src/torchio/data/invertible.py
def apply_inverse_transform(self, **kwargs: Any) -> Self:
    """Apply the inverse of all applied transforms, in reverse order.

    Non-invertible transforms are skipped. Intensity transforms
    can be ignored with `ignore_intensity=True`.

    Args:
        **kwargs: Forwarded to
            `get_inverse_transform()` (`warn`,
            `ignore_intensity`).

    Returns:
        Data with transforms undone.

    Examples:
        >>> transformed = transform(subject)
        >>> restored = transformed.apply_inverse_transform()
    """
    inverse_transform = self.get_inverse_transform(**kwargs)
    result = inverse_transform(self)
    if hasattr(result, "applied_transforms"):
        result.applied_transforms = []
    return result

clear_history()

Remove all applied transform records.

Source code in src/torchio/data/invertible.py
def clear_history(self) -> None:
    """Forget every transform recorded in `applied_transforms`."""
    # Rebind to a new list rather than clearing in place, so any other
    # reference to the previous history list is left untouched.
    empty_history = []
    self.applied_transforms = empty_history

from_images(images) classmethod

Stack a list of images into a batch.

All images must have the same shape.

Parameters:

Name Type Description Default
images list[Image]

List of Image instances to stack.

required
Source code in src/torchio/data/batch.py
@classmethod
def from_images(cls, images: list[Image]) -> Self:
    """Build a batch by stacking the given images.

    All images must have the same shape. The batch remembers the class
    of the first image and stores a clone of each image's affine.

    Args:
        images: List of `Image` instances to stack.

    Returns:
        A new batch whose data is `torch.stack` of the image data.

    Raises:
        ValueError: If `images` is empty.
    """
    if not images:
        raise ValueError("Cannot create batch from empty list")
    batch_tensor = torch.stack([image.data for image in images])
    # Clone so the batch does not share affine objects with the inputs.
    cloned_affines = [image.affine.clone() for image in images]
    first_class = type(images[0])
    return cls(batch_tensor, cloned_affines, image_class=first_class)

to(*args, **kwargs)

Move batch data to a device and/or cast dtype.

Source code in src/torchio/data/batch.py
def to(self, *args: Any, **kwargs: Any) -> Self:
    """Move batch data to a device and/or cast dtype."""
    self._data = self._data.to(*args, **kwargs)
    for affine in self._affines:
        affine.to(*args, **kwargs)
    return self

unbatch()

Split the batch into individual images.

Source code in src/torchio/data/batch.py
def unbatch(self) -> list[Image]:
    """Return the batch's samples as a list of individual images.

    Each element is obtained by indexing the batch, in sample order.
    """
    count = self.batch_size
    singles = []
    for position in range(count):
        singles.append(self[position])
    return singles