Source code for spatialvi.model._scviva

from __future__ import annotations

import logging
import warnings
from functools import partial
from typing import TYPE_CHECKING

import joblib
import numpy as np
import pandas as pd
import torch
from anndata import AnnData
from anndata import concat as anndata_concat
from rich import print
from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import _DATA_REGISTRY_KEY, _FIELD_REGISTRIES_KEY, _STATE_REGISTRY_KEY
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
    CategoricalJointObsField,
    CategoricalObsField,
    LayerField,
    NumericalJointObsField,
    NumericalObsField,
    ObsmField,
)
from scvi.model._utils import _init_library_size, scrna_raw_counts_properties
from scvi.model.base import (
    ArchesMixin,
    BaseMinifiedModeModelClass,
    EmbeddingMixin,
    RNASeqMixin,
    UnsupervisedTrainingMixin,
    VAEMixin,
)
from scvi.model.base._archesmixin import _get_loaded_data
from scvi.model.base._de_core import _de_core
from scvi.utils import de_dsp, setup_anndata_dsp, unsupported_if_adata_minified

from spatialvi._constants import SCVIVA_REGISTRY_KEYS
from spatialvi.model.base import SpatialBaseModel, SpatialNeighborhoodMixin
from spatialvi.model.utils._scviva_de import _niche_de_core
from spatialvi.module._nichevae import nicheVAE

if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence
    from typing import Literal

    from scvi.model.base import (
        BaseModelClass,
    )
    from torch import Tensor

    from spatialvi.model.utils._scviva_de import DifferentialExpressionResults

_SCVI_LATENT_QZM = "_scvi_latent_qzm"
_SCVI_LATENT_QZV = "_scvi_latent_qzv"
_SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size"

logger = logging.getLogger(__name__)


