Skip to content

Visualization

plot_image(image, *, channel=0, indices=None, coordinates=None, axes=None, cmap=None, percentiles=(0.5, 99.5), figsize=None, title=None, output_path=None, show=True, savefig_kwargs=None, voxels=False, figsize_multiplier=2.0, intersections=True, **imshow_kwargs)

plot_image(image: Image, *, show: Literal[False], channel: int = ..., indices: tuple[int | None, int | None, int | None] | None = ..., coordinates: tuple[float | None, float | None, float | None] | None = ..., axes: Sequence[Axes] | None = ..., cmap: str | Colormap | dict[int, tuple[int, int, int]] | None = ..., percentiles: tuple[float, float] = ..., figsize: tuple[float, float] | None = ..., title: str | None = ..., output_path: TypePath | None = ..., savefig_kwargs: dict[str, Any] | None = ..., voxels: bool = ..., figsize_multiplier: float = ..., intersections: bool = ..., **imshow_kwargs: Any) -> Figure
plot_image(image: Image, *, show: Literal[True] = ..., channel: int = ..., indices: tuple[int | None, int | None, int | None] | None = ..., coordinates: tuple[float | None, float | None, float | None] | None = ..., axes: Sequence[Axes] | None = ..., cmap: str | Colormap | dict[int, tuple[int, int, int]] | None = ..., percentiles: tuple[float, float] = ..., figsize: tuple[float, float] | None = ..., title: str | None = ..., output_path: TypePath | None = ..., savefig_kwargs: dict[str, Any] | None = ..., voxels: bool = ..., figsize_multiplier: float = ..., intersections: bool = ..., **imshow_kwargs: Any) -> None

Plot 3 orthogonal slices of a 3D image.

Always displays Sagittal, Coronal, Axial with fixed anatomical positions regardless of image orientation. Data is flipped and transposed as needed. Uses lazy Image.__getitem__ so only the 3 requested planes are read from disk.

Parameters:

Name Type Description Default
image Image

The image to plot.

required
channel int

Which channel to display.

0
indices tuple[int | None, int | None, int | None] | None

Slice index for each spatial axis. None entries default to the mid-slice. Pass None instead of a tuple to use the mid-slice on every axis. Mutually exclusive with coordinates.

None
coordinates tuple[float | None, float | None, float | None] | None

World coordinates in mm for each slice. None entries default to the mid-slice. Converted to the nearest voxel index via the inverse affine. Mutually exclusive with indices.

None
axes Sequence[Axes] | None

Pre-created sequence of 3 matplotlib Axes. If None, a new figure with correct proportions is created.

None
cmap str | Colormap | dict[int, tuple[int, int, int]] | None

Colormap. Defaults to 'gray' for intensity images.

None
percentiles tuple[float, float]

Intensity percentile range for display windowing. Ignored for label maps.

(0.5, 99.5)
figsize tuple[float, float] | None

Figure size in inches (width, height).

None
title str | None

Figure super-title.

None
output_path TypePath | None

Save figure to this path.

None
show bool

Call plt.show() after plotting.

True
savefig_kwargs dict[str, Any] | None

Extra keyword arguments for fig.savefig().

None
voxels bool

Show voxel indices on ticks instead of world coordinates in mm.

False
figsize_multiplier float

Scale factor applied to the default rcParams["figure.figsize"] when figsize is None.

2.0
intersections bool

Draw coloured cross-hair lines showing where the other two slices intersect each view.

True
**imshow_kwargs Any

Forwarded to ax.imshow().

{}

Returns:

Type Description
Figure | None

The matplotlib Figure, or None when show=True (the figure is displayed and closed to prevent duplicate rendering in notebooks).

