Skip to content

Patch-based pipelines

GridSampler

Bases: PatchSampler, Dataset

Extract patches on a regular grid for dense inference.

A map-style Dataset with known length and random access. Pass directly to a DataLoader for batched inference. Typically used with PatchAggregator.

Parameters:

Name Type Description Default
subject Subject

Subject to extract patches from.

required
patch_size int | TypeThreeInts

Spatial size of each patch.

required
patch_overlap int | TypeThreeInts

Overlap between adjacent patches. Must be even. A single int is broadcast to all axes.

0
padding_mode str | None

If not None, pad the volume by overlap // 2 on each side before sampling.

None
fill float

Fill value when padding_mode='constant'.

0

Examples:

>>> sampler = tio.GridSampler(subject, patch_size=64, patch_overlap=8)
>>> loader = DataLoader(sampler, batch_size=4)
>>> aggregator = tio.PatchAggregator(subject.spatial_shape, overlap_mode="hann")
>>> for batch in loader:
...     outputs = model(batch.t1.data)
...     aggregator.add_batch(outputs, batch.patch_location)
>>> volume = aggregator.get_output()
Source code in src/torchio/data/sampler.py
class GridSampler(PatchSampler, Dataset):
    """Sample patches on a regular grid that covers a whole subject.

    The sampler is a map-style `Dataset`: its length is the number of
    grid positions and indexing it returns the patch at that position,
    so it can be passed directly to a `DataLoader` for batched, dense
    inference. It is typically paired with
    [`PatchAggregator`][torchio.data.PatchAggregator].

    Args:
        subject: Subject to extract patches from.
        patch_size: Spatial size of each patch.
        patch_overlap: Overlap between adjacent patches. Must be even.
            A single `int` is broadcast to all axes.
        padding_mode: If not `None`, pad the volume by
            `overlap // 2` on each side before sampling.
        fill: Fill value when `padding_mode='constant'`.

    Examples:
        >>> sampler = tio.GridSampler(subject, patch_size=64, patch_overlap=8)
        >>> loader = DataLoader(sampler, batch_size=4)
        >>> aggregator = tio.PatchAggregator(subject.spatial_shape, overlap_mode="hann")
        >>> for batch in loader:
        ...     outputs = model(batch.t1.data)
        ...     aggregator.add_batch(outputs, batch.patch_location)
        >>> volume = aggregator.get_output()
    """

    def __init__(
        self,
        subject: Subject,
        patch_size: int | TypeThreeInts,
        patch_overlap: int | TypeThreeInts = 0,
        padding_mode: str | None = None,
        fill: float = 0,
    ) -> None:
        super().__init__(patch_size)
        overlap = patch_overlap
        if isinstance(overlap, int):
            # Broadcast a scalar overlap to the three spatial axes.
            overlap = (overlap,) * 3
        self.patch_overlap: TypeThreeInts = overlap
        self.padding_mode = padding_mode
        self.fill = fill
        # Padding (when requested) is applied once, up front; the grid is
        # then laid out over the padded volume.
        self.subject = self._maybe_pad(subject)
        self.locations = self._compute_locations(self.subject.spatial_shape)

    def __len__(self) -> int:
        """Return the number of grid positions."""
        return len(self.locations)

    def __getitem__(self, index: int) -> Subject:
        """Return the patch stored at grid position `index`."""
        location = self.locations[index]
        return self._extract_patch(self.subject, location)

    def _maybe_pad(self, subject: Subject) -> Subject:
        """Pad `subject` by half the overlap per side, if padding is enabled."""
        if self.padding_mode is None:
            return subject
        from ..transforms.spatial.pad import Pad

        halves = tuple(v // 2 for v in self.patch_overlap)
        half_i, half_j, half_k = halves[0], halves[1], halves[2]
        transform = Pad(
            # Same amount before and after along each axis.
            padding=(half_i, half_i, half_j, half_j, half_k, half_k),
            padding_mode=self.padding_mode,
            fill=self.fill,
            copy=False,
        )
        return transform(subject)

    def _compute_locations(
        self,
        spatial_shape: TypeThreeInts,
    ) -> list[PatchLocation]:
        """Compute grid locations covering the volume."""
        starts_per_axis: list[list[int]] = []
        for axis in range(3):
            extent = spatial_shape[axis]
            window = self.patch_size[axis]
            # The stride never drops below one voxel, even if the overlap
            # is as large as the patch.
            stride = max(window - self.patch_overlap[axis], 1)
            last_start = extent - window
            starts = list(range(0, last_start + 1, stride))
            # Add a final patch flush with the far border when the regular
            # stride does not land on it (or when the axis is shorter than
            # the patch, in which case the single start is 0).
            if not starts or starts[-1] != last_start:
                starts.append(max(last_start, 0))
            starts_per_axis.append(starts)

        starts_i, starts_j, starts_k = starts_per_axis
        return [
            PatchLocation(index=(i, j, k), size=self.patch_size)
            for i in starts_i
            for j in starts_j
            for k in starts_k
        ]

UniformSampler

Bases: PatchSampler, IterableDataset

Random patches with uniform spatial probability.

An IterableDataset for training. Also callable for use with Queue.

Parameters:

Name Type Description Default
subject Subject

Subject to sample patches from (for Dataset use).

required
patch_size int | TypeThreeInts

Spatial size of each patch.

required
num_patches int | None

Number of patches per epoch. If None, yields indefinitely.

None

Examples:

>>> sampler = tio.UniformSampler(subject, patch_size=64, num_patches=100)
>>> loader = DataLoader(sampler, batch_size=8)
Source code in src/torchio/data/sampler.py
class UniformSampler(PatchSampler, IterableDataset):
    """Random patches with uniform spatial probability.

    An `IterableDataset` for training. Also callable for use with
    [`Queue`][torchio.data.Queue].

    Args:
        subject: Subject to sample patches from (for Dataset use).
        patch_size: Spatial size of each patch.
        num_patches: Number of patches per epoch. If `None`,
            yields indefinitely.

    Examples:
        >>> sampler = tio.UniformSampler(subject, patch_size=64, num_patches=100)
        >>> loader = DataLoader(sampler, batch_size=8)
    """

    def __init__(
        self,
        subject: Subject,
        patch_size: int | TypeThreeInts,
        num_patches: int | None = None,
    ) -> None:
        super().__init__(patch_size)
        self.subject = subject
        self.num_patches = num_patches

    def __call__(
        self,
        subject: Subject,
        num_patches: int | None = None,
    ) -> Iterator[Subject]:
        """Sample random patches from a given subject.

        Args:
            subject: Subject to sample patches from.
            num_patches: Number of patches to yield. If `None`, the
                value given at construction time is used; if that is
                also `None`, patches are yielded indefinitely.

        Yields:
            One randomly located patch per iteration.
        """
        # Fall back to the instance default only when the argument is
        # missing: an explicit 0 must mean "no patches", not "use the
        # default" (which could be unbounded).
        limit = self.num_patches if num_patches is None else num_patches
        count = 0
        while limit is None or count < limit:
            index = self._random_index(subject.spatial_shape)
            loc = PatchLocation(index=index, size=self.patch_size)
            yield self._extract_patch(subject, loc)
            count += 1

    def __iter__(self) -> Iterator[Subject]:
        """Iterate over patches of the subject given at construction time."""
        return self(self.subject, self.num_patches)

    def _random_index(
        self,
        spatial_shape: TypeThreeInts,
    ) -> TypeThreeInts:
        """Draw a uniformly random patch corner inside the volume."""

        def _rand(d: int) -> int:
            # Exclusive upper bound; when the patch is at least as large
            # as the axis the only candidate corner is 0.
            hi = max(spatial_shape[d] - self.patch_size[d], 0) + 1
            return int(torch.randint(0, hi, (1,)).item())

        return (_rand(0), _rand(1), _rand(2))

WeightedSampler

Bases: PatchSampler, IterableDataset

Random patches weighted by a probability map.

An IterableDataset for training with spatial priors.

Parameters:

Name Type Description Default
subject Subject

Subject to sample patches from.

required
patch_size int | TypeThreeInts

Spatial size of each patch.

required
probability_map str

Name of the image in the subject to use as sampling weights.

required
num_patches int | None

Number of patches per epoch. If None, yields indefinitely.

None
Source code in src/torchio/data/sampler.py
class WeightedSampler(PatchSampler, IterableDataset):
    """Random patches weighted by a probability map.

    An `IterableDataset` for training with spatial priors.

    Args:
        subject: Subject to sample patches from.
        patch_size: Spatial size of each patch.
        probability_map: Name of the image in the subject to use
            as sampling weights.
        num_patches: Number of patches per epoch. If `None`,
            yields indefinitely.
    """

    def __init__(
        self,
        subject: Subject,
        patch_size: int | TypeThreeInts,
        probability_map: str,
        num_patches: int | None = None,
    ) -> None:
        super().__init__(patch_size)
        self.subject = subject
        self.probability_map = probability_map
        self.num_patches = num_patches

    def __call__(
        self,
        subject: Subject,
        num_patches: int | None = None,
    ) -> Iterator[Subject]:
        """Sample weighted patches from a given subject.

        Args:
            subject: Subject to sample patches from.
            num_patches: Number of patches to yield. If `None`, the
                value given at construction time is used; if that is
                also `None`, patches are yielded indefinitely.

        Yields:
            One patch per iteration, centered on a voxel drawn with
            probability proportional to the probability map.

        Raises:
            RuntimeError: If the (border-masked) probability map sums
                to zero, so no voxel can be drawn.
        """
        prob_data = self._build_probability_map_for(subject)
        flat = prob_data.flatten()
        if flat.sum() == 0:
            msg = f"Probability map '{self.probability_map}' is all zeros"
            raise RuntimeError(msg)

        # Fall back to the instance default only when the argument is
        # missing: an explicit 0 must mean "no patches", not "use the
        # default" (which could be unbounded).
        limit = self.num_patches if num_patches is None else num_patches
        count = 0
        while limit is None or count < limit:
            # NOTE(review): torch.multinomial limits the number of
            # categories (2**24 in current PyTorch, to be confirmed), so
            # very large volumes may need a different sampling strategy.
            idx_flat = torch.multinomial(flat, 1).item()
            center = tuple(
                int(x) for x in np.unravel_index(int(idx_flat), prob_data.shape)
            )
            index = _center_to_corner(center, subject.spatial_shape, self.patch_size)
            loc = PatchLocation(index=index, size=self.patch_size)
            yield self._extract_patch(subject, loc)
            count += 1

    def __iter__(self) -> Iterator[Subject]:
        """Iterate over patches of the subject given at construction time."""
        return self(self.subject, self.num_patches)

    def _build_probability_map_for(self, subject: Subject) -> Tensor:
        """Return the first channel of the map image as float, borders masked."""
        prob_image = subject.images[self.probability_map]
        prob_data = prob_image.data[0].float()
        return _mask_borders(prob_data, subject.spatial_shape, self.patch_size)

    def _build_probability_map(self) -> Tensor:
        """Build the probability map for the subject given at construction."""
        return self._build_probability_map_for(self.subject)

LabelSampler

Bases: WeightedSampler

Random patches centered on labeled voxels.

An IterableDataset for training with class imbalance.

Parameters:

Name Type Description Default
subject Subject

Subject to sample patches from.

required
patch_size int | TypeThreeInts

Spatial size of each patch.

required
label_name str

Name of the label image in the subject.

required
label_probabilities dict[int, float] | None

Dict mapping label values to sampling weights. If None, all non-zero labels have equal weight.

None
num_patches int | None

Number of patches per epoch. If None, yields indefinitely.

None
Source code in src/torchio/data/sampler.py
class LabelSampler(WeightedSampler):
    """Draw random patches whose centers fall on labeled voxels.

    An `IterableDataset` intended for training when classes are
    imbalanced: the sampling weights are derived from a label image
    instead of an explicit probability map.

    Args:
        subject: Subject to sample patches from.
        patch_size: Spatial size of each patch.
        label_name: Name of the label image in the subject.
        label_probabilities: Dict mapping label values to sampling
            weights. If `None`, all non-zero labels have equal
            weight.
        num_patches: Number of patches per epoch.
    """

    def __init__(
        self,
        subject: Subject,
        patch_size: int | TypeThreeInts,
        label_name: str,
        label_probabilities: dict[int, float] | None = None,
        num_patches: int | None = None,
    ) -> None:
        # The label image doubles as the parent's probability-map name.
        super().__init__(
            subject,
            patch_size,
            probability_map=label_name,
            num_patches=num_patches,
        )
        self.label_name = label_name
        self.label_probabilities = label_probabilities

    def _build_probability_map_for(self, subject: Subject) -> Tensor:
        """Turn the label image of `subject` into per-voxel sampling weights."""
        # Only the first channel of the label image is used.
        labels = subject.images[self.label_name].data[0]
        weights_by_label = self.label_probabilities

        if weights_by_label is None:
            # Every foreground (positive) voxel is equally likely.
            weights = (labels > 0).float()
        else:
            # Labels absent from the mapping keep a weight of zero.
            weights = torch.zeros_like(labels, dtype=torch.float32)
            for value, weight in weights_by_label.items():
                weights[labels == value] = weight

        return _mask_borders(weights, subject.spatial_shape, self.patch_size)

    def _build_probability_map(self) -> Tensor:
        """Build the weights for the subject given at construction time."""
        return self._build_probability_map_for(self.subject)

PatchAggregator

Reassemble patches into a full volume.

Handles overlapping patches with configurable blending modes. Supports outputs of different spatial sizes than the input patches (e.g., downsampled feature maps or embeddings).

Parameters:

Name Type Description Default
spatial_shape TypeThreeInts

Output volume spatial shape (I, J, K).

required
overlap_mode str

How to handle overlapping regions: 'crop' keeps only non-overlapping centers (fast, best for argmax segmentation); 'average' averages overlapping values (best for probabilistic outputs); 'hann' uses Hann-window weighting (smoothest, best for continuous outputs).

'crop'
patch_overlap int | TypeThreeInts

The overlap used during sampling, needed for 'crop' mode to compute how much to trim.

0
output_shape TypeThreeInts | None

If the model output is spatially smaller than the input patch (e.g., due to strided convolutions), specify the output volume shape here. Patch locations will be scaled accordingly.

None

Examples:

>>> aggregator = tio.PatchAggregator(
...     spatial_shape=(256, 256, 176),
...     overlap_mode="hann",
... )
>>> for batch in loader:
...     outputs = model(batch.t1.data)
...     aggregator.add_batch(outputs, batch.patch_location)
>>> volume = aggregator.get_output()
Source code in src/torchio/data/aggregator.py
class PatchAggregator:
    """Reassemble patches into a full volume.

    Handles overlapping patches with configurable blending modes.
    Supports outputs of different spatial sizes than the input
    patches (e.g., downsampled feature maps or embeddings).

    Args:
        spatial_shape: Output volume spatial shape `(I, J, K)`.
        overlap_mode: How to handle overlapping regions:
            `'crop'` keeps only non-overlapping centers (fast,
            best for argmax segmentation);
            `'average'` averages overlapping values (best for
            probabilistic outputs);
            `'hann'` uses Hann-window weighting (smoothest,
            best for continuous outputs).
        patch_overlap: The overlap used during sampling, needed
            for `'crop'` mode to compute how much to trim.
        output_shape: If the model output is spatially smaller
            than the input patch (e.g., due to strided
            convolutions), specify the output volume shape here.
            Patch locations will be scaled accordingly.

    Examples:
        >>> aggregator = tio.PatchAggregator(
        ...     spatial_shape=(256, 256, 176),
        ...     overlap_mode="hann",
        ... )
        >>> for batch in loader:
        ...     outputs = model(batch.t1.data)
        ...     aggregator.add_batch(outputs, locations)
        >>> volume = aggregator.get_output()
    """

    def __init__(
        self,
        spatial_shape: TypeThreeInts,
        overlap_mode: str = "crop",
        patch_overlap: int | TypeThreeInts = 0,
        output_shape: TypeThreeInts | None = None,
    ) -> None:
        _validate_overlap_mode(overlap_mode)
        self.input_spatial_shape = spatial_shape
        self.overlap_mode = overlap_mode

        if isinstance(patch_overlap, int):
            patch_overlap = (patch_overlap, patch_overlap, patch_overlap)
        self.patch_overlap: TypeThreeInts = patch_overlap

        if output_shape is not None:
            self.spatial_shape = output_shape
            self._scale = (
                output_shape[0] / spatial_shape[0],
                output_shape[1] / spatial_shape[1],
                output_shape[2] / spatial_shape[2],
            )
        else:
            self.spatial_shape = spatial_shape
            self._scale = (1.0, 1.0, 1.0)

        self._outputs: dict[str, Tensor] = {}
        self._counts: dict[str, Tensor] = {}
        self._hann_cache: dict[TypeThreeInts, Tensor] = {}

    def add_batch(
        self,
        batch: Tensor | dict[str, Tensor],
        locations: list[PatchLocation],
    ) -> None:
        """Add a batch of model outputs to the aggregation buffer.

        Args:
            batch: 5D tensor `(B, C, I, J, K)` or dict of such
                tensors keyed by name.
            locations: List of `PatchLocation` for each item in
                the batch.
        """
        tensors: dict[str, Tensor] = (
            {"__default__": batch} if isinstance(batch, Tensor) else batch
        )

        for key, tensor in tensors.items():
            tensor = tensor.cpu()
            for idx, loc in enumerate(locations):
                patch = tensor[idx]
                if self._scale != (1.0, 1.0, 1.0):
                    loc = loc.scaled(self._scale)
                self._add_patch(key, patch, loc)

    def get_output(self, key: str | None = None) -> Tensor:
        """Get the aggregated output volume.

        Args:
            key: Name of the output to retrieve. If `None` and
                only a single (unnamed) output was added, return it.

        Returns:
            The aggregated tensor with shape `(C, I, J, K)`.
        """
        resolve_key = key if key is not None else "__default__"
        if resolve_key not in self._outputs:
            available = [k for k in self._outputs if k != "__default__"]
            msg = f"No output for key {key!r}. Available: {available}"
            raise KeyError(msg)

        output = self._outputs[resolve_key]

        if self.overlap_mode in ("average", "hann"):
            counts = self._counts[resolve_key]
            counts = counts.clamp(min=1)
            output = output / counts

        return output

    def _add_patch(
        self,
        key: str,
        patch: Tensor,
        location: PatchLocation,
    ) -> None:
        self._ensure_buffer(key, patch)
        match self.overlap_mode:
            case "crop":
                self._add_crop(key, patch, location)
            case "average":
                self._add_average(key, patch, location)
            case "hann":
                self._add_hann(key, patch, location)

    def _ensure_buffer(self, key: str, patch: Tensor) -> None:
        if key in self._outputs:
            return
        num_channels = patch.shape[0]
        self._outputs[key] = torch.zeros(
            num_channels,
            *self.spatial_shape,
            dtype=patch.dtype,
        )
        if self.overlap_mode in ("average", "hann"):
            self._counts[key] = torch.zeros(
                num_channels,
                *self.spatial_shape,
                dtype=patch.dtype,
            )

    def _add_crop(
        self,
        key: str,
        patch: Tensor,
        location: PatchLocation,
    ) -> None:
        """Place only the non-overlapping center of the patch."""
        scaled_overlap = (
            round(self.patch_overlap[0] * self._scale[0]),
            round(self.patch_overlap[1] * self._scale[1]),
            round(self.patch_overlap[2] * self._scale[2]),
        )
        half = [o // 2 for o in scaled_overlap]
        ini = list(location.index_ini)
        fin = list(location.index_fin)
        crop_ini = [0, 0, 0]
        crop_fin = list(location.size)

        for d in range(3):
            if ini[d] > 0:
                ini[d] += half[d]
                crop_ini[d] += half[d]
            if fin[d] < self.spatial_shape[d]:
                fin[d] -= half[d]
                crop_fin[d] -= half[d]

        cropped = patch[
            :,
            crop_ini[0] : crop_fin[0],
            crop_ini[1] : crop_fin[1],
            crop_ini[2] : crop_fin[2],
        ]
        self._outputs[key][
            :,
            ini[0] : fin[0],
            ini[1] : fin[1],
            ini[2] : fin[2],
        ] = cropped

    def _add_average(
        self,
        key: str,
        patch: Tensor,
        location: PatchLocation,
    ) -> None:
        si, sj, sk = location.to_slices()
        self._outputs[key][:, si, sj, sk] += patch
        self._counts[key][:, si, sj, sk] += 1

    def _add_hann(
        self,
        key: str,
        patch: Tensor,
        location: PatchLocation,
    ) -> None:
        patch_shape = (
            patch.shape[-3],
            patch.shape[-2],
            patch.shape[-1],
        )
        window = self._get_hann_window(patch_shape)
        si, sj, sk = location.to_slices()
        self._outputs[key][:, si, sj, sk] += patch * window
        self._counts[key][:, si, sj, sk] += window

    def _get_hann_window(self, patch_size: TypeThreeInts) -> Tensor:
        if patch_size in self._hann_cache:
            return self._hann_cache[patch_size]
        window = _build_hann_3d(patch_size)
        self._hann_cache[patch_size] = window
        return window

add_batch(batch, locations)

Add a batch of model outputs to the aggregation buffer.

Parameters:

Name Type Description Default
batch Tensor | dict[str, Tensor]

5D tensor (B, C, I, J, K) or dict of such tensors keyed by name.

required
locations list[PatchLocation]

List of PatchLocation for each item in the batch.

required
Source code in src/torchio/data/aggregator.py
def add_batch(
    self,
    batch: Tensor | dict[str, Tensor],
    locations: list[PatchLocation],
) -> None:
    """Accumulate a batch of model outputs into the aggregation buffers.

    Args:
        batch: 5D tensor `(B, C, I, J, K)` or dict of such
            tensors keyed by name.
        locations: List of `PatchLocation` for each item in
            the batch.
    """
    # A bare tensor is filed under a reserved name so that both input
    # forms go through the same loop.
    if isinstance(batch, Tensor):
        named_outputs: dict[str, Tensor] = {"__default__": batch}
    else:
        named_outputs = batch

    identity_scale = (1.0, 1.0, 1.0)
    for name, values in named_outputs.items():
        values_on_cpu = values.cpu()
        for position, location in enumerate(locations):
            item = values_on_cpu[position]
            # Locations are only converted when the output volume has a
            # different spatial shape than the input volume.
            target = (
                location
                if self._scale == identity_scale
                else location.scaled(self._scale)
            )
            self._add_patch(name, item, target)

get_output(key=None)

Get the aggregated output volume.

Parameters:

Name Type Description Default
key str | None

Name of the output to retrieve. If None and only a single (unnamed) output was added, return it.

None

Returns:

Type Description
Tensor

The aggregated tensor with shape (C, I, J, K).

Source code in src/torchio/data/aggregator.py
def get_output(self, key: str | None = None) -> Tensor:
    """Get the aggregated output volume.

    Args:
        key: Name of the output to retrieve. If `None` and
            only a single (unnamed) output was added, return it.

    Returns:
        The aggregated tensor with shape `(C, I, J, K)`.

    Raises:
        KeyError: If no output was added under the requested name.
    """
    resolve_key = key if key is not None else "__default__"
    if resolve_key not in self._outputs:
        available = [k for k in self._outputs if k != "__default__"]
        msg = f"No output for key {key!r}. Available: {available}"
        raise KeyError(msg)

    output = self._outputs[resolve_key]

    if self.overlap_mode in ("average", "hann"):
        weights = self._counts[resolve_key]
        # Normalise by the accumulated weight, guarding only against
        # division by zero for voxels no patch touched. Clamping the
        # weights to a minimum of 1 would be wrong in 'hann' mode, where
        # the accumulated window weights can be below 1 and the result
        # would stay attenuated.
        safe_weights = torch.where(
            weights > 0,
            weights,
            torch.ones_like(weights),
        )
        output = output / safe_weights

    return output

Queue

Bases: IterableDataset

Buffer of patches for stochastic patch-based training.

Loads and preprocesses subjects in background threads, extracts random patches via a sampler, and yields them one at a time. Designed for use with SubjectsLoader or DataLoader.

Parameters:

Name Type Description Default
subjects Sequence[Subject]

Sequence of subjects to sample patches from.

required
patch_sampler PatchSampler

A sampler (e.g., UniformSampler) used to extract patches from loaded subjects. The sampler must accept a subject and num_patches in its __call__.

required
max_length int

Maximum number of patches held in the buffer. Larger values increase diversity but use more RAM.

300
patches_per_volume int

Maximum patches to extract from each subject. The sampler may yield fewer if valid positions are exhausted.

10
num_workers int

Number of background threads for loading and preprocessing subjects. Set to 0 for synchronous loading.

0
shuffle_subjects bool

Shuffle the subject order at the start of each epoch.

True
shuffle_patches bool

Shuffle the buffer after each refill.

True
transform Any | None

Optional transform applied to each subject after loading and before patch extraction.

None
subject_sampler Sampler | None

A torch.utils.data.Sampler (e.g., DistributedSampler) that yields subject indices. When provided, shuffle_subjects must be False.

None

Examples:

>>> queue = tio.Queue(
...     subjects,
...     patch_sampler=tio.UniformSampler(subject, patch_size=64),
...     max_length=300,
...     patches_per_volume=10,
...     num_workers=4,
... )
>>> loader = SubjectsLoader(queue, batch_size=16)
>>> for batch in loader:
...     outputs = model(batch.t1.data)
Source code in src/torchio/data/queue.py
class Queue(IterableDataset):
    """Buffer of patches for stochastic patch-based training.

    Loads and preprocesses subjects in background threads, extracts
    random patches via a sampler, and yields them one at a time.
    Designed for use with `SubjectsLoader` or `DataLoader`.

    Args:
        subjects: Sequence of subjects to sample patches from.
        patch_sampler: A sampler (e.g.,
            [`UniformSampler`][torchio.data.UniformSampler]) used to
            extract patches from loaded subjects. The sampler must
            accept a subject and `num_patches` in its `__call__`.
        max_length: Maximum number of patches held in the buffer.
            Larger values increase diversity but use more RAM.
        patches_per_volume: Maximum patches to extract from each
            subject. The sampler may yield fewer if valid positions
            are exhausted.
        num_workers: Number of background threads for loading and
            preprocessing subjects. Set to 0 for synchronous loading.
        shuffle_subjects: Shuffle the subject order at the start of
            each epoch.
        shuffle_patches: Shuffle the buffer after each refill. When
            `False`, patches are yielded in the order they were
            extracted.
        transform: Optional transform applied to each subject after
            loading and before patch extraction.
        subject_sampler: A `torch.utils.data.Sampler` (e.g.,
            `DistributedSampler`) that yields subject indices.
            When provided, `shuffle_subjects` must be `False`.

    Raises:
        ValueError: If `subject_sampler` is given together with
            `shuffle_subjects=True`.

    Examples:
        >>> queue = tio.Queue(
        ...     subjects,
        ...     patch_sampler=tio.UniformSampler(subject, patch_size=64),
        ...     max_length=300,
        ...     patches_per_volume=10,
        ...     num_workers=4,
        ... )
        >>> loader = SubjectsLoader(queue, batch_size=16)
        >>> for batch in loader:
        ...     outputs = model(batch.t1.data)
    """

    def __init__(
        self,
        subjects: Sequence[Subject],
        patch_sampler: PatchSampler,
        max_length: int = 300,
        patches_per_volume: int = 10,
        num_workers: int = 0,
        shuffle_subjects: bool = True,
        shuffle_patches: bool = True,
        transform: Any | None = None,
        subject_sampler: Sampler | None = None,
    ) -> None:
        if subject_sampler is not None and shuffle_subjects:
            msg = (
                "shuffle_subjects must be False when subject_sampler"
                " is provided (the sampler controls the order)"
            )
            raise ValueError(msg)
        self.subjects = subjects
        self.patch_sampler = patch_sampler
        self.max_length = max_length
        self.patches_per_volume = patches_per_volume
        self.num_workers = num_workers
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.transform = transform
        self.subject_sampler = subject_sampler

    def __iter__(self) -> Iterator[Subject]:
        """Yield patches, loading subjects in the background."""
        buffer: list[Subject] = []
        subject_iter = self._make_subject_iter()

        if self.num_workers > 0:
            yield from self._iter_threaded(subject_iter, buffer)
        else:
            yield from self._iter_sync(subject_iter, buffer)

    def _iter_sync(
        self,
        subject_iter: Iterator[Subject],
        buffer: list[Subject],
    ) -> Iterator[Subject]:
        """Prepare subjects one by one in the calling thread."""
        for raw in subject_iter:
            prepared = self._prepare(raw)
            buffer.extend(self._sample_patches(prepared))
            yield from self._drain_if_full(buffer)
        yield from self._flush(buffer)

    def _iter_threaded(
        self,
        subject_iter: Iterator[Subject],
        buffer: list[Subject],
    ) -> Iterator[Subject]:
        """Prepare subjects in a thread pool, in submission order."""
        with ThreadPoolExecutor(max_workers=self.num_workers) as pool:
            futures: deque[Future[Subject]] = deque()

            for raw in subject_iter:
                futures.append(pool.submit(self._prepare, raw))
                # Backpressure: without a bound, every subject would be
                # submitted (and loaded) before the first patch is
                # consumed. Wait for the oldest job once more jobs are
                # pending than there are workers.
                while len(futures) > self.num_workers:
                    prepared = futures.popleft().result()
                    buffer.extend(self._sample_patches(prepared))
                    yield from self._drain_if_full(buffer)
                yield from self._collect_ready(futures, buffer)
                yield from self._drain_if_full(buffer)

            # Drain remaining futures, still honouring max_length rather
            # than piling every leftover patch into one oversized buffer.
            while futures:
                prepared = futures.popleft().result()
                buffer.extend(self._sample_patches(prepared))
                yield from self._drain_if_full(buffer)

        yield from self._flush(buffer)

    def _collect_ready(
        self,
        futures: deque[Future[Subject]],
        buffer: list[Subject],
    ) -> Iterator[Subject]:
        """Move patches from completed futures into the buffer."""
        # Only the head of the queue is inspected so that subjects are
        # consumed in submission order.
        while futures and futures[0].done():
            prepared = futures.popleft().result()
            buffer.extend(self._sample_patches(prepared))
        return iter(())  # nothing to yield yet

    def _drain_if_full(self, buffer: list[Subject]) -> Iterator[Subject]:
        """Yield all patches from buffer if it reached max_length."""
        if len(buffer) >= self.max_length:
            yield from self._flush(buffer)

    def _flush(self, buffer: list[Subject]) -> Iterator[Subject]:
        """Shuffle (if enabled) and yield all patches from buffer.

        The buffer is emptied in place. Without shuffling, patches are
        yielded in the order they were appended.
        """
        if self.shuffle_patches:
            _random.shuffle(buffer)
        else:
            # Patches are popped from the end (O(1), and each one is
            # released as soon as it is yielded), so reverse first to
            # keep insertion order.
            buffer.reverse()
        while buffer:
            yield buffer.pop()

    def _prepare(self, subject: Subject) -> Subject:
        """Load images and apply transform (may run in a thread)."""
        subject.load()
        if self.transform is not None:
            subject = self.transform(subject)
        return subject

    def _sample_patches(self, subject: Subject) -> list[Subject]:
        """Extract up to patches_per_volume patches."""
        # The sampler may be unbounded; islice enforces the cap.
        gen = iter(self.patch_sampler(subject))
        return list(islice(gen, self.patches_per_volume))

    def _make_subject_iter(self) -> Iterator[Subject]:
        """Build the subject iterator for one epoch."""
        if self.subject_sampler is not None:
            indices = list(self.subject_sampler)
            return (self.subjects[i] for i in indices)
        # Shuffle a copy so the caller's sequence is left untouched.
        subjects = list(self.subjects)
        if self.shuffle_subjects:
            _random.shuffle(subjects)
        return iter(subjects)

    @property
    def num_subjects(self) -> int:
        """Number of subjects per epoch.

        Raises:
            TypeError: If a `subject_sampler` without `__len__` is used.
        """
        sampler = self.subject_sampler
        if sampler is not None:
            if not isinstance(sampler, Sized):
                msg = "subject_sampler must have a __len__ method"
                raise TypeError(msg)
            return len(sampler)
        return len(self.subjects)

    @property
    def patches_per_epoch(self) -> int:
        """Total patches yielded per epoch (upper bound)."""
        return self.num_subjects * self.patches_per_volume

    @property
    def max_memory(self) -> int:
        """Estimated max RAM for the patch buffer in bytes.

        The estimate uses the channel count of the first subject and
        4 bytes per voxel.
        """
        sample = self.subjects[0]
        channels = sum(img.num_channels for img in sample.images.values())
        voxels = 1
        for s in self.patch_sampler.patch_size:
            voxels *= s
        return 4 * channels * voxels * self.max_length

    @property
    def max_memory_pretty(self) -> str:
        """Human-readable max memory estimate."""
        return humanize.naturalsize(self.max_memory, binary=True)

num_subjects property

Number of subjects per epoch.

patches_per_epoch property

Total patches yielded per epoch (upper bound).

max_memory property

Estimated max RAM for the patch buffer in bytes.

max_memory_pretty property

Human-readable max memory estimate.

PatchLocation dataclass

Spatial location of an extracted patch within a volume.

Attributes:

Name Type Description
index TypeThreeInts

(i, j, k) voxel indices of the patch corner (the corner closest to the origin).

size TypeThreeInts

(si, sj, sk) spatial shape of the patch.

subject_index int | None

Optional identifier for multi-subject batches.

Source code in src/torchio/data/patch.py
@dataclass(frozen=True)
class PatchLocation:
    """Where a patch sits inside its source volume.

    Attributes:
        index: `(i, j, k)` voxel indices of the patch corner
            (the corner closest to the origin).
        size: `(si, sj, sk)` spatial shape of the patch.
        subject_index: Optional identifier for multi-subject batches.
    """

    index: TypeThreeInts
    size: TypeThreeInts
    subject_index: int | None = None

    @property
    def index_ini(self) -> TypeThreeInts:
        """Voxel indices `(i, j, k)` at which the patch starts."""
        return self.index

    @property
    def index_fin(self) -> TypeThreeInts:
        """Voxel indices one past the last voxel of the patch."""
        corner, extent = self.index, self.size
        end_i = corner[0] + extent[0]
        end_j = corner[1] + extent[1]
        end_k = corner[2] + extent[2]
        return (end_i, end_j, end_k)

    def to_slices(self) -> tuple[slice, slice, slice]:
        """Return one slice per spatial axis, ready for tensor indexing."""
        begin = self.index_ini
        end = self.index_fin
        slice_i = slice(begin[0], end[0])
        slice_j = slice(begin[1], end[1])
        slice_k = slice(begin[2], end[2])
        return (slice_i, slice_j, slice_k)

    def scaled(self, factor: tuple[float, float, float]) -> PatchLocation:
        """Return a new location whose corner and size are scaled per axis.

        Each component is multiplied by the factor of its axis and
        rounded with the built-in `round`; `subject_index` is carried
        over unchanged.
        """

        def _rescale(values: TypeThreeInts) -> TypeThreeInts:
            return (
                round(values[0] * factor[0]),
                round(values[1] * factor[1]),
                round(values[2] * factor[2]),
            )

        new_index = _rescale(self.index)
        new_size = _rescale(self.size)
        return PatchLocation(
            index=new_index,
            size=new_size,
            subject_index=self.subject_index,
        )

index_ini property

Starting voxel indices (i, j, k).

index_fin property

One-past-the-end voxel indices.

to_slices()

Convert to spatial slices for tensor indexing.

Source code in src/torchio/data/patch.py
def to_slices(self) -> tuple[slice, slice, slice]:
    """Return one slice per spatial axis, ready for tensor indexing."""
    begin = self.index_ini
    end = self.index_fin
    slice_i = slice(begin[0], end[0])
    slice_j = slice(begin[1], end[1])
    slice_k = slice(begin[2], end[2])
    return (slice_i, slice_j, slice_k)

scaled(factor)

Return a new location with indices and size scaled by factor.

Source code in src/torchio/data/patch.py
def scaled(self, factor: tuple[float, float, float]) -> PatchLocation:
    """Return a new location whose corner and size are scaled per axis.

    Each component is multiplied by the factor of its axis and rounded
    with the built-in `round`; `subject_index` is carried over unchanged.
    """

    def _rescale(values: TypeThreeInts) -> TypeThreeInts:
        return (
            round(values[0] * factor[0]),
            round(values[1] * factor[1]),
            round(values[2] * factor[2]),
        )

    new_index = _rescale(self.index)
    new_size = _rescale(self.size)
    return PatchLocation(
        index=new_index,
        size=new_size,
        subject_index=self.subject_index,
    )