# Source code for spatialvi.model.base._deconvolution_mixin

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import pandas as pd

if TYPE_CHECKING:
    from anndata import AnnData

# Module-level logger named after this module; used by SpatialDeconvolutionMixin's plotting helper.
logger = logging.getLogger(__name__)


class SpatialDeconvolutionMixin:
    """Mixin for spatial deconvolution result formatting and visualization.

    Applied to: DestVI only.

    Requires the model to implement:

    - ``self.cell_type_mapping``: np.ndarray of cell type label strings. It is
      used as column names when ``get_proportions`` returns an array.
    - ``self.get_proportions(...)``: returns either an array-like of shape
      (n_spots, n_cell_types) or a :class:`pandas.DataFrame`. It may optionally
      declare an ``adata`` parameter, which is then forwarded.
    """

    def get_proportions_df(self, adata: AnnData | None = None) -> pd.DataFrame:
        """Return cell type proportions as a tidy DataFrame.

        Parameters
        ----------
        adata
            AnnData object. If None, uses the model's registered adata (when the
            model has an ``adata`` attribute). It is forwarded only when
            ``self.get_proportions`` declares an ``adata`` parameter. Otherwise
            ``get_proportions`` is called with no arguments and decides for
            itself which data it reports on.

        Returns
        -------
        DataFrame of shape (n_spots, n_cell_types) with cell type names as
        columns. If ``get_proportions`` already returns a DataFrame, it is
        returned unchanged. Rows are expected to sum to 1; this method does not
        renormalize.
        """
        import inspect

        if adata is None and hasattr(self, "adata"):
            adata = self.adata

        try:
            params = list(inspect.signature(self.get_proportions).parameters.values())
        except (TypeError, ValueError):
            # Signature cannot be introspected: fall back to the no-argument call.
            params = []
        adata_param = next((p for p in params if p.name == "adata"), None)

        if adata_param is None:
            proportions = self.get_proportions()
        elif adata_param.kind is inspect.Parameter.POSITIONAL_ONLY:
            # A positional-only parameter cannot be passed by keyword. Only pass it
            # positionally when it is the first parameter, so it lands in the right slot.
            if params[0] is adata_param:
                proportions = self.get_proportions(adata)
            else:
                proportions = self.get_proportions()
        else:
            # Pass by keyword so keyword-only and non-first ``adata`` parameters work too.
            proportions = self.get_proportions(adata=adata)

        if isinstance(proportions, pd.DataFrame):
            return proportions

        # Wrap an array result in a DataFrame with cell type names as columns.
        import numpy as np

        values = np.asarray(proportions)
        columns = list(self.cell_type_mapping) if hasattr(self, "cell_type_mapping") else None
        return pd.DataFrame(values, columns=columns)

    def plot_cell_type_map(
        self,
        adata: AnnData | None = None,
        cell_type: str | None = None,
        basis: str = "spatial",
        ax=None,
        **kwargs,
    ):
        """Plot spatial map(s) of cell type proportions.

        Parameters
        ----------
        adata
            AnnData object. If None, uses the model's registered adata.
        cell_type
            Name of the cell type to visualize. It must be a column of
            :meth:`get_proportions_df`, normally one of ``self.cell_type_mapping``.
            If None, every cell type is plotted in one ``scanpy`` call.
        basis
            Key in ``adata.obsm`` for spatial coordinates.
        ax
            Matplotlib axes. If None, a new figure is created. It is only
            forwarded when ``cell_type`` is given; in grid mode it is ignored and
            a warning is logged.
        **kwargs
            Forwarded to :func:`scanpy.pl.embedding`.

        Returns
        -------
        Whatever :func:`scanpy.pl.embedding` returns for these arguments.

        Raises
        ------
        ValueError
            Raised in three cases, all before ``adata.obs`` is modified:

            - no AnnData is available;
            - the number of proportion rows differs from ``adata.n_obs``;
            - ``cell_type`` is not a known column.

        Notes
        -----
        Proportions are written into ``adata.obs`` under the keys
        ``"_spatialvi_prop_<cell type>"`` so scanpy can color by them. These
        columns are left in place after plotting.
        """
        if adata is None and hasattr(self, "adata"):
            adata = self.adata
        if adata is None:
            raise ValueError(
                "No AnnData available: pass `adata` or use a model with a registered `adata`."
            )

        df = self.get_proportions_df(adata)
        # Values are copied positionally (``.values``), so row counts must agree.
        if len(df) != adata.n_obs:
            raise ValueError(
                f"Proportions have {len(df)} rows but adata has {adata.n_obs} observations."
            )

        if cell_type is not None:
            if cell_type not in df.columns:
                raise ValueError(
                    f"cell_type '{cell_type}' not found. Available: {list(df.columns)}"
                )
            to_plot = [cell_type]
        else:
            logger.info("No cell_type specified; plotting all %d cell types.", len(df.columns))
            if ax is not None:
                logger.warning("`ax` is ignored when plotting all cell types as a grid.")
            to_plot = list(df.columns)

        # Import only after validation so a bad call fails fast without needing scanpy.
        import scanpy as sc

        keys = []
        for ct in to_plot:
            obs_key = f"_spatialvi_prop_{ct}"
            adata.obs[obs_key] = df[ct].values
            keys.append(obs_key)

        if cell_type is not None:
            return sc.pl.embedding(adata, basis=basis, color=keys[0], ax=ax, **kwargs)
        return sc.pl.embedding(adata, basis=basis, color=keys, **kwargs)