Source code in src/torchio/visualization.py
def plot_image(
    image: Image,
    *,
    channel: int = 0,
    indices: tuple[int | None, int | None, int | None] | None = None,
    coordinates: tuple[float | None, float | None, float | None] | None = None,
    axes: Sequence[Axes] | None = None,
    cmap: str | Colormap | dict[int, tuple[int, int, int]] | None = None,
    percentiles: tuple[float, float] = (0.5, 99.5),
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
    output_path: TypePath | None = None,
    show: bool = True,
    savefig_kwargs: dict[str, Any] | None = None,
    voxels: bool = False,
    figsize_multiplier: float = 2.0,
    intersections: bool = True,
    **imshow_kwargs: Any,
) -> Figure | None:
    """Plot 3 orthogonal slices of a 3D image.

    Always displays Sagittal, Coronal, Axial with fixed anatomical
    positions regardless of image orientation. Data is flipped and
    transposed as needed. Uses lazy `Image.__getitem__` so only
    the 3 requested planes are read from disk.

    Args:
        image: The image to plot.
        channel: Which channel to display.
        indices: Slice index for each spatial axis. `None` entries
            default to the mid-slice. Pass `None` instead of a tuple
            to use the mid-slice on every axis. Mutually exclusive
            with `coordinates`.
        coordinates: World coordinates in mm for each slice.
            `None` entries default to the mid-slice. Converted to
            the nearest voxel index via the inverse affine. Mutually
            exclusive with `indices`.
        axes: Pre-created sequence of 3 matplotlib `Axes`. If
            `None`, a new figure with correct proportions is created.
        cmap: Colormap. Defaults to `'gray'` for intensity images.
        percentiles: Intensity percentile range for display windowing.
            Ignored for label maps.
        figsize: Figure size in inches `(width, height)`. Only used
            when `axes` is `None`.
        title: Figure super-title.
        output_path: Save figure to this path.
        show: Call `plt.show()` after plotting.
        savefig_kwargs: Extra keyword arguments for `fig.savefig()`.
        voxels: Show voxel indices on ticks instead of world
            coordinates in mm.
        figsize_multiplier: Scale factor applied to the default
            `rcParams["figure.figsize"]` when `figsize` is `None`.
        intersections: Draw coloured cross-hair lines showing where
            the other two slices intersect each view.
        **imshow_kwargs: Forwarded to `ax.imshow()`.

    Returns:
        The matplotlib `Figure`, or `None` when `show=True`
            (the figure is displayed and closed to prevent duplicate
            rendering in notebooks).

    Raises:
        ValueError: If `axes` is given but holds fewer than 3 entries.
    """
    # NOTE: continuation lines in the Returns/Raises sections above are
    # indented on purpose; unindented, each line is rendered as a separate
    # return item by the documentation generator.
    mpl, plt = _get_mpl()

    resolved = _resolve_indices(image, indices, coordinates)

    fig: Figure
    if axes is None:
        # Layout metadata is only needed when we build the figure ourselves.
        # Read spatial metadata from headers (no data load)
        spatial_shape = image.spatial_shape
        spacing = image.spacing
        orientation = image.orientation

        # Find tensor axis for each anatomical pair
        axis_for: dict[str, int] = {
            pair: _find_axis(orientation, pair) for pair in ("LR", "AP", "SI")
        }

        # Physical extents (voxels * spacing) drive proportional column
        # widths: AP for the first panel, LR for the other two.
        lr_mm = spatial_shape[axis_for["LR"]] * spacing[axis_for["LR"]]
        ap_mm = spatial_shape[axis_for["AP"]] * spacing[axis_for["AP"]]
        width_ratios = [ap_mm, lr_mm, lr_mm]

        if figsize is None:
            default_w, default_h = plt.rcParams["figure.figsize"]
            figsize = (
                default_w * figsize_multiplier,
                default_h * figsize_multiplier,
            )
        gs = mpl.gridspec.GridSpec(1, 3, width_ratios=width_ratios)
        fig = plt.figure(figsize=figsize)
        # NOTE(review): closed right after creation, presumably so pyplot
        # stops tracking it and notebooks do not render it a second time;
        # the Figure object itself remains usable below.
        plt.close(fig)
        plot_axes: Sequence[Axes] = [fig.add_subplot(gs[0, i]) for i in range(3)]
    else:
        if len(axes) < 3:
            msg = f"Expected 3 axes, got {len(axes)}"
            raise ValueError(msg)
        plot_axes = axes
        # The owning figure is taken from the first supplied Axes.
        fig = cast("Figure", plot_axes[0].get_figure())

    _plot_image_on_axes(
        image=image,
        plot_axes=plot_axes,
        channel=channel,
        resolved=resolved,
        cmap=cmap,
        percentiles=percentiles,
        voxels=voxels,
        intersections=intersections,
        **imshow_kwargs,
    )

    if title is not None:
        fig.suptitle(title)
    fig.tight_layout()

    if output_path is not None:
        fig.savefig(output_path, **(savefig_kwargs or {}))
    if show:
        _display_figure(fig)
        return None

    return fig

plot_subject(subject, *, channel=0, indices=None, coordinates=None, cmap_dict=None, percentiles=(0.5, 99.5), figsize=None, title=None, output_path=None, show=True, savefig_kwargs=None, voxels=False, figsize_multiplier=2.0, intersections=True, **imshow_kwargs)

plot_subject(subject: Subject, *, show: Literal[False], channel: int = ..., indices: tuple[int | None, int | None, int | None] | None = ..., coordinates: tuple[float | None, float | None, float | None] | None = ..., cmap_dict: dict[str, Any] | None = ..., percentiles: tuple[float, float] = ..., figsize: tuple[float, float] | None = ..., title: str | None = ..., output_path: TypePath | None = ..., savefig_kwargs: dict[str, Any] | None = ..., voxels: bool = ..., figsize_multiplier: float = ..., intersections: bool = ..., **imshow_kwargs: Any) -> Figure
plot_subject(subject: Subject, *, show: Literal[True] = ..., channel: int = ..., indices: tuple[int | None, int | None, int | None] | None = ..., coordinates: tuple[float | None, float | None, float | None] | None = ..., cmap_dict: dict[str, Any] | None = ..., percentiles: tuple[float, float] = ..., figsize: tuple[float, float] | None = ..., title: str | None = ..., output_path: TypePath | None = ..., savefig_kwargs: dict[str, Any] | None = ..., voxels: bool = ..., figsize_multiplier: float = ..., intersections: bool = ..., **imshow_kwargs: Any) -> None

