"""GIMVI model for imputing missing genes across spatial and scRNA-seq data."""
from __future__ import annotations
import logging
import os
import warnings
from itertools import cycle
from typing import TYPE_CHECKING
import numpy as np
import torch
from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY
from scvi.data.fields import CategoricalObsField, LayerField
from scvi.dataloaders import DataSplitter
from scvi.model._utils import _init_library_size, parse_device_args
from scvi.train import Trainer
from scvi.train._config import merge_kwargs
from scvi.utils import setup_anndata_dsp
from scvi.utils._docstrings import devices_dsp
from torch.utils.data import DataLoader
from spatialvi.model.base._spatial_base import SpatialBaseModel
from spatialvi.model.utils._gimvi_utils import _load_saved_gimvi_files
from spatialvi.module._jvae import JVAE
from spatialvi.train._gimvi_trainingplans import GIMVITrainingPlan
if TYPE_CHECKING:
from anndata import AnnData
from scvi.dataloaders import AnnDataLoader
# Module-level logger named after this module (standard ``logging.getLogger(__name__)`` pattern).
logger = logging.getLogger(__name__)
def _unpack_tensors(tensors):
    """Return ``(x, batch_index, y)`` from a minibatch dictionary.

    If a tensor carries an extra leading singleton axis (three or more
    dimensions, e.g. ``(1, n_cells, n_features)``), that axis is dropped.
    Two-dimensional tensors are returned unchanged, so a minibatch holding a
    single cell keeps its ``(1, n_features)`` shape instead of collapsing to
    1-D, as an unconditional ``squeeze_(0)`` would do.

    Parameters
    ----------
    tensors
        Mapping keyed by ``REGISTRY_KEYS.X_KEY``, ``REGISTRY_KEYS.BATCH_KEY``
        and ``REGISTRY_KEYS.LABELS_KEY``.

    Returns
    -------
    Tuple of the expression, batch-index and label tensors. The input mapping
    and its tensors are not modified.
    """

    def _drop_leading_axis(tensor):
        # Out-of-place squeeze, so the caller's tensors keep their shape.
        # ``squeeze(0)`` is itself a no-op when dim 0 does not have size 1.
        return tensor.squeeze(0) if tensor.dim() > 2 else tensor

    x = _drop_leading_axis(tensors[REGISTRY_KEYS.X_KEY])
    batch_index = _drop_leading_axis(tensors[REGISTRY_KEYS.BATCH_KEY])
    y = _drop_leading_axis(tensors[REGISTRY_KEYS.LABELS_KEY])
    return x, batch_index, y
