Source code for spatialvi.module._jvae

"""Joint VAE (JVAE) module for GIMVI."""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F
from scvi import REGISTRY_KEYS

if TYPE_CHECKING:
    import numpy as np
from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.nn import Encoder, MultiDecoder, MultiEncoder
from torch.distributions import Normal, Poisson
from torch.distributions import kl_divergence as kl
from torch.nn import ModuleList

# Import-time side effect: this sets torch's global cuDNN benchmark flag, so it
# applies to every model in the process, not only to the module defined below.
torch.backends.cudnn.benchmark = True


class JVAE(BaseModuleClass):
    """Joint variational auto-encoder for imputing missing genes in spatial data.

    Implementation of gimVI :cite:p:`Lopez19`.

    Parameters
    ----------
    dim_input_list
        List of number of input genes for each dataset.
    total_genes
        Total number of different genes.
    indices_mappings
        List of mappings from model inputs to model output locations.
    gene_likelihoods
        List of distributions: 'zinb', 'nb', or 'poisson'.
    model_library_bools
        Whether to model library size with a latent variable per dataset.
    library_log_means
        List of 1 x n_batch arrays of log library size means.
    library_log_vars
        List of 1 x n_batch arrays of log library size variances.
    n_latent
        Dimension of latent space.
    n_layers_encoder_individual
        Number of individual encoder layers.
    n_layers_encoder_shared
        Number of shared encoder layers.
    dim_hidden_encoder
        Hidden layer dimension for encoder.
    n_layers_decoder_individual
        Number of individual decoder layers.
    n_layers_decoder_shared
        Number of shared decoder layers.
    dim_hidden_decoder_individual
        Hidden layer dimension for individual decoder.
    dim_hidden_decoder_shared
        Hidden layer dimension for shared decoder.
    dropout_rate_encoder
        Dropout rate for encoder.
    dropout_rate_decoder
        Dropout rate for decoder.
    n_batch
        Total number of batches.
    n_labels
        Total number of labels.
    dispersion
        Dispersion parameterization: 'gene', 'gene-batch', 'gene-label', or 'gene-cell'.
    log_variational
        Log(data+1) prior to encoding for numerical stability.
    """

    def __init__(
        self,
        dim_input_list: list[int],
        total_genes: int,
        indices_mappings: list[np.ndarray | slice],
        gene_likelihoods: list[str],
        model_library_bools: list[bool],
        library_log_means: list[np.ndarray | None],
        library_log_vars: list[np.ndarray | None],
        n_latent: int = 10,
        n_layers_encoder_individual: int = 1,
        n_layers_encoder_shared: int = 1,
        dim_hidden_encoder: int = 64,
        n_layers_decoder_individual: int = 0,
        n_layers_decoder_shared: int = 0,
        dim_hidden_decoder_individual: int = 64,
        dim_hidden_decoder_shared: int = 64,
        dropout_rate_encoder: float = 0.2,
        dropout_rate_decoder: float = 0.2,
        n_batch: int = 0,
        n_labels: int = 0,
        dispersion: str = "gene-batch",
        log_variational: bool = True,
    ):
        super().__init__()

        self.n_input_list = dim_input_list
        self.total_genes = total_genes
        self.indices_mappings = indices_mappings
        self.gene_likelihoods = gene_likelihoods
        self.model_library_bools = model_library_bools

        # Library-size priors are only registered for the datasets whose library
        # size is modelled as a latent variable; loss() reads them back by name.
        for mode in range(len(dim_input_list)):
            if self.model_library_bools[mode]:
                self.register_buffer(
                    f"library_log_means_{mode}",
                    torch.from_numpy(library_log_means[mode]).float(),
                )
                self.register_buffer(
                    f"library_log_vars_{mode}",
                    torch.from_numpy(library_log_vars[mode]).float(),
                )

        self.n_latent = n_latent
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.dispersion = dispersion
        self.log_variational = log_variational

        self.z_encoder = MultiEncoder(
            n_heads=len(dim_input_list),
            n_input_list=dim_input_list,
            n_output=self.n_latent,
            n_hidden=dim_hidden_encoder,
            n_layers_individual=n_layers_encoder_individual,
            n_layers_shared=n_layers_encoder_shared,
            dropout_rate=dropout_rate_encoder,
            return_dist=True,
        )

        # One library encoder per dataset; the slot is None when the library
        # size of that dataset is not modelled (inference() then uses the
        # observed log total count instead).
        self.l_encoders = ModuleList(
            [
                Encoder(
                    self.n_input_list[i],
                    1,
                    n_layers=1,
                    dropout_rate=dropout_rate_encoder,
                    return_dist=True,
                )
                if self.model_library_bools[i]
                else None
                for i in range(len(self.n_input_list))
            ]
        )

        self.decoder = MultiDecoder(
            self.n_latent,
            self.total_genes,
            n_hidden_conditioned=dim_hidden_decoder_individual,
            n_hidden_shared=dim_hidden_decoder_shared,
            n_layers_conditioned=n_layers_decoder_individual,
            n_layers_shared=n_layers_decoder_shared,
            n_cat_list=[self.n_batch],
            dropout_rate=dropout_rate_decoder,
        )

        # For any other dispersion value (e.g. 'gene-cell') no parameter is
        # created here; generative() then keeps the px_r returned by the decoder.
        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(self.total_genes))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(self.total_genes, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(self.total_genes, n_labels))

    def _resolve_mode(
        self,
        mode: int | None,
        message: str = "Must provide a mode when having multiple datasets",
    ) -> int:
        """Return ``mode``, defaulting to 0 when the model has a single dataset.

        Raises
        ------
        ValueError
            If ``mode`` is None and the model was built with several datasets.
        """
        if mode is not None:
            return mode
        if len(self.n_input_list) == 1:
            return 0
        raise ValueError(message)

    def sample_from_posterior_z(
        self, x: torch.Tensor, mode: int | None = None, deterministic: bool = False
    ) -> torch.Tensor:
        """Sample tensor of latent values from the posterior.

        Parameters
        ----------
        x
            Input expression tensor for dataset ``mode``.
        mode
            Index of the dataset head; may be omitted only with a single dataset.
        deterministic
            If True, return the location of ``qz`` instead of the sampled ``z``.

        Raises
        ------
        ValueError
            If ``mode`` is None and the model has several datasets.
        """
        mode = self._resolve_mode(mode)
        outputs = self.inference(x, mode)
        return outputs["qz"].loc if deterministic else outputs["z"]

    def sample_from_posterior_l(
        self, x: torch.Tensor, mode: int | None = None, deterministic: bool = False
    ) -> torch.Tensor:
        """Sample the tensor of library sizes from the posterior.

        Parameters
        ----------
        x
            Input expression tensor for dataset ``mode``.
        mode
            Index of the dataset head; may be omitted only with a single dataset.
        deterministic
            If True and a library posterior exists for this dataset, return its
            location; otherwise return the ``library`` value from inference.

        Raises
        ------
        ValueError
            If ``mode`` is None and the model has several datasets.
        """
        # Resolve the default the same way sample_from_posterior_z does;
        # inference() indexes per-dataset lists with ``mode`` and cannot take None.
        mode = self._resolve_mode(mode)
        inference_out = self.inference(x, mode)
        if deterministic and inference_out["ql"] is not None:
            return inference_out["ql"].loc
        return inference_out["library"]

    def sample_scale(
        self,
        x: torch.Tensor,
        mode: int,
        batch_index: torch.Tensor,
        y: torch.Tensor | None = None,
        deterministic: bool = False,
        decode_mode: int | None = None,
    ) -> torch.Tensor:
        """Return the tensor of predicted frequencies of expression.

        Runs :meth:`_run_forward` and returns its ``"px_scale"`` entry.
        """
        gen_out = self._run_forward(
            x, mode, batch_index, y=y, deterministic=deterministic, decode_mode=decode_mode
        )
        return gen_out["px_scale"]

    def get_sample_rate(self, x, batch_index, *_, **__):
        """Get the sample rate for the model.

        Delegates to :meth:`sample_rate` with ``mode`` fixed to 0; extra
        positional and keyword arguments are ignored.
        """
        return self.sample_rate(x, 0, batch_index)

    def _run_forward(
        self,
        x: torch.Tensor,
        mode: int,
        batch_index: torch.Tensor,
        y: torch.Tensor | None = None,
        deterministic: bool = False,
        decode_mode: int | None = None,
    ) -> dict:
        """Run the forward pass of the model.

        Parameters
        ----------
        x
            Input expression tensor for dataset ``mode``.
        mode
            Index of the dataset head used for encoding.
        batch_index
            Batch indices passed to the generative model.
        y
            Optional labels passed to the generative model.
        deterministic
            If True, decode from the posterior locations instead of samples.
        decode_mode
            Index of the dataset head used for decoding; defaults to ``mode``.

        Returns
        -------
        The dictionary returned by :meth:`generative`.
        """
        if decode_mode is None:
            decode_mode = mode

        inference_out = self.inference(x, mode)
        z = inference_out["z"]
        library = inference_out["library"]
        if deterministic:
            z = inference_out["qz"].loc
            # Without a library posterior, ``library`` already holds the
            # deterministic value computed by inference().
            if inference_out["ql"] is not None:
                library = inference_out["ql"].loc

        return self.generative(z, library, batch_index, y, decode_mode)

    def sample_rate(
        self,
        x: torch.Tensor,
        mode: int,
        batch_index: torch.Tensor,
        y: torch.Tensor | None = None,
        deterministic: bool = False,
        decode_mode: int | None = None,
    ) -> torch.Tensor:
        """Return the tensor of scaled frequencies of expression.

        Runs :meth:`_run_forward` and returns its ``"px_rate"`` entry.
        """
        gen_out = self._run_forward(
            x, mode, batch_index, y=y, deterministic=deterministic, decode_mode=decode_mode
        )
        return gen_out["px_rate"]

    def reconstruction_loss(
        self,
        x: torch.Tensor,
        px_rate: torch.Tensor,
        px_r: torch.Tensor,
        px_dropout: torch.Tensor,
        mode: int,
    ) -> torch.Tensor:
        """Compute the reconstruction loss.

        The negative log-likelihood of ``x`` under the distribution named by
        ``self.gene_likelihoods[mode]``, summed over the last axis.

        Raises
        ------
        ValueError
            If the configured likelihood is not 'zinb', 'nb' or 'poisson'.
        """
        likelihood = self.gene_likelihoods[mode]
        if likelihood == "zinb":
            dist = ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
        elif likelihood == "nb":
            dist = NegativeBinomial(mu=px_rate, theta=px_r)
        elif likelihood == "poisson":
            dist = Poisson(px_rate)
        else:
            # Fail here with a clear message rather than returning None and
            # breaking later inside loss().
            raise ValueError(
                f"Unknown gene likelihood {likelihood!r} for mode {mode}; "
                "expected 'zinb', 'nb' or 'poisson'"
            )
        # All likelihoods reduce over the gene (last) axis.
        return -dist.log_prob(x).sum(dim=-1)

    def _get_inference_input(self, tensors) -> dict[str, torch.Tensor | None]:
        """Get the input for the inference model.

        ``batch_index`` is None when the batch key is absent from ``tensors``.
        """
        return {
            "x": tensors[REGISTRY_KEYS.X_KEY],
            "batch_index": tensors.get(REGISTRY_KEYS.BATCH_KEY, None),
        }

    def _get_generative_input(self, tensors, inference_outputs, transform_batch=None):
        """Get the input for the generative model.

        When ``transform_batch`` is given, every entry of the batch index is
        replaced by that value.
        """
        batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
        if transform_batch is not None:
            batch_index = torch.ones_like(batch_index) * transform_batch

        return {
            "z": inference_outputs["z"],
            "library": inference_outputs["library"],
            "batch_index": batch_index,
            "y": tensors[REGISTRY_KEYS.LABELS_KEY],
        }

    @auto_move_data
    def inference(
        self,
        x: torch.Tensor,
        mode: int | None = 0,
        n_samples: int | None = 1,
        batch_index: torch.Tensor | None = None,
    ) -> dict:
        """Run the inference model.

        Parameters
        ----------
        x
            Input expression tensor for dataset ``mode``.
        mode
            Index of the dataset head to encode with.
        n_samples
            When greater than 1, ``z`` is re-drawn with a leading sample axis.
        batch_index
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        Dictionary with keys ``"qz"``, ``"z"``, ``"ql"`` (None when the library
        size of this dataset is not modelled) and ``"library"``.
        """
        x_ = torch.log(1 + x) if self.log_variational else x

        qz, z = self.z_encoder(x_, mode)

        if self.model_library_bools[mode]:
            ql, library = self.l_encoders[mode](x_)
        else:
            # No library posterior: use the observed log total count of the raw input.
            ql = None
            library = torch.log(torch.sum(x, dim=1)).view(-1, 1)

        # NOTE: only z is re-sampled here; ``library`` keeps its single-sample shape.
        if n_samples is not None and n_samples > 1:
            untran_z = qz.sample((n_samples,))
            z = self.z_encoder.z_transformation(untran_z)

        return {"qz": qz, "z": z, "ql": ql, "library": library}

    @auto_move_data
    def generative(
        self,
        z: torch.Tensor,
        library: torch.Tensor,
        batch_index: torch.Tensor | None = None,
        y: torch.Tensor | None = None,
        mode: int | None = 0,
        transform_batch: torch.Tensor | None = None,
    ) -> dict:
        """Run the generative model.

        Parameters
        ----------
        z
            Latent values.
        library
            Log library sizes.
        batch_index
            Batch indices; required for the 'gene-batch' dispersion.
        y
            Labels; required for the 'gene-label' dispersion.
        mode
            Index of the dataset head to decode with.
        transform_batch
            If given, every batch index is replaced by this value before decoding.

        Returns
        -------
        Dictionary with keys ``"px_scale"``, ``"px"``, ``"px_r"``, ``"px_rate"``,
        ``"px_dropout"`` and ``"batch_index"``.
        """
        # Apply the batch override before anything consumes batch_index, so that
        # both the decoder and the dispersion lookup see the transformed batch.
        if transform_batch is not None:
            batch_index = torch.ones_like(batch_index) * transform_batch

        px_scale, px_r, px_rate, px_dropout = self.decoder(
            z, mode, library, self.dispersion, batch_index, y
        )

        if self.dispersion == "gene-label":
            px_r = F.linear(F.one_hot(y.squeeze(-1).long(), self.n_labels).float(), self.px_r)
        elif self.dispersion == "gene-batch":
            # F.one_hot needs an int64 index, hence the cast (as for labels above).
            px_r = F.linear(
                F.one_hot(batch_index.squeeze(-1).long(), self.n_batch).float(), self.px_r
            )
        elif self.dispersion == "gene":
            px_r = self.px_r.view(1, self.px_r.size(0))
        px_r = torch.exp(px_r)

        # Renormalise so the scale sums to one over the output columns that
        # dataset ``mode`` maps to, then recompute the rate from that scale.
        px_scale = px_scale / torch.sum(px_scale[:, self.indices_mappings[mode]], dim=1).view(
            -1, 1
        )
        px_rate = px_scale * torch.exp(library)

        # NOTE: ``px`` is always a NegativeBinomial, whatever gene_likelihoods[mode]
        # says; loss() goes through reconstruction_loss() and does not use it.
        px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale)

        return {
            "px_scale": px_scale,
            "px": px,
            "px_r": px_r,
            "px_rate": px_rate,
            "px_dropout": px_dropout,
            "batch_index": batch_index,
        }

    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        mode: int | None = None,
        kl_weight: float = 1.0,
    ) -> LossOutput:
        """Return the reconstruction loss and the Kullback divergences.

        Parameters
        ----------
        tensors
            Mapping holding at least the X and batch registry keys.
        inference_outputs
            Output of :meth:`inference` (``"qz"`` and ``"ql"`` are read).
        generative_outputs
            Output of :meth:`generative` (``"px_rate"``, ``"px_r"`` and
            ``"px_dropout"`` are read).
        mode
            Index of the dataset; may be omitted only with a single dataset.
        kl_weight
            Multiplier applied to the local KL term in the total loss.

        Raises
        ------
        ValueError
            If ``mode`` is None and the model has several datasets.
        """
        mode = self._resolve_mode(mode, "Must provide a mode")

        x = tensors[REGISTRY_KEYS.X_KEY]
        batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]

        qz = inference_outputs["qz"]
        ql = inference_outputs["ql"]
        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]
        px_dropout = generative_outputs["px_dropout"]

        # The decoder outputs cover all genes; keep only the columns that
        # correspond to the inputs of this dataset.
        mapping_indices = self.indices_mappings[mode]
        reconstruction_loss = self.reconstruction_loss(
            x,
            px_rate[:, mapping_indices],
            px_r[:, mapping_indices],
            px_dropout[:, mapping_indices],
            mode,
        )

        # KL(q(z|x) || N(0, I)).
        mean = torch.zeros_like(qz.loc)
        scale = torch.ones_like(qz.scale)
        kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1)

        if self.model_library_bools[mode]:
            library_log_means = getattr(self, f"library_log_means_{mode}")
            library_log_vars = getattr(self, f"library_log_vars_{mode}")

            # Pick the per-batch prior parameters for each row of the minibatch.
            batch_one_hot = F.one_hot(batch_index.squeeze(-1).long(), self.n_batch).float()
            local_library_log_means = F.linear(batch_one_hot, library_log_means)
            local_library_log_vars = F.linear(batch_one_hot, library_log_vars)
            kl_divergence_l = kl(
                ql,
                Normal(local_library_log_means, local_library_log_vars.sqrt()),
            ).sum(dim=1)
        else:
            kl_divergence_l = torch.zeros_like(kl_divergence_z)

        kl_local = kl_divergence_l + kl_divergence_z

        loss = torch.mean(reconstruction_loss + kl_weight * kl_local) * x.size(0)

        return LossOutput(loss=loss, reconstruction_loss=reconstruction_loss, kl_local=kl_local)