# Source code for Garfield.plot.spatial_plot

## Adapted from cell2location and scTM;
## please refer to https://github.com/JinmiaoChenLab/scTM/blob/main/sctm/pl.py

import anndata as ad
import matplotlib as mpl
import matplotlib.pyplot as plt

import textwrap
import numpy as np
import scanpy as sc

# import seaborn as sns
from matplotlib import rcParams
from matplotlib.axes import Axes
from collections import defaultdict
import pandas as pd
import anndata

# import matplotlib.pyplot as plt
# import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.gridspec import GridSpec

# import pandas as pd
from matplotlib.patches import Patch

# from upsetplot import plot, from_contents
# from itertools import chain
from scanpy._utils import Empty, _empty
from scanpy.pl._tools.scatterplots import (
    _check_crop_coord,
    _check_img,
    _check_na_color,
    _check_scale_factor,
    _check_spatial_data,
    _check_spot_size,
)


def get_rgb_function(cmap, min_value, max_value):
    r"""Build a function that maps continuous values to colours through ``cmap``.

    Values are clipped to ``[min_value, max_value]`` and rescaled linearly to
    ``[0, 1]`` before being handed to ``cmap``.

    Parameters
    ----------
    cmap
        Callable that accepts an array of positions and returns colours
        (for example a matplotlib colormap).
    min_value
        Value mapped to position 0.
    max_value
        Value mapped to position 1.

    Returns
    -------
    Callable taking an array-like ``x`` and returning ``cmap`` evaluated at the
    rescaled positions.

    Raises
    ------
    ValueError
        If ``min_value`` is greater than ``max_value``.
    """
    if min_value > max_value:
        raise ValueError("Max_value should be greater than or equal to min_value.")

    if min_value == max_value:
        # Degenerate range: rescaling would divide by zero and produce NaN
        # positions. Use a constant position instead: the bottom of the
        # colormap when everything is zero, the midpoint otherwise.
        factor = 0 if max_value == 0 else 0.5

        def func_equal(x):
            return cmap(np.ones_like(x, dtype=float) * factor)

        return func_equal

    span = max_value - min_value

    def func(x):
        return cmap((np.clip(x, min_value, max_value) - min_value) / span)

    return func


def rgb_to_ryb(rgb):
    """
    Converts colours from RGB colorspace to RYB

    Parameters
    ----------

    rgb
        numpy array Nx3 (a single colour of length 3 is also accepted and is
        treated as a 1x3 array). Integer or boolean input is converted to
        float so that fractional channel values are not truncated.

    Returns
    -------
    Numpy array Nx3
    """
    rgb = np.array(rgb)
    if rgb.dtype.kind in "biu":
        # An integer buffer would silently floor the halved channels below.
        rgb = rgb.astype(float)
    if rgb.ndim == 1:
        rgb = rgb[np.newaxis, :]

    # Split off the achromatic parts; only the remainder is converted.
    white = rgb.min(axis=1)
    black = (1 - rgb).min(axis=1)
    rgb = rgb - white[:, np.newaxis]

    # Yellow is the part shared by the red and green channels.
    yellow = rgb[:, :2].min(axis=1)
    ryb = np.zeros_like(rgb)
    ryb[:, 0] = rgb[:, 0] - yellow
    ryb[:, 1] = (yellow + rgb[:, 1]) / 2
    ryb[:, 2] = (rgb[:, 2] + rgb[:, 1] - yellow) / 2

    # Rescale every non-zero row so its largest channel matches the largest
    # channel of the (white-removed) RGB row.
    mask = ~(ryb == 0).all(axis=1)
    if mask.any():
        norm = ryb[mask].max(axis=1) / rgb[mask].max(axis=1)
        ryb[mask] = ryb[mask] / norm[:, np.newaxis]

    return ryb + black[:, np.newaxis]


