Source code for spatialvi.external.stereoscope._model

"""Stereoscope models for spatial transcriptomics deconvolution."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import torch
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import CategoricalObsField, LayerField, NumericalObsField
from scvi.model.base import UnsupervisedTrainingMixin
from scvi.train._config import merge_kwargs
from scvi.utils import setup_anndata_dsp
from scvi.utils._docstrings import devices_dsp

from spatialvi.external.stereoscope._module import RNADeconv, SpatialDeconv
from spatialvi.model.base._deconvolution_mixin import SpatialDeconvolutionMixin
from spatialvi.model.base._spatial_base import SpatialBaseModel

if TYPE_CHECKING:
    from typing import Literal

    from anndata import AnnData

logger = logging.getLogger(__name__)


class RNAStereoscope(UnsupervisedTrainingMixin, SpatialBaseModel):
    """Single-cell (scRNA-seq) half of the Stereoscope reimplementation :cite:p:`Andersson20`.

    The RNA model is fitted first; its learned parameters are afterwards handed to
    :class:`~spatialvi.external.SpatialStereoscope`, which performs the spatial
    deconvolution. Original implementation: https://github.com/almaan/stereoscope.

    Parameters
    ----------
    sc_adata
        Single-cell AnnData registered via
        :meth:`~spatialvi.external.RNAStereoscope.setup_anndata`.
    **model_kwargs
        Keyword args for :class:`~spatialvi.external.stereoscope.RNADeconv`.

    Examples
    --------
    >>> RNAStereoscope.setup_anndata(sc_adata, labels_key="labels")
    >>> sc_model = RNAStereoscope(sc_adata)
    >>> sc_model.train()

    Notes
    -----
    See further usage examples in the following tutorial:

    1. :doc:`/tutorials/notebooks/spatial/stereoscope_heart_LV_tutorial`
    """

    def __init__(self, sc_adata: AnnData, **model_kwargs):
        super().__init__(sc_adata)

        # Problem dimensions are read from the registry summary of ``sc_adata``.
        self.n_genes = self.summary_stats.n_vars
        self.n_labels = self.summary_stats.n_labels

        self.module = RNADeconv(
            n_genes=self.n_genes,
            n_labels=self.n_labels,
            **model_kwargs,
        )
        self._model_summary_string = (
            "RNADeconv Model with params: \n"
            f"n_genes: {self.n_genes}, n_labels: {self.n_labels}"
        )
        # NOTE: no extra local variables are created in this method so that
        # ``locals()`` holds only ``self``, the constructor arguments and the
        # implicit ``__class__`` cell created by ``super()``.
        self.init_params_ = self._get_init_params(locals())

    @devices_dsp.dedent
    def train(
        self,
        max_epochs: int = 400,
        lr: float = 0.01,
        accelerator: str = "auto",
        devices: int | list[int] | str = "auto",
        train_size: float = 1,
        validation_size: float | None = None,
        shuffle_set_split: bool = True,
        batch_size: int = 128,
        datasplitter_kwargs: dict | None = None,
        plan_kwargs: dict | None = None,
        **kwargs,
    ):
        """Train the model using MAP inference.

        Parameters
        ----------
        max_epochs
            Number of epochs to train for.
        lr
            Learning rate for optimization. It is written into ``plan_kwargs``
            under the key ``"lr"`` and therefore overrides any ``"lr"`` given there.
        %(param_accelerator)s
        %(param_devices)s
        train_size
            Size of the training set in [0.0, 1.0].
        validation_size
            Forwarded as ``validation_size`` to the parent ``train``.
        shuffle_set_split
            Whether to shuffle indices before splitting.
        batch_size
            Minibatch size.
        datasplitter_kwargs
            Additional kwargs for :class:`~scvi.dataloaders.DataSplitter`.
        plan_kwargs
            Keyword args for :class:`~scvi.train.TrainingPlan`.
        **kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        plan_kwargs = merge_kwargs(None, plan_kwargs, name="plan")
        # The explicit ``lr`` argument always wins over a value inside plan_kwargs.
        plan_kwargs.update(lr=lr)

        forwarded = {
            "max_epochs": max_epochs,
            "accelerator": accelerator,
            "devices": devices,
            "train_size": train_size,
            "validation_size": validation_size,
            "shuffle_set_split": shuffle_set_split,
            "batch_size": batch_size,
            "datasplitter_kwargs": datasplitter_kwargs,
            "plan_kwargs": plan_kwargs,
        }
        super().train(**forwarded, **kwargs)

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        labels_key: str | None = None,
        layer: str | None = None,
        **kwargs,
    ):
        """%(summary)s.

        Parameters
        ----------
        %(param_labels_key)s
        %(param_layer)s
        """
        # Must stay the first statement: ``locals()`` is captured here and should
        # contain nothing but the call arguments.
        setup_args = cls._get_setup_method_args(**locals())

        counts_field = LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True)
        labels_field = CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key)

        manager = AnnDataManager(
            fields=[counts_field, labels_field],
            setup_method_args=setup_args,
        )
        manager.register_fields(adata, **kwargs)
        cls.register_manager(manager)
