from __future__ import annotations
import logging
from collections import OrderedDict
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
import torch
from scipy.sparse import csr_matrix
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data._constants import _SETUP_ARGS_KEY
from scvi.data.fields import CategoricalObsField, LayerField, NumericalObsField
from scvi.model.base import UnsupervisedTrainingMixin
from scvi.model.base._archesmixin import _get_loaded_data
from scvi.train._config import merge_kwargs
from scvi.utils import setup_anndata_dsp
from scvi.utils._docstrings import devices_dsp
from spatialvi.model.base import SpatialBaseModel, SpatialDeconvolutionMixin
from spatialvi.module._mrdeconv import MRDeconv
if TYPE_CHECKING:
    # Imported only for static type checking; used solely in (lazy) annotations.
    from collections.abc import Sequence

    from anndata import AnnData
    from scvi.model import CondSCVI

logger = logging.getLogger(__name__)
class DestVI(
    SpatialDeconvolutionMixin,
    UnsupervisedTrainingMixin,
    SpatialBaseModel,
):
    """Multi-resolution deconvolution of Spatial Transcriptomics data (DestVI) :cite:p:`Lopez22`.

    Most users will use the alternate constructor (see example).

    Parameters
    ----------
    st_adata
        spatial transcriptomics AnnData object that has been registered via
        :meth:`~spatialvi.model.DestVI.setup_anndata`.
    cell_type_mapping
        mapping between numerals and cell type labels
    decoder_state_dict
        state_dict from the decoder of the CondSCVI model
    px_decoder_state_dict
        state_dict from the px_decoder of the CondSCVI model
    px_r
        parameters for the px_r tensor in the CondSCVI model
    per_ct_bias
        parameters for the per_ct_bias tensor in the CondSCVI model
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    dropout_decoder
        Dropout rate forwarded to the module; :meth:`from_rna_model` takes it from the
        CondSCVI model's ``dropout_rate`` init parameter.
    **module_kwargs
        Keyword args for :class:`~spatialvi.module.MRDeconv`

    Examples
    --------
    >>> sc_adata = anndata.read_h5ad(path_to_scRNA_anndata)
    >>> scvi.model.CondSCVI.setup_anndata(sc_adata)
    >>> sc_model = scvi.model.CondSCVI(sc_adata)
    >>> st_adata = anndata.read_h5ad(path_to_ST_anndata)
    >>> DestVI.setup_anndata(st_adata)
    >>> spatial_model = DestVI.from_rna_model(st_adata, sc_model)
    >>> spatial_model.train(max_epochs=2000)
    >>> st_adata.obsm["proportions"] = spatial_model.get_proportions(st_adata)
    >>> gamma = spatial_model.get_gamma(st_adata)

    Notes
    -----
    See further usage examples in the following tutorials:

    1. :doc:`/tutorials/notebooks/spatial/DestVI_tutorial`
    2. :doc:`/tutorials/notebooks/r/DestVI_in_R`
    """

    _module_cls = MRDeconv

    def __init__(
        self,
        st_adata: AnnData,
        cell_type_mapping: np.ndarray,
        decoder_state_dict: OrderedDict,
        px_decoder_state_dict: OrderedDict,
        px_r: torch.Tensor,
        per_ct_bias: torch.Tensor,
        n_hidden: int,
        n_latent: int,
        n_layers: int,
        dropout_decoder: float,
        **module_kwargs,
    ):
        """Build the MRDeconv module from pre-trained CondSCVI pieces.

        See the class docstring for parameter descriptions.
        """
        super().__init__(st_adata)
        self.module = self._module_cls(
            n_spots=st_adata.n_obs,
            n_labels=cell_type_mapping.shape[0],
            n_batch=self.summary_stats.n_batch,
            decoder_state_dict=decoder_state_dict,
            px_decoder_state_dict=px_decoder_state_dict,
            px_r=px_r,
            per_ct_bias=per_ct_bias,
            n_genes=st_adata.n_vars,
            n_latent=n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_decoder=dropout_decoder,
            **module_kwargs,
        )
        self.cell_type_mapping = cell_type_mapping
        # Known labels followed by one placeholder name per extra (unannotated) cell type
        # the module adds on top of the reference labels.
        self.cell_type_mapping_extended = list(self.cell_type_mapping) + [
            f"additional_{i}" for i in range(self.module.add_celltypes)
        ]
        self._model_summary_string = "DestVI Model"
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    def from_rna_model(
        cls,
        st_adata: AnnData,
        sc_model: CondSCVI,
        vamp_prior_p: int | None = 15,
        anndata_setup_kwargs: dict | None = None,
        **module_kwargs,
    ):
        """Alternate constructor for exploiting a pre-trained model on a RNA-seq dataset.

        Parameters
        ----------
        st_adata
            registered anndata object
        sc_model
            trained CondSCVI model
        vamp_prior_p
            number of mixture parameter for VampPrior calculations. If `None`, no prior
            parameters are passed to the module (mean, variance and mixture weights are all
            `None`). Ignored when the CondSCVI model was trained with ``prior="mog"``, in which
            case its learned mixture parameters are used instead.
        anndata_setup_kwargs
            Keyword args for :meth:`~spatialvi.model.DestVI.setup_anndata`
        **module_kwargs
            Keyword args for :class:`~spatialvi.model.DestVI`

        Raises
        ------
        ValueError
            If a VampPrior must be computed but ``sc_model`` is a string rather than a
            loaded model.
        """
        attr_dict, _, load_state_dict, _ = _get_loaded_data(sc_model)
        registry = attr_dict.pop("registry_")

        def _sub_state_dict(prefix: str) -> OrderedDict:
            # Entries whose first dotted component equals `prefix`, with "prefix." stripped.
            sub = OrderedDict()
            for key, value in load_state_dict.items():
                head, _, rest = key.partition(".")
                if head == prefix:
                    sub[rest] = value
            return sub

        decoder_state_dict = _sub_state_dict("decoder")
        px_decoder_state_dict = _sub_state_dict("px_decoder")
        px_r = load_state_dict["px_r"]
        per_ct_bias = load_state_dict["per_ct_bias"]
        mapping = registry["field_registries"]["labels"]["state_registry"]["categorical_mapping"]
        dropout_decoder = attr_dict["init_params_"]["non_kwargs"]["dropout_rate"]

        if vamp_prior_p is None:
            mean_vprior = None
            var_vprior = None
            # Must be bound here too; it is forwarded to the constructor below.
            mp_vprior = None
        elif attr_dict["init_params_"]["kwargs"]["module_kwargs"].get("prior") == "mog":
            mean_vprior = load_state_dict["prior_means"].clone().detach()
            var_vprior = torch.exp(load_state_dict["prior_log_std"]) ** 2
            mp_vprior = torch.nn.Softmax(dim=-1)(load_state_dict["prior_logits"])
        else:
            # get_vamp_prior needs the in-memory model, not a saved-model path.
            if isinstance(sc_model, str):
                raise ValueError("VampPrior requires loading CondSCVI model and providing it")
            vamp_prior = sc_model.get_vamp_prior(sc_model.adata, p=vamp_prior_p)
            mean_vprior = torch.tensor(vamp_prior["mean_vprior"], dtype=torch.float32)
            var_vprior = torch.tensor(vamp_prior["var_vprior"], dtype=torch.float32)
            mp_vprior = torch.tensor(vamp_prior["weights_vprior"], dtype=torch.float32)

        if anndata_setup_kwargs is None:
            anndata_setup_kwargs = {}
        # Register st_adata against the single-cell registry so categories line up.
        cls.setup_anndata(
            st_adata,
            source_registry=registry,
            extend_categories=True,
            **anndata_setup_kwargs,
            **registry[_SETUP_ARGS_KEY],
        )
        return cls(
            st_adata,
            mapping,
            decoder_state_dict,
            px_decoder_state_dict,
            px_r,
            per_ct_bias,
            sc_model.module.n_hidden,
            sc_model.module.n_latent,
            sc_model.module.n_layers,
            mean_vprior=mean_vprior,
            var_vprior=var_vprior,
            mp_vprior=mp_vprior,
            dropout_decoder=dropout_decoder,
            **module_kwargs,
        )

    @torch.inference_mode()
    def get_proportions(
        self,
        keep_additional: bool = False,
        normalize: bool = True,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
    ) -> pd.DataFrame:
        """Returns the estimated cell type proportion for the spatial data.

        Shape is n_cells x n_labels OR n_cells x (n_labels + add_celltypes) if keep_additional.

        Parameters
        ----------
        keep_additional
            whether to account for the additional cell-types as standalone cell types in the
            proportion estimate.
        normalize
            whether to normalize the proportions to sum to 1.
        indices
            Indices of cells in adata to use. Only used if amortization. If `None`, all cells are
            used.
        batch_size
            Minibatch size for data loading into model. Only used if amortization. Defaults to
            `scvi.settings.batch_size`.

        Returns
        -------
        DataFrame indexed by spot (``adata.obs`` index) with one column per cell type.
        """
        self._check_if_trained()
        index_names = self.adata.obs.index
        n_labels = len(self.cell_type_mapping)
        if keep_additional:
            column_names = list(self.cell_type_mapping_extended)
        else:
            column_names = list(self.cell_type_mapping)

        if self.module.amortization in ["both", "proportion"]:
            stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size)
            prop_ = []
            for tensors in stdl:
                inference_inputs = self.module._get_inference_input(tensors)
                outputs = self.module.inference(**inference_inputs)
                generative_inputs = self.module._get_generative_input(tensors, outputs)
                prop_local = self.module.generative(**generative_inputs)["v"][0, ...]
                prop_ += [prop_local.cpu()]
            data = torch.cat(prop_).detach().numpy()
            # `is not None`: truthiness of a numpy index array is ambiguous and raises.
            if indices is not None:
                index_names = index_names[indices]
        else:
            data = (
                torch.nn.functional.softplus(self.module.V).transpose(0, 1).detach().cpu().numpy()
            )

        if not keep_additional:
            # Slice by the number of known labels; `:-add_celltypes` would give an empty
            # array when add_celltypes == 0.
            data = data[:, :n_labels]
        if normalize:
            data = data / data.sum(axis=1, keepdims=True)
        return pd.DataFrame(
            data=data,
            columns=column_names,
            index=index_names,
        )

    @torch.inference_mode()
    def get_fine_celltypes(
        self,
        sc_model: CondSCVI,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
    ) -> pd.DataFrame:
        """Returns per-spot scores for each fine cell-type label of the single-cell model.

        For every spot, the fine-label probabilities of ``sc_model``'s classifier (evaluated
        on the cell-type-specific latent values ``gamma``) are weighted by the spot's
        cell-type abundances ``v`` (not normalized here) and by the mode proportions, then
        summed over cell types and modes.

        Parameters
        ----------
        sc_model
            trained CondSCVI model
        indices
            Indices of cells in adata to use. Only used if amortization.
            If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Only used if amortization.
            Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        DataFrame indexed by spot with one column per entry of
        ``sc_model._fine_label_mapping``.

        Raises
        ------
        RuntimeError
            If ``sc_model`` was trained without fine labels.
        """
        self._check_if_trained()
        if sc_model.n_fine_labels is None:
            raise RuntimeError(
                "Single cell model does not contain fine labels. "
                "Please train the single-cell model with fine labels."
            )
        index_names = self.adata.obs.index
        if indices is not None:
            index_names = index_names[indices]
        n_labels = self.module.n_labels
        stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size)

        predicted_fine_celltype_ = []
        for tensors in stdl:
            inference_inputs = self.module._get_inference_input(tensors)
            outputs = self.module.inference(**inference_inputs)
            generative_inputs = self.module._get_generative_input(tensors, outputs)
            generative_outputs = self.module.generative(**generative_inputs)
            gamma_local = generative_outputs["gamma"][0, ...].transpose(-2, -4)  # c, n, p, m
            proportions_modes_local = generative_outputs["proportion_modes"][0, ...]  # pmc
            # Separate name so the `batch_size` argument is not shadowed.
            n_modes, n_spots_local, n_celltypes = proportions_modes_local.shape
            gamma_local_ = gamma_local.permute((3, 2, 0, 1)).reshape(
                -1, self.module.n_latent
            )  # m*p*c, n
            proportions_modes_local_ = proportions_modes_local.permute(
                (1, 0, 2)
            ).flatten()  # m*p*c
            # Known cell types only; slicing by count also works when add_celltypes == 0.
            v_local = (
                generative_outputs["v"][0, ..., :n_labels].flatten().repeat_interleave(n_modes)
            )  # m*p*c
            label = (
                torch.arange(n_labels, device=gamma_local.device)
                .repeat(n_spots_local)
                .repeat_interleave(n_modes)
                .unsqueeze(-1)
            )  # m*p*c, 1
            predicted_fine_celltype_local = (
                v_local.unsqueeze(-1)
                * proportions_modes_local_.unsqueeze(-1)
                * torch.nn.functional.softmax(
                    sc_model.module.classify(gamma_local_, label), dim=-1
                )
            )
            predicted_fine_celltype_sum = predicted_fine_celltype_local.reshape(
                n_spots_local, n_celltypes * n_modes, sc_model.n_fine_labels
            ).sum(1)
            predicted_fine_celltype_.append(predicted_fine_celltype_sum.detach().cpu())

        predicted_fine_celltype = torch.cat(predicted_fine_celltype_, dim=0).numpy()
        return pd.DataFrame(
            predicted_fine_celltype,
            columns=sc_model._fine_label_mapping,
            index=index_names,
        )

    @torch.inference_mode()
    def get_gamma(
        self,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
        return_numpy: bool = False,
    ) -> np.ndarray | dict[str, pd.DataFrame]:
        """Returns the estimated cell-type specific latent space for the spatial data.

        Parameters
        ----------
        indices
            Indices of cells in adata to use. Only used if amortization. If `None`, all cells are
            used.
        batch_size
            Minibatch size for data loading into model. Only used if amortization. Defaults to
            `scvi.settings.batch_size`.
        return_numpy
            if activated, will return a numpy array of shape is n_spots x n_latent x n_labels.

        Returns
        -------
        Either the numpy array above, or a dict mapping each cell type name to a DataFrame of
        spots x latent dimensions (columns ``"0"``, ``"1"``, ...).
        """
        self._check_if_trained()
        column_names = [str(i) for i in range(self.module.n_latent)]
        index_names = self.adata.obs.index

        if self.module.amortization in ["both", "latent"]:
            stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size)
            gamma_ = []
            for tensors in stdl:
                inference_inputs = self.module._get_inference_input(tensors)
                outputs = self.module.inference(**inference_inputs)
                generative_inputs = self.module._get_generative_input(tensors, outputs)
                generative_outputs = self.module.generative(**generative_inputs)
                gamma_local = generative_outputs["gamma"][0, ...]
                if self.module.prior_mode == "mog":
                    # Collapse the mixture modes using the per-mode proportions.
                    proportions_model_local = generative_outputs["proportion_modes"][0, ...]
                    gamma_local = torch.einsum(
                        "pncm,pmc->ncm", gamma_local, proportions_model_local
                    )
                else:
                    gamma_local = gamma_local[0, ...].squeeze(0)
                gamma_ += [gamma_local.cpu()]
            # Minibatches are concatenated along the last axis.
            data = torch.cat(gamma_, dim=-1).numpy()
            if indices is not None:
                index_names = index_names[indices]
        else:
            data = self.module.gamma.detach().cpu().numpy()

        # Move the last (spot) axis to the front.
        data = np.transpose(data, (2, 0, 1))
        if return_numpy:
            return data
        return {
            ct: pd.DataFrame(data=data[:, :, i], columns=column_names, index=index_names)
            for i, ct in enumerate(self.cell_type_mapping)
        }

    @torch.inference_mode()
    def get_latent_representation(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        give_mean: bool = True,
        mc_samples: int = 5000,
        batch_size: int | None = None,
        return_dist: bool = False,
        backend: str = "cpu",
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Return the latent representation for each cell.

        This is typically denoted as :math:`z_n`.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        give_mean
            If `True`, return the mean ``qz.loc``; otherwise return the sample ``z``.
        mc_samples
            Passed as ``n_samples`` to the module's ``inference``. Only the first sample
            (index 0) of ``z`` / ``qz`` is used by this method.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        return_dist
            Return ``(mean, variance)`` i.e. ``(qz.loc, qz.scale ** 2)`` instead of a single
            array. If `True`, ``give_mean`` is ignored.
        backend
            ``"rapids"`` converts the result(s) to CuPy arrays (requires ``cupy``); any other
            value returns NumPy arrays.

        Returns
        -------
        Low-dimensional representation for each cell or a tuple containing its mean and variance.

        Raises
        ------
        ValueError
            If the module has no amortized latent representation.
        ImportError
            If ``backend="rapids"`` and ``cupy`` is not installed.
        """
        if self.module.n_latent_amortization is None:
            raise ValueError("Model has no latent representation for amortized values.")
        self._check_if_trained(warn=False)
        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)

        latent = []
        latent_qzm = []
        latent_qzv = []
        for tensors in scdl:
            inference_inputs = self.module._get_inference_input(tensors)
            # NOTE(review): mc_samples samples are drawn but only index 0 is kept below;
            # confirm whether this is intended.
            inference_outputs = self.module.inference(**inference_inputs, n_samples=mc_samples)
            z = inference_outputs["z"][0, ...]
            qz = inference_outputs["qz"]
            if give_mean:
                latent += [qz.loc[0, ...].cpu()]
            else:
                latent += [z.cpu()]
            latent_qzm += [qz.loc[0, ...].cpu()]
            latent_qzv += [qz.scale[0, ...].square().cpu()]

        latent = (
            (torch.cat(latent_qzm).numpy(), torch.cat(latent_qzv).numpy())
            if return_dist
            else torch.cat(latent).numpy()
        )
        if backend == "rapids":
            try:
                import cupy as cp
            except ImportError as e:
                raise ImportError(
                    "backend='rapids' requires cupy. "
                    "Install with: pip install 'spatialvi-tools[rapids]'"
                ) from e
            if return_dist:
                return tuple(cp.asarray(x) for x in latent)
            return cp.asarray(latent)
        return latent

    @torch.inference_mode()
    def get_scale_for_ct(
        self,
        label: str,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
    ) -> pd.DataFrame:
        r"""Return the scaled parameter of the NB for every spot in queried cell types.

        Parameters
        ----------
        label
            cell type of interest (a known label or one of the ``additional_{i}`` names)
        indices
            Indices of cells in self.adata to use. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        Pandas dataframe of gene_expression

        Raises
        ------
        ValueError
            If ``label`` is not a known (or additional) cell type.
        """
        self._check_if_trained()
        self._validate_anndata()
        if label not in self.cell_type_mapping_extended:
            raise ValueError("Unknown cell type")
        y = self.cell_type_mapping_extended.index(label)

        stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size)
        scale = []
        for tensors in stdl:
            inference_inputs = self.module._get_inference_input(tensors)
            outputs = self.module.inference(**inference_inputs)
            generative_inputs = self.module._get_generative_input(tensors, outputs)
            px_scale = self.module.generative(**generative_inputs)["px_mu"][0, :, y, :]
            scale += [px_scale.cpu()]
        data = torch.cat(scale).numpy()

        column_names = self.adata.var.index
        index_names = self.adata.obs.index
        if indices is not None:
            index_names = index_names[indices]
        return pd.DataFrame(data=data, columns=column_names, index=index_names)

    @torch.inference_mode()
    def get_expression_for_ct(
        self,
        label: str,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
        return_sparse_array: bool = False,
    ) -> pd.DataFrame | csr_matrix:
        r"""Return the observed counts attributed to one cell type in every spot.

        Each spot's counts ``X`` are multiplied, gene by gene, by the fraction of the
        proportion-weighted expected expression (``px_mu * v``) that comes from ``label``.

        Parameters
        ----------
        label
            cell type of interest (a known label or one of the ``additional_{i}`` names)
        indices
            Indices of cells in self.adata to use. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        return_sparse_array
            If `True`, returns a ``scipy.sparse.csr_matrix`` of the *transposed* matrix
            (genes x spots) instead of a dataframe.

        Returns
        -------
        Pandas dataframe of gene_expression (spots x genes), or the sparse matrix above.

        Raises
        ------
        ValueError
            If ``label`` is not a known (or additional) cell type.
        """
        self._check_if_trained()
        if label not in self.cell_type_mapping_extended:
            raise ValueError("Unknown cell type")
        y = self.cell_type_mapping_extended.index(label)

        stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size)
        expression_ct = []
        for tensors in stdl:
            inference_inputs = self.module._get_inference_input(tensors)
            outputs = self.module.inference(**inference_inputs)
            generative_inputs = self.module._get_generative_input(tensors, outputs)
            generative_outputs = self.module.generative(**generative_inputs)
            px_scale = generative_outputs["px_mu"][0, ...]
            proportions = generative_outputs["v"][0, ...]
            px_scale_expected = torch.einsum("mkl,mk->mkl", px_scale, proportions)
            # Share of each gene's expected expression contributed by cell type `y`.
            px_scale_proportions = px_scale_expected[:, y, :] / px_scale_expected.sum(dim=1)
            x_ct = tensors["X"].to(px_scale_proportions.device) * px_scale_proportions
            expression_ct += [x_ct.cpu()]
        data = torch.cat(expression_ct).numpy()

        if return_sparse_array:
            return csr_matrix(data.T)
        column_names = self.adata.var.index
        index_names = self.adata.obs.index
        if indices is not None:
            index_names = index_names[indices]
        return pd.DataFrame(data=data, columns=column_names, index=index_names)

    @devices_dsp.dedent
    def train(
        self,
        max_epochs: int = 2000,
        lr: float = 0.003,
        accelerator: str = "auto",
        devices: int | list[int] | str = "auto",
        train_size: float = 1.0,
        validation_size: float | None = None,
        shuffle_set_split: bool = True,
        batch_size: int = 128,
        n_epochs_kl_warmup: int = 200,
        datasplitter_kwargs: dict | None = None,
        plan_kwargs: dict | None = None,
        **kwargs,
    ):
        """Trains the model using MAP inference.

        Parameters
        ----------
        max_epochs
            Number of epochs to train for
        lr
            Learning rate for optimization.
        %(param_accelerator)s
        %(param_devices)s
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the test set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        shuffle_set_split
            Whether to shuffle indices before splitting. If `False`, the val, train, and test set
            are split in the sequential order of the data according to `validation_size` and
            `train_size` percentages.
        batch_size
            Minibatch size to use during training.
        n_epochs_kl_warmup
            number of epochs needed to reach unit kl weight in the elbo
        datasplitter_kwargs
            Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`.
        plan_kwargs
            Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
            `train()` will overwrite values present in `plan_kwargs`, when appropriate.
        **kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        plan_kwargs = merge_kwargs(None, plan_kwargs, name="plan")
        # Explicit train() arguments take precedence over the same keys in plan_kwargs.
        plan_kwargs.update({"lr": lr, "n_epochs_kl_warmup": n_epochs_kl_warmup})
        super().train(
            max_epochs=max_epochs,
            accelerator=accelerator,
            devices=devices,
            train_size=train_size,
            validation_size=validation_size,
            shuffle_set_split=shuffle_set_split,
            batch_size=batch_size,
            datasplitter_kwargs=datasplitter_kwargs,
            plan_kwargs=plan_kwargs,
            **kwargs,
        )

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,
        smoothed_layer: str | None = None,
        batch_key: str | None = None,
        **kwargs,
    ):
        """%(summary)s.

        Side effect: writes an integer column ``"_indices"`` (0..n_obs-1) into ``adata.obs``.

        Parameters
        ----------
        %(param_adata)s
        %(param_layer)s
        smoothed_layer
            Optional key in ``adata.layers`` holding a smoothed count matrix; when given it is
            registered under the ``"x_smoothed"`` key.
        %(param_batch_key)s
        **kwargs
            Forwarded to ``AnnDataManager.register_fields``.
        """
        # Must run first: captures this call's arguments via locals().
        setup_method_args = cls._get_setup_method_args(**locals())
        # add index for each cell (provided to pyro plate for correct minibatching)
        adata.obs["_indices"] = np.arange(adata.n_obs)
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
        ]
        if smoothed_layer is not None:
            anndata_fields.append(LayerField("x_smoothed", smoothed_layer, is_count_data=True))
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)