from __future__ import annotations
import logging
from functools import partial
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
import pyro
from pyro.infer import Trace_ELBO
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data._utils import get_anndata_attribute
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
LabelsWithUnlabeledObsField,
LayerField,
NumericalObsField,
ObsmField,
)
from scvi.dataloaders import AnnTorchDataset
from scvi.model._utils import (
scrna_raw_counts_properties,
)
from scvi.model.base import ArchesMixin, BaseModelClass, PyroSampleMixin, PyroSviTrainMixin
from scvi.model.base._de_core import _de_core
from scvi.train._config import merge_kwargs
from scvi.utils import de_dsp, setup_anndata_dsp
from spatialvi.model.base import SpatialBaseModel, SpatialNeighborhoodMixin
from spatialvi.model.base._resolvi_predictive import ResolVIPredictiveMixin
from spatialvi.module._resolvae import RESOLVAE
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from typing import Literal
from anndata import AnnData
# Logger named after this module (``__name__``).
logger = logging.getLogger(name=__name__)
class ResolVI(
    SpatialNeighborhoodMixin,
    SpatialBaseModel,
    PyroSviTrainMixin,
    PyroSampleMixin,
    ResolVIPredictiveMixin,
    ArchesMixin,
    BaseModelClass,
):
    """
    ResolVI addresses noise and bias in single-cell resolved spatial transcriptomics data.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :meth:`~spatialvi.model.ResolVI.setup_anndata`.
    n_hidden
        Number of nodes per hidden layer.
    n_hidden_encoder
        Hidden-layer width forwarded to the module as ``n_hidden_encoder``
        (presumably used by the encoder networks).
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    dropout_rate
        Dropout rate for neural networks.
    dispersion
        One of the following:

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
    gene_likelihood
        One of:

        * ``'nb'`` - Negative binomial distribution
        * ``'poisson'`` - Poisson distribution
    background_ratio
        Background-ratio prior passed to the module. If ``None``, it is estimated from
        the data by :meth:`compute_dataset_dependent_priors`.
    median_distance
        Neighbor-distance scale passed to the module. If ``None``, it is estimated from
        the registered ``distance_neighbor`` matrix by
        :meth:`compute_dataset_dependent_priors`.
    semisupervised
        If ``True``, use the labels registered via ``labels_key`` in
        :meth:`setup_anndata`. The label mapping set up in this case is required by
        :meth:`predict` and :meth:`differential_niche_abundance`.
    mixture_k
        Forwarded unchanged to the module as ``mixture_k``.
    downsample_counts
        If ``True``, the median and standard deviation of the ``log1p`` library sizes
        are passed to the module as ``downsample_counts_mean`` /
        ``downsample_counts_std``. If ``False``, ``None`` and ``1.0`` are passed instead.
    **model_kwargs
        Keyword args for :class:`~spatialvi.module.RESOLVAE`

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> ResolVI.setup_anndata(adata, batch_key="batch")
    >>> model = ResolVI(adata)
    >>> model.train()
    >>> adata.obsm["X_resolVI"] = model.get_latent_representation()
    >>> adata.layers["X_normalized_resolVI"] = model.get_normalized_expression()
    """

    _module_cls = RESOLVAE
    # Names of module parameters that ``train`` optimizes with a learning rate of 0
    # (unless they are listed in ``expose_params``).
    _block_parameters = []

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 32,
        n_hidden_encoder: int = 128,
        n_latent: int = 10,
        n_layers: int = 2,
        dropout_rate: float = 0.05,
        dispersion: Literal["gene", "gene-batch"] = "gene",
        gene_likelihood: Literal["nb", "poisson"] = "nb",
        background_ratio=None,
        median_distance=None,
        semisupervised=False,
        mixture_k=50,
        downsample_counts=True,
        **model_kwargs,
    ):
        # Pyro keeps parameters in a global store; clear it so parameters from a model
        # built earlier in this process are not picked up by this one.
        pyro.clear_param_store()
        super().__init__(adata)
        if semisupervised:
            # Sets the label mapping later used by ``predict`` and
            # ``differential_niche_abundance``.
            self._set_indices_and_labels()

        results = self.compute_dataset_dependent_priors()
        n_cats_per_cov = (
            self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key
            if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
            else None
        )
        # One registered label category is excluded -- presumably the unlabeled
        # category (compare ``self._label_mapping[:-1]`` in differential_niche_abundance).
        n_labels = self.summary_stats.n_labels - 1

        # User-supplied priors take precedence over the data-derived estimates.
        # Note: these locals are rebound, so ``_get_init_params(locals())`` at the end
        # captures the resolved values rather than ``None``.
        if background_ratio is None:
            background_ratio = results["background_ratio"]
        if median_distance is None:
            median_distance = results["median_distance"]
        if downsample_counts:
            downsample_counts_mean = results["mean_log_counts"]
            downsample_counts_std = results["std_log_counts"]
        else:
            downsample_counts_mean = None
            downsample_counts_std = 1.0

        # Dataset over the registered "X" tensor only (loaded as sparse tensors),
        # handed to the module.
        expression_anntorchdata = AnnTorchDataset(
            self.adata_manager,
            getitem_tensors=["X"],
            load_sparse_tensor=True,
        )
        self.module = self._module_cls(
            n_input=self.summary_stats.n_vars,
            n_batch=self.summary_stats.n_batch,
            n_cats_per_cov=n_cats_per_cov,
            n_labels=n_labels,
            mixture_k=mixture_k,
            expression_anntorchdata=expression_anntorchdata,
            n_neighbors=self.summary_stats.n_distance_neighbor,
            n_obs=self.summary_stats["n_ind_x"],
            n_hidden=n_hidden,
            n_hidden_encoder=n_hidden_encoder,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            gene_likelihood=gene_likelihood,
            background_ratio=background_ratio,
            median_distance=median_distance,
            semisupervised=semisupervised,
            downsample_counts_mean=downsample_counts_mean,
            downsample_counts_std=downsample_counts_std,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"ResolVI Model with the following params: \nn_hidden: {n_hidden}, "
            f"n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
            f"{dropout_rate}, dispersion: {dispersion}, gene_likelihood: {gene_likelihood}, "
            f"n_neighbors: {self.summary_stats.n_distance_neighbor}"
        )
        self.init_params_ = self._get_init_params(locals())

    def train(
        self,
        max_epochs: int = 50,
        lr: float = 3e-3,
        lr_extra: float = 1e-2,
        extra_lr_parameters: tuple = ("per_neighbor_diffusion_map", "u_prior_means"),
        batch_size: int = 512,
        weight_decay: float = 0.0,
        eps: float = 1e-4,
        n_steps_kl_warmup: int | None = None,
        n_epochs_kl_warmup: int | None = 20,
        plan_kwargs: dict | None = None,
        expose_params: Iterable[str] = (),
        **kwargs,
    ):
        """
        Train the model using amortized variational inference.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset.
        lr
            Learning rate for optimization.
        lr_extra
            Learning rate for parameters (non-amortized and custom ones)
        extra_lr_parameters
            Names of parameters to train with the `lr_extra` learning rate.
        batch_size
            Minibatch size to use during training.
        weight_decay
            weight decay regularization term for optimization
        eps
            Optimizer eps
        n_steps_kl_warmup
            Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
            Only activated when `n_epochs_kl_warmup` is set to None.
        n_epochs_kl_warmup
            Number of epochs to scale weight on KL divergences from 0 to 1.
            Overrides `n_steps_kl_warmup` when both are not `None`. If both are `None`,
            the warmup spans `max_epochs` epochs.
        plan_kwargs
            Keyword args for the Pyro training plan. The optimizer, loss, and KL-warmup
            entries are always set by this method and override values given here.
        expose_params
            Names of parameters to train normally even if they are frozen
            (``requires_grad=False``) or listed in ``_block_parameters``. Used when
            running the model in Arches (transfer learning) mode.
        **kwargs
            Other keyword args for the Trainer.
        """
        # Frozen parameters (requires_grad=False) are switched back on for autograd but
        # kept fixed by giving them a learning rate of 0 in ``per_param_callable``.
        blocked = self._block_parameters.copy()
        for name, param in self.module.named_parameters():
            if not param.requires_grad:
                blocked.append(name)
                param.requires_grad = True
        blocked = set(blocked) - set(expose_params)
        if blocked:
            print("Running transfer learning mode. Set lr to 0 and blocking variables.")

        def per_param_callable(module_name, param_name):
            # Rebuild the name used in ``blocked`` / ``extra_lr_parameters`` from the
            # (module_name, param_name) pair that Pyro passes to two-argument callables.
            store_name = f"{module_name}$$${param_name}" if "." in param_name else param_name
            if store_name in blocked:
                return {"lr": 0.0, "weight_decay": 0, "eps": eps}
            if store_name in extra_lr_parameters:
                return {"lr": lr_extra, "weight_decay": weight_decay, "eps": eps}
            return {"lr": lr, "weight_decay": weight_decay, "eps": eps}

        optim = pyro.optim.Adam(per_param_callable)

        # Fall back to an epoch-based warmup over the whole run only when no schedule
        # was requested at all. Replacing a ``None`` epoch warmup unconditionally would
        # make ``n_steps_kl_warmup`` impossible to activate.
        if n_epochs_kl_warmup is None and n_steps_kl_warmup is None:
            n_epochs_kl_warmup = max_epochs

        plan_kwargs = merge_kwargs(None, plan_kwargs, name="plan")
        plan_kwargs.update(
            {
                "optim_kwargs": {"lr": lr, "weight_decay": weight_decay, "eps": eps},
                "optim": optim,
                "blocked": blocked,
                "n_epochs_kl_warmup": n_epochs_kl_warmup,
                "n_steps_kl_warmup": n_steps_kl_warmup,
                "loss_fn": Trace_ELBO(
                    num_particles=5, vectorize_particles=True, retain_graph=True
                ),
            }
        )
        super().train(
            max_epochs=max_epochs,
            train_size=1.0,
            plan_kwargs=plan_kwargs,
            batch_size=batch_size,
            **kwargs,
        )

    @staticmethod
    def _prepare_data(
        adata: AnnData,
        n_neighbors: int = 10,
        spatial_rep: str = "X_spatial",
        batch_key: str | None = None,
        slice_key: str | None = None,
        **kwargs,
    ) -> None:
        """Compute spatial neighbors and store in ``adata.obsm``.

        Neighbors are searched separately within each batch/slice, so a cell's
        neighbors always come from its own group. Writes ``adata.obsm["X_spatial"]``
        (a copy of the ``spatial_rep`` entry), ``adata.obsm["index_neighbor"]``
        (global row indices of each cell's neighbors) and
        ``adata.obsm["distance_neighbor"]`` (squared distances to those neighbors).

        Parameters
        ----------
        adata
            AnnData object with spatial coordinates.
        n_neighbors
            Number of spatial neighbors to compute.
        spatial_rep
            Key in ``adata.obsm`` containing spatial coordinates.
        batch_key
            Key in ``adata.obs`` for batch/slice grouping of neighbor computation.
        slice_key
            Alias for ``batch_key``; takes precedence when given.
        **kwargs
            Ignored.

        Raises
        ------
        ImportError
            If scanpy or scikit-learn is not installed.
        """
        if slice_key is not None:
            batch_key = slice_key
        try:
            import scanpy
            from sklearn.neighbors._base import _kneighbors_from_graph
        except ImportError as err:
            raise ImportError(
                "Please install scanpy and scikit-learn -- `pip install scanpy`"
            ) from err

        if batch_key is None:
            indices = [np.arange(adata.n_obs)]
        else:
            indices = [
                np.where(adata.obs[batch_key] == i)[0] for i in adata.obs[batch_key].unique()
            ]

        # Placeholder values for rows that no group fills in.
        distance_neighbor = 1e6 * np.ones([adata.n_obs, n_neighbors])
        index_neighbor = np.zeros([adata.n_obs, n_neighbors], dtype=int)
        for index in indices:
            sub_data = adata[index].copy()
            # Prefer the GPU implementation when rapids_singlecell is importable.
            try:
                import rapids_singlecell

                rapids_singlecell.pp.neighbors(
                    sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep
                )
            except ImportError:
                scanpy.pp.neighbors(sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep)
            # Elementwise square of the stored distances. ``** 2`` on a scipy
            # ``spmatrix`` (e.g. ``csr_matrix``) is a *matrix* power (A @ A), which would
            # mix two-hop paths into the graph and reorder the stored entries.
            # ``.power`` keeps the sparsity pattern and entry order; _kneighbors_from_graph
            # reads the first ``n_neighbors`` stored entries per row (assumed to be
            # nearest-first, as produced by the neighbor search).
            distances = sub_data.obsp["distances"].power(2)
            batch_distances, batch_neighbors = _kneighbors_from_graph(
                distances, n_neighbors, return_distance=True
            )
            distance_neighbor[index, :] = batch_distances
            # Map neighbor positions within the group back to rows of the full ``adata``.
            index_neighbor[index, :] = index[batch_neighbors]

        adata.obsm["X_spatial"] = adata.obsm[spatial_rep]
        adata.obsm["index_neighbor"] = index_neighbor
        adata.obsm["distance_neighbor"] = distance_neighbor

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,
        batch_key: str | None = None,
        labels_key: str | None = None,
        size_factor_key: str | None = None,
        categorical_covariate_keys: list[str] | None = None,
        prepare_data: bool | None = True,
        prepare_data_kwargs: dict | None = None,
        unlabeled_category: str = "unknown",
        **kwargs,
    ):
        """%(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_layer)s
        %(param_batch_key)s
        %(param_labels_key)s
        size_factor_key
            Key in ``adata.obs`` corresponding to pre-computed size factors.
        %(param_cat_cov_keys)s
        prepare_data
            If ``True``, automatically compute spatial neighbors via :meth:`_prepare_data`.
            Set to ``False`` if neighbors are already in ``adata.obsm``.
        prepare_data_kwargs
            Keyword args for :meth:`_prepare_data` (e.g. ``n_neighbors``, ``spatial_rep``).
            ``batch_key`` defaults to the ``batch_key`` given here but may be overridden.
        %(param_unlabeled_category)s
        """
        # Must run first: it records the call arguments from ``locals()``.
        setup_method_args = cls._get_setup_method_args(**locals())

        x = adata.X if layer is None else adata.layers[layer]
        assert np.min(x.sum(axis=1)) > 0, (
            "Please filter cells with less than 5 counts prior to running ResolVI."
        )

        if prepare_data:
            # Copy so the caller's dict is not mutated; an explicit ``batch_key`` in
            # ``prepare_data_kwargs`` wins instead of raising a duplicate-keyword error.
            neighbor_kwargs = dict(prepare_data_kwargs or {})
            neighbor_kwargs.setdefault("batch_key", batch_key)
            spatial_rep = neighbor_kwargs.get("spatial_rep", "X_spatial")
            if spatial_rep in adata.obsm:
                cls._prepare_data(adata, **neighbor_kwargs)
            elif "index_neighbor" not in adata.obsm:
                raise KeyError(
                    f"Spatial key '{spatial_rep}' not found in adata.obsm and no pre-computed "
                    "neighbors found. Either provide spatial coordinates or pre-compute "
                    "neighbors manually and call setup_anndata with prepare_data=False."
                )

        # Per-cell identifier; prefixed by batch so identical obs_names in different
        # batches stay distinct.
        if batch_key is not None:
            adata.obs["_indices"] = (
                adata.obs[batch_key].astype(str) + "_" + adata.obs_names.astype(str)
            )
        else:
            adata.obs["_indices"] = adata.obs_names.astype(str)
        adata.obs["_indices"] = adata.obs["_indices"].astype("category")
        assert not adata.obs["_indices"].duplicated(keep="first").any(), (
            "obs_names need to be unique prior to running ResolVI."
        )

        if labels_key is None:
            label_field = CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key)
        else:
            label_field = LabelsWithUnlabeledObsField(
                REGISTRY_KEYS.LABELS_KEY, labels_key, unlabeled_category
            )
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False),
            ObsmField("index_neighbor", "index_neighbor"),
            ObsmField("distance_neighbor", "distance_neighbor"),
            CategoricalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
            label_field,
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    def compute_dataset_dependent_priors(self, n_small_genes=None):
        """Compute dataset-dependent priors for background ratio and median distance.

        Parameters
        ----------
        n_small_genes
            Number of genes with the lowest total counts used to estimate the background
            ratio. Defaults to ``n_vars // 50`` (at least 1).

        Returns
        -------
        Dictionary with the keys

        * ``"background_ratio"`` - mean over cells of (mean expression of the
          ``n_small_genes`` lowest-count genes) / (mean expression over all genes).
        * ``"median_distance"`` - median over cells of the 6th-smallest entry of the
          registered ``distance_neighbor`` matrix (the largest entry when there are
          fewer than 6 neighbors per cell).
        * ``"mean_log_counts"`` - the *median* of the ``log1p`` library sizes.
        * ``"std_log_counts"`` - standard deviation of the ``log1p`` library sizes.
        """
        x = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
        if n_small_genes is None:
            # ``n_vars // 50`` is 0 for fewer than 50 genes, which would average over
            # an empty gene selection and yield NaN.
            n_small_genes = max(1, x.shape[1] // 50)
        else:
            n_small_genes = int(n_small_genes)
        gene_totals = np.asarray(x.sum(0)).ravel()
        lowest_genes = gene_totals.argsort()[:n_small_genes]
        smallest_means = x[:, lowest_genes].mean(1) / np.array(x.mean(1))
        background_ratio = np.mean(np.array(smallest_means))

        distance = self.adata_manager.get_from_registry("distance_neighbor")
        # Clamp the rank so that fewer than 6 neighbors per cell does not make
        # np.partition fail with an out-of-bounds kth.
        kth = min(5, distance.shape[1] - 1)
        median_distance = np.median(np.partition(distance, kth, axis=1)[:, kth])

        log_library_size = np.log1p(np.array(x.sum(1)))
        mean_log_counts = np.median(log_library_size)
        std_log_counts = np.std(log_library_size)
        return {
            "background_ratio": background_ratio,
            "median_distance": median_distance,
            "mean_log_counts": mean_log_counts,
            "std_log_counts": std_log_counts,
        }

    @de_dsp.dedent
    def differential_expression(
        self,
        adata: AnnData | None = None,
        groupby: str | None = None,
        group1: Iterable[str] | None = None,
        group2: str | None = None,
        idx1: Sequence[int] | Sequence[bool] | None = None,
        idx2: Sequence[int] | Sequence[bool] | None = None,
        subset_idx: Sequence[int] | None = None,
        mode: Literal["vanilla", "change"] = "change",
        delta: float = 0.25,
        batch_size: int | None = None,
        all_stats: bool = True,
        batch_correction: bool = False,
        batchid1: Iterable[str] | None = None,
        batchid2: Iterable[str] | None = None,
        fdr_target: float = 0.05,
        silent: bool = False,
        weights: Literal["uniform", "importance"] | None = "uniform",
        filter_outlier_cells: bool = False,
        n_samples: int = 5,
        size_scaling: bool = False,
        library_scaling: bool = False,
        **kwargs,
    ) -> pd.DataFrame:
        r"""A unified method for differential expression analysis.

        Implements `"vanilla"` DE :cite:p:`Lopez18` and `"change"` mode DE :cite:p:`Boyeau19`.

        Parameters
        ----------
        %(de_adata)s
        %(de_groupby)s
        %(de_group1)s
        %(de_group2)s
        %(de_idx1)s
        %(de_idx2)s
        %(de_subset_idx)s
        %(de_mode)s
        %(de_delta)s
        %(de_batch_size)s
        %(de_all_stats)s
        %(de_batch_correction)s
        %(de_batchid1)s
        %(de_batchid2)s
        %(de_fdr_target)s
        %(de_silent)s
        weights
            Precomputed weight for importance sampling.
        filter_outlier_cells
            Whether to filter outlier cells.
        n_samples
            Number of posterior samples to use for estimation.
        size_scaling
            If True, will scale normalized expression by size factors.
        library_scaling
            If True, will scale normalized expression to library size. Only supported
            with ``weights="importance"``.
        **kwargs
            Keyword args for DifferentialComputation.get_bayes_factors

        Returns
        -------
        Differential expression DataFrame.

        Raises
        ------
        ValueError
            If ``library_scaling`` is requested without ``weights="importance"``.
        """
        adata = self._validate_anndata(adata)
        if library_scaling and weights != "importance":
            raise ValueError(
                "library_scaling=True is only supported with weights='importance'. "
                "Pass weights='importance' or set library_scaling=False."
            )
        # Both expression getters are called with return_mean=False so the DE core
        # receives individual posterior samples.
        if weights == "importance":
            model_fn = partial(
                self.get_normalized_expression_importance,
                return_numpy=True,
                n_samples=n_samples,
                batch_size=batch_size,
                weights=weights,
                return_mean=False,
                size_scaling=size_scaling,
                library_scaling=library_scaling,
            )
        else:
            model_fn = partial(
                self.get_normalized_expression,
                return_numpy=True,
                n_samples=n_samples,
                batch_size=batch_size,
                weights=weights,
                return_mean=False,
                size_scaling=size_scaling,
            )
        representation_fn = self.get_latent_representation if filter_outlier_cells else None
        result = _de_core(
            adata_manager=self.get_anndata_manager(adata, required=True),
            model_fn=model_fn,
            representation_fn=representation_fn,
            groupby=groupby,
            group1=group1,
            group2=group2,
            idx1=idx1,
            idx2=idx2,
            subset_idx=subset_idx,
            all_stats=all_stats,
            all_stats_fn=scrna_raw_counts_properties,
            col_names=adata.var_names,
            mode=mode,
            batchid1=batchid1,
            batchid2=batchid2,
            delta=delta,
            batch_correction=batch_correction,
            fdr=fdr_target,
            silent=silent,
            **kwargs,
        )
        return result

    @de_dsp.dedent
    def differential_niche_abundance(
        self,
        adata: AnnData | None = None,
        groupby: str | None = None,
        group1: Iterable[str] | None = None,
        group2: str | None = None,
        neighbor_key: str | None = None,
        idx1: Sequence[int] | Sequence[bool] | None = None,
        idx2: Sequence[int] | Sequence[bool] | None = None,
        subset_idx: Sequence[int] | None = None,
        mode: Literal["vanilla", "change"] = "change",
        delta: float = 0.25,
        batch_size: int | None = None,
        fdr_target: float = 0.05,
        silent: bool = False,
        filter_outlier_cells: bool = False,
        pseudocounts: float = 1e-3,
        **kwargs,
    ) -> pd.DataFrame:
        r"""A unified method for niche differential abundance analysis.

        Requires a model built with ``semisupervised=True`` (uses the label mapping).

        Parameters
        ----------
        %(de_adata)s
        %(de_groupby)s
        %(de_group1)s
        %(de_group2)s
        neighbor_key
            Obsm key containing the spatial neighbors of each cell.
        %(de_idx1)s
        %(de_idx2)s
        %(de_subset_idx)s
        %(de_mode)s
        %(de_delta)s
        %(de_batch_size)s
        %(de_fdr_target)s
        %(de_silent)s
        filter_outlier_cells
            Whether to filter outlier cells.
        pseudocounts
            pseudocount offset used for the mode `change`.
        **kwargs
            Keyword args for DifferentialComputation.get_bayes_factors

        Returns
        -------
        Differential abundance DataFrame with one column per label (the last label
        category is excluded).
        """
        adata = self._validate_anndata(adata)
        model_fn = partial(
            self.get_neighbor_abundance,
            return_numpy=True,
            n_samples=5,
            batch_size=batch_size,
            return_mean=False,
            neighbor_key=neighbor_key,
        )
        representation_fn = self.get_latent_representation if filter_outlier_cells else None
        result = _de_core(
            adata_manager=self.get_anndata_manager(adata, required=True),
            model_fn=model_fn,
            representation_fn=representation_fn,
            groupby=groupby,
            group1=group1,
            group2=group2,
            idx1=idx1,
            idx2=idx2,
            subset_idx=subset_idx,
            all_stats=False,
            all_stats_fn=scrna_raw_counts_properties,
            col_names=self._label_mapping[:-1],
            mode=mode,
            batchid1=None,
            batchid2=None,
            delta=delta,
            batch_correction=False,
            fdr=fdr_target,
            silent=silent,
            pseudocounts=pseudocounts,
            **kwargs,
        )
        return result

    def predict(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        soft: bool = False,
        batch_size: int | None = 500,
        num_samples: int | None = 30,
    ) -> np.ndarray | pd.DataFrame:
        """
        Return cell label predictions.

        Requires a model built with ``semisupervised=True`` (uses the label mapping).

        Parameters
        ----------
        adata
            AnnData object that has been registered via setup_anndata.
        indices
            Subsample AnnData to these indices.
        soft
            If True, returns per class probabilities
        batch_size
            Minibatch size for data loading into the model.
        num_samples
            Samples to draw from the posterior for cell-type prediction.

        Returns
        -------
        If ``soft`` is False, an array with the most probable label of each cell.
        Otherwise a DataFrame of posterior-mean class probabilities indexed by
        ``obs_names`` with one column per label.
        """
        adata = self._validate_anndata(adata)
        if indices is None:
            indices = np.arange(adata.n_obs)
        sampled_prediction = self.sample_posterior(
            adata=adata,
            indices=indices,
            model=self.module.model_corrected,
            return_sites=["probs_prediction"],
            num_samples=num_samples,
            return_samples=False,
            batch_size=batch_size,
            summary_frequency=10,
            return_observed=True,
        )
        # Posterior mean of the class probabilities, one row per selected cell.
        y_pred = sampled_prediction["post_sample_means"]["probs_prediction"]
        if not soft:
            y_pred = y_pred.argmax(axis=1)
            predictions = [self._code_to_label[p] for p in y_pred]
            return np.array(predictions)
        n_labels = len(y_pred[0])
        predictions = pd.DataFrame(
            y_pred,
            columns=self._label_mapping[:n_labels],
            index=adata.obs_names[indices],
        )
        return predictions

    def _set_indices_and_labels(self):
        """Set indices for labeled and unlabeled cells.

        Stores the label mapping (``_label_mapping``, ``_code_to_label``), the
        unlabeled category and the positions of unlabeled / labeled cells.
        """
        labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY)
        self.original_label_key = labels_state_registry.original_key
        self.unlabeled_category_ = labels_state_registry.unlabeled_category
        labels = get_anndata_attribute(
            self.adata,
            self.adata_manager.data_registry.labels.attr_name,
            self.original_label_key,
        ).ravel()
        self._label_mapping = labels_state_registry.categorical_mapping
        self._unlabeled_indices = np.argwhere(labels == self.unlabeled_category_).ravel()
        self._labeled_indices = np.argwhere(labels != self.unlabeled_category_).ravel()
        # Integer code -> label name (position in the categorical mapping).
        self._code_to_label = dict(enumerate(self._label_mapping))