def ryb_to_rgb(ryb):
    """
    Converts colours from RYB colorspace to RGB

    Parameters
    ----------

    ryb
        numpy array Nx3 (a single colour of length 3 is also accepted and is
        treated as a 1x3 array). Integer or boolean input is converted to
        float so that fractional channel values are not truncated.

    Returns
    -------
    Numpy array Nx3
    """
    ryb = np.array(ryb)
    if ryb.dtype.kind in "biu":
        # An integer buffer would silently floor the normalised channels below.
        ryb = ryb.astype(float)
    if ryb.ndim == 1:
        ryb = ryb[np.newaxis, :]

    # Split off the achromatic parts; only the remainder is converted.
    black = ryb.min(axis=1)
    white = (1 - ryb).min(axis=1)
    ryb = ryb - black[:, np.newaxis]

    # Green is the part shared by the yellow and blue channels.
    green = ryb[:, 1:].min(axis=1)
    rgb = np.zeros_like(ryb)
    rgb[:, 0] = ryb[:, 0] + ryb[:, 1] - green
    rgb[:, 1] = green + ryb[:, 1]
    rgb[:, 2] = (ryb[:, 2] - green) * 2

    # Rescale every non-zero row so its largest channel matches the largest
    # channel of the (black-removed) RYB row.
    mask = ~(ryb == 0).all(axis=1)
    if mask.any():
        norm = rgb[mask].max(axis=1) / ryb[mask].max(axis=1)
        rgb[mask] = rgb[mask] / norm[:, np.newaxis]

    return rgb + white[:, np.newaxis]


