Patch-based pipelines
GridSampler
Bases: PatchSampler, Dataset
Extract patches on a regular grid for dense inference.
A map-style Dataset with known length and random access.
Pass directly to a DataLoader for batched inference.
Typically used with
PatchAggregator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
subject
|
Subject
|
Subject to extract patches from. |
required |
patch_size
|
int | TypeThreeInts
|
Spatial size of each patch. |
required |
patch_overlap
|
int | TypeThreeInts
|
Overlap between adjacent patches. Must be even.
A single |
0
|
padding_mode
|
str | None
|
If not |
None
|
fill
|
float
|
Fill value when |
0
|
Examples:
>>> sampler = tio.GridSampler(subject, patch_size=64, patch_overlap=8)
>>> loader = DataLoader(sampler, batch_size=4)
>>> aggregator = tio.PatchAggregator(subject.spatial_shape, overlap_mode="hann")
>>> for batch in loader:
... outputs = model(batch.t1.data)
... aggregator.add_batch(outputs, batch.patch_location)
>>> volume = aggregator.get_output()
Source code in src/torchio/data/sampler.py
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | |
UniformSampler
Bases: PatchSampler, IterableDataset
Random patches with uniform spatial probability.
An IterableDataset for training. Also callable for use with
Queue.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
subject
|
Subject
|
Subject to sample patches from (for Dataset use). |
required |
patch_size
|
int | TypeThreeInts
|
Spatial size of each patch. |
required |
num_patches
|
int | None
|
Number of patches per epoch. If |
None
|
Examples:
>>> sampler = tio.UniformSampler(subject, patch_size=64, num_patches=100)
>>> loader = DataLoader(sampler, batch_size=8)
Source code in src/torchio/data/sampler.py
WeightedSampler
Bases: PatchSampler, IterableDataset
Random patches weighted by a probability map.
An IterableDataset for training with spatial priors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
subject
|
Subject
|
Subject to sample patches from. |
required |
patch_size
|
int | TypeThreeInts
|
Spatial size of each patch. |
required |
probability_map
|
str
|
Name of the image in the subject to use as sampling weights. |
required |
num_patches
|
int | None
|
Number of patches per epoch. If |
None
|
Source code in src/torchio/data/sampler.py
LabelSampler
Bases: WeightedSampler
Random patches centered on labeled voxels.
An IterableDataset for training with class imbalance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
subject
|
Subject
|
Subject to sample patches from. |
required |
patch_size
|
int | TypeThreeInts
|
Spatial size of each patch. |
required |
label_name
|
str
|
Name of the label image in the subject. |
required |
label_probabilities
|
dict[int, float] | None
|
Dict mapping label values to sampling
weights. If |
None
|
num_patches
|
int | None
|
Number of patches per epoch. |
None
|
Source code in src/torchio/data/sampler.py
PatchAggregator
Reassemble patches into a full volume.
Handles overlapping patches with configurable blending modes. Supports outputs of different spatial sizes than the input patches (e.g., downsampled feature maps or embeddings).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
spatial_shape
|
TypeThreeInts
|
Output volume spatial shape |
required |
overlap_mode
|
str
|
How to handle overlapping regions:
|
'crop'
|
patch_overlap
|
int | TypeThreeInts
|
The overlap used during sampling, needed
for |
0
|
output_shape
|
TypeThreeInts | None
|
If the model output is spatially smaller than the input patch (e.g., due to strided convolutions), specify the output volume shape here. Patch locations will be scaled accordingly. |
None
|
Examples:
>>> aggregator = tio.PatchAggregator(
... spatial_shape=(256, 256, 176),
... overlap_mode="hann",
... )
>>> for batch in loader:
... outputs = model(batch.t1.data)
... aggregator.add_batch(outputs, locations)
>>> volume = aggregator.get_output()
Source code in src/torchio/data/aggregator.py
12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 | |
add_batch(batch, locations)
Add a batch of model outputs to the aggregation buffer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch
|
Tensor | dict[str, Tensor]
|
5D tensor |
required |
locations
|
list[PatchLocation]
|
List of |
required |
Source code in src/torchio/data/aggregator.py
get_output(key=None)
Get the aggregated output volume.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
key
|
str | None
|
Name of the output to retrieve. If |
None
|
Returns:
| Type | Description |
|---|---|
Tensor
|
The aggregated tensor with shape |
Source code in src/torchio/data/aggregator.py
Queue
Bases: IterableDataset
Buffer of patches for stochastic patch-based training.
Loads and preprocesses subjects in background threads, extracts
random patches via a sampler, and yields them one at a time.
Designed for use with SubjectsLoader or DataLoader.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
subjects
|
Sequence[Subject]
|
Sequence of subjects to sample patches from. |
required |
patch_sampler
|
PatchSampler
|
A sampler (e.g.,
|
required |
max_length
|
int
|
Maximum number of patches held in the buffer. Larger values increase diversity but use more RAM. |
300
|
patches_per_volume
|
int
|
Maximum patches to extract from each subject. The sampler may yield fewer if valid positions are exhausted. |
10
|
num_workers
|
int
|
Number of background threads for loading and preprocessing subjects. Set to 0 for synchronous loading. |
0
|
shuffle_subjects
|
bool
|
Shuffle the subject order at the start of each epoch. |
True
|
shuffle_patches
|
bool
|
Shuffle the buffer after each refill. |
True
|
transform
|
Any | None
|
Optional transform applied to each subject after loading and before patch extraction. |
None
|
subject_sampler
|
Sampler | None
|
A |
None
|
Examples:
>>> queue = tio.Queue(
... subjects,
... patch_sampler=tio.UniformSampler(subject, patch_size=64),
... max_length=300,
... patches_per_volume=10,
... num_workers=4,
... )
>>> loader = SubjectsLoader(queue, batch_size=16)
>>> for batch in loader:
... outputs = model(batch.t1.data)
Source code in src/torchio/data/queue.py
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 | |
num_subjects
property
Number of subjects per epoch.
patches_per_epoch
property
Total patches yielded per epoch (upper bound).
max_memory
property
Estimated max RAM for the patch buffer in bytes.
max_memory_pretty
property
Human-readable max memory estimate.
PatchLocation
dataclass
Spatial location of an extracted patch within a volume.
Attributes:
| Name | Type | Description |
|---|---|---|
index |
TypeThreeInts
|
|
size |
TypeThreeInts
|
|
subject_index |
int | None
|
Optional identifier for multi-subject batches. |
Source code in src/torchio/data/patch.py
index_ini
property
Starting voxel indices (i, j, k).
index_fin
property
One-past-the-end voxel indices.
to_slices()
Convert to spatial slices for tensor indexing.
scaled(factor)
Return a new location with indices and size scaled by factor.