Source code for spatialvi.module._nichevae

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch
from scvi import REGISTRY_KEYS
from scvi.module import VAE, Classifier
from scvi.module._constants import MODULE_KEYS
from scvi.module.base import (
    LossOutput,
    auto_move_data,
)
from torch.nn.functional import one_hot

from spatialvi._constants import SCVIVA_MODULE_KEYS, SCVIVA_REGISTRY_KEYS
from spatialvi.module.utils._nichevae_components import DirichletDecoder, Encoder, NicheDecoder

if TYPE_CHECKING:
    from typing import Literal

    import numpy as np
    from scvi._types import LossRecord
    from torch.distributions import Distribution

logger = logging.getLogger(__name__)


class nicheVAE(VAE):
    """Variational auto-encoder with niche decoders :cite:p:`Levy25`.

    On top of the gene-expression VAE inherited from :class:`~scvi.module.VAE`, the
    module owns a niche state decoder, a niche composition decoder and, when
    ``semisupervised`` is ``True``, a cell type classifier on the latent space.

    Parameters
    ----------
    n_input
        Number of input features.
    n_output_niche
        Output dimensionality of the niche state decoder (passed as ``n_output``).
    n_batch
        Number of batches. If ``0``, no batch correction is performed.
    n_labels
        Number of labels.
    n_hidden
        Number of nodes per hidden layer. Passed into :class:`~scvi.nn.Encoder` and
        :class:`~scvi.nn.DecoderSCVI`.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers. Passed into :class:`~scvi.nn.Encoder` and
        :class:`~scvi.nn.DecoderSCVI`.
    n_layers_niche
        Number of hidden layers in the niche state decoder.
    n_layers_compo
        Number of hidden layers in the composition decoder.
    n_hidden_niche
        Number of nodes per hidden layer in the niche state decoder.
    n_hidden_compo
        Number of nodes per hidden layer in the composition decoder.
    n_continuous_cov
        Number of continuous covariates.
    n_cats_per_cov
        A list of integers containing the number of categories for each categorical
        covariate.
    dropout_rate
        Dropout rate. Passed into :class:`~scvi.nn.Encoder` but not
        :class:`~scvi.nn.DecoderSCVI`.
    dispersion
        Flexibility of the dispersion parameter when ``gene_likelihood`` is either
        ``"nb"`` or ``"zinb"``. One of the following:

        * ``"gene"``: parameter is constant per gene across cells.
        * ``"gene-batch"``: parameter is constant per gene per batch.
        * ``"gene-label"``: parameter is constant per gene per label.
        * ``"gene-cell"``: parameter is constant per gene per cell.
    log_variational
        If ``True``, use :func:`~torch.log1p` on input data before encoding for
        numerical stability (not normalization).
    gene_likelihood
        Distribution to use for reconstruction in the generative process. One of the
        following:

        * ``"nb"``: :class:`~scvi.distributions.NegativeBinomial`.
        * ``"zinb"``: :class:`~scvi.distributions.ZeroInflatedNegativeBinomial`.
        * ``"poisson"``: :class:`~scvi.distributions.Poisson`.
    latent_distribution
        Distribution to use for the latent space. One of the following:

        * ``"normal"``: isotropic normal.
        * ``"ln"``: logistic normal with normal params N(0, 1).
    niche_likelihood
        Distribution to use for the niche state. One of the following:

        * ``"poisson"``: :class:`~torch.distributions.Poisson`.
        * ``"gaussian"``: :class:`~torch.distributions.Normal`.

        Default is ``"gaussian"`` and Poisson should be used if the niche state is
        count data.
    cell_rec_weight
        Weight of the cell reconstruction loss.
    latent_kl_weight
        Weight of the latent KL divergence.
    spatial_weight
        Weight of the spatial losses (niche state and niche composition).
    prior_mixture
        If ``True``, use a mixture of Gaussians for the latent space. Else, use
        unimodal Gaussian.
    prior_mixture_k
        Number of components in the Gaussian mixture. Replaced by ``n_labels`` when
        ``semisupervised`` is ``True``.
    semisupervised
        If ``True``, use a classifier to predict cell type labels from the latent
        space.
    linear_classifier
        If ``True``, use a linear classifier. Else, use a neural network.
    inpute_covariates_niche_decoder
        If ``True``, covariates are concatenated to the input of the niche state
        decoder.
    encode_covariates
        If ``True``, covariates are concatenated to gene expression prior to passing
        through the encoder(s). Else, only gene expression is used.
    deeply_inject_covariates
        If ``True`` and ``n_layers > 1``, covariates are concatenated to the outputs
        of hidden layers in the encoder(s) (if ``encode_covariates`` is ``True``) and
        the decoder prior to passing through the next layer.
    batch_representation
        ``EXPERIMENTAL`` Method for encoding batch information. One of the following:

        * ``"one-hot"``: represent batches with one-hot encodings.
        * ``"embedding"``: represent batches with continuously-valued embeddings
          using :class:`~scvi.nn.Embedding`.

        Note that batch representations are only passed into the encoder(s) if
        ``encode_covariates`` is ``True``.
    use_batch_norm
        Specifies where to use :class:`~torch.nn.BatchNorm1d` in the model. One of
        the following:

        * ``"none"``: don't use batch norm in either encoder(s) or decoder.
        * ``"encoder"``: use batch norm only in the encoder(s).
        * ``"decoder"``: use batch norm only in the decoder.
        * ``"both"``: use batch norm in both encoder(s) and decoder.

        Note: if ``use_layer_norm`` is also specified, both will be applied (first
        :class:`~torch.nn.BatchNorm1d`, then :class:`~torch.nn.LayerNorm`).
    use_layer_norm
        Specifies where to use :class:`~torch.nn.LayerNorm` in the model. One of the
        following:

        * ``"none"``: don't use layer norm in either encoder(s) or decoder.
        * ``"encoder"``: use layer norm only in the encoder(s).
        * ``"decoder"``: use layer norm only in the decoder.
        * ``"both"``: use layer norm in both encoder(s) and decoder.

        Note: if ``use_batch_norm`` is also specified, both will be applied (first
        :class:`~torch.nn.BatchNorm1d`, then :class:`~torch.nn.LayerNorm`).
    use_size_factor_key
        If ``True``, use the :attr:`~anndata.AnnData.obs` column as defined by the
        ``size_factor_key`` parameter in the model's ``setup_anndata`` method as the
        scaling factor in the mean of the conditional distribution. Takes priority
        over ``use_observed_lib_size``.
    use_observed_lib_size
        If ``True``, use the observed library size for RNA as the scaling factor in
        the mean of the conditional distribution.
    library_log_means
        :class:`~numpy.ndarray` of shape ``(1, n_batch)`` of means of the log library
        sizes that parameterize the prior on library size if ``use_size_factor_key``
        is ``False`` and ``use_observed_lib_size`` is ``False``.
    library_log_vars
        :class:`~numpy.ndarray` of shape ``(1, n_batch)`` of variances of the log
        library sizes that parameterize the prior on library size if
        ``use_size_factor_key`` is ``False`` and ``use_observed_lib_size`` is
        ``False``.
    batch_embedding_kwargs
        Keyword arguments passed into :class:`~scvi.nn.Embedding` if
        ``batch_representation`` is set to ``"embedding"``.
    extra_decoder_kwargs
        Additional keyword arguments passed into :class:`~scvi.nn.DecoderSCVI`, and
        also into the niche state and composition decoders.
    extra_encoder_kwargs
        Additional keyword arguments passed into the latent encoder.
    **vae_kwargs
        Remaining keyword arguments, forwarded to :class:`~scvi.module.VAE`.

    Notes
    -----
    Lifecycle: argument ``batch_representation`` is experimental in v1.2.
    """

    def __init__(
        self,
        n_input: int,
        n_output_niche: int,
        n_batch: int = 0,
        n_labels: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        n_layers_niche: int = 1,
        n_layers_compo: int = 1,
        n_hidden_niche: int = 128,
        n_hidden_compo: int = 128,
        n_continuous_cov: int = 0,
        n_cats_per_cov: list[int] | None = None,
        dropout_rate: float = 0.1,
        dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
        log_variational: bool = True,
        gene_likelihood: Literal["zinb", "nb", "poisson"] = "poisson",
        latent_distribution: Literal["normal", "ln"] = "normal",
        niche_likelihood: Literal["poisson", "gaussian"] = "gaussian",
        cell_rec_weight: float = 1.0,
        latent_kl_weight: float = 1.0,
        spatial_weight: float = 10,
        prior_mixture: bool = False,
        prior_mixture_k: int = 20,
        semisupervised: bool = True,
        linear_classifier: bool = True,
        inpute_covariates_niche_decoder: bool = True,
        encode_covariates: bool = False,
        deeply_inject_covariates: bool = True,
        batch_representation: Literal["one-hot", "embedding"] = "one-hot",
        use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none",
        use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both",
        use_size_factor_key: bool = False,
        use_observed_lib_size: bool = True,
        library_log_means: np.ndarray | None = None,
        library_log_vars: np.ndarray | None = None,
        batch_embedding_kwargs: dict | None = None,
        extra_decoder_kwargs: dict | None = None,
        extra_encoder_kwargs: dict | None = None,
        **vae_kwargs,
    ):
        super().__init__(
            n_input=n_input,
            n_batch=n_batch,
            n_labels=n_labels,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            n_continuous_cov=n_continuous_cov,
            n_cats_per_cov=n_cats_per_cov,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            log_variational=log_variational,
            gene_likelihood=gene_likelihood,
            latent_distribution=latent_distribution,
            encode_covariates=encode_covariates,
            deeply_inject_covariates=deeply_inject_covariates,
            batch_representation=batch_representation,
            use_size_factor_key=use_size_factor_key,
            use_observed_lib_size=use_observed_lib_size,
            library_log_means=library_log_means,
            library_log_vars=library_log_vars,
            extra_decoder_kwargs=extra_decoder_kwargs,
            batch_embedding_kwargs=batch_embedding_kwargs,
            extra_encoder_kwargs=extra_encoder_kwargs,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            **vae_kwargs,
        )

        # Loss weights and niche-specific settings.
        self.latent_kl_weight = latent_kl_weight
        self.cell_rec_weight = cell_rec_weight
        self.spatial_weight = spatial_weight
        self.n_output_niche = n_output_niche
        self.niche_likelihood = niche_likelihood
        self.prior_mixture = prior_mixture
        self.semisupervised = semisupervised

        self.batch_representation = batch_representation
        if self.batch_representation == "embedding":
            self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, **(batch_embedding_kwargs or {}))
            batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim
        elif self.batch_representation != "one-hot":
            raise ValueError("`batch_representation` must be one of 'one-hot', 'embedding'.")

        use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both"
        use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both"
        use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both"
        use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both"

        if self.prior_mixture is True:
            if self.semisupervised:
                # One mixture component per label, all initialised identically.
                prior_mixture_k = n_labels
                self.prior_mixture_k = prior_mixture_k
                self.prior_means = torch.nn.Parameter(torch.zeros([prior_mixture_k, n_latent]))
                self.prior_log_scales = torch.nn.Parameter(
                    torch.zeros([prior_mixture_k, n_latent])
                )
                self.prior_logits = torch.nn.Parameter(torch.ones([prior_mixture_k]))
            else:
                # Unsupervised mixture: random means, log-scales initialised at -1.
                self.prior_mixture_k = prior_mixture_k
                self.prior_means = torch.nn.Parameter(torch.randn([prior_mixture_k, n_latent]))
                self.prior_log_scales = torch.nn.Parameter(
                    torch.zeros([prior_mixture_k, n_latent]) - 1.0
                )
                self.prior_logits = torch.nn.Parameter(torch.ones([prior_mixture_k]))

        # Covariates only widen the encoder input when ``encode_covariates`` is set.
        n_input_encoder = n_input + n_continuous_cov * encode_covariates
        if self.batch_representation == "embedding":
            # The batch enters as a dense vector, so it is not a categorical input.
            n_input_encoder += batch_dim * encode_covariates
            cat_list = list([] if n_cats_per_cov is None else n_cats_per_cov)
        else:
            cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov)

        encoder_cat_list = cat_list if encode_covariates else None
        _extra_encoder_kwargs = extra_encoder_kwargs or {}
        self.z_encoder = Encoder(
            n_input_encoder,
            n_latent,
            n_cat_list=encoder_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            distribution=latent_distribution,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_batch_norm_encoder,
            use_layer_norm=use_layer_norm_encoder,
            return_dist=True,
            **_extra_encoder_kwargs,
        )

        n_input_decoder = n_latent + n_continuous_cov
        if self.batch_representation == "embedding":
            n_input_decoder += batch_dim

        _extra_decoder_kwargs = extra_decoder_kwargs or {}

        self.niche_decoder = NicheDecoder(
            n_input=n_input_decoder,
            n_output=n_output_niche,
            n_niche_components=n_labels,
            n_cat_list=cat_list if inpute_covariates_niche_decoder else None,
            n_layers=n_layers_niche,
            n_hidden=n_hidden_niche,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_batch_norm_decoder,
            use_layer_norm=use_layer_norm_decoder,
            dropout_rate=dropout_rate,
            **_extra_decoder_kwargs,
        )

        self.composition_decoder = DirichletDecoder(
            n_input_decoder,
            n_labels,
            n_cat_list=None,  # do not batch-correct the cell type proportions.
            n_layers=n_layers_compo,
            n_hidden=n_hidden_compo,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_batch_norm_decoder,
            use_layer_norm=use_layer_norm_decoder,
            **_extra_decoder_kwargs,
        )

        if self.semisupervised:
            # Classifier takes n_latent as input
            cls_parameters = {
                "n_layers": 0 if linear_classifier else n_layers,
                "n_hidden": 0 if linear_classifier else n_hidden,
                "dropout_rate": dropout_rate,
                "logits": True,  # no Softmax
            }
            self.classifier = Classifier(
                n_latent,
                n_labels=n_labels,
                use_batch_norm=use_batch_norm_encoder,
                use_layer_norm=use_layer_norm_encoder,
                **cls_parameters,
            )
        else:
            self.classifier = None

    @staticmethod
    def _flat_one_hot(index: torch.Tensor, n_classes: int) -> torch.Tensor:
        """One-hot encode ``index`` after dropping a trailing singleton axis.

        ``torch.nn.functional.one_hot`` appends the class axis to the input shape, so
        a column of indices would otherwise produce a tensor with one axis too many.
        """
        if index.dim() > 1 and index.shape[-1] == 1:
            index = index.squeeze(-1)
        return one_hot(index.long(), n_classes).float()

    @auto_move_data
    def generative(
        self,
        z: torch.Tensor,
        library: torch.Tensor,
        batch_index: torch.Tensor,
        cont_covs: torch.Tensor | None = None,
        cat_covs: torch.Tensor | None = None,
        size_factor: torch.Tensor | None = None,
        y: torch.Tensor | None = None,
        transform_batch: torch.Tensor | None = None,
    ) -> dict[str, Distribution | None]:
        """Run the generative process.

        Parameters
        ----------
        z
            Latent samples fed to the decoders.
        library
            Library size; used as the size factor unless ``use_size_factor_key`` is set.
        batch_index
            Batch index of each cell. Overwritten by ``transform_batch`` when given.
        cont_covs
            Continuous covariates, concatenated to ``z`` when provided.
        cat_covs
            Categorical covariates, split column-wise before being passed on.
        size_factor
            Size factor; only used when ``use_size_factor_key`` is set.
        y
            Labels. Needed for ``"gene-label"`` dispersion and for the
            semi-supervised mixture prior.
        transform_batch
            Batch value substituted for every entry of ``batch_index``.

        Returns
        -------
        Dictionary holding the expression likelihood, the library prior (``None``
        when the observed library size is used), the latent prior, the niche state
        mean and variance, and the niche state and niche composition distributions.
        """
        from scvi.distributions import (
            NegativeBinomial,
            Poisson,
            ZeroInflatedNegativeBinomial,
        )
        from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal
        from torch.nn.functional import linear

        # TODO: refactor forward function to not rely on y
        # Likelihood distribution
        if cont_covs is None:
            decoder_input = z
        elif z.dim() != cont_covs.dim():
            # ``z`` carries an extra leading axis; repeat the covariates along it.
            decoder_input = torch.cat(
                [z, cont_covs.unsqueeze(0).expand(z.size(0), -1, -1)], dim=-1
            )
        else:
            decoder_input = torch.cat([z, cont_covs], dim=-1)

        categorical_input = torch.split(cat_covs, 1, dim=1) if cat_covs is not None else ()

        if transform_batch is not None:
            batch_index = torch.ones_like(batch_index) * transform_batch

        if not self.use_size_factor_key:
            size_factor = library

        if self.batch_representation == "embedding":
            batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index)
            decoder_input = torch.cat([decoder_input, batch_rep], dim=-1)
            px_scale, px_r, px_rate, px_dropout = self.decoder(
                self.dispersion,
                decoder_input,
                size_factor,
                *categorical_input,
                y,
            )
        else:
            px_scale, px_r, px_rate, px_dropout = self.decoder(
                self.dispersion,
                decoder_input,
                size_factor,
                batch_index,
                *categorical_input,
                y,
            )

        # The index tensors are flattened before one-hot encoding so that ``px_r``
        # keeps the same number of axes as ``px_rate`` instead of gaining one.
        if self.dispersion == "gene-label":
            px_r = linear(
                self._flat_one_hot(y, self.n_labels), self.px_r
            )  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = linear(self._flat_one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r

        px_r = torch.exp(px_r)

        if self.gene_likelihood == "zinb":
            px = ZeroInflatedNegativeBinomial(
                mu=px_rate,
                theta=px_r,
                zi_logits=px_dropout,
                scale=px_scale,
            )
        elif self.gene_likelihood == "nb":
            px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale)
        elif self.gene_likelihood == "poisson":
            px = Poisson(px_rate, scale=px_scale)

        # Priors
        if self.use_observed_lib_size:
            pl = None
        else:
            (
                local_library_log_means,
                local_library_log_vars,
            ) = self._compute_local_library_params(batch_index)
            pl = Normal(local_library_log_means, local_library_log_vars.sqrt())

        if self.prior_mixture is True:
            u_prior_logits = self.prior_logits
            u_prior_means = self.prior_means
            u_prior_scales = torch.exp(self.prior_log_scales) + 1e-4

            if self.semisupervised:
                # One-hot rows for labelled cells, all-zero rows for any label index
                # >= n_labels. Built in one shot on the labels' own device; the
                # earlier per-cell loop mixed in a CPU tensor for unlabelled cells.
                labels = y.ravel().long()
                is_labelled = labels < self.n_labels
                safe_labels = torch.where(is_labelled, labels, torch.zeros_like(labels))
                logits_input = (
                    one_hot(safe_labels, self.n_labels)
                    .float()
                    .masked_fill(~is_labelled.unsqueeze(-1), 0.0)
                    .to(z.device)
                )
                # Strongly favour the mixture component matching the observed label.
                u_prior_logits = u_prior_logits + 10 * logits_input
                u_prior_means = u_prior_means.expand(y.shape[0], -1, -1)
                u_prior_scales = u_prior_scales.expand(y.shape[0], -1, -1)

            cats = Categorical(logits=u_prior_logits)
            normal_dists = Independent(
                Normal(u_prior_means, u_prior_scales), reinterpreted_batch_ndims=1
            )
            pz = MixtureSameFamily(cats, normal_dists)
        else:
            pz = Normal(torch.zeros_like(z), torch.ones_like(z))

        # NOTE(review): both niche decoders receive ``batch_index`` even in
        # "embedding" mode, where the batch is already part of ``decoder_input`` and
        # was left out of their categorical lists — confirm they tolerate this.
        niche_composition = self.composition_decoder(
            decoder_input, batch_index, *categorical_input
        )  # DirichletDecoder, niche_composition is a distribution

        niche_mean, niche_variance = self.niche_decoder(
            decoder_input, batch_index, *categorical_input
        )

        if self.niche_likelihood == "poisson":
            # NOTE(review): the Poisson rate is the decoder's second output, not
            # ``niche_mean`` — presumably because that output is constrained to be
            # positive; confirm against NicheDecoder.
            niche_expression = torch.distributions.Poisson(niche_variance)
        else:
            # ``Normal`` takes a standard deviation, so convert the variance
            # (same treatment as the library-size prior above).
            niche_expression = Normal(niche_mean, niche_variance.sqrt())

        return {
            MODULE_KEYS.PX_KEY: px,
            MODULE_KEYS.PL_KEY: pl,
            MODULE_KEYS.PZ_KEY: pz,
            SCVIVA_MODULE_KEYS.NICHE_MEAN: niche_mean,
            SCVIVA_MODULE_KEYS.NICHE_VARIANCE: niche_variance,
            SCVIVA_MODULE_KEYS.P_NICHE_EXPRESSION: niche_expression,
            SCVIVA_MODULE_KEYS.P_NICHE_COMPOSITION: niche_composition,
        }

    def loss(
        self,
        tensors: dict[str, torch.Tensor],
        inference_outputs: dict[str, torch.Tensor | Distribution | None],
        generative_outputs: dict[str, torch.Tensor | Distribution | None],
        kl_weight: float = 1.0,
        classification_ratio: float = 50,
        epsilon: float = 1e-6,
        n_samples_mixture: int = 10,
    ) -> NicheLossOutput:
        """Compute the loss.

        Parameters
        ----------
        tensors
            Minibatch tensors keyed by registry key.
        inference_outputs
            Outputs of the inference step (posterior distributions).
        generative_outputs
            Outputs of :meth:`generative`.
        kl_weight
            Warm-up weight applied to the latent KL term only.
        classification_ratio
            Weight of the classification loss added to the cell reconstruction loss
            when ``semisupervised`` is ``True``.
        epsilon
            Constant added to the observed niche composition before renormalising.
        n_samples_mixture
            Number of posterior samples used to estimate the latent KL term when a
            mixture prior is used.

        Returns
        -------
        The total loss (mean over cells) together with its individual terms.
        """
        from torch.distributions import kl_divergence

        x = tensors[REGISTRY_KEYS.X_KEY]

        if self.semisupervised:
            y = tensors[REGISTRY_KEYS.LABELS_KEY].ravel().long()
            z_mean = inference_outputs[MODULE_KEYS.QZ_KEY].loc
            y_ct = self.classifier(z_mean)
            classification_loss = torch.nn.functional.cross_entropy(y_ct, y, reduction="none")

        if self.prior_mixture is True:
            # No closed-form KL against the mixture: estimate it by sampling.
            z = inference_outputs[MODULE_KEYS.QZ_KEY].rsample(
                sample_shape=(n_samples_mixture,)
            )  # sample multiple times
            # sample x n_obs x n_latent
            kl_divergence_z = (
                inference_outputs[MODULE_KEYS.QZ_KEY].log_prob(z).sum(-1)
                - generative_outputs[MODULE_KEYS.PZ_KEY].log_prob(z)
            ).mean(0)
        else:
            kl_divergence_z = kl_divergence(
                inference_outputs[MODULE_KEYS.QZ_KEY],
                generative_outputs[MODULE_KEYS.PZ_KEY],
            ).sum(dim=-1)

        if not self.use_observed_lib_size:
            kl_divergence_l = kl_divergence(
                inference_outputs[MODULE_KEYS.QL_KEY],
                generative_outputs[MODULE_KEYS.PL_KEY],
            ).sum(dim=1)
        else:
            kl_divergence_l = torch.zeros_like(kl_divergence_z)

        reconst_loss_cell = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
        if self.semisupervised:
            reconst_loss_cell = reconst_loss_cell + classification_ratio * classification_loss

        kl_local_for_warmup = kl_divergence_z
        kl_local_no_warmup = kl_divergence_l
        weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

        # Binary mask: only components with non-zero observed composition
        # contribute to the niche state loss.
        niche_weights = tensors[SCVIVA_REGISTRY_KEYS.NICHE_COMPOSITION_KEY]
        niche_weights = (niche_weights > 0).float()

        z1_mean_niche = tensors[
            SCVIVA_REGISTRY_KEYS.Z1_MEAN_CT_KEY
        ]  # batch times cell_types times n_latent

        reconst_loss_niche = (
            -generative_outputs[SCVIVA_MODULE_KEYS.P_NICHE_EXPRESSION]
            .log_prob(z1_mean_niche)
            .sum(dim=(-1))
        )
        masked_reconst_loss_niche = (reconst_loss_niche * niche_weights).sum(dim=-1)

        # Offset by ``epsilon`` and renormalise the observed composition before
        # scoring it under the predicted composition distribution.
        true_niche_composition = tensors[SCVIVA_REGISTRY_KEYS.NICHE_COMPOSITION_KEY] + epsilon
        true_niche_composition = true_niche_composition / true_niche_composition.sum(
            dim=-1,
            keepdim=True,
        )
        reconst_niche_composition = generative_outputs[SCVIVA_MODULE_KEYS.P_NICHE_COMPOSITION]
        composition_loss = -reconst_niche_composition.log_prob(true_niche_composition)

        _weighted_reconst_loss_cell = self.cell_rec_weight * reconst_loss_cell
        _weighted_reconst_loss_niche = self.spatial_weight * masked_reconst_loss_niche
        _weighted_composition_loss = self.spatial_weight * composition_loss
        _weighted_kl_local = self.latent_kl_weight * weighted_kl_local

        loss = torch.mean(
            _weighted_reconst_loss_cell
            + _weighted_reconst_loss_niche
            + _weighted_kl_local
            + _weighted_composition_loss
        )

        return NicheLossOutput(
            loss=loss,
            reconstruction_loss=reconst_loss_cell,
            classification_loss=classification_loss.mean() if self.semisupervised else None,
            true_labels=y if self.semisupervised else None,
            logits=y_ct if self.semisupervised else None,
            kl_local={
                MODULE_KEYS.KL_L_KEY: kl_divergence_l,
                MODULE_KEYS.KL_Z_KEY: kl_divergence_z,
            },
            composition_loss=composition_loss,
            niche_loss=masked_reconst_loss_niche,
            extra_metrics={
                SCVIVA_MODULE_KEYS.NLL_NICHE_COMPOSITION_KEY: torch.mean(composition_loss),
                SCVIVA_MODULE_KEYS.NLL_NICHE_EXPRESSION_KEY: torch.mean(masked_reconst_loss_niche),
            },
        )