class SCVIVA(
    SpatialNeighborhoodMixin,
    EmbeddingMixin,
    RNASeqMixin,
    VAEMixin,
    ArchesMixin,
    UnsupervisedTrainingMixin,
    SpatialBaseModel,
    BaseMinifiedModeModelClass,
):
    """scVIVA: variational auto-encoder with niche decoders for ST :cite:p:`Levy25`.

    Parameters
    ----------
    adata
        AnnData object that has been registered via
        :meth:`~spatialvi.model.SCVIVA.setup_anndata`. If ``None``, then the underlying
        module will not be initialized until training, and a
        :class:`~lightning.pytorch.core.LightningDataModule` must be passed in during
        training.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    dropout_rate
        Dropout rate for neural networks.
    dispersion
        One of the following:

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
        * ``'gene-label'`` - dispersion can differ between different labels
        * ``'gene-cell'`` - dispersion can differ for every gene in every cell
    gene_likelihood
        One of:

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution
        * ``'poisson'`` - Poisson distribution
    latent_distribution
        One of:

        * ``'normal'`` - Normal distribution
        * ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax)
    **kwargs
        Additional keyword arguments for :class:`~spatialvi.module._nichevae.nicheVAE`.

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> spatialvi.model.SCVIVA.preprocessing_anndata(
    ...     adata,
    ...     k_nn=20,
    ...     sample_key="slide_ID",
    ...     labels_key="cell_type",
    ...     cell_coordinates_key="spatial",
    ...     expression_embedding_key="X_scVI",
    ... )
    >>> spatialvi.model.SCVIVA.setup_anndata(adata, batch_key="batch")
    >>> vae = spatialvi.model.SCVIVA(adata)
    >>> vae.train()
    >>> adata.obsm["X_scVIVA"] = vae.get_latent_representation()
    >>> adata.obsm["X_normalized_scVIVA"] = vae.get_normalized_expression()

    Notes
    -----
    See further usage examples in the following tutorials:

    1. :doc:`/tutorials/notebooks/spatial/scVIVA_tutorial`

    See Also
    --------
    :class:`~spatialvi.module._nichevae.nicheVAE`
    """

    _module_cls = nicheVAE

    def __init__(
        self,
        adata: AnnData | None = None,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
        gene_likelihood: Literal["zinb", "nb", "poisson"] = "poisson",
        latent_distribution: Literal["normal", "ln"] = "normal",
        **kwargs,
    ):
        """Store the module hyper-parameters and, when possible, build the module."""
        super().__init__(adata)

        # NOTE(review): this is read before the `_module_init_on_train` branch below;
        # presumably `summary_stats` is unavailable when `adata` is None -- confirm
        # that the deferred-initialisation path can actually get past this line.
        self.n_labels = self.summary_stats.n_labels

        # Kept so the module can be built later (at train time) with the same settings.
        self._module_kwargs = {
            "n_hidden": n_hidden,
            "n_latent": n_latent,
            "n_layers": n_layers,
            "dropout_rate": dropout_rate,
            "dispersion": dispersion,
            "gene_likelihood": gene_likelihood,
            "latent_distribution": latent_distribution,
            **kwargs,
        }
        self._model_summary_string = (
            "scVIVA model with the following parameters: \n"
            f"n_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, "
            f"dropout_rate: {dropout_rate}, dispersion: {dispersion}, "
            f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}."
        )

        if self._module_init_on_train:
            self.module = None
            warnings.warn(
                "Model was initialized without `adata`. The module will be initialized when "
                "calling `train`. This behavior is experimental and may change in the future.",
                UserWarning,
                stacklevel=settings.warnings_stacklevel,
            )
        else:
            n_cats_per_cov = (
                self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key
                if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
                else None
            )
            n_batch = self.summary_stats.n_batch
            use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry

            # Library-size priors are only computed when no size factor was registered
            # and the data is not minified.
            library_log_means, library_log_vars = None, None
            if not use_size_factor_key and self.minified_data_type is None:
                library_log_means, library_log_vars = _init_library_size(
                    self.adata_manager, n_batch
                )

            self.module = self._module_cls(
                n_input=self.summary_stats.n_vars,
                n_output_niche=self.summary_stats.n_latent_mean,
                n_batch=n_batch,
                n_labels=self.summary_stats.n_labels,
                n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
                n_cats_per_cov=n_cats_per_cov,
                n_hidden=n_hidden,
                n_latent=n_latent,
                n_layers=n_layers,
                dropout_rate=dropout_rate,
                dispersion=dispersion,
                gene_likelihood=gene_likelihood,
                latent_distribution=latent_distribution,
                use_size_factor_key=use_size_factor_key,
                library_log_means=library_log_means,
                library_log_vars=library_log_vars,
                **kwargs,
            )
            self.module.minified_data_type = self.minified_data_type

        self.init_params_ = self._get_init_params(locals())

    def get_latent_representation(
        self,
        adata=None,
        indices=None,
        give_mean: bool = True,
        batch_size=None,
        backend: str = "cpu",
        **kwargs,
    ):
        """Return latent representation with optional RAPIDS backend.

        Parameters
        ----------
        adata
            Forwarded unchanged to the parent ``get_latent_representation``.
        indices
            Forwarded unchanged to the parent ``get_latent_representation``.
        give_mean
            Forwarded unchanged to the parent ``get_latent_representation``.
        batch_size
            Forwarded unchanged to the parent ``get_latent_representation``.
        backend
            ``"cpu"`` (default) or ``"rapids"`` for cupy array output. Any other value
            behaves like ``"cpu"``.
        **kwargs
            Forwarded unchanged to the parent ``get_latent_representation``.

        Returns
        -------
        The parent implementation's result, converted with ``cupy.asarray`` when
        ``backend == "rapids"``.

        Raises
        ------
        ImportError
            If ``backend == "rapids"`` and ``cupy`` cannot be imported.
        """
        latent = super().get_latent_representation(
            adata=adata, indices=indices, give_mean=give_mean, batch_size=batch_size, **kwargs
        )
        if backend != "rapids":
            return latent

        # Only the import is guarded, so an unrelated ImportError raised during the
        # conversion is not misreported as "cupy is missing".
        try:
            import cupy as cp
        except ImportError as e:
            raise ImportError(
                "backend='rapids' requires cupy. "
                "Install with: pip install 'spatialvi-tools[rapids]'"
            ) from e
        return cp.asarray(latent)

    @staticmethod
    def preprocessing_anndata(
        adata: AnnData,
        k_nn: int = 20,
        sample_key: str | None = None,
        labels_key: str = "cell_type",
        cell_coordinates_key: str = "spatial",
        expression_embedding_key: str = "X_scVI",
        expression_embedding_niche_key: str = "niche_activation",
        niche_composition_key: str = "niche_composition",
        niche_indexes_key: str = "niche_indexes",
        niche_distances_key: str = "niche_distances",
        log1p: bool = False,
    ) -> None:
        """Preprocess an AnnData object for scVIVA analysis.

        This function prepares the input AnnData object by computing niche indexes,
        neighborhood composition, and average latent space embeddings per cell type.

        Parameters
        ----------
        adata : AnnData
            The annotated data matrix of shape `n_obs` x `n_vars`.
            Rows correspond to cells and columns to genes.
        k_nn : int, optional
            Number of nearest neighbors for niche computation. Default is 20.
        sample_key : str or None, optional
            Key in `adata.obs` for sample identifiers. Default is None.
        labels_key : str, optional
            Key in `adata.obs` for cell type labels. Default is "cell_type".
        cell_coordinates_key : str, optional
            Key in `adata.obsm` for spatial coordinates. Default is "spatial".
        expression_embedding_key : str, optional
            Key in `adata.obsm` for latent space embeddings. Default is "X_scVI".
        expression_embedding_niche_key : str, optional
            Key in `adata.obsm` where average latent embeddings per cell type are stored.
            Default is "niche_activation".
        niche_composition_key : str, optional
            Key in `adata.obsm` where neighborhood composition is stored.
            Default is "niche_composition".
        niche_indexes_key : str, optional
            Key in `adata.obsm` where niche indexes are stored.
            Default is "niche_indexes".
        niche_distances_key : str, optional
            Key in `adata.obsm` where neighbor distances are stored.
            Default is "niche_distances".
        log1p : bool, optional
            Whether to apply log1p to latent space embeddings. Default is False.

        Returns
        -------
        None
            The function modifies the input AnnData object in place.
        """
        # The three steps are order-dependent: compositions and per-cell-type averages
        # both read the neighbor indexes written by the first step.
        get_niche_indexes(
            adata=adata,
            sample_key=sample_key,
            niche_indexes_key=niche_indexes_key,
            niche_distances_key=niche_distances_key,
            cell_coordinates_key=cell_coordinates_key,
            k_nn=k_nn,
        )
        get_neighborhood_composition(
            adata=adata,
            cell_type_column=labels_key,
            indices_key=niche_indexes_key,
            niche_composition_key=niche_composition_key,
        )
        get_average_latent_per_celltype(
            adata=adata,
            labels_key=labels_key,
            niche_indexes_key=niche_indexes_key,
            latent_mean_key=expression_embedding_key,
            latent_mean_ct_key=expression_embedding_niche_key,
            log1p=log1p,
        )
        return None

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,
        batch_key: str | None = None,
        size_factor_key: str | None = None,
        categorical_covariate_keys: list[str] | None = None,
        continuous_covariate_keys: list[str] | None = None,
        sample_key: str | None = None,
        labels_key: str = "cell_type",
        cell_coordinates_key: str = "spatial",
        expression_embedding_key: str = "X_scVI",
        expression_embedding_niche_key: str = "niche_activation",
        niche_composition_key: str = "niche_composition",
        niche_indexes_key: str = "niche_indexes",
        niche_distances_key: str = "niche_distances",
        **kwargs,
    ):
        """%(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_layer)s
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_size_factor_key)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s
        sample_key
            Key in `adata.obs` registered as the sample field.
        cell_coordinates_key
            Key in `adata.obsm` registered as the cell coordinates field.
        expression_embedding_key
            Key in `adata.obsm` registered as the expression embedding field.
        expression_embedding_niche_key
            Key in `adata.obsm` registered as the per-cell-type niche embedding field.
        niche_composition_key
            Key in `adata.obsm` registered as the niche composition field.
        niche_indexes_key
            Key in `adata.obsm` registered as the niche indexes field.
        niche_distances_key
            Key in `adata.obsm` registered as the niche distances field.
        **kwargs
            Forwarded to ``AnnDataManager.register_fields``.
        """
        # Must stay the first statement: `locals()` has to contain the call arguments only.
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False),
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
            CategoricalObsField(SCVIVA_REGISTRY_KEYS.SAMPLE_KEY, sample_key),
            ObsmField(SCVIVA_REGISTRY_KEYS.NICHE_COMPOSITION_KEY, niche_composition_key),
            ObsmField(SCVIVA_REGISTRY_KEYS.CELL_COORDINATES_KEY, cell_coordinates_key),
            ObsmField(SCVIVA_REGISTRY_KEYS.NICHE_INDEXES_KEY, niche_indexes_key),
            ObsmField(SCVIVA_REGISTRY_KEYS.NICHE_DISTANCES_KEY, niche_distances_key),
            ObsmField(SCVIVA_REGISTRY_KEYS.Z1_MEAN_KEY, expression_embedding_key),
            ObsmField(SCVIVA_REGISTRY_KEYS.Z1_MEAN_CT_KEY, expression_embedding_niche_key),
        ]
        # register new fields if the adata is minified
        adata_minify_type = _get_adata_minify_type(adata)
        if adata_minify_type is not None:
            anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    def _run_spatial_decoder(self, decoder_fn, adata, indices, batch_size):
        """Run a spatial decoder over mini-batches and return concatenated numpy output.

        Parameters
        ----------
        decoder_fn
            Callable ``(decoder_input, batch_index) -> tensor | tuple``. When a tuple is
            returned its first element is collected.
        adata
            AnnData object or ``None`` (falls back to model adata).
        indices
            Cell indices or ``None`` for all cells.
        batch_size
            Mini-batch size.

        Returns
        -------
        The per-batch decoder outputs, moved to CPU, concatenated and converted to a
        numpy array.
        """
        self._check_if_trained(warn=False)
        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)

        results = []
        for tensors in scdl:
            inference_inputs = self.module._get_inference_input(tensors)
            outputs = self.module.inference(**inference_inputs)
            # The decoders are fed the location (`.loc`) of the `qz` distribution.
            decoder_input = outputs["qz"].loc
            batch_index = tensors[REGISTRY_KEYS.BATCH_KEY].to(decoder_input.device)
            out = decoder_fn(decoder_input, batch_index)
            if isinstance(out, tuple):
                out = out[0]
            results.append(out.detach().cpu())
        return torch.cat(results).numpy()

    @torch.inference_mode()
    def predict_neighborhood(
        self,
        adata: AnnData | None = None,
        indices: np.ndarray | None = None,
        batch_size: int | None = 1024,
    ) -> np.ndarray:
        """Predict the cell type composition of each cell niche in the dataset.

        Parameters
        ----------
        adata
            AnnData object. If ``None``, the model's ``adata`` will be used.
        indices
            Indices of cells to use. If ``None``, all cells will be used.
        batch_size
            Minibatch size to use during inference.

        Returns
        -------
        ct_prediction
            Predicted cell type composition of each cell niche in the dataset.
            It is computed as the expectation of the Dirichlet distribution.
        """

        def _decoder(decoder_input, batch_index):
            dist = self.module.composition_decoder(decoder_input, batch_index)
            # Concentrations normalised to sum to one along dim 1.
            return dist.concentration / dist.concentration.sum(dim=1).unsqueeze(1)

        return self._run_spatial_decoder(_decoder, adata, indices, batch_size)

    @torch.inference_mode()
    def predict_niche_activation(
        self,
        adata: AnnData | None = None,
        indices: np.ndarray | None = None,
        batch_size: int | None = 1024,
    ) -> np.ndarray:
        """Predict the activation of each cell niche in the dataset.

        Parameters
        ----------
        adata
            AnnData object. If ``None``, the model's ``adata`` will be used.
        indices
            Indices of cells to use. If ``None``, all cells will be used.
        batch_size
            Minibatch size to use during inference.

        Returns
        -------
        niche_activation
            Predicted activation of each cell niche in the dataset.
        """
        return self._run_spatial_decoder(self.module.niche_decoder, adata, indices, batch_size)

    @de_dsp.dedent
    def differential_expression(
        self,
        adata: AnnData | None = None,
        groupby: str | None = None,
        group1: list[str] | None = None,
        group2: str | None = None,
        idx1: list[int] | list[bool] | str | None = None,
        idx2: list[int] | list[bool] | str | None = None,
        mode: Literal["vanilla", "change"] = "change",
        delta: float | list[float] = 0.15,
        batch_size: int | None = None,
        all_stats: bool = True,
        batch_correction: bool = False,
        batchid1: list[str] | None = None,
        batchid2: list[str] | None = None,
        fdr_target: float | list[float] = 0.05,
        silent: bool = False,
        weights: Literal["uniform", "importance"] | None = "uniform",
        filter_outlier_cells: bool = False,
        importance_weighting_kwargs: dict | None = None,
        ###### scVIVA specific ######
        niche_mode: bool = True,
        radius: int | None = None,
        k_nn: int | None = None,
        n_restarts_optimizer_gpc: int = 10,
        path_to_save: str | None = None,
        **kwargs,
    ) -> DifferentialExpressionResults:
        r"""A unified method for differential expression analysis.

        Implements ``'vanilla'`` DE :cite:p:`Lopez18` and ``'change'`` mode DE
        :cite:p:`Boyeau19`. Adds a neighborhood component to the DE analysis
        :cite:p:`Levy25`.

        Parameters
        ----------
        %(de_adata)s
        %(de_groupby)s
        %(de_group1)s
        %(de_group2)s
        %(de_idx1)s
        %(de_idx2)s
        %(de_mode)s
        %(de_delta)s
        %(de_batch_size)s
        %(de_all_stats)s
        %(de_batch_correction)s
        %(de_batchid1)s
        %(de_batchid2)s
        %(de_fdr_target)s
        %(de_silent)s
        weights
            Weights to use for sampling. If `None`, defaults to `"uniform"`.
        filter_outlier_cells
            Whether to filter outlier cells with
            :meth:`~scvi.model.base.DifferentialComputation.filter_outlier_cells`.
        importance_weighting_kwargs
            Keyword arguments passed into
            :meth:`~scvi.model.base.RNASeqMixin.get_importance_weights`.
        niche_mode
            Whether to use scVIVA DE or SCVI DE.
        radius
            Radius for scVIVA DE.
        k_nn
            Number of nearest neighbors for scVIVA DE.
        n_restarts_optimizer_gpc
            Number of restarts for the Gaussian Process Classifier optimization.
        path_to_save
            Path to save the results to, as a pickle file.
        **kwargs
            Keyword args for
            :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors`

        Returns
        -------
        Differential expression dataclass.
        """
        adata = self._validate_anndata(adata)
        col_names = adata.var_names
        importance_weighting_kwargs = importance_weighting_kwargs or {}
        model_fn = partial(
            self.get_normalized_expression,
            return_numpy=True,
            n_samples=1,
            batch_size=batch_size,
            weights=weights,
            **importance_weighting_kwargs,
        )
        # A representation function is only supplied when outlier filtering is requested.
        representation_fn = self.get_latent_representation if filter_outlier_cells else None

        if niche_mode:
            result = _niche_de_core(
                self.get_anndata_manager(adata, required=True),
                model_fn,
                representation_fn,
                groupby,
                group1,
                group2,
                idx1,
                idx2,
                all_stats,
                scrna_raw_counts_properties,
                col_names,
                mode,
                batchid1,
                batchid2,
                delta,
                batch_correction,
                fdr_target,
                silent,
                radius=radius,
                k_nn=k_nn,
                n_restarts_optimizer_gpc=n_restarts_optimizer_gpc,
                **kwargs,
            )
        else:
            result = _de_core(
                self.get_anndata_manager(adata, required=True),
                model_fn,
                representation_fn,
                groupby,
                group1,
                group2,
                idx1,
                idx2,
                all_stats,
                scrna_raw_counts_properties,
                col_names,
                mode,
                batchid1,
                batchid2,
                delta,
                batch_correction,
                fdr_target,
                silent,
                **kwargs,
            )

        if path_to_save is not None:
            joblib.dump(result, path_to_save)
        return result

    @torch.inference_mode()
    @unsupported_if_adata_minified
    def get_composition_error(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
        dataloader: Iterator[dict[str, Tensor | None]] = None,
        return_mean: bool = True,
        **kwargs,
    ) -> dict[str, float]:
        r"""Compute the composition prediction error on the data.

        The error is the negative log likelihood of the data (alpha) given the latent
        variables. This is typically written as :math:`p(alpha \mid z)`, the likelihood
        term given one posterior sample.

        Parameters
        ----------
        adata
            :class:`~anndata.AnnData` object with :attr:`~anndata.AnnData.var_names` in
            the same order as the ones used to train the model. If ``None`` and
            ``dataloader`` is also ``None``, it defaults to the object used to initialize
            the model.
        indices
            Indices of observations in ``adata`` to use. If ``None``, defaults to all
            observations. Ignored if ``dataloader`` is not ``None``
        batch_size
            Minibatch size for the forward pass. If ``None``, defaults to
            ``scvi.settings.batch_size``. Ignored if ``dataloader`` is not ``None``
        dataloader
            An iterator over minibatches of data on which to compute the metric. The
            minibatches should be formatted as a dictionary of :class:`~torch.Tensor`
            with keys as expected by the model. If ``None``, a dataloader is created
            from ``adata``.
        return_mean
            Whether to return the mean reconstruction loss or the reconstruction loss
            for each observation.
        **kwargs
            Additional keyword arguments to pass into the forward method of the module.

        Returns
        -------
        The composition prediction error on the data.

        Raises
        ------
        ValueError
            If both ``adata`` and ``dataloader`` are provided.

        Notes
        -----
        This is not the negative reconstruction error, so higher is better.
        """
        from spatialvi.module._nichevae import compute_composition_error

        if adata is not None and dataloader is not None:
            raise ValueError("Only one of `adata` or `dataloader` can be provided.")

        if dataloader is None:
            adata = self._validate_anndata(adata)
            dataloader = self._make_data_loader(
                adata=adata, indices=indices, batch_size=batch_size
            )

        return compute_composition_error(
            self.module, dataloader, return_mean=return_mean, **kwargs
        )

    @torch.inference_mode()
    @unsupported_if_adata_minified
    def get_niche_error(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
        dataloader: Iterator[dict[str, Tensor | None]] = None,
        return_mean: bool = True,
        **kwargs,
    ) -> dict[str, float]:
        r"""Compute the niche state prediction error on the data.

        The error is the negative log likelihood of the data (eta) given the latent
        variables. This is typically written as :math:`p(eta \mid z)`, the likelihood
        term given one posterior sample.

        Parameters
        ----------
        adata
            :class:`~anndata.AnnData` object with :attr:`~anndata.AnnData.var_names` in
            the same order as the ones used to train the model. If ``None`` and
            ``dataloader`` is also ``None``, it defaults to the object used to initialize
            the model.
        indices
            Indices of observations in ``adata`` to use. If ``None``, defaults to all
            observations. Ignored if ``dataloader`` is not ``None``
        batch_size
            Minibatch size for the forward pass. If ``None``, defaults to
            ``scvi.settings.batch_size``. Ignored if ``dataloader`` is not ``None``
        dataloader
            An iterator over minibatches of data on which to compute the metric. The
            minibatches should be formatted as a dictionary of :class:`~torch.Tensor`
            with keys as expected by the model. If ``None``, a dataloader is created
            from ``adata``.
        return_mean
            Whether to return the mean reconstruction loss or the reconstruction loss
            for each observation.
        **kwargs
            Additional keyword arguments to pass into the forward method of the module.

        Returns
        -------
        The niche state prediction error of the data.

        Raises
        ------
        ValueError
            If both ``adata`` and ``dataloader`` are provided.

        Notes
        -----
        This is not the negative reconstruction error, so higher is better.
        """
        from spatialvi.module._nichevae import compute_niche_error

        if adata is not None and dataloader is not None:
            raise ValueError("Only one of `adata` or `dataloader` can be provided.")

        if dataloader is None:
            adata = self._validate_anndata(adata)
            dataloader = self._make_data_loader(
                adata=adata, indices=indices, batch_size=batch_size
            )

        return compute_niche_error(self.module, dataloader, return_mean=return_mean, **kwargs)

    def preprocessing_query_anndata(
        self,
        adata: AnnData,
        reference_model: str | BaseModelClass,
        return_reference_var_names: bool = False,
        inplace: bool = True,
        k_nn: int = 20,
        sample_key: str | None = None,
        labels_key: str = "cell_type",
        cell_coordinates_key: str = "spatial",
        expression_embedding_key: str = "X_scVI",
        expression_embedding_niche_key: str = "niche_activation",
        niche_composition_key: str = "niche_composition",
        niche_indexes_key: str = "niche_indexes",
        niche_distances_key: str = "niche_distances",
        log1p: bool = False,
    ) -> AnnData | pd.Index | None:
        """Prepare data for query integration.

        Merges SCVIVA.preprocessing_anndata and ArchesMixin.prepare_query_anndata.
        This function will return a new AnnData object with padded zeros for missing
        features (genes, alpha, and eta), as well as correctly sorted features.

        Parameters
        ----------
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run setup_anndata, as AnnData is validated against
            the ``registry``.
        reference_model
            Either an already instantiated model of the same class or a path to saved
            outputs for the reference model.
        return_reference_var_names
            Only load and return reference var names if True. Note that ``adata`` is
            still preprocessed in place before the names are returned.
        inplace
            Whether to subset and rearrange query vars inplace or return new AnnData.
        k_nn : int, optional
            Number of nearest neighbors for niche computation. Default is 20.
        sample_key : str or None, optional
            Key in `adata.obs` for sample identifiers. Default is None.
        labels_key : str, optional
            Key in `adata.obs` for cell type labels. Default is "cell_type".
        cell_coordinates_key : str, optional
            Key in `adata.obsm` for spatial coordinates. Default is "spatial".
        expression_embedding_key : str, optional
            Key in `adata.obsm` for latent space embeddings. Default is "X_scVI".
        expression_embedding_niche_key : str, optional
            Key in `adata.obsm` where average latent embeddings per cell type are stored.
            Default is "niche_activation".
        niche_composition_key : str, optional
            Key in `adata.obsm` where neighborhood composition is stored.
            Default is "niche_composition".
        niche_indexes_key : str, optional
            Key in `adata.obsm` where niche indexes are stored.
            Default is "niche_indexes".
        niche_distances_key : str, optional
            Key in `adata.obsm` where neighbor distances are stored.
            Default is "niche_distances".
        log1p : bool, optional
            Whether to apply log1p to latent space embeddings. Default is False.

        Returns
        -------
        Query adata ready to use in `load_query_data` unless `return_reference_var_names`
        in which case a pd.Index of reference var names is returned.

        Raises
        ------
        AssertionError
            If ``niche_composition_key`` or ``expression_embedding_niche_key`` differ
            from the keys recorded in the reference model's registry.
        """
        SCVIVA.preprocessing_anndata(
            adata=adata,
            k_nn=k_nn,
            sample_key=sample_key,
            labels_key=labels_key,
            cell_coordinates_key=cell_coordinates_key,
            expression_embedding_key=expression_embedding_key,
            expression_embedding_niche_key=expression_embedding_niche_key,
            niche_composition_key=niche_composition_key,
            niche_indexes_key=niche_indexes_key,
            niche_distances_key=niche_distances_key,
            log1p=log1p,
        )

        attr_dict, var_names, _, _ = _get_loaded_data(reference_model, device="cpu")
        var_names = pd.Index(var_names)

        if return_reference_var_names:
            return var_names

        # `reference_model` may be a path to a saved model, which has no `.registry`
        # attribute; in that case use the registry stored with the loaded attributes.
        if isinstance(reference_model, str):
            registry = attr_dict["registry_"]
        else:
            registry = reference_model.registry
        field_registries = registry[_FIELD_REGISTRIES_KEY]
        composition_registry = field_registries[SCVIVA_REGISTRY_KEYS.NICHE_COMPOSITION_KEY]
        niche_embedding_registry = field_registries[SCVIVA_REGISTRY_KEYS.Z1_MEAN_CT_KEY]

        reference_niche_composition_key = composition_registry[_DATA_REGISTRY_KEY]["attr_key"]
        assert reference_niche_composition_key == niche_composition_key, (
            f"niche_composition_key in query ({niche_composition_key}) must match that "
            f"of reference ({reference_niche_composition_key})"
        )

        reference_expression_embedding_niche_key = niche_embedding_registry[_DATA_REGISTRY_KEY][
            "attr_key"
        ]
        assert reference_expression_embedding_niche_key == expression_embedding_niche_key, (
            f"expression_embedding_niche_key in query ({expression_embedding_niche_key}) must "
            f"match that of reference ({reference_expression_embedding_niche_key})"
        )

        # Cell-type labels known to the reference vs. those present in the query.
        reference_label_names = pd.Index(composition_registry[_STATE_REGISTRY_KEY]["column_names"])
        query_label_names = pd.Index(adata.obsm[niche_composition_key].columns)

        return _pad_and_sort_query_anndata(
            adata,
            var_names,
            reference_label_names,
            query_label_names,
            niche_composition_key,
            expression_embedding_niche_key,
            inplace,
        )