class SpatialStereoscope(SpatialDeconvolutionMixin, UnsupervisedTrainingMixin, SpatialBaseModel):
    """Reimplementation of Stereoscope for the spatial component :cite:p:`Andersson20`.

    Deconvolves spatial transcriptomics spots into cell type proportions using
    parameters learned by a pre-trained :class:`~spatialvi.external.RNAStereoscope`
    model.

    Inherits :class:`~spatialvi.model.base.SpatialDeconvolutionMixin` which provides
    :meth:`get_proportions_df` and :meth:`plot_cell_type_map`.

    Parameters
    ----------
    st_adata
        Spatial AnnData registered via
        :meth:`~spatialvi.external.SpatialStereoscope.setup_anndata`.
    sc_params
        Parameters of the trained RNA model. :meth:`from_rna_model` obtains them
        with ``sc_model.module.get_params()``.
    cell_type_mapping
        numpy array mapping for the cell types used in the deconvolution.
        :meth:`from_rna_model` takes it from the categorical mapping of the RNA
        model's labels field.
    prior_weight
        How to reweight minibatches. ``"n_obs"`` is statistically correct;
        ``"minibatch"`` reproduces the original Stereoscope paper.
    **model_kwargs
        Keyword args for :class:`~spatialvi.external.stereoscope.SpatialDeconv`.

    Examples
    --------
    >>> RNAStereoscope.setup_anndata(sc_adata, labels_key="labels")
    >>> sc_model = RNAStereoscope(sc_adata)
    >>> sc_model.train()
    >>> SpatialStereoscope.setup_anndata(st_adata)
    >>> st_model = SpatialStereoscope.from_rna_model(st_adata, sc_model)
    >>> st_model.train()
    >>> st_adata.obsm["deconv"] = st_model.get_proportions()

    Notes
    -----
    See further usage examples in the following tutorial:

    1. :doc:`/tutorials/notebooks/spatial/stereoscope_heart_LV_tutorial`
    """

    def __init__(
        self,
        st_adata: AnnData,
        sc_params: tuple[np.ndarray],
        cell_type_mapping: np.ndarray,
        prior_weight: Literal["n_obs", "minibatch"] = "n_obs",
        **model_kwargs,
    ):
        super().__init__(st_adata)
        self.module = SpatialDeconv(
            n_spots=st_adata.n_obs,
            sc_params=sc_params,
            prior_weight=prior_weight,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"SpatialDeconv Model with params: \nn_spots: {st_adata.n_obs}"
        )
        self.cell_type_mapping = cell_type_mapping
        # NOTE: no extra local variables are created in this method so that
        # ``locals()`` holds only ``self``, the constructor arguments and the
        # implicit ``__class__`` cell created by ``super()``.
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    def from_rna_model(
        cls,
        st_adata: AnnData,
        sc_model: RNAStereoscope,
        prior_weight: Literal["n_obs", "minibatch"] = "n_obs",
        **model_kwargs,
    ):
        """Alternate constructor using a pre-trained RNA model.

        Parameters
        ----------
        st_adata
            Registered spatial AnnData.
        sc_model
            Trained :class:`~spatialvi.external.RNAStereoscope` model.
        prior_weight
            How to reweight minibatches for stochastic optimization.
        **model_kwargs
            Keyword args for :class:`~spatialvi.external.stereoscope.SpatialDeconv`.

        Returns
        -------
        A new instance of this class built from ``sc_model``'s module parameters
        and the categorical mapping of its labels field.
        """
        return cls(
            st_adata,
            sc_model.module.get_params(),
            sc_model.adata_manager.get_state_registry(
                REGISTRY_KEYS.LABELS_KEY
            ).categorical_mapping,
            prior_weight=prior_weight,
            **model_kwargs,
        )

    def get_proportions(self, keep_noise: bool = False) -> pd.DataFrame:
        """Return the estimated cell type proportions for the spatial data.

        Shape is (n_spots, n_labels) or (n_spots, n_labels + 1) if keep_noise.

        Parameters
        ----------
        keep_noise
            Whether to include the noise term as a standalone cell type. When
            set, an extra column named ``"noise_term"`` is appended.

        Returns
        -------
        DataFrame indexed like ``self.adata.obs`` with one column per entry of
        ``cell_type_mapping`` (plus ``"noise_term"`` if requested).
        """
        self._check_if_trained()

        column_names = self.cell_type_mapping
        if keep_noise:
            # np.append returns a new array; the stored mapping is not modified.
            column_names = np.append(column_names, "noise_term")
        return pd.DataFrame(
            data=self.module.get_proportions(keep_noise),
            columns=column_names,
            index=self.adata.obs.index,
        )

    def get_scale_for_ct(self, y: np.ndarray) -> np.ndarray:
        r"""Calculate the cell-type-specific expression.

        Parameters
        ----------
        y
            One-dimensional array (or other sequence) of cell type names to query.
            Every name must be present in ``cell_type_mapping``.

        Returns
        -------
        Gene expression array of shape (n_query, n_genes).

        Raises
        ------
        ValueError
            If ``y`` is not one-dimensional, or if any queried cell type is not
            found in ``cell_type_mapping``.
        """
        self._check_if_trained()

        query = np.asarray(y)
        if query.ndim != 1:
            raise ValueError(
                "Incorrect shape after matching cell types to reference mapping. "
                "Please check cell type query."
            )

        # Explicit integer dtype so that an empty query still yields an index
        # tensor rather than a float one.
        ind_y = np.empty(query.shape[0], dtype=np.intp)
        unknown = []
        for pos, ct in enumerate(query):
            # flatnonzero also copes with a scalar comparison result.
            hits = np.flatnonzero(ct == self.cell_type_mapping)
            if hits.size == 0:
                unknown.append(str(ct))
            else:
                # First match, as the mapping is searched front to back.
                ind_y[pos] = hits[0]

        if unknown:
            # Report unknown names explicitly instead of failing with a bare
            # IndexError while indexing an empty match.
            raise ValueError(
                f"Cell types {unknown} not found in the reference mapping. "
                "Please check cell type query."
            )

        px_scale = self.module.get_ct_specific_expression(torch.tensor(ind_y)[:, None])
        return np.array(px_scale.cpu())

    @devices_dsp.dedent
    def train(
        self,
        max_epochs: int = 400,
        lr: float = 0.01,
        accelerator: str = "auto",
        devices: int | list[int] | str = "auto",
        shuffle_set_split: bool = True,
        batch_size: int = 128,
        datasplitter_kwargs: dict | None = None,
        plan_kwargs: dict | None = None,
        **kwargs,
    ):
        """Train the model using MAP inference.

        The parent ``train`` is always called with ``train_size=1`` and
        ``validation_size=None``; these are not configurable here.

        Parameters
        ----------
        max_epochs
            Number of epochs to train for.
        lr
            Learning rate for optimization. It is written into ``plan_kwargs``
            under the key ``"lr"`` and therefore overrides any ``"lr"`` given there.
        %(param_accelerator)s
        %(param_devices)s
        shuffle_set_split
            Whether to shuffle indices before splitting.
        batch_size
            Minibatch size.
        datasplitter_kwargs
            Additional kwargs for :class:`~scvi.dataloaders.DataSplitter`.
        plan_kwargs
            Keyword args for :class:`~scvi.train.TrainingPlan`.
        **kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        plan_kwargs = merge_kwargs(None, plan_kwargs, name="plan")
        plan_kwargs.update({"lr": lr})
        super().train(
            max_epochs=max_epochs,
            accelerator=accelerator,
            devices=devices,
            train_size=1,
            validation_size=None,
            shuffle_set_split=shuffle_set_split,
            batch_size=batch_size,
            datasplitter_kwargs=datasplitter_kwargs,
            plan_kwargs=plan_kwargs,
            **kwargs,
        )

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,
        **kwargs,
    ):
        """%(summary)s.

        Parameters
        ----------
        %(param_layer)s
        """
        # Must stay the first statement: ``locals()`` is captured here and should
        # contain nothing but the call arguments.
        setup_method_args = cls._get_setup_method_args(**locals())
        # Side effect: writes (or overwrites) an ``_indices`` column in
        # ``adata.obs`` holding 0..n_obs-1, registered below as the indices field.
        adata.obs["_indices"] = np.arange(adata.n_obs)
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)