Plot all images in a subject as a grid.

Each image gets a row (or column if >3 images) of Sagittal, Coronal, Axial views. LabelMaps are automatically detected and use categorical colormaps.

Parameters:

Name Type Description Default
subject Subject

The subject to plot.

required
channel int

Which channel to display.

0
indices tuple[int | None, int | None, int | None] | None

Voxel indices for each slice. Mutually exclusive with coordinates.

None
coordinates tuple[float | None, float | None, float | None] | None

World coordinates in mm. Mutually exclusive with indices.

None
cmap_dict dict[str, Any] | None

Per-image colormap overrides, keyed by image name.

None
percentiles tuple[float, float]

Intensity percentile range for windowing.

(0.5, 99.5)
figsize tuple[float, float] | None

Figure size in inches.

None
title str | None

Figure super-title.

None
output_path TypePath | None

Save figure to this path.

None
show bool

Call plt.show() after plotting.

True
savefig_kwargs dict[str, Any] | None

Extra keyword arguments for fig.savefig().

None
voxels bool

Show voxel ticks instead of world coordinates.

False
figsize_multiplier float

Scale factor for default figure size.

2.0
intersections bool

Draw slice intersection cross-hairs.

True
**imshow_kwargs Any

Forwarded to ax.imshow().

{}

Returns:

Type Description
Figure | None

The Figure, or None when show=True.

Source code in src/torchio/visualization.py
def plot_subject(
    subject: Subject,
    *,
    channel: int = 0,
    indices: tuple[int | None, int | None, int | None] | None = None,
    coordinates: tuple[float | None, float | None, float | None] | None = None,
    cmap_dict: dict[str, Any] | None = None,
    percentiles: tuple[float, float] = (0.5, 99.5),
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
    output_path: TypePath | None = None,
    show: bool = True,
    savefig_kwargs: dict[str, Any] | None = None,
    voxels: bool = False,
    figsize_multiplier: float = 2.0,
    intersections: bool = True,
    **imshow_kwargs: Any,
) -> Figure | None:
    """Draw every image of a subject in one grid figure.

    Each image occupies a row of Sagittal, Coronal, Axial views; with
    more than 3 images the layout switches to one column per image.
    LabelMaps are automatically detected and use categorical
    colormaps.

    Args:
        subject: The subject to plot.
        channel: Which channel to display.
        indices: Voxel indices for each slice. Mutually exclusive
            with `coordinates`.
        coordinates: World coordinates in mm. Mutually exclusive
            with `indices`.
        cmap_dict: Per-image colormap overrides, keyed by image name.
        percentiles: Intensity percentile range for windowing.
        figsize: Figure size in inches.
        title: Figure super-title.
        output_path: Save figure to this path.
        show: Call `plt.show()` after plotting.
        savefig_kwargs: Extra keyword arguments for `fig.savefig()`.
        voxels: Show voxel ticks instead of world coordinates.
        figsize_multiplier: Scale factor for default figure size.
        intersections: Draw slice intersection cross-hairs.
        **imshow_kwargs: Forwarded to `ax.imshow()`.

    Returns:
        The `Figure`, or `None` when `show=True`.

    Raises:
        ValueError: If the subject contains no images.
    """
    mpl, plt = _get_mpl()

    subject_images = subject.images
    image_count = len(subject_images)
    if image_count == 0:
        msg = "Subject has no images to plot"
        raise ValueError(msg)

    reference_image = next(iter(subject_images.values()))
    # The return value is intentionally unused here.
    # NOTE(review): presumably called so that invalid `indices` /
    # `coordinates` are rejected before any figure is created - confirm.
    _resolve_indices(reference_image, indices, coordinates)

    # More than 3 images switches the grid to the per-column arrangement.
    column_layout = image_count > 3
    fig, grid_axes = _create_subject_grid(
        reference_image,
        image_count,
        column_layout,
        figsize,
        figsize_multiplier,
        mpl,
        plt,
    )

    _populate_subject_grid(
        subject_images,
        grid_axes,
        column_layout,
        indices,
        coordinates,
        channel=channel,
        cmap_dict=cmap_dict,
        percentiles=percentiles,
        voxels=voxels,
        intersections=intersections,
        **imshow_kwargs,
    )

    if title is not None:
        fig.suptitle(title)
    fig.tight_layout()

    if output_path is not None:
        save_options = savefig_kwargs or {}
        fig.savefig(output_path, **save_options)

    if not show:
        return fig

    _display_figure(fig)
    return None