def plot_spatial_general(
    value_df,
    coords,
    labels,
    text=None,
    circle_radius=None,
    display_zeros=False,
    figsize=(10, 10),
    alpha_scaling=1.0,
    max_col=(np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf),
    max_color_quantile=0.98,
    show_img=True,
    img=None,
    img_alpha=1.0,
    adjust_text=False,
    plt_axis="off",
    axis_y_flipped=False,
    x_y_labels=("", ""),
    crop_x=None,
    crop_y=None,
    text_box_alpha=0.9,
    reorder_cmap=range(7),
    style="fast",
    colorbar_position="right",
    colorbar_label_kw={},
    colorbar_shape={},
    colorbar_tick_size=12,
    colorbar_grid=None,
    image_cmap="Greys_r",
    white_spacing=20,
):
    """Scatter-plot up to seven value columns at once, one colour per column.

    Every column of ``value_df`` gets its own single-hue colormap whose alpha
    ramps from transparent to opaque. Per spot, the column colours are mixed
    in RYB space, weighted by the squared relative value of each column, and
    the result is drawn with ``Axes.scatter``.

    Parameters
    ----------
    value_df
        DataFrame-like object (``.shape`` and ``.values`` are used) with one
        row per spot and at most 7 columns.
    coords
        2-D array; column 0 is used as x and column 1 as y.
    labels
        Sequence of colour bar titles, indexed by column position.
    text
        Optional DataFrame whose first three columns (by position) are x, y
        and the text to draw.
    circle_radius
        Marker radius; its square is passed as the scatter size ``s``. When
        ``None`` the matplotlib default marker size is used.
    display_zeros
        If true, spots whose final alpha is 0 are drawn in opaque light grey.
    figsize
        Figure size passed to ``plt.figure``.
    alpha_scaling
        Factor multiplied into the alpha channel of every column colour.
    max_col
        Per-column upper bound for the top of the colour scale.
    max_color_quantile
        Quantile of each column used as the top of the colour scale (before
        applying ``max_col``).
    show_img, img, img_alpha, image_cmap
        ``img`` is drawn with ``imshow`` (using ``img_alpha`` and
        ``image_cmap``) when it is not ``None`` and ``show_img`` is true.
    adjust_text
        If true and ``text`` is given, ``adjustText.adjust_text`` is imported
        and applied to the labels.
    plt_axis
        ``"off"`` hides spines, ticks and tick labels of the main axes.
    axis_y_flipped
        If true, the y axis of the main axes is inverted.
    x_y_labels
        Pair of x and y axis labels.
    crop_x, crop_y
        Optional ``(low, high)`` pairs passed to ``set_xlim`` / ``set_ylim``.
    text_box_alpha
        Alpha of the box drawn behind each text label.
    reorder_cmap
        Indices selecting and ordering the built-in colormaps
        (yellow, red, blue, green, purple, pink, dark grey).
    style
        Matplotlib style used while building the figure.
    colorbar_position
        ``"right"``, ``"bottom"`` or ``None`` (no colour bars).
    colorbar_label_kw
        Overrides for the colour bar title properties (never mutated).
    colorbar_shape
        Overrides for the ``vertical_gaps``, ``horizontal_gaps``, ``width``
        and ``height`` layout values (never mutated).
    colorbar_tick_size
        Tick label size of the colour bars.
    colorbar_grid
        ``(rows, columns)`` layout of the colour bars; derived from
        ``len(labels)`` when ``None``.
    white_spacing
        Percentage of the low end of every colormap that is fully transparent.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``value_df`` has more than 7 columns or ``colorbar_position`` is
        not one of the supported values.
    """
    if value_df.shape[1] > 7:
        raise ValueError(
            "Maximum of 7 cell types / factors can be plotted at the moment"
        )
    if colorbar_position not in ("right", "bottom", None):
        # Fail early: otherwise the main axes is never created and the
        # function dies later with an unrelated unbound-variable error.
        raise ValueError("colorbar_position must be 'right', 'bottom' or None")

    def create_colormap(R, G, B):
        # Constant hue; only the alpha channel varies along the colormap.
        spacing = int(white_spacing * 2.55)

        N = 255
        M = 3

        alphas = np.concatenate(
            [[0] * spacing * M, np.linspace(0, 1.0, (N - spacing) * M)]
        )

        vals = np.ones((N * M, 4))
        for i, color in enumerate([R, G, B]):
            vals[:, i] = color / 255
        vals[:, 3] = alphas

        return ListedColormap(vals)

    # Single-hue colormaps, one per plottable column.
    YellowCM = create_colormap(240, 228, 66)  # #F0E442
    RedCM = create_colormap(213, 94, 0)  # #D55E00
    BlueCM = create_colormap(86, 180, 233)  # #56B4E9
    GreenCM = create_colormap(0, 158, 115)  # #009E73
    PinkCM = create_colormap(255, 105, 180)  # #FF69B4
    WhiteCM = create_colormap(50, 50, 50)  # #323232 (dark grey despite the name)
    PurpleCM = create_colormap(90, 20, 165)  # #5A14A5

    cmaps = [YellowCM, RedCM, BlueCM, GreenCM, PurpleCM, PinkCM, WhiteCM]

    cmaps = [cmaps[i] for i in reorder_cmap]

    with mpl.style.context(style):
        fig = plt.figure(figsize=figsize)
        if colorbar_position == "right":
            if colorbar_grid is None:
                colorbar_grid = (len(labels), 1)

            shape = {
                "vertical_gaps": 1.5,
                "horizontal_gaps": 0,
                "width": 0.15,
                "height": 0.2,
            }
            shape = {**shape, **colorbar_shape}

            gs = GridSpec(
                nrows=colorbar_grid[0] + 2,
                ncols=colorbar_grid[1] + 1,
                width_ratios=[1, *[shape["width"]] * colorbar_grid[1]],
                height_ratios=[1, *[shape["height"]] * colorbar_grid[0], 1],
                hspace=shape["vertical_gaps"],
                wspace=shape["horizontal_gaps"],
            )
            ax = fig.add_subplot(gs[:, 0], aspect="equal", rasterized=True)

        elif colorbar_position == "bottom":
            if colorbar_grid is None:
                if len(labels) <= 3:
                    colorbar_grid = (1, len(labels))
                else:
                    # Ceiling division: three colour bars per row.
                    n_rows = -(-len(labels) // 3)
                    colorbar_grid = (n_rows, 3)

            shape = {
                "vertical_gaps": 0.3,
                "horizontal_gaps": 0.6,
                "width": 0.2,
                "height": 0.035,
            }
            shape = {**shape, **colorbar_shape}

            gs = GridSpec(
                nrows=colorbar_grid[0] + 1,
                ncols=colorbar_grid[1] + 2,
                width_ratios=[0.3, *[shape["width"]] * colorbar_grid[1], 0.3],
                height_ratios=[1, *[shape["height"]] * colorbar_grid[0]],
                hspace=shape["vertical_gaps"],
                wspace=shape["horizontal_gaps"],
            )

            ax = fig.add_subplot(gs[0, :], aspect="equal", rasterized=True)

        else:
            ax = fig.add_subplot(aspect="equal", rasterized=True)

        if colorbar_position is not None:
            cbar_axes = []
            for row in range(1, colorbar_grid[0] + 1):
                for column in range(1, colorbar_grid[1] + 1):
                    cbar_axes.append(fig.add_subplot(gs[row, column]))

            # Hide grid cells that have no label to show.
            n_excess = colorbar_grid[0] * colorbar_grid[1] - len(labels)
            if n_excess > 0:
                for i in range(1, n_excess + 1):
                    cbar_axes[-i].set_visible(False)

        ax.set_xlabel(x_y_labels[0])
        ax.set_ylabel(x_y_labels[1])

        if img is not None and show_img:
            ax.imshow(img, alpha=img_alpha, cmap=image_cmap)

        # crop the view if requested
        if crop_x is not None:
            ax.set_xlim(crop_x[0], crop_x[1])
        if crop_y is not None:
            ax.set_ylim(crop_y[0], crop_y[1])

        if axis_y_flipped:
            ax.invert_yaxis()

        if plt_axis == "off":
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.tick_params(bottom=False, labelbottom=False, left=False, labelleft=False)

        counts = value_df.values.copy()

        # Per-column RGBA colour and mixing weight of every spot.
        colors = np.zeros((*counts.shape, 4))
        weights = np.zeros(counts.shape)

        for c in range(counts.shape[1]):
            min_color_intensity = counts[:, c].min()
            max_color_intensity = np.min(
                [np.quantile(counts[:, c], max_color_quantile), max_col[c]]
            )

            rgb_function = get_rgb_function(
                cmap=cmaps[c],
                min_value=min_color_intensity,
                max_value=max_color_intensity,
            )

            color = rgb_function(counts[:, c])
            color[:, 3] = color[:, 3] * alpha_scaling

            norm = mpl.colors.Normalize(
                vmin=min_color_intensity, vmax=max_color_intensity
            )

            if colorbar_position is not None:
                cbar_ticks = [
                    min_color_intensity,
                    np.mean([min_color_intensity, max_color_intensity]),
                    max_color_intensity,
                ]
                cbar_ticks = np.array(cbar_ticks)

                # Large scales get integer ticks, small ones two decimals.
                if max_color_intensity > 13:
                    cbar_ticks = cbar_ticks.astype(np.int32)
                else:
                    cbar_ticks = cbar_ticks.round(2)

                cbar = fig.colorbar(
                    mpl.cm.ScalarMappable(norm=norm, cmap=cmaps[c]),
                    cax=cbar_axes[c],
                    orientation="horizontal",
                    extend="both",
                    ticks=cbar_ticks,
                )

                cbar.ax.tick_params(labelsize=colorbar_tick_size)
                # Title colour is the column colour at two thirds of the maximum.
                max_color = rgb_function(max_color_intensity / 1.5)
                cbar.ax.set_title(
                    labels[c],
                    **{
                        **{"size": 20, "color": max_color, "alpha": 1},
                        **colorbar_label_kw,
                    },
                )

            colors[:, c] = color
            weights[:, c] = np.clip(counts[:, c] / (max_color_intensity + 1e-10), 0, 1)
            weights[:, c][counts[:, c] < min_color_intensity] = 0

        # rgb_to_ryb works row by row, so all (spot, column) colours can be
        # converted in a single call instead of one call per spot.
        colors_ryb = rgb_to_ryb(colors[:, :, :3].reshape(-1, 3)).reshape(
            *weights.shape, 3
        )

        # Squared weights emphasise the dominant column of each spot.
        kernel_weights = weights[:, :, np.newaxis] ** 2
        weight_totals = kernel_weights.sum(axis=1)
        # Spots with all-zero weights have a zero numerator as well; dividing
        # by 1 instead of 0 gives them a defined colour instead of NaN.
        weight_totals[weight_totals == 0] = 1.0
        weighted_colors_ryb = (colors_ryb * kernel_weights).sum(axis=1) / weight_totals

        weighted_colors = np.zeros((weights.shape[0], 4))
        weighted_colors[:, :3] = ryb_to_rgb(weighted_colors_ryb)
        weighted_colors[:, 3] = colors[:, :, 3].max(axis=1)

        if display_zeros:
            weighted_colors[weighted_colors[:, 3] == 0] = [
                210 / 255,
                210 / 255,
                210 / 255,
                1,
            ]

        # scatter sizes are areas, hence the squared radius; None falls back
        # to the matplotlib default marker size.
        marker_size = None if circle_radius is None else circle_radius**2
        ax.scatter(x=coords[:, 0], y=coords[:, 1], c=weighted_colors, s=marker_size)

        # add text
        if text is not None:
            bbox_props = dict(boxstyle="round", ec="0.5", alpha=text_box_alpha, fc="w")
            texts = []
            for x, y, s in zip(
                np.array(text.iloc[:, 0].values).flatten(),
                np.array(text.iloc[:, 1].values).flatten(),
                text.iloc[:, 2].tolist(),
            ):
                texts.append(
                    ax.text(x, y, s, ha="center", va="bottom", bbox=bbox_props)
                )

            if adjust_text:
                # Optional dependency, imported only when requested. The
                # import rebinds the boolean parameter of the same name.
                from adjustText import adjust_text

                adjust_text(texts, arrowprops=dict(arrowstyle="->", color="w", lw=0.5))

    plt.grid(False)
    return fig


def plot_multi_patterns_spatial(
    adata,
    topic_prop,
    basis="spatial",
    bw=False,
    img=None,
    library_id=_empty,
    crop_coord=None,
    img_key=_empty,
    spot_size=None,
    na_color=None,
    scale_factor=None,
    scale_default=0.5,
    show_img=True,
    display_zeros=False,
    figsize=(10, 10),
    **kwargs,
):
    """Plot several topic/pattern proportions on one set of coordinates.

    Adapted from cell2location
    (https://cell2location.readthedocs.io/en/latest/_modules/cell2location/plt/plot_spatial.html#plot_spatial).
    Image, spot size and scale factor defaults are resolved with scanpy's
    spatial helpers and the drawing is delegated to ``plot_spatial_general``.

    Args:
        adata: AnnData object; ``adata.uns`` is searched for spatial metadata
            and ``adata.obsm[basis]`` provides the coordinates.
        topic_prop: DataFrame of values to plot, one column per pattern. Its
            column names are used as colour bar labels.
        basis (str, optional): Key of ``adata.obsm`` holding the coordinates.
            Defaults to "spatial".
        bw (bool, optional): Forwarded to scanpy's ``_check_img``.
            Defaults to False.
        img (optional): Image to draw; resolved by ``_check_img``.
            Defaults to None.
        library_id (optional): Forwarded to scanpy's ``_check_spatial_data``.
            Defaults to _empty.
        crop_coord (optional): Checked by scanpy's ``_check_crop_coord``; the
            result is currently not used for plotting. Defaults to None.
        img_key (optional): Forwarded to ``_check_img`` and
            ``_check_scale_factor``. Defaults to _empty.
        spot_size (optional): Spot diameter; resolved by ``_check_spot_size``.
            Defaults to None.
        na_color (optional): Checked by scanpy's ``_check_na_color``; the
            result is currently not used for plotting. Defaults to None.
        scale_factor (optional): Factor applied to the coordinates and the
            spot size; resolved by ``_check_scale_factor``. Defaults to None.
        scale_default (float, optional): Extra factor applied to the circle
            radius. Defaults to 0.5.
        show_img (bool, optional): Whether to pass the resolved image on for
            display. Defaults to True.
        display_zeros (bool, optional): Forwarded to ``plot_spatial_general``.
            Defaults to False.
        figsize (tuple, optional): Figure size. Defaults to (10, 10).
        **kwargs: Further keyword arguments for ``plot_spatial_general``.

    Returns:
        The figure created by ``plot_spatial_general``.
    """
    # get default image params if available
    library_id, spatial_data = _check_spatial_data(adata.uns, library_id)
    img, img_key = _check_img(spatial_data, img, img_key, bw=bw)
    spot_size = _check_spot_size(spatial_data, spot_size)
    scale_factor = _check_scale_factor(
        spatial_data, img_key=img_key, scale_factor=scale_factor
    )
    # NOTE(review): crop_coord and na_color are validated here but never
    # passed on to the plotting call - confirm whether that is intended.
    crop_coord = _check_crop_coord(crop_coord, scale_factor)
    na_color = _check_na_color(na_color, img=img)

    if scale_factor is not None:
        circle_radius = scale_factor * spot_size * 0.5 * scale_default
    else:
        circle_radius = spot_size * 0.5

    if show_img is True:
        kwargs["show_img"] = True
        kwargs["img"] = img

    kwargs["coords"] = adata.obsm[basis] * scale_factor

    fig = plot_spatial_general(
        value_df=topic_prop,
        labels=topic_prop.columns,
        circle_radius=circle_radius,
        figsize=figsize,
        display_zeros=display_zeros,
        **kwargs,
    )
    # NOTE(review): this inverts whichever axes is current after plotting,
    # which may be a colour bar axes and not the main one - confirm.
    plt.gca().invert_yaxis()
    return fig
def plot_markers(
    adata: anndata.AnnData,
    groupby: str,
    mks: pd.DataFrame,
    n_genes: int = 5,
    kind: str = 'dotplot',
    remove_genes: list = [],
    **kwargs
):
    """
    Plot markers for specific groups.

    Parameters
    ----------
    adata : AnnData
        AnnData object containing expression data.
    groupby : str
        Column in `adata.obs` used for grouping cells.
    mks : DataFrame
        DataFrame containing marker statistics. Genes are taken from its
        index (which must come out of `reset_index()` as a column named
        'index') and groups from the 'top_frac_group' column.
    n_genes : int, optional
        Number of top genes to plot per group. Default is 5.
    kind : str, optional
        Name of the `scanpy.pl` function to call ('dotplot', 'violin', etc.).
        Default is 'dotplot'.
    remove_genes : list, optional
        Names of boolean columns in the variable table; genes flagged in any
        of them are excluded from the plot. Default is an empty list (the
        list is only read, never modified).
    **kwargs : dict
        Additional keyword arguments passed to the plotting function.

    Returns
    -------
    Whatever the selected `scanpy.pl` function returns.
    """
    df = mks.reset_index()[['index', 'top_frac_group']].rename(
        columns={'index': 'gene', 'top_frac_group': 'cluster'}
    )

    # Use the same variable table the plotting call will use: an explicit
    # use_raw wins, otherwise raw is used whenever it exists.
    use_raw = kwargs.get('use_raw', None)
    if use_raw is None:
        use_raw = adata.raw is not None
    var_tb = adata.raw.var if use_raw else adata.var

    remove_gene_set = set()
    for g_cat in remove_genes:
        if g_cat in var_tb.columns:
            remove_gene_set.update(var_tb.index[var_tb[g_cat].values])
    df = df[~df.gene.isin(list(remove_gene_set))].copy()

    # Keep the first n_genes rows of every cluster, in the order given by mks.
    df1 = df.groupby('cluster').head(n_genes)
    mks_dict = defaultdict(list)
    for c, g in zip(df1.cluster, df1.gene):
        mks_dict[c].append(g)

    func = getattr(sc.pl, kind)
    if sc.__version__.startswith('1.4'):
        # scanpy 1.4.x is given a flat gene list, not a group mapping.
        return func(adata, df1.gene.to_list(), groupby=groupby, **kwargs)
    return func(adata, mks_dict, groupby=groupby, **kwargs)
def niches_enrichment_barplot(
    enrichments,
    niche,
    type="enrichr",  # defaults to the enrichr type
    figsize=(10, 5),
    n_enrichments=5,
    qval_cutoff=0.05,
    title="auto",
):
    """
    Create a barplot for the enrichment results (either from Enrichr or GSEA).

    Parameters
    ----------
    enrichments : dict
        Dictionary of enrichment results where the key is the niche name and
        the value is a dataframe of enrichment results.
    niche : str
        The niche (cluster) to visualize enrichment for.
    type : str, optional
        The type of enrichment analysis ('enrichr' or 'gsea').
        Default is 'enrichr'.
    figsize : tuple, optional
        Figure size for the plot.
    n_enrichments : int, optional
        Number of top enrichment terms to display (default is 5).
    qval_cutoff : float, optional
        Cutoff applied to the 'Adjusted P-value' column (enrichr) or to the
        'NOM p-val' column (gsea) to filter enrichment terms
        (default is 0.05).
    title : str, optional
        Title of the plot. The default 'auto' uses the 'Gene_set' (enrichr)
        or 'Name' (gsea) value of the first displayed term, or the niche name
        when no term passes the cutoff.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The matplotlib axis containing the plot.

    Raises
    ------
    KeyError
        If `niche` is not a key of `enrichments`.
    ValueError
        If `type` is neither 'enrichr' nor 'gsea'.
    """
    # Check if the niche exists in enrichments
    if niche not in enrichments:
        raise KeyError(f"Niche '{niche}' not found in enrichments.")
    if type not in ("enrichr", "gsea"):
        raise ValueError("Unsupported enrichment type. Choose either 'enrichr' or 'gsea'.")

    # Extract enrichment data for the given niche
    enrichment = enrichments[niche]

    if type == "enrichr":
        # Keep significant terms, most significant first
        enrichment = enrichment.loc[enrichment["Adjusted P-value"] < qval_cutoff, :]
        enrichment = enrichment.sort_values("Adjusted P-value")
        enrichment = enrichment.iloc[:n_enrichments, :]

        widths = -np.log(enrichment["Adjusted P-value"])
        x_label = "- Log Adjusted P-value"
        auto_title_column = "Gene_set"
        # NOTE(review): `fill` is a boolean patch property, so "blue" here
        # does not look like it sets the bar colour; kept as in the original.
        bar_kwargs = {"fill": "blue"}
    else:
        # Keep significant, positively enriched terms, highest NES first
        enrichment = enrichment.loc[enrichment["NOM p-val"] < qval_cutoff, :]
        enrichment = enrichment[enrichment["NES"] > 0]
        enrichment = enrichment.sort_values("NES", ascending=False)
        # Work on a copy so the label clean-up does not write into a slice
        # of the caller's dataframe.
        enrichment = enrichment.iloc[:n_enrichments, :].copy()
        enrichment["Term"] = enrichment["Term"].str.replace("_", " ")

        widths = enrichment["NES"]
        x_label = "NES"
        auto_title_column = "Name"
        bar_kwargs = {}

    # Set title
    if title == "auto":
        if len(enrichment) > 0:
            title = enrichment[auto_title_column].iloc[0]
        else:
            # Nothing passed the cutoff: there is no row to take a name from.
            title = str(niche)

    # Create the plot
    fig, ax = plt.subplots(figsize=figsize)
    ax.barh(y=enrichment["Term"], width=widths, align="center", **bar_kwargs)

    # Format y-axis labels
    ax.set_yticklabels(
        [textwrap.fill(term, 24) for term in enrichment["Term"].values]
    )
    ax.set_xlabel(x_label)
    ax.set_title(title)
    ax.invert_yaxis()  # Reverse y-axis for top to bottom order

    plt.tight_layout()
    return ax
def niches_enrichment_dotplot(
    enrichments,
    niche,
    type="gsea",  # defaults to the gsea type
    figsize=(10, 5),
    n_enrichments=10,
    title="auto",
    cmap=None,
    qval_cutoff=0.05,
):
    """
    Create a dotplot for the enrichment results (either from Enrichr or GSEA).

    Parameters
    ----------
    enrichments : dict
        Dictionary of enrichment results where the key is the niche name and
        the value is a dataframe of enrichment results. The dataframe is not
        modified.
    niche : str
        The niche (cluster) to visualize enrichment for.
    type : str, optional
        The type of enrichment analysis ('enrichr' or 'gsea').
        Default is 'gsea'.
    figsize : tuple, optional
        Figure size for the plot.
    n_enrichments : int, optional
        Number of top enrichment terms to display (default is 10).
    title : str, optional
        Title of the plot. The default 'auto' uses the 'Gene_set' (enrichr)
        or 'Name' (gsea) value of the first displayed term; any other
        non-None value is used as given.
    cmap : matplotlib.colors.Colormap, optional
        Colormap for the plot.
    qval_cutoff : float, optional
        Cutoff applied to the 'Adjusted P-value' column (enrichr) or to the
        'FDR q-val' column (gsea) to filter enrichment terms
        (default is 0.05).

    Returns
    -------
    ax : matplotlib.axes.Axes
        The matplotlib axis containing the plot.

    Raises
    ------
    KeyError
        If `niche` is not a key of `enrichments`.
    ValueError
        If `type` is neither 'enrichr' nor 'gsea', or if no term passes
        `qval_cutoff`.
    """
    # Validate the arguments before any figure is created
    if niche not in enrichments:
        raise KeyError(f"Niche '{niche}' not found in enrichments.")
    if type not in ("enrichr", "gsea"):
        raise ValueError("Unsupported enrichment type. Choose either 'enrichr' or 'gsea'.")

    # Work on a copy: helper columns are added below and must not leak into
    # the caller's dataframe.
    enrichment = enrichments[niche].copy()

    if type == "enrichr":
        # "Overlap" is parsed as "<hits>/<gene set size>"
        overlap = enrichment["Overlap"].str.split("/")
        enrichment["gene_size"] = overlap.str[1].astype(int)
        enrichment["gene_ratio"] = overlap.str[0].astype(int) / enrichment["gene_size"]

        # Filter by q-value cutoff, then keep the highest gene ratios
        enrichment = enrichment.loc[enrichment["Adjusted P-value"] < qval_cutoff, :]
        enrichment = enrichment.sort_values("gene_ratio", ascending=False)
        enrichment = enrichment.iloc[:n_enrichments, :]

        x_values = enrichment["gene_ratio"].values
        sizes = enrichment["gene_size"].values
        color_values = enrichment["Combined Score"].values
        x_label = "Gene Ratio"
        size_title, color_title = "Geneset Size", "Combined Score"
        legend_x = 1.04
        wrap_width = 24
        auto_title_column = "Gene_set"
    else:
        # "Tag %" is parsed as "<hits>/<gene set size>"
        tags = enrichment["Tag %"].str.split("/")
        enrichment["gene_size"] = tags.str[1].astype(int)
        enrichment["gene_ratio"] = tags.str[0].astype(int) / enrichment["gene_size"]
        # Cast once so that filtering and the log use the same numeric values
        fdr = enrichment["FDR q-val"].astype(float)
        enrichment["-log_qval"] = -np.log(fdr + 1e-7)

        # Filter by q-value cutoff, then keep the most significant terms
        enrichment = enrichment.loc[fdr < qval_cutoff, :]
        enrichment = enrichment.sort_values("-log_qval", ascending=False)
        enrichment = enrichment.iloc[:n_enrichments, :]

        x_values = enrichment["-log_qval"].values
        # NOTE(review): gene ratios are fractions, so these marker sizes look
        # very small - confirm whether a scaling factor is intended.
        sizes = enrichment["gene_ratio"].values.astype(float)
        color_values = enrichment["NES"].values
        x_label = "-log q_val"
        size_title, color_title = "Gene Ratio", "NES"
        legend_x = 1
        wrap_width = 30
        auto_title_column = "Name"

    if enrichment.shape[0] == 0:
        # Legends and the automatic title cannot be built from zero terms.
        raise ValueError(
            f"No enrichment terms for niche '{niche}' pass qval_cutoff={qval_cutoff}."
        )

    # Initialize figure and axis
    fig, ax = plt.subplots(figsize=figsize)

    # Scatter plot
    scatter = ax.scatter(
        x=x_values,
        y=enrichment["Term"].values,
        s=sizes,
        c=color_values,
        cmap=cmap,
    )
    ax.set_xlabel(x_label)

    # Legends: the second ax.legend call replaces the first, so the size
    # legend is re-attached afterwards with add_artist.
    legend1 = ax.legend(
        *scatter.legend_elements(prop="sizes", num=5),
        bbox_to_anchor=(legend_x, 1),
        loc="upper left",
        title=size_title,
        labelspacing=1,
        borderpad=1,
    )
    ax.legend(
        *scatter.legend_elements(prop="colors", num=5),
        bbox_to_anchor=(legend_x, 0),
        loc="lower left",
        title=color_title,
        labelspacing=1,
        borderpad=1,
    )
    ax.add_artist(legend1)

    # Format y-axis labels
    ax.set_yticklabels(
        [textwrap.fill(term, wrap_width) for term in enrichment["Term"].values]
    )

    # Set plot title
    if title == "auto":
        ax.set_title(enrichment[auto_title_column].values[0])
    elif title is not None:
        ax.set_title(title)

    if type == "gsea":
        ax.invert_yaxis()  # Reverse y-axis for top to bottom order

    # Tight layout for better visualization
    plt.tight_layout()
    return ax