"""Vendored from scvi-tools: scvi.external.resolvi._utils.ResolVIPredictiveMixin.
This avoids importing from a private scvi internal module.
"""
from __future__ import annotations
import logging
import warnings
from functools import partial
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
import torch
from pyro import infer
from scvi import settings
from scvi.model._utils import _get_batch_code_from_category, parse_device_args
from scvi.utils import track
if TYPE_CHECKING:
from collections.abc import Sequence
from anndata import AnnData
# Module-level logger. The stray ``[docs]`` line copied from the Sphinx HTML page was
# removed: it evaluated the undefined name ``docs`` and broke importing this module.
logger = logging.getLogger(__name__)
class ResolVIPredictiveMixin:
    """Mixin class for generating samples from posterior distribution using infer.predictive.

    The host model class must provide the scvi model API used below:
    ``_validate_anndata``, ``_make_data_loader``, ``get_anndata_manager``,
    ``adata_manager``, ``module``, ``sample_posterior`` and ``_label_mapping``.
    """

    # ------------------------------------------------------------------ private helpers

    @staticmethod
    def _inference_device() -> torch.device:
        """Return the torch device selected by scvi's ``"auto"`` accelerator/device lookup."""
        _, _, device = parse_device_args(
            accelerator="auto",
            devices="auto",
            return_device="torch",
            validate_single_device=True,
        )
        return device

    def _encode_minibatch(self, tensors, device):
        """Move one minibatch to ``device`` and run ``module.z_encoder`` on it.

        Returns ``(batch_kwargs, qz_m, qz_v, z)``: the device-resident keyword
        arguments from ``module._get_fn_args_from_batch`` plus the encoder outputs.
        """
        _, batch_kwargs = self.module._get_fn_args_from_batch(tensors)
        batch_kwargs = {
            k: v.to(device) if v is not None else v for k, v in batch_kwargs.items()
        }
        # Categorical covariates only reach the encoder when the module encodes them.
        if batch_kwargs["cat_covs"] is not None and self.module.encode_covariates:
            categorical_input = list(torch.split(batch_kwargs["cat_covs"], 1, dim=1))
        else:
            categorical_input = ()
        x = batch_kwargs["x"]
        # Encoder input: log1p of the counts divided by each cell's mean count.
        qz_m, qz_v, z = self.module.z_encoder(
            torch.log1p(x / torch.mean(x, dim=1, keepdim=True)),
            batch_kwargs["batch_index"],
            *categorical_input,
        )
        return batch_kwargs, qz_m, qz_v, z

    @staticmethod
    def _resolve_return_numpy(n_samples, return_mean, return_numpy):
        """Force NumPy output when per-sample (3d) results are requested.

        Warns when the caller explicitly asked for a DataFrame in that case;
        otherwise ``return_numpy`` is returned unchanged.
        """
        if n_samples > 1 and return_mean is False:
            if return_numpy is False:
                warnings.warn(
                    "`return_numpy` must be `True` if `n_samples > 1` and `return_mean` "
                    "is `False`, returning an `np.ndarray`.",
                    UserWarning,
                    # +1 so the warning still points at the public method's caller.
                    stacklevel=settings.warnings_stacklevel + 1,
                )
            return True
        return return_numpy

    def _require_size_factor(self, batch_kwargs):
        """Return the minibatch size factor.

        Raises
        ------
        ValueError
            If no size factor was registered in ``setup_anndata``.
        """
        if "size_factor" not in self.adata_manager.data_registry:
            raise ValueError(
                "size_scaling is True but no size_factor_key was provided "
                "in setup_anndata."
            )
        return batch_kwargs["size_factor"]

    @staticmethod
    def _divide_by_size_factor(values, size_factor):
        """Divide ``(n_samples, n_cells, n_genes)`` values by one size factor per cell."""
        # (1, n_cells, 1) aligns the factor with the cell axis and broadcasts it over
        # the sample and gene axes.
        return values / size_factor.reshape(1, -1, 1)

    @staticmethod
    def _importance_probabilities(log_weights, weights=None):
        """Convert per-row log importance weights into resampling probabilities.

        ``weights="uniform"`` returns ``None`` (uniform resampling). A NumPy array
        is used as precomputed non-negative weights, one per candidate row, and is
        normalized to sum to one. Otherwise ``log_weights`` are soft-maxed.
        """
        if isinstance(weights, str) and weights == "uniform":
            return None
        if isinstance(weights, np.ndarray):
            precomputed = np.asarray(weights, dtype=np.float64).reshape(-1)
            return precomputed / precomputed.sum()
        flat = np.asarray(log_weights, dtype=np.float64).reshape(-1)
        # Subtract the maximum before exponentiating to avoid overflow.
        unnormalized = np.exp(flat - flat.max())
        return unnormalized / unnormalized.sum()

    @staticmethod
    def _subsample_rows(values, n_samples_overall, p=None):
        """Flatten all leading axes of ``values`` and draw rows with replacement.

        Returns ``(rows, drawn)``, where ``drawn`` holds the flat row positions.
        """
        flat = values.reshape(-1, values.shape[-1])
        drawn = np.random.choice(flat.shape[0], n_samples_overall, p=p, replace=True)
        return flat[drawn], drawn

    @staticmethod
    def _as_frame(values, columns, obs_names, drawn=None):
        """Wrap a per-cell result in a DataFrame indexed by cell name.

        When ``drawn`` (flat row positions from :meth:`_subsample_rows`) is given,
        each row is labeled with the cell it was drawn from.
        """
        if values.ndim == 3 and values.shape[0] == 1:
            # A single posterior sample with return_mean=False: drop the sample axis.
            values = values[0]
        if drawn is None:
            index = obs_names
        else:
            # Flattened rows are sample-major, so flat row r belongs to cell r % n_cells.
            index = obs_names[drawn % len(obs_names)]
        return pd.DataFrame(values, columns=columns, index=index)

    @staticmethod
    def _average_over_neighbors(values, weights=None):
        """Average ``(n_samples, n_cells, n_neighbors, n_labels)`` values over neighbors.

        ``weights`` may be ``None`` (plain mean), a 1d array with one weight per
        neighbor (as accepted by :func:`numpy.average`), or a 2d array of shape
        ``(n_cells, n_neighbors)`` with per-cell neighbor weights.
        """
        values = np.asarray(values)
        if weights is not None:
            weights = np.asarray(weights)
            if weights.ndim == 2:
                # np.average rejects 2d weights for a 4d array; broadcast them manually.
                weights = weights[None, :, :, None]
                return (values * weights).sum(axis=-2) / weights.sum(axis=-2)
        return np.average(values, axis=-2, weights=weights)

    # ------------------------------------------------------------------ public API

    @torch.inference_mode()
    def get_latent_representation(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        give_mean: bool = True,
        mc_samples: int = 1,  # consistency, noqa, pylint: disable=unused-argument
        batch_size: int | None = None,
        return_dist: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """
        Return the latent representation for each cell.

        This is denoted as :math:`z` in RESOLVI.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        give_mean
            Give mean of distribution or sample from it.
        mc_samples
            For consistency with scVI, this parameter is ignored.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        return_dist
            Return the distribution parameters of the latent variables rather than their sampled
            values. If `True`, ignores `give_mean` and `mc_samples`.

        Returns
        -------
        Low-dimensional representation for each cell, or a tuple ``(mean, variance)`` of the
        encoder's Normal posterior when ``return_dist`` is True.
        """
        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
        device = self._inference_device()

        latent, latent_qzm, latent_qzv = [], [], []
        for tensors in scdl:
            _, qz_m, qz_v, z = self._encode_minibatch(tensors, device)
            qz = torch.distributions.Normal(qz_m, qz_v.sqrt())
            if give_mean:
                z = qz.loc
            latent.append(z.cpu())
            latent_qzm.append(qz.loc.cpu())
            latent_qzv.append(qz.scale.square().cpu())

        if return_dist:
            return torch.cat(latent_qzm).numpy(), torch.cat(latent_qzv).numpy()
        return torch.cat(latent).numpy()

    @torch.inference_mode()
    def get_normalized_expression_importance(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        transform_batch: Sequence[int | str] | None = None,
        gene_list: Sequence[str] | None = None,
        library_size: float | None = 1,
        n_samples: int = 30,
        n_samples_overall: int = None,
        batch_size: int | None = None,
        weights: str | np.ndarray | None = None,
        return_mean: bool = True,
        return_numpy: bool | None = None,
        library_scaling: bool = False,
        size_scaling: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        r"""Returns the normalized (decoded) importance-sampled gene expression.

        This is denoted as :math:`\rho_n` in the scVI paper. For every minibatch, pyro
        importance sampling with ``10 * n_samples`` particles is run on
        ``module.model_simplified(corrected_rate=True)`` with
        ``module.guide.guide_simplified`` as proposal, and ``n_samples`` draws of the
        ``mean_poisson`` and ``px_scale`` sites are taken from the empirical marginal.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        transform_batch
            Not supported. Passed to ``_get_batch_code_from_category`` for consistency with
            other functions but otherwise ignored.
        gene_list
            Return frequencies of expression for a subset of genes.
            This can save memory when working with large datasets and few genes are
            of interest.
        library_size
            Scale the expression frequencies to a common library size.
            This allows gene expression levels to be interpreted on a common scale of relevant
            magnitude.
        n_samples
            Number of posterior samples to use for estimation.
        n_samples_overall
            If given, the collected rows (``n_samples * n_cells`` rows, or ``n_cells`` rows when
            ``return_mean`` is True) are resampled with replacement to this many rows, using the
            probabilities described under ``weights``.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        weights
            Resampling weights, only used with ``n_samples_overall``. ``"uniform"`` disables
            importance weighting; a NumPy array is used as precomputed non-negative weights (one
            per candidate row). Otherwise each row is weighted by the Poisson log-likelihood of
            the observed counts under ``mean_poisson``, summed over genes and divided by the
            cell's total count (averaged over samples when ``return_mean`` is True).
        return_mean
            Whether to return the mean of the samples.
        return_numpy
            Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame
            includes gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults
            to `False`. Otherwise, it defaults to `True`.
        library_scaling
            If `True`, returns the ``mean_poisson`` site (the decoded expression multiplied by
            the library size) instead of ``px_scale``.
        size_scaling
            If `True`, returns ``mean_poisson`` divided by the per-cell size factor (e.g.
            cell_area). Requires that a size factor key was provided in
            :meth:`~scvi.external.RESOLVI.setup_anndata`.

        Returns
        -------
        If `return_mean` is False, a 3d array of shape (n_samples, n_cells, n_genes).
        If `return_mean` is True, a 2d array of shape (n_cells, n_genes); a
        :class:`~pandas.DataFrame` unless `return_numpy` is True.
        If `n_samples_overall` is provided, a 2d result of shape (n_samples_overall, n_genes)
        resampled by importance; as a DataFrame, each row is labeled with the cell it came from.
        """
        adata = self._validate_anndata(adata)
        if indices is None:
            indices = np.arange(adata.n_obs)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
        # transform_batch is unsupported here; the lookup is kept for parity with
        # get_normalized_expression and its result is discarded.
        _get_batch_code_from_category(
            self.get_anndata_manager(adata, required=True), transform_batch
        )
        gene_mask = slice(None) if gene_list is None else adata.var_names.isin(gene_list)
        return_numpy = self._resolve_return_numpy(n_samples, return_mean, return_numpy)

        device = self._inference_device()
        model_now = partial(self.module.model_simplified, corrected_rate=True)
        exprs = []
        log_weights = []
        for tensors in scdl:
            args, batch_kwargs = self.module._get_fn_args_from_batch(tensors)
            batch_kwargs = {
                k: v.to(device) if v is not None else v for k, v in batch_kwargs.items()
            }
            importance_dist = infer.Importance(
                model_now, guide=self.module.guide.guide_simplified, num_samples=10 * n_samples
            )
            posterior = importance_dist.run(*args, **batch_kwargs)
            marginal = infer.EmpiricalMarginal(posterior, sites=["mean_poisson", "px_scale"])
            # Each draw stacks the two sites on dim 0; stacking draws on dim 1 gives
            # (2, n_samples, n_cells_in_batch, n_genes).
            samples = torch.stack([marginal() for _ in range(n_samples)], dim=1)
            mean_poisson, px_scale = samples[0], samples[1]

            x = batch_kwargs["x"].to(samples.device)
            batch_log_weights = (
                torch.distributions.Poisson(mean_poisson + 1e-3).log_prob(x).sum(-1)
                / x.sum(-1)
            )
            # Keep the (n_samples, n_cells_in_batch) layout so the weights flatten in the
            # same sample-major order as the expression rows below.
            log_weights.append(batch_log_weights.cpu().numpy())

            if library_scaling or size_scaling:
                decoded = mean_poisson
                if size_scaling:
                    decoded = self._divide_by_size_factor(
                        decoded, self._require_size_factor(batch_kwargs)
                    )
            else:
                decoded = px_scale
            exprs.append(decoded.cpu().numpy()[..., gene_mask])

        exprs = np.concatenate(exprs, axis=1)  # (n_samples, n_cells, n_genes)
        log_weights = np.concatenate(log_weights, axis=1)  # (n_samples, n_cells)
        if return_mean:
            exprs = exprs.mean(0)
            # One weight per returned row; unchanged from the per-sample value when
            # n_samples == 1.
            log_weights = log_weights.mean(0)
        if library_size is not None:
            exprs = library_size * exprs

        drawn = None
        if n_samples_overall is not None:
            p = self._importance_probabilities(log_weights, weights)
            exprs, drawn = self._subsample_rows(exprs, n_samples_overall, p=p)

        if return_numpy is None or return_numpy is False:
            return self._as_frame(
                exprs, adata.var_names[gene_mask], adata.obs_names[indices], drawn
            )
        return exprs

    @torch.inference_mode()
    def get_normalized_expression(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        transform_batch: Sequence[int | str] | None = None,
        gene_list: Sequence[str] | None = None,
        library_size: float | None = 1,
        size_scaling: bool = False,
        n_samples: int = 1,
        n_samples_overall: int = None,
        batch_size: int | None = None,
        return_mean: bool = True,
        return_numpy: bool | None = None,
        silent: bool = True,
        **kwargs,
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        r"""Returns the normalized (decoded) gene expression.

        This is denoted as :math:`\rho_n` in the scVI paper. Latents are sampled from the
        encoder's Normal posterior and decoded with ``module.model.decoder``; results are
        averaged over the requested ``transform_batch`` conditions.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        transform_batch
            Batch to condition on.
            If transform_batch is:

            - None, then real observed batch is used.
            - int, then batch transform_batch is used.
        gene_list
            Return frequencies of expression for a subset of genes.
            This can save memory when working with large datasets and few genes are
            of interest.
        library_size
            Scale the expression frequencies to a common library size.
            This allows gene expression levels to be interpreted on a common scale of relevant
            magnitude. If `None`, the decoder's ``px_rate`` is returned instead.
        size_scaling
            If `True`, returns ``px_rate`` divided by the per-cell size factor (e.g. cell_area)
            and ignores ``library_size``. Requires that a size factor key was provided in
            :meth:`~scvi.external.RESOLVI.setup_anndata`.
        n_samples
            Number of posterior samples to use for estimation.
        n_samples_overall
            If given, the collected rows (``n_samples * n_cells`` rows, or ``n_cells`` rows when
            ``return_mean`` is True) are resampled uniformly with replacement to this many rows.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        return_mean
            Whether to return the mean of the samples.
        return_numpy
            Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame
            includes gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults
            to `False`. Otherwise, it defaults to `True`.
        silent
            If `False`, show a progress bar over ``transform_batch``.
        **kwargs
            Ignored; accepted for API compatibility.

        Returns
        -------
        If `return_mean` is False, a 3d array of shape (n_samples, n_cells, n_genes).
        If `return_mean` is True, a 2d array of shape (n_cells, n_genes); a
        :class:`~pandas.DataFrame` unless `return_numpy` is True.
        If `n_samples_overall` is provided, a 2d result of shape (n_samples_overall, n_genes);
        as a DataFrame, each row is labeled with the cell it was drawn from.
        """
        adata = self._validate_anndata(adata)
        if indices is None:
            indices = np.arange(adata.n_obs)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
        transform_batch = _get_batch_code_from_category(
            self.get_anndata_manager(adata, required=True), transform_batch
        )
        gene_mask = slice(None) if gene_list is None else adata.var_names.isin(gene_list)
        return_numpy = self._resolve_return_numpy(n_samples, return_mean, return_numpy)

        device = self._inference_device()
        exprs = []
        for tensors in scdl:
            per_condition = []
            for batch in track(transform_batch, disable=silent):
                batch_kwargs, qz_m, qz_v, _ = self._encode_minibatch(tensors, device)
                z = torch.distributions.Normal(qz_m, qz_v.sqrt()).sample([n_samples])
                # Unlike the encoder, the decoder always receives every categorical covariate.
                if batch_kwargs["cat_covs"] is not None:
                    decoder_covs = list(torch.split(batch_kwargs["cat_covs"], 1, dim=1))
                else:
                    decoder_covs = ()
                if batch is None:
                    batch_index = batch_kwargs["batch_index"]
                else:
                    batch_index = torch.full_like(batch_kwargs["batch_index"], batch)
                px_scale, _, px_rate, _ = self.module.model.decoder(
                    self.module.model.dispersion,
                    z,
                    batch_kwargs["library"],
                    batch_index,
                    *decoder_covs,
                )
                # Decoder outputs are (n_samples, n_cells_in_batch, n_genes).
                if size_scaling:
                    exp_ = self._divide_by_size_factor(
                        px_rate, self._require_size_factor(batch_kwargs)
                    )
                elif library_size is not None:
                    exp_ = library_size * px_scale
                else:
                    exp_ = px_rate
                per_condition.append(exp_[..., gene_mask][None].cpu())
            # Average over the transform_batch conditions.
            exprs.append(torch.cat(per_condition, dim=0).mean(0).numpy())

        exprs = np.concatenate(exprs, axis=1)  # (n_samples, n_cells, n_genes)
        if return_mean:
            exprs = exprs.mean(0)

        drawn = None
        if n_samples_overall is not None:
            exprs, drawn = self._subsample_rows(exprs, n_samples_overall)

        if return_numpy is None or return_numpy is False:
            return self._as_frame(
                exprs, adata.var_names[gene_mask], adata.obs_names[indices], drawn
            )
        return exprs

    @torch.inference_mode()
    def get_neighbor_abundance(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        neighbor_key: str | None = None,
        n_samples: int = 1,
        n_samples_overall: int = None,
        batch_size: int | None = None,
        summary_frequency: int = 2,
        weights: str | None = None,
        return_mean: bool = True,
        return_numpy: bool | None = None,
        **kwargs,
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        r"""Returns the abundance of cell-types within spatial proximity of center cells.

        The ``probs_prediction`` site of ``module.model_corrected`` is sampled for every
        neighbor of every requested cell and averaged over each cell's neighbors.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        neighbor_key
            Obsm key containing the spatial neighbors of each cell. Required when `adata` is
            given; otherwise defaults to the neighbor key registered at setup.
        n_samples
            Number of posterior samples to use for estimation.
        n_samples_overall
            If given, the collected rows (``n_samples * n_cells`` rows, or ``n_cells`` rows when
            ``return_mean`` is True) are resampled uniformly with replacement to this many rows.
        batch_size
            Minibatch size for data loading into model; must be divisible by the number of
            neighbors. Defaults to ``n_neighbors * scvi.settings.batch_size``.
        summary_frequency
            Compute summary_fn after summary_frequency batches. Reduces memory footprint.
        weights
            Spatial weights for each neighbor. If `None` performs no spatial weighting.
            Either a 1d array of length `n_neighbors` or a 2d array of shape
            `n_cells` by `n_neighbors` (one row per requested cell).
        return_mean
            Whether to return the mean of the samples.
        return_numpy
            Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame
            includes cell-type labels as columns. If either `n_samples=1` or `return_mean=True`,
            defaults to `False`. Otherwise, it defaults to `True`.
        kwargs
            Additional keyword arguments that have no effect and only serve for compatibility.

        Returns
        -------
        If `return_mean` is False, a 3d array of shape (n_samples, n_cells, n_celltypes).
        If `return_mean` is True, a 2d array of shape (n_cells, n_celltypes); a
        :class:`~pandas.DataFrame` unless `return_numpy` is True.
        If `n_samples_overall` is provided, a 2d result of shape (n_samples_overall, n_celltypes);
        as a DataFrame, each row is labeled with the cell it was drawn from.
        """
        if adata is not None:
            assert neighbor_key is not None, "Must provide `neighbor_key` if `adata` is provided."
        adata = self._validate_anndata(adata)
        if indices is None:
            indices = np.arange(adata.n_obs)
        if neighbor_key is None:
            # Fall back to the neighbor field registered at setup time.
            neighbor_key = self.adata_manager.registry["field_registries"]["index_neighbor"][
                "data_registry"
            ]["attr_key"]
        neighbor_obsm = adata.obsm[neighbor_key]
        n_neighbors = neighbor_obsm.shape[-1]
        return_numpy = self._resolve_return_numpy(n_samples, return_mean, return_numpy)

        if batch_size is not None and batch_size % n_neighbors != 0:
            raise ValueError("Batch size must be divisible by the number of neighbors.")
        if batch_size is None:
            batch_size = n_neighbors * settings.batch_size

        # One row per (cell, neighbor) pair; each cell's neighbors stay contiguous, which the
        # reshape below relies on (shuffle=False).
        neighbor_indices = neighbor_obsm[indices].reshape(-1)
        dl = self._make_data_loader(
            adata=adata, indices=neighbor_indices, shuffle=False, batch_size=batch_size
        )
        sampled_prediction = self.sample_posterior(
            input_dl=dl,
            model=self.module.model_corrected,
            return_sites=["probs_prediction"],
            summary_frequency=summary_frequency,
            num_samples=n_samples,
            return_samples=True,
        )
        flat_abundance = sampled_prediction["posterior_samples"]["probs_prediction"]
        per_neighbor = flat_abundance.reshape(n_samples, len(indices), n_neighbors, -1)
        neighbor_abundance = self._average_over_neighbors(per_neighbor, weights)
        if return_mean:
            neighbor_abundance = np.mean(neighbor_abundance, axis=0)

        drawn = None
        if n_samples_overall is not None:
            neighbor_abundance, drawn = self._subsample_rows(
                neighbor_abundance, n_samples_overall
            )

        if return_numpy is None or return_numpy is False:
            assert return_mean, "Only numpy output is supported when `return_mean` is False."
            n_labels = neighbor_abundance.shape[-1]
            return self._as_frame(
                neighbor_abundance,
                self._label_mapping[:n_labels],
                adata.obs_names[indices],
                drawn,
            )
        return neighbor_abundance