def get_niche_indexes(
    adata: AnnData,
    sample_key: str,
    niche_indexes_key: str,
    cell_coordinates_key: str,
    k_nn: int,
    niche_distances_key: str,
) -> None:
    """Get the k nearest neighbors of each cell in the dataset, grouped per sample.

    Neighbors are searched only among cells that share the same value of
    ``adata.obs[sample_key]``. The indexes of the neighbors (row positions in ``adata``)
    are stored in ``adata.obsm[niche_indexes_key]`` and the distances to the neighbors
    are stored in ``adata.obsm[niche_distances_key]``. Both arrays have shape
    ``(n_obs, k_nn)``.

    Parameters
    ----------
    adata
        Anndata object
    sample_key
        Key in adata.obs that contain the sample of each cell (i.e., the donor slice)
    niche_indexes_key
        Key in adata.obsm where the indexes of the neighbors will be stored
    cell_coordinates_key
        Key in adata.obsm that contains the spatial coordinates of each cell
    k_nn
        Number of nearest neighbors to compute
    niche_distances_key
        Key in adata.obsm where the distances to the neighbors will be stored

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If a sample contains ``k_nn`` cells or fewer, so that ``k_nn`` neighbors other
        than the cell itself cannot be found.
    """
    from sklearn.neighbors import NearestNeighbors

    n_obs = adata.n_obs
    # Filled sample by sample and written to adata.obsm once at the end, so a failure
    # half-way through does not leave partially filled arrays behind.
    niche_indexes = np.zeros((n_obs, k_nn), dtype=int)
    niche_distances = np.zeros((n_obs, k_nn))

    # Row position of every cell, exposed as an obs column.
    # NOTE(review): the neighbor search below does not need this column; it is written
    # in case other code reads adata.obs["index"] -- confirm before removing.
    adata.obs["index"] = np.arange(n_obs)

    sample_labels = adata.obs[sample_key]
    coordinates = adata.obsm[cell_coordinates_key]

    for sample in sample_labels.unique():
        in_sample = (sample_labels == sample).to_numpy()
        # Row positions (in adata) of the cells belonging to this sample.
        global_index = np.flatnonzero(in_sample)
        if len(global_index) <= k_nn:
            raise ValueError(
                f"Sample {sample!r} has {len(global_index)} cell(s), but more than "
                f"k_nn={k_nn} cells are required to find k_nn neighbors per cell."
            )

        sample_coord = coordinates[in_sample]
        knn = NearestNeighbors(n_neighbors=k_nn)
        knn.fit(sample_coord)
        # Calling kneighbors() without a query returns the neighbors of the fitted
        # points themselves and never reports a point as its own neighbor, which stays
        # correct even when several cells share the same coordinates.
        distances, indices = knn.kneighbors()

        # `indices` are positions within the sample; translate them to rows of adata.
        niche_indexes[in_sample] = global_index[indices]
        niche_distances[in_sample] = distances

    adata.obsm[niche_indexes_key] = niche_indexes
    adata.obsm[niche_distances_key] = niche_distances

    print(
        f"[bold cyan]Saved {niche_indexes_key} and {niche_distances_key} in adata.obsm[/bold cyan]"
    )

    return None


