spatialvi.model.SCVIVA
- class spatialvi.model.SCVIVA(adata=None, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1, dispersion='gene', gene_likelihood='poisson', latent_distribution='normal', **kwargs)
scVIVA: a variational autoencoder with niche decoders for spatial transcriptomics (ST) :cite:p:`Levy25`.
- Parameters:
  - adata (AnnData | None) – AnnData object that has been registered via setup_anndata(). If None, the underlying module is not initialized until training, and a LightningDataModule must be passed in during training.
  - n_hidden (int) – Number of nodes per hidden layer.
  - n_latent (int) – Dimensionality of the latent space.
  - n_layers (int) – Number of hidden layers used for the encoder and decoder neural networks.
  - dropout_rate (float) – Dropout rate for the neural networks.
  - dispersion (Literal['gene', 'gene-batch', 'gene-label', 'gene-cell']) – One of the following:
    - 'gene' - dispersion parameter of the NB is constant per gene across cells
    - 'gene-batch' - dispersion can differ between batches
    - 'gene-label' - dispersion can differ between labels
    - 'gene-cell' - dispersion can differ for every gene in every cell
  - gene_likelihood (Literal['zinb', 'nb', 'poisson']) – One of:
    - 'nb' - Negative binomial distribution
    - 'zinb' - Zero-inflated negative binomial distribution
    - 'poisson' - Poisson distribution
  - latent_distribution (Literal['normal', 'ln']) – One of:
    - 'normal' - Normal distribution
    - 'ln' - Logistic normal distribution (Normal(0, I) transformed by softmax)
  - **kwargs – Additional keyword arguments for nicheVAE.
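The gene_likelihood and latent_distribution options above can be illustrated with a small standalone sketch. It is not the library's implementation: the function names are hypothetical, and the mean/inverse-dispersion parameterization of the negative binomial (mu, theta) is an assumption chosen for illustration. It shows how the three count likelihoods relate to each other and what "Normal(0, I) transformed by softmax" produces.

```python
import math
import random


def poisson_log_prob(x, mu):
    """Log-probability of count x under a Poisson with mean mu."""
    return x * math.log(mu) - mu - math.lgamma(x + 1)


def nb_log_prob(x, mu, theta):
    """Log-probability of count x under a negative binomial.

    Assumed parameterization: mean mu and inverse dispersion theta.
    """
    return (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )


def zinb_log_prob(x, mu, theta, pi):
    """Log-probability under a zero-inflated NB; pi is the extra-zero probability."""
    nb_prob = math.exp(nb_log_prob(x, mu, theta))
    zero_mass = pi if x == 0 else 0.0
    return math.log(zero_mass + (1.0 - pi) * nb_prob)


def logistic_normal_sample(n_latent, rng):
    """Draw z ~ Normal(0, I) and map it onto the simplex with a softmax."""
    z = [rng.gauss(0.0, 1.0) for _ in range(n_latent)]
    z_max = max(z)  # subtract the maximum for numerical stability
    exp_z = [math.exp(v - z_max) for v in z]
    total = sum(exp_z)
    return [v / total for v in exp_z]


# A logistic-normal sample is non-negative and sums to one.
sample = logistic_normal_sample(10, random.Random(0))
```

With a very large theta the negative binomial approaches the Poisson, and with pi = 0 the zero-inflated variant reduces to the plain negative binomial, which is one way to read the three choices as nested models.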
Examples
>>> import anndata
>>> import spatialvi
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> spatialvi.model.SCVIVA.preprocessing_anndata(
...     adata,
...     k_nn=20,
...     sample_key="slide_ID",
...     labels_key="cell_type",
...     cell_coordinates_key="spatial",
...     expression_embedding_key="X_scVI",
...     **kwargs,
... )
>>> spatialvi.model.SCVIVA.setup_anndata(adata, batch_key="batch")
>>> vae = spatialvi.model.SCVIVA(adata)
>>> vae.train()
>>> adata.obsm["X_scVIVA"] = vae.get_latent_representation()
>>> adata.obsm["X_normalized_scVIVA"] = vae.get_normalized_expression()
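The preprocessing call above takes a neighbor count (k_nn), cell coordinates, and cell-type labels. As a rough illustration of what a spatial niche built from those inputs might look like, the sketch below computes, for each cell, the cell-type composition of its k nearest spatial neighbors. This is an assumption about the general idea, not the library's code; the helper names and the toy data are hypothetical.

```python
import math
from collections import Counter


def k_nearest_neighbors(coords, k):
    """Return, for each cell, the indices of its k nearest other cells (Euclidean)."""
    neighbors = []
    for i, point in enumerate(coords):
        dists = [(math.dist(point, other), j)
                 for j, other in enumerate(coords) if j != i]
        dists.sort()
        neighbors.append([j for _, j in dists[:k]])
    return neighbors


def niche_composition(coords, labels, k):
    """Fraction of each cell type among the k nearest neighbors of every cell."""
    cell_types = sorted(set(labels))
    compositions = []
    for neigh in k_nearest_neighbors(coords, k):
        counts = Counter(labels[j] for j in neigh)
        compositions.append([counts[ct] / len(neigh) for ct in cell_types])
    return cell_types, compositions


# Toy data: two well-separated clusters of three cells each.
coords = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
labels = ["T", "T", "B", "B", "B", "B"]
cell_types, compositions = niche_composition(coords, labels, k=2)
```

Each row of compositions is a probability vector over cell_types, so a cell surrounded only by B cells gets all of its mass on "B".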
Notes
See further usage examples in the following tutorials:
/tutorials/notebooks/spatial/scVIVA_tutorial
See also
- __init__(adata=None, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1, dispersion='gene', gene_likelihood='poisson', latent_distribution='normal', **kwargs)
Methods
__init__([adata, n_hidden, n_latent, ...])
compute_neighbors(adata[, spatial_key, ...]) – Compute spatial neighbor graph and store in adata.obsm.
convert_legacy_save(dir_path, output_dir_path) – Converts a legacy saved model (<v0.15.0) to the updated save format.
data_registry(registry_key) – Returns the object in AnnData associated with the key in the data registry.
deregister_manager([adata]) – Deregisters the AnnDataManager instance associated with adata.
differential_abundance([adata, sample_key, ...]) – Compute the differential abundance between samples.
differential_expression([adata, groupby, ...]) – A unified method for differential expression analysis.
from_spatialdata(sdata[, table_key, region]) – Convenience constructor from a SpatialData object.
get_aggregated_posterior([adata, indices, ...]) – Compute the aggregated posterior over the u latent representations.
get_anndata_manager(adata[, required]) – Retrieves the AnnDataManager for a given AnnData object.
get_batch_representation([adata, indices, ...]) – Get the batch representation for a given set of indices.
get_composition_error([adata, indices, ...]) – Compute the composition prediction error on the data.
get_elbo([adata, indices, batch_size, ...]) – Compute the evidence lower bound (ELBO) on the data.
get_feature_correlation_matrix([adata, ...]) – Generate gene-gene correlation matrix using scvi uncertainty and expression.
get_from_registry(adata, registry_key) – Returns the object in AnnData associated with the key in the data registry.
get_importance_weights(adata, indices, qz, ...) – Computes importance weights for the given samples.
get_latent_library_size([adata, indices, ...]) – Returns the latent library size for each cell.
get_latent_representation([adata, indices, ...]) – Return latent representation with optional RAPIDS backend.
get_likelihood_parameters([adata, indices, ...]) – Estimates for the parameters of the likelihood \(p(x \mid z)\).
get_marginal_ll([adata, indices, ...]) – Compute the marginal log-likelihood of the data.
get_niche_error([adata, indices, ...]) – Compute the niche state prediction error on the data.
get_normalized_expression([adata, indices, ...]) – Returns the normalized (decoded) gene expression.
get_reconstruction_error([adata, indices, ...]) – Compute the reconstruction error on the data.
get_setup_arg(setup_arg) – Returns the string provided to setup of a specific setup_arg.
get_state_registry(registry_key) – Returns the state registry for the AnnDataField registered with this instance.
get_var_names([legacy_mudata_format]) – Variable names of input data.
load(dir_path[, adata, accelerator, device, ...]) – Instantiate a model from the saved output.
load_query_data([adata, reference_model, ...]) – Online update of a reference model with the scArches algorithm.
load_registry(dir_path[, prefix]) – Return the full registry saved with the model.
minify_adata([minified_data_type, ...]) – Minify the model's adata.
plot_spatial_embedding([adata, basis, color]) – Plot latent embedding overlaid on tissue spatial coordinates.
posterior_predictive_sample([adata, ...]) – Generate predictive samples from the posterior predictive distribution.
predict_neighborhood([adata, indices, ...]) – Predict the cell type composition of each cell niche in the dataset.
predict_niche_activation([adata, indices, ...]) – Predict the activation of each cell niche in the dataset.
prepare_query_anndata(adata, reference_model) – Prepare data for query integration.
prepare_query_mudata(mdata, reference_model) – Prepare multimodal dataset for query integration.
preprocessing_anndata(adata[, k_nn, ...]) – Preprocess an AnnData object for scVIVA analysis.
preprocessing_query_anndata(adata, ...[, ...]) – Prepare data for query integration.
register_manager(adata_manager) – Registers an AnnDataManager instance with this model class.
save(dir_path[, prefix, overwrite, ...]) – Save the state of the model.
setup_anndata(adata[, layer, batch_key, ...]) – Sets up the AnnData object for this model.
setup_spatialdata(sdata[, table_key, region]) – Register fields from a SpatialData object.
to_device(device) – Move the model to the device.
train([max_epochs, accelerator, devices, ...]) – Train the model.
transfer_fields(adata, **kwargs) – Transfer fields from a model to an AnnData object.
update_setup_method_args(setup_method_args) – Update setup method args.
view_anndata_setup([adata, ...]) – Print summary of the setup for the initial AnnData or a given AnnData object.
view_registry([hide_state_registries]) – Prints summary of the registry.
view_setup_args(dir_path[, prefix]) – Print args used to setup a saved model.
view_setup_method_args() – Prints setup kwargs used to produce a given registry.
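Several of the methods above (get_elbo, get_reconstruction_error, get_marginal_ll) report quantities from the variational objective. The sketch below shows the generic single-cell form of the evidence lower bound for a diagonal Gaussian posterior and a standard normal prior. It is a textbook illustration under those assumptions, with hypothetical function names; it leaves out any niche-decoder terms that scVIVA may add and is not the library's computation.

```python
import math


def kl_diag_normal_to_standard(mu, var):
    """KL( Normal(mu, diag(var)) || Normal(0, I) ) in closed form."""
    return 0.5 * sum(m * m + v - 1.0 - math.log(v) for m, v in zip(mu, var))


def elbo(recon_log_lik, mu, var):
    """ELBO for one cell: reconstruction log-likelihood minus the KL term."""
    return recon_log_lik - kl_diag_normal_to_standard(mu, var)


# A posterior equal to the prior pays no KL penalty,
# so the ELBO equals the reconstruction term.
no_penalty = elbo(-10.0, mu=[0.0, 0.0], var=[1.0, 1.0])

# Moving the posterior mean away from zero lowers the ELBO.
with_penalty = elbo(-10.0, mu=[1.0, 0.0], var=[1.0, 1.0])
```

The reconstruction error is the negative of the first term, which is why a lower reconstruction error and a smaller KL divergence both raise the ELBO.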
Attributes
adata – Data attached to model instance.
adata_manager – Manager instance associated with self.adata.
device – The current device that the module's params are on.
get_normalized_function_name – Name of the function that returns normalized values.
history – Returns computed metrics during training.
is_trained – Whether the model has been trained.
minified_data_type – The type of minified data associated with this model, if applicable.
registry – Data attached to model instance.
run_id – Returns the run id of the model.
run_name – Returns the run name of the model.
summary_string – Summary string of the model.
test_indices – Observations that are in test set.
train_indices – Observations that are in train set.
validation_indices – Observations that are in validation set.
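The train_indices, validation_indices, and test_indices attributes describe a partition of the observations. The sketch below shows one way such a partition could be produced from a seeded shuffle. The function name, the default fractions, and the rounding rule are all assumptions for illustration; the page does not state how the library performs the split.

```python
import random


def split_indices(n_obs, train_size=0.8, validation_size=0.1, seed=0):
    """Partition range(n_obs) into disjoint train, validation, and test index lists."""
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be in (0, 1]")
    if validation_size < 0.0 or train_size + validation_size > 1.0 + 1e-9:
        raise ValueError("validation_size must be >= 0 and fit next to train_size")

    indices = list(range(n_obs))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible

    n_train = round(train_size * n_obs)
    n_val = min(round(validation_size * n_obs), n_obs - n_train)

    train = sorted(indices[:n_train])
    validation = sorted(indices[n_train:n_train + n_val])
    test = sorted(indices[n_train + n_val:])  # whatever remains
    return train, validation, test


train_idx, val_idx, test_idx = split_indices(10)
```

Every observation lands in exactly one of the three lists, and calling the function again with the same seed returns the same partition.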