@dataclass
class NicheLossOutput(LossOutput):
    """Modify loss output to record niche losses."""

    # Negative log-likelihood of the niche composition.
    composition_loss: LossRecord | None = None
    # Negative log-likelihood of the niche state.
    niche_loss: LossRecord | None = None

    def __post_init__(self):
        super().__post_init__()

        # Missing niche losses fall back to a zero placeholder derived from ``loss``.
        default = 0 * self.loss
        if self.composition_loss is None:
            object.__setattr__(self, "composition_loss", default)
        if self.niche_loss is None:
            object.__setattr__(self, "niche_loss", default)

        # Normalise both records through the inherited ``_as_dict`` helper.
        object.__setattr__(self, "composition_loss", self._as_dict("composition_loss"))
        object.__setattr__(self, "niche_loss", self._as_dict("niche_loss"))


def _accumulate_loss_record(
    module,
    dataloader,
    record_name: str,
    return_mean: bool,
    **kwargs,
):
    """Collect the loss record ``record_name`` over every minibatch of ``dataloader``."""
    per_batch = []
    for tensors in dataloader:
        _, _, losses = module(tensors, **kwargs)
        record = getattr(losses, record_name)
        if isinstance(record, dict):
            # Several named terms: add them up into a single tensor.
            batch_error = torch.stack(list(record.values())).sum(dim=0)
        else:
            batch_error = record
        # A scalar record (e.g. the zero placeholder used when the loss was not
        # supplied) is 0-dim, which ``torch.cat`` rejects; promote it to 1-D.
        per_batch.append(torch.atleast_1d(batch_error))

    error = torch.cat(per_batch, dim=0)
    return error.mean() if return_mean else error