def get_neighborhood_composition(
    adata: AnnData,
    cell_type_column: str,
    indices_key: str = "niche_indexes",
    niche_composition_key: str = "niche_composition",
) -> None:
    """Get the composition of each neighborhood in the dataset (alpha).

    For every cell, the fraction of its neighbors belonging to each cell type is stored
    as one row of a DataFrame in ``adata.obsm[niche_composition_key]`` (one column per
    cell type, in order of first appearance in ``adata.obs[cell_type_column]``). The
    label-to-column mapping is stored in ``adata.uns["cell_type_to_int"]``.

    Parameters
    ----------
    adata
        Anndata object
    cell_type_column
        Key in adata.obs that contains the cell type of each cell
    indices_key
        Key in adata.obsm that contains the indexes of the neighbors of each cell
    niche_composition_key
        Key in adata.obsm where the neighborhood composition will be stored

    Returns
    -------
    None
    """
    labels = adata.obs[cell_type_column]
    cell_types = labels.unique().tolist()
    n_cell_types = len(cell_types)
    cell_type_to_int = {cell_type: i for i, cell_type in enumerate(cell_types)}
    adata.uns["cell_type_to_int"] = cell_type_to_int

    indices = adata.obsm[indices_key].astype(int)
    n_cells, n_neighbors = indices.shape

    # Integer code of every cell's label, then of every neighbor of every cell.
    integer_vector = np.vectorize(cell_type_to_int.get)(labels)
    neighbor_types = integer_vector[indices]

    # Count labels per neighborhood with a single bincount: shifting the codes of row i
    # by i * n_cell_types gives every (cell, cell type) pair its own bin.
    row_offsets = np.arange(n_cells)[:, None] * n_cell_types
    counts = np.bincount(
        (neighbor_types + row_offsets).ravel(),
        minlength=n_cells * n_cell_types,
    ).reshape(n_cells, n_cell_types)

    # Normalize the counts into proportions of the neighborhood size.
    composition = counts / n_neighbors

    adata.obsm[niche_composition_key] = pd.DataFrame(
        data=composition,
        columns=cell_types,
        index=adata.obs_names,
    )

    print(f"[bold green]Saved {niche_composition_key} in adata.obsm[/bold green]")

    return None