class GIMVI(SpatialBaseModel):
    """Joint VAE for imputing missing genes in spatial data :cite:p:`Lopez19`.

    Learns a joint latent space for paired scRNA-seq and spatial transcriptomics
    data, enabling imputation of spatially unmeasured genes.

    Parameters
    ----------
    adata_seq
        AnnData object registered via :meth:`~spatialvi.model.GIMVI.setup_anndata`
        containing scRNA-seq data.
    adata_spatial
        AnnData object registered via :meth:`~spatialvi.model.GIMVI.setup_anndata`
        containing spatial transcriptomics data. Its genes must be a subset of
        the genes of ``adata_seq``.
    generative_distributions
        List of generative distributions for seq and spatial data.
        Defaults to ``['zinb', 'nb']``.
    model_library_size
        Whether to model library size per dataset. Defaults to ``[True, False]``.
    n_latent
        Dimensionality of the latent space.
    **model_kwargs
        Keyword args for :class:`~spatialvi.module.JVAE`.

    Examples
    --------
    >>> adata_seq = anndata.read_h5ad(path_to_seq)
    >>> adata_spatial = anndata.read_h5ad(path_to_spatial)
    >>> spatialvi.model.GIMVI.setup_anndata(adata_seq)
    >>> spatialvi.model.GIMVI.setup_anndata(adata_spatial)
    >>> model = spatialvi.model.GIMVI(adata_seq, adata_spatial)
    >>> model.train(max_epochs=200)

    Notes
    -----
    See further usage examples in the following tutorial:

    1. :doc:`/tutorials/notebooks/spatial/gimvi_tutorial`
    """

    def __init__(
        self,
        adata_seq: AnnData,
        adata_spatial: AnnData,
        generative_distributions: list[str] | None = None,
        model_library_size: list[bool] | None = None,
        n_latent: int = 10,
        **model_kwargs,
    ):
        super().__init__(adata_seq)
        if adata_seq is adata_spatial:
            raise ValueError(
                "`adata_seq` and `adata_spatial` cannot point to the same object. "
                "If you would really like to do this, make a copy and pass it as `adata_spatial`."
            )
        model_library_size = model_library_size or [True, False]
        generative_distributions = generative_distributions or ["zinb", "nb"]
        self.adatas = [adata_seq, adata_spatial]
        # Insertion order matters: "seq" is mode 0 and "spatial" is mode 1
        # wherever the managers / adatas are enumerated below.
        self.adata_managers = {
            "seq": self._get_most_recent_anndata_manager(adata_seq, required=True),
            "spatial": self._get_most_recent_anndata_manager(adata_spatial, required=True),
        }
        self.registries_ = []
        for adm in self.adata_managers.values():
            self._register_manager_for_instance(adm)
            self.registries_.append(adm.registry)

        seq_var_names = adata_seq.var_names
        spatial_var_names = adata_spatial.var_names
        if not set(spatial_var_names) <= set(seq_var_names):
            raise ValueError("spatial genes must be a subset of seq genes")
        # Position of every spatial gene on the seq gene axis. A single
        # name -> position dict (first occurrence wins, like the previous
        # ``np.argwhere(...)[0]`` lookup) avoids rescanning all seq genes
        # once per spatial gene.
        seq_gene_position = {}
        for position, gene in enumerate(seq_var_names):
            seq_gene_position.setdefault(gene, position)
        spatial_gene_loc = np.array(
            [seq_gene_position[gene] for gene in spatial_var_names], dtype=np.intp
        )
        # Seq data covers all genes; spatial data maps onto its gene subset.
        gene_mappings = [slice(None), spatial_gene_loc]

        sum_stats = [adm.summary_stats for adm in self.adata_managers.values()]
        n_inputs = [s["n_vars"] for s in sum_stats]
        total_genes = n_inputs[0]

        # Offset the spatial batch codes by the number of seq batches so both
        # datasets share one batch index space of size ``n_batches``. This
        # mutates ``adata_spatial.obs`` in place; it is skipped when the
        # smallest code is already non-zero.
        adata_seq_n_batches = sum_stats[0]["n_batch"]
        spatial_batch_col = (
            self.adata_managers["spatial"].data_registry[REGISTRY_KEYS.BATCH_KEY].attr_key
        )
        if np.min(adata_spatial.obs[spatial_batch_col]) == 0:
            adata_spatial.obs[spatial_batch_col] += adata_seq_n_batches
        n_batches = sum(s["n_batch"] for s in sum_stats)

        library_log_means = []
        library_log_vars = []
        for adata_manager in self.adata_managers.values():
            adata_library_log_means, adata_library_log_vars = _init_library_size(
                adata_manager, n_batches
            )
            library_log_means.append(adata_library_log_means)
            library_log_vars.append(adata_library_log_vars)

        self.module = JVAE(
            n_inputs,
            total_genes,
            gene_mappings,
            generative_distributions,
            model_library_size,
            library_log_means,
            library_log_vars,
            n_batch=n_batches,
            n_latent=n_latent,
            **model_kwargs,
        )
        self._model_summary_string = (
            "GIMVI Model with the following params: \n"
            f"n_latent: {n_latent}, n_inputs: {n_inputs}, n_genes: {total_genes}, "
            f"n_batch: {n_batches}, generative distributions: {generative_distributions}"
        )
        # Must stay last: captures the (resolved) constructor arguments from locals().
        self.init_params_ = self._get_init_params(locals())

    # ------------------------------------------------------------------ #
    # SpatialData integration (override to clarify spatial-adata role)
    # ------------------------------------------------------------------ #
    @classmethod
    def setup_spatialdata(
        cls,
        sdata,
        table_key: str = "table",
        region: str | None = None,
        **kwargs,
    ) -> None:
        """Register the *spatial* adata component from a SpatialData object.

        Delegates unchanged to ``SpatialBaseModel.setup_spatialdata``, which is
        expected to extract the table and register it via :meth:`setup_anndata`.
        The scRNA-seq component must always be registered via
        :meth:`setup_anndata` directly.

        Parameters
        ----------
        sdata
            A :class:`spatialdata.SpatialData` object.
        table_key
            Key in ``sdata`` pointing to the spatial AnnData table.
        region
            Region name to subset (stored in ``sdata[table_key].obs``).
        **kwargs
            Passed to :meth:`setup_anndata`.
        """
        super().setup_spatialdata(sdata, table_key=table_key, region=region, **kwargs)

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #
    @devices_dsp.dedent
    def train(
        self,
        max_epochs: int = 200,
        accelerator: str = "auto",
        devices: int | list[int] | str = "auto",
        kappa: int = 5,
        train_size: float | None = None,
        validation_size: float | None = None,
        shuffle_set_split: bool = True,
        batch_size: int = 128,
        datasplitter_kwargs: dict | None = None,
        plan_kwargs: dict | None = None,
        **kwargs,
    ):
        """Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset.
        %(param_accelerator)s
        %(param_devices)s
        kappa
            Scaling parameter for the discriminator loss.
        train_size
            Size of training set in the range [0.0, 1.0]. ``None`` defers to the
            :class:`~scvi.dataloaders.DataSplitter` default.
        validation_size
            Size of the validation set. ``None`` defers to the
            :class:`~scvi.dataloaders.DataSplitter` default.
        shuffle_set_split
            Whether to shuffle indices before splitting.
        batch_size
            Minibatch size.
        datasplitter_kwargs
            Additional kwargs for :class:`~scvi.dataloaders.DataSplitter`.
        plan_kwargs
            Keyword args for the training plan.
        **kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        accelerator, devices, device = parse_device_args(
            accelerator=accelerator,
            devices=devices,
            return_device="torch",
        )
        datasplitter_kwargs = datasplitter_kwargs or {}
        self.trainer = Trainer(
            max_epochs=max_epochs,
            accelerator=accelerator,
            devices=devices,
            **kwargs,
        )
        self.train_indices_, self.test_indices_, self.validation_indices_ = [], [], []
        train_dls, val_dls = [], []
        for mode, adm in enumerate(self.adata_managers.values()):
            splitter = DataSplitter(
                adm,
                train_size=train_size,
                validation_size=validation_size,
                batch_size=batch_size,
                shuffle_set_split=shuffle_set_split,
                **datasplitter_kwargs,
            )
            splitter.setup()
            train_dls.append(splitter.train_dataloader())
            # ``val_dataloader`` may return None when a dataset has no
            # validation cells (e.g. ``train_size=1.0``); tagging it
            # unconditionally used to raise AttributeError.
            val_dl = splitter.val_dataloader()
            if val_dl is not None:
                # Tag the loader with its dataset index (0 = seq, 1 = spatial).
                val_dl.mode = mode
                val_dls.append(val_dl)
            self.train_indices_.append(splitter.train_idx)
            self.test_indices_.append(splitter.test_idx)
            self.validation_indices_.append(splitter.val_idx)

        train_dl = TrainDL(train_dls)
        plan_kwargs = merge_kwargs(None, plan_kwargs, name="plan")
        self._training_plan = GIMVITrainingPlan(
            self.module,
            adversarial_classifier=True,
            scale_adversarial_loss=kappa,
            **plan_kwargs,
        )
        # Validate only when requested and every dataset produced a validation
        # loader; a partial list would leave a dataset unvalidated and shift
        # the loader positions relative to the dataset order.
        if train_size != 1.0 and len(val_dls) == len(train_dls):
            self.trainer.fit(self._training_plan, train_dl, val_dls)
        else:
            self.trainer.fit(self._training_plan, train_dl)
        try:
            self.history_ = self.trainer.logger.history
        except AttributeError:
            self.history_ = None
        self.module.eval()
        self.to_device(device)
        self.is_trained_ = True

    # ------------------------------------------------------------------ #
    # Posterior queries
    # ------------------------------------------------------------------ #
    def _make_scvi_dls(
        self, adatas: list[AnnData] | None = None, batch_size: int = 128
    ) -> list[AnnDataLoader]:
        """Build one inference data loader per AnnData.

        Each loader gets a ``mode`` attribute equal to its position in
        ``adatas`` (0 = seq, 1 = spatial). ``None`` uses the training adatas.
        """
        if adatas is None:
            adatas = self.adatas
        post_list = [self._make_data_loader(ad, batch_size=batch_size) for ad in adatas]
        for mode, dl in enumerate(post_list):
            dl.mode = mode
        return post_list

    @torch.inference_mode()
    def get_latent_representation(
        self,
        adatas: list[AnnData] | None = None,
        deterministic: bool = True,
        batch_size: int = 128,
        backend: str = "cpu",
    ) -> list[np.ndarray]:
        """Return the latent space embedding for each dataset.

        Parameters
        ----------
        adatas
            List of [adata_seq, adata_spatial]. If None, uses the training adatas.
        deterministic
            If True, use the encoder mean instead of a Gaussian sample.
        batch_size
            Minibatch size for data loading.
        backend
            ``"cpu"`` returns numpy arrays.
            ``"rapids"`` converts each latent array to a cupy array for GPU
            downstream processing (requires ``cupy``).

        Returns
        -------
        List of arrays, one per dataset (seq then spatial).

        Raises
        ------
        ValueError
            If ``backend`` is neither ``"cpu"`` nor ``"rapids"``.
        ImportError
            If ``backend="rapids"`` and cupy is not installed.
        """
        # Fail fast on bad options, before any inference work is done.
        if backend not in ("cpu", "rapids"):
            raise ValueError(f"backend must be 'cpu' or 'rapids', got {backend!r}.")
        cp = None
        if backend == "rapids":
            try:
                import cupy as cp
            except ImportError as e:
                raise ImportError(
                    "backend='rapids' requires cupy. "
                    "Install with: pip install 'spatialvi-tools[rapids]'"
                ) from e

        self.module.eval()
        latents = []
        for mode, scdl in enumerate(self._make_scvi_dls(adatas, batch_size=batch_size)):
            latent = []
            for tensors in scdl:
                (sample_batch, *_) = _unpack_tensors(tensors)
                latent.append(
                    self.module.sample_from_posterior_z(
                        sample_batch, mode, deterministic=deterministic
                    )
                    .cpu()
                    .detach()
                )
            latents.append(torch.cat(latent).numpy())

        if cp is not None:
            return [cp.asarray(z) for z in latents]
        return latents

    @torch.inference_mode()
    def get_imputed_values(
        self,
        adatas: list[AnnData] | None = None,
        deterministic: bool = True,
        normalized: bool = True,
        decode_mode: int | None = None,
        batch_size: int = 128,
    ) -> list[np.ndarray]:
        """Return imputed values for all genes for each dataset.

        Parameters
        ----------
        adatas
            List of adata_seq and adata_spatial. If None, uses the training adatas.
        deterministic
            If True, use the encoder mean.
        normalized
            Return normalized values (``module.sample_scale``) or raw rates
            (``module.sample_rate``).
        decode_mode
            If provided, use the decoder of this dataset id for all inputs.
        batch_size
            Minibatch size.

        Returns
        -------
        List of numpy arrays, one per dataset (seq then spatial).
        """
        self.module.eval()
        # Both samplers are called with the same arguments; pick one up front
        # instead of duplicating the call per branch.
        sampler = self.module.sample_scale if normalized else self.module.sample_rate
        imputed_values = []
        for mode, scdl in enumerate(self._make_scvi_dls(adatas, batch_size=batch_size)):
            imputed_value = []
            for tensors in scdl:
                (sample_batch, batch_index, label, *_) = _unpack_tensors(tensors)
                imputed_value.append(
                    sampler(
                        sample_batch,
                        mode,
                        batch_index,
                        label,
                        deterministic=deterministic,
                        decode_mode=decode_mode,
                    )
                    .cpu()
                    .detach()
                )
            imputed_values.append(torch.cat(imputed_value).numpy())
        return imputed_values

    # ------------------------------------------------------------------ #
    # Save / Load
    # ------------------------------------------------------------------ #
    def save(
        self,
        dir_path: str,
        prefix: str | None = None,
        overwrite: bool = False,
        save_anndata: bool = False,
        save_kwargs: dict | None = None,
        **anndata_write_kwargs,
    ):
        """Save the state of the model.

        Parameters
        ----------
        dir_path
            Path to a directory.
        prefix
            Prefix to prepend to saved file names.
        overwrite
            Overwrite existing data or not.
        save_anndata
            If True, also saves the anndata objects.
        save_kwargs
            Keyword arguments passed into :func:`~torch.save`.
        **anndata_write_kwargs
            Keyword arguments forwarded to :meth:`anndata.AnnData.write` for both
            AnnData files (only used when ``save_anndata=True``).

        Raises
        ------
        ValueError
            If ``dir_path`` exists and ``overwrite`` is False.
        """
        if os.path.exists(dir_path) and not overwrite:
            raise ValueError(
                f"{dir_path} already exists. Provide a non-existing directory for saving."
            )
        os.makedirs(dir_path, exist_ok=overwrite)

        file_name_prefix = prefix or ""
        save_kwargs = save_kwargs or {}
        seq_adata, spatial_adata = self.adatas
        if save_anndata:
            seq_adata.write(
                os.path.join(dir_path, f"{file_name_prefix}adata_seq.h5ad"),
                **anndata_write_kwargs,
            )
            spatial_adata.write(
                os.path.join(dir_path, f"{file_name_prefix}adata_spatial.h5ad"),
                **anndata_write_kwargs,
            )

        model_state_dict = self.module.state_dict()
        seq_var_names = seq_adata.var_names.astype(str).to_numpy()
        spatial_var_names = spatial_adata.var_names.astype(str).to_numpy()
        # Persist only the trailing-underscore attributes (e.g. ``init_params_``,
        # ``registries_``) that ``load`` restores.
        user_attributes = {
            name: value for name, value in self._get_user_attributes() if name.endswith("_")
        }
        torch.save(
            {
                "model_state_dict": model_state_dict,
                "seq_var_names": seq_var_names,
                "spatial_var_names": spatial_var_names,
                "attr_dict": user_attributes,
            },
            os.path.join(dir_path, f"{file_name_prefix}model.pt"),
            **save_kwargs,
        )

    @classmethod
    @devices_dsp.dedent
    def load(
        cls,
        dir_path: str,
        adata_seq: AnnData | None = None,
        adata_spatial: AnnData | None = None,
        accelerator: str = "auto",
        device: int | str = "auto",
        prefix: str | None = None,
        backup_url: str | None = None,
    ):
        """Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        adata_seq
            scRNA-seq AnnData. If None, will check for saved anndata.
        adata_spatial
            Spatial AnnData. If None, will check for saved anndata.
        %(param_accelerator)s
        %(param_device)s
        prefix
            Prefix of saved file names.
        backup_url
            URL to retrieve saved outputs from if not present on disk.

        Returns
        -------
        Model with loaded state dictionaries.

        Raises
        ------
        ValueError
            If a saved registry belongs to a different model class or lacks the
            original setup arguments.
        """
        _, _, device = parse_device_args(
            accelerator=accelerator,
            devices=device,
            return_device="torch",
        )
        (
            attr_dict,
            seq_var_names,
            spatial_var_names,
            model_state_dict,
            loaded_adata_seq,
            loaded_adata_spatial,
        ) = _load_saved_gimvi_files(
            dir_path,
            adata_seq is None,
            adata_spatial is None,
            prefix=prefix,
            map_location=device,
            backup_url=backup_url,
        )
        # Explicit None checks: ``loaded or user`` relied on AnnData truthiness
        # (its ``__len__``), which is not the same as "was something loaded".
        if loaded_adata_seq is not None:
            adata_seq = loaded_adata_seq
        if loaded_adata_spatial is not None:
            adata_spatial = loaded_adata_spatial
        adatas = [adata_seq, adata_spatial]

        for adata, saved_var_names in zip(adatas, (seq_var_names, spatial_var_names)):
            user_var_names = adata.var_names.astype(str)
            if not np.array_equal(saved_var_names, user_var_names):
                warnings.warn(
                    "var_names for adata passed in does not match var_names of adata "
                    "used to train the model. For valid results, the vars need to be the "
                    "same and in the same order as the adata used to train the model.",
                    UserWarning,
                    stacklevel=settings.warnings_stacklevel,
                )

        registries = attr_dict.pop("registries_")
        for adata, registry in zip(adatas, registries, strict=True):
            if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__:
                raise ValueError("It appears you are loading a model from a different class.")
            if _SETUP_ARGS_KEY not in registry:
                raise ValueError(
                    "Saved model does not contain original setup inputs. "
                    "Cannot load the original setup."
                )
            # Re-register each AnnData using the saved registry and setup args.
            cls.setup_anndata(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY])

        init_params = attr_dict.pop("init_params_")
        if "non_kwargs" in init_params:
            non_kwargs = init_params["non_kwargs"]
            grouped_kwargs = init_params["kwargs"]
        else:
            # Older layout: dict-valued entries hold the ``**kwargs`` groups.
            non_kwargs = {k: v for k, v in init_params.items() if not isinstance(v, dict)}
            grouped_kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
        # Flatten e.g. {"model_kwargs": {...}} into a single keyword dict.
        kwargs = {k: v for group in grouped_kwargs.values() for k, v in group.items()}

        model = cls(adata_seq, adata_spatial, **non_kwargs, **kwargs)
        for attr, val in attr_dict.items():
            setattr(model, attr, val)
        model.module.load_state_dict(model_state_dict)
        model.module.eval()
        model.to_device(device)
        return model

    # ------------------------------------------------------------------ #
    # AnnData setup
    # ------------------------------------------------------------------ #
    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: str | None = None,
        labels_key: str | None = None,
        layer: str | None = None,
        **kwargs,
    ):
        """%(summary)s.

        Call once for ``adata_seq`` and once for ``adata_spatial`` before constructing the model.

        Parameters
        ----------
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_layer)s
        **kwargs
            Keyword args for :meth:`~scvi.data.AnnDataManager.register_fields`.
        """
        # Must run first so that ``locals()`` holds only the call arguments.
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)
class TrainDL(DataLoader):
"""Train data loader for GIMVI that cycles the shorter dataset."""
def __init__(self, data_loader_list, **kwargs):
self.data_loader_list = data_loader_list
self.largest_train_dl_idx = np.argmax([len(dl.indices) for dl in data_loader_list])
self.largest_dl = self.data_loader_list[self.largest_train_dl_idx]
super().__init__(self.largest_dl, **kwargs)
def __len__(self):
return len(self.largest_dl)
def __iter__(self):
train_dls = [
dl if i == self.largest_train_dl_idx else cycle(dl)
for i, dl in enumerate(self.data_loader_list)
]
return zip(*train_dls, strict=True)