def compute_composition_error(
    module,
    dataloader,
    return_mean: bool = True,
    **kwargs,
):
    """Compute the composition prediction error on the data.

    The error is the negative log likelihood of the data (alpha) given the latent
    variables.

    Parameters
    ----------
    module
        A callable that takes a dictionary of tensors and returns a tuple whose last
        element is a :class:`NicheLossOutput`.
    dataloader
        Iterator over minibatches of data formatted as expected by ``module.forward``.
    return_mean
        If ``True``, return the mean error across the dataset; otherwise per-cell.
    **kwargs
        Additional keyword arguments forwarded to ``module``.

    Returns
    -------
    The composition prediction error.

    Raises
    ------
    RuntimeError
        If ``dataloader`` yields no minibatch (nothing to concatenate).
    """
    return _accumulate_loss_record(
        module, dataloader, "composition_loss", return_mean, **kwargs
    )


def compute_niche_error(
    module,
    dataloader,
    return_mean: bool = True,
    **kwargs,
):
    """Compute the niche state prediction error on the data.

    The error is the negative log likelihood of the data (eta) given the latent
    variables.

    Parameters
    ----------
    module
        A callable that takes a dictionary of tensors and returns a tuple whose last
        element is a :class:`NicheLossOutput`.
    dataloader
        Iterator over minibatches of data formatted as expected by ``module.forward``.
    return_mean
        If ``True``, return the mean error across the dataset; otherwise per-cell.
    **kwargs
        Additional keyword arguments forwarded to ``module``.

    Returns
    -------
    The niche state prediction error.

    Raises
    ------
    RuntimeError
        If ``dataloader`` yields no minibatch (nothing to concatenate).
    """
    return _accumulate_loss_record(module, dataloader, "niche_loss", return_mean, **kwargs)