def get_average_latent_per_celltype(
    adata: AnnData,
    labels_key: str,
    niche_indexes_key: str,
    latent_mean_key: str,
    latent_mean_ct_key: str = "niche_activation",
    log1p: bool = False,
) -> None:
    """Get the average embedding per cell type in the dataset.

    For this one needs to provide the cell type of each cell in the dataset and an
    embedding for each cell, computed, for instance, with PCA, or scVI. For every cell
    and every cell type, the embeddings of the neighbors of that type are averaged; the
    entry is left at zero when the neighborhood contains no cell of that type. The
    result has shape ``(n_obs, n_cell_types, n_latent)``.

    Parameters
    ----------
    adata
        Anndata object
    labels_key
        Key in adata.obs that contains the cell type of each cell
    niche_indexes_key
        Key in adata.obsm that contains the indexes of the neighbors of each cell
    latent_mean_key
        Key in adata.obsm that contains the expression embedding of each cell
    latent_mean_ct_key
        Key in adata.obsm where the average embedding per cell type will be stored
    log1p
        Whether to apply ``np.log1p`` to the embedding before averaging

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If ``adata.uns["cell_type_to_int"]`` is missing.
    """
    try:
        cell_type_to_int = adata.uns["cell_type_to_int"]
    except KeyError as err:
        raise KeyError(
            "adata.uns['cell_type_to_int'] not found; "
            "run get_neighborhood_composition before get_average_latent_per_celltype."
        ) from err

    n_cells = adata.n_obs
    n_cell_types = len(adata.obs[labels_key].unique())
    niche_indexes = adata.obsm[niche_indexes_key]

    latent = np.asarray(adata.obsm[latent_mean_key])
    n_latent_z1 = latent.shape[1]
    if log1p:
        # Element-wise, so transforming once before gathering gives the same values as
        # transforming the gathered (n_cells, k_nn, n_latent) array.
        latent = np.log1p(latent)
    # Embedding of every neighbor of every cell: (n_cells, k_nn, n_latent).
    z1_mean_niches = latent[niche_indexes]

    integer_vector = np.vectorize(cell_type_to_int.get)(adata.obs[labels_key])
    # Cell type (as integer) of every neighbor of every cell: (n_cells, k_nn).
    neighbor_types = integer_vector[niche_indexes]

    # Zero is the value kept for cell types absent from a neighborhood.
    z1_mean_niches_ct = np.zeros((n_cells, n_cell_types, n_latent_z1))

    for cell_type_idx in cell_type_to_int.values():
        is_type = neighbor_types == cell_type_idx
        n_of_type = is_type.sum(axis=1)
        has_type = n_of_type > 0
        if not has_type.any():
            continue
        # Masking with np.where (rather than multiplying by the mask) keeps non-finite
        # embeddings of other cell types from leaking into this average.
        summed = np.where(
            is_type[has_type][:, :, None], z1_mean_niches[has_type], 0.0
        ).sum(axis=1)
        z1_mean_niches_ct[has_type, cell_type_idx, :] = summed / n_of_type[has_type, None]

    adata.obsm[latent_mean_ct_key] = z1_mean_niches_ct

    print(f"[bold green]Saved {latent_mean_ct_key} in adata.obsm[/bold green]")

    return None


def _pad_and_sort_query_anndata(
    adata: AnnData,
    reference_var_names: pd.Index,
    reference_label_names: pd.Index,
    query_label_names: pd.Index,
    niche_composition_key: str,
    expression_embedding_niche_key: str,
    inplace: bool,
    min_var_name_ratio: float = 0.8,
) -> AnnData | None:
    r"""
    Pad and sort anndata to match reference var names.

    Also covers \alpha and \eta niche features: labels present in the reference but not
    in the query are added as zero columns to ``adata.obsm[niche_composition_key]`` and
    as zero slices (axis 1) to ``adata.obsm[expression_embedding_niche_key]``, and both
    are reordered to follow ``reference_label_names``.

    Parameters
    ----------
    adata
        Query AnnData object.
    reference_var_names
        Var names of the reference, in reference order.
    reference_label_names
        Label names of the reference, in reference order.
    query_label_names
        Label names observed in the query.
    niche_composition_key
        Key in adata.obsm holding the neighborhood composition DataFrame.
    expression_embedding_niche_key
        Key in adata.obsm holding the per-cell-type average embeddings.
    inplace
        If ``True``, ``adata`` is updated in place and ``None`` is returned; otherwise
        the padded and sorted AnnData is returned.
    min_var_name_ratio
        A warning is emitted when the fraction of reference var names found in the
        query is below this value.

    Returns
    -------
    ``None`` if ``inplace`` is ``True``, else the padded and sorted AnnData.

    Raises
    ------
    ValueError
        If the query shares no var names with the reference, or if the query contains
        labels that are not in the reference.

    Notes
    -----
    The niche features in ``adata.obsm`` and the mapping in
    ``adata.uns["cell_type_to_int"]`` (extended with the missing labels) are modified on
    the passed ``adata`` regardless of ``inplace``. The mapping is not rewritten to the
    reference order after the embeddings are reordered.
    """
    intersection_genes = adata.var_names.intersection(reference_var_names)
    inter_len_genes = len(intersection_genes)
    if inter_len_genes == 0:
        raise ValueError(
            "No reference var names found in query data. "
            "Please rerun with return_reference_var_names=True "
            "to see reference var names."
        )

    ratio = inter_len_genes / len(reference_var_names)
    logger.info(f"Found {ratio * 100}% reference vars in query data.")
    if ratio < min_var_name_ratio:
        warnings.warn(
            f"Query data contains less than {min_var_name_ratio:.0%} of reference "
            "var names. This may result in poor performance.",
            UserWarning,
            stacklevel=settings.warnings_stacklevel,
        )

    missing_in_query = reference_label_names.difference(query_label_names)
    extra_in_query = query_label_names.difference(reference_label_names)

    if len(missing_in_query) > 0:
        logger.info(f"Labels not observed in query: {sorted(missing_in_query.tolist())}")

    if len(extra_in_query) > 0:
        raise ValueError(
            f"Label(s) observed in query but not in reference: {sorted(extra_in_query.tolist())}"
        )

    # Pad alpha (composition): missing labels become all-zero columns and the columns
    # are put in reference order in the same step.
    adata.obsm[niche_composition_key] = adata.obsm[niche_composition_key].reindex(
        columns=reference_label_names, fill_value=0.0
    )

    # Pad eta (niche activation): append ONE zero slice PER missing label along axis 1,
    # so that every index registered below points at an existing slice.
    cell_type_to_int = adata.uns["cell_type_to_int"]
    z_niche = adata.obsm[expression_embedding_niche_key]
    if len(missing_in_query) > 0:
        n_query_types = z_niche.shape[1]
        for offset, ct in enumerate(missing_in_query):
            cell_type_to_int[ct] = n_query_types + offset
        zeros = np.zeros((adata.n_obs, len(missing_in_query), z_niche.shape[-1]))
        z_niche = np.concatenate([z_niche, zeros], axis=1)

    # reorder axis 1 to match reference_label_names
    idxs = np.array([cell_type_to_int[ct] for ct in reference_label_names])
    adata.obsm[expression_embedding_niche_key] = z_niche[:, idxs, :]

    genes_to_add = reference_var_names.difference(adata.var_names)
    needs_padding = len(genes_to_add) > 0
    if needs_padding:
        from scipy.sparse import csr_matrix

        # Empty (all-zero) sparse matrix, built without allocating a dense array first.
        padding_mtx = csr_matrix((adata.n_obs, len(genes_to_add)))
        adata_padding = AnnData(
            X=padding_mtx.copy(),
            layers={layer: padding_mtx.copy() for layer in adata.layers},
        )
        adata_padding.var_names = genes_to_add
        adata_padding.obs_names = adata.obs_names
        # Concatenate object
        adata_out = anndata_concat(
            [adata, adata_padding],
            axis=1,
            join="outer",
            index_unique=None,
            merge="unique",
        )
    else:
        adata_out = adata

    # also covers the case when new adata has more var names than old
    if not reference_var_names.equals(adata_out.var_names):
        adata_out._inplace_subset_var(reference_var_names)

    if inplace:
        if adata_out is not adata:
            adata._init_as_actual(adata_out)
    else:
        return adata_out