spatialvi.model.SCVIVA

class spatialvi.model.SCVIVA(adata=None, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1, dispersion='gene', gene_likelihood='poisson', latent_distribution='normal', **kwargs)

scVIVA: a variational auto-encoder with niche decoders for spatial transcriptomics (ST) [Levy25].

Parameters:
  • adata (AnnData | None) – AnnData object that has been registered via setup_anndata(). If None, then the underlying module will not be initialized until training, and a LightningDataModule must be passed in during training.

  • n_hidden (int) – Number of nodes per hidden layer.

  • n_latent (int) – Dimensionality of the latent space.

  • n_layers (int) – Number of hidden layers used for encoder and decoder NNs.

  • dropout_rate (float) – Dropout rate for neural networks.

  • dispersion (Literal['gene', 'gene-batch', 'gene-label', 'gene-cell']) –

    One of the following:

    • 'gene' - dispersion parameter of NB is constant per gene across cells

    • 'gene-batch' - dispersion can differ between different batches

    • 'gene-label' - dispersion can differ between different labels

    • 'gene-cell' - dispersion can differ for every gene in every cell

  • gene_likelihood (Literal['zinb', 'nb', 'poisson']) –

    One of:

    • 'nb' - Negative binomial distribution

    • 'zinb' - Zero-inflated negative binomial distribution

    • 'poisson' - Poisson distribution

  • latent_distribution (Literal['normal', 'ln']) –

    One of:

    • 'normal' - Normal distribution

    • 'ln' - Logistic normal distribution (Normal(0, I) transformed by softmax)

  • **kwargs – Additional keyword arguments for nicheVAE.
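The gene_likelihood and latent_distribution options above can be illustrated with a small standard-library sketch. It is not the library's implementation: the function names and the mean / inverse-dispersion parametrization (mu, theta, pi) are assumptions chosen for illustration, and the code is not hardened against numerical underflow.

```python
import math
import random


def poisson_logpmf(x, mu):
    """log P(x) for a Poisson distribution with mean mu."""
    return x * math.log(mu) - mu - math.lgamma(x + 1)


def nb_logpmf(x, mu, theta):
    """log P(x) for a negative binomial with mean mu and inverse dispersion theta."""
    return (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )


def zinb_logpmf(x, mu, theta, pi):
    """log P(x) for a zero-inflated NB; pi is the probability of an extra zero."""
    nb_prob = math.exp(nb_logpmf(x, mu, theta))
    point_mass = 1.0 if x == 0 else 0.0
    return math.log(pi * point_mass + (1.0 - pi) * nb_prob)


def sample_logistic_normal(n_latent, rng):
    """Draw z ~ Normal(0, I) and map it onto the simplex with a softmax ('ln')."""
    z = [rng.gauss(0.0, 1.0) for _ in range(n_latent)]
    z_max = max(z)  # subtract the max before exponentiating for stability
    exp_z = [math.exp(v - z_max) for v in z]
    total = sum(exp_z)
    return [v / total for v in exp_z]


# Compare the three likelihoods on a zero count, then draw one 'ln' latent sample.
zero_count_logprobs = {
    "poisson": poisson_logpmf(0, 3.0),
    "nb": nb_logpmf(0, 3.0, 2.0),
    "zinb": zinb_logpmf(0, 3.0, 2.0, 0.3),
}
latent_sample = sample_logistic_normal(10, random.Random(0))
```

With the same mean, the zero-inflated variant assigns more probability to a zero count than the plain negative binomial, and a sample from the 'ln' option is a vector of positive entries that sum to one.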

Examples

>>> import anndata
>>> import spatialvi.model
>>> # path_to_anndata is a placeholder for the path to an .h5ad file
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> spatialvi.model.SCVIVA.preprocessing_anndata(
...     adata,
...     k_nn=20,
...     sample_key="slide_ID",
...     labels_key="cell_type",
...     cell_coordinates_key="spatial",
...     expression_embedding_key="X_scVI",
...     **kwargs,  # placeholder: a user-defined dict of further keyword arguments
... )
>>> spatialvi.model.SCVIVA.setup_anndata(adata, batch_key="batch")
>>> vae = spatialvi.model.SCVIVA(adata)
>>> vae.train()
>>> adata.obsm["X_scVIVA"] = vae.get_latent_representation()
>>> adata.obsm["X_normalized_scVIVA"] = vae.get_normalized_expression()
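To give a sense of what a k_nn setting such as the one in the example means, the following is a minimal brute-force sketch of a k-nearest-neighbor search over 2D cell coordinates. It only illustrates the general idea: the helper k_nearest_neighbors and the toy coordinates are hypothetical, the sketch ignores sample_key, and it makes no claim about how preprocessing_anndata builds its neighborhoods internally.

```python
import math


def k_nearest_neighbors(coords, k):
    """Return, for each point, the indices of its k nearest other points.

    Brute-force O(n^2) search using Euclidean distance; ties are broken by index.
    """
    neighbors = []
    for i, p in enumerate(coords):
        # Distance from point i to every other point, paired with that point's index.
        dists = [(math.dist(p, q), j) for j, q in enumerate(coords) if j != i]
        dists.sort()
        neighbors.append([j for _, j in dists[:k]])
    return neighbors


# Hypothetical cell centroids, such as might be stored under adata.obsm["spatial"].
coords = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (5.0, 5.0)]
nn = k_nearest_neighbors(coords, k=2)
```

For datasets of realistic size, a spatial index (for example a k-d tree) would normally replace the quadratic loop.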

Notes

See further usage examples in the following tutorials:

  1. /tutorials/notebooks/spatial/scVIVA_tutorial

See also

nicheVAE

__init__(adata=None, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1, dispersion='gene', gene_likelihood='poisson', latent_distribution='normal', **kwargs)

Methods

__init__([adata, n_hidden, n_latent, ...])

compute_neighbors(adata[, spatial_key, ...])

Compute spatial neighbor graph and store in adata.obsm.

convert_legacy_save(dir_path, output_dir_path)

Converts a legacy saved model (<v0.15.0) to the updated save format.

data_registry(registry_key)

Returns the object in AnnData associated with the key in the data registry.

deregister_manager([adata])

Deregisters the AnnDataManager instance associated with adata.

differential_abundance([adata, sample_key, ...])

Compute the differential abundance between samples.

differential_expression([adata, groupby, ...])

A unified method for differential expression analysis.

from_spatialdata(sdata[, table_key, region])

Convenience constructor from a SpatialData object.

get_aggregated_posterior([adata, indices, ...])

Compute the aggregated posterior over the u latent representations.

get_anndata_manager(adata[, required])

Retrieves the AnnDataManager for a given AnnData object.

get_batch_representation([adata, indices, ...])

Get the batch representation for a given set of indices.

get_composition_error([adata, indices, ...])

Compute the composition prediction error on the data.

get_elbo([adata, indices, batch_size, ...])

Compute the evidence lower bound (ELBO) on the data.

get_feature_correlation_matrix([adata, ...])

Generate gene-gene correlation matrix using scvi uncertainty and expression.

get_from_registry(adata, registry_key)

Returns the object in AnnData associated with the key in the data registry.

get_importance_weights(adata, indices, qz, ...)

Computes importance weights for the given samples.

get_latent_library_size([adata, indices, ...])

Returns the latent library size for each cell.

get_latent_representation([adata, indices, ...])

Return latent representation with optional RAPIDS backend.

get_likelihood_parameters([adata, indices, ...])

Estimates for the parameters of the likelihood \(p(x \mid z)\).

get_marginal_ll([adata, indices, ...])

Compute the marginal log-likelihood of the data.

get_niche_error([adata, indices, ...])

Compute the niche state prediction error on the data.

get_normalized_expression([adata, indices, ...])

Returns the normalized (decoded) gene expression.

get_reconstruction_error([adata, indices, ...])

Compute the reconstruction error on the data.

get_setup_arg(setup_arg)

Returns the string passed to setup for a specific setup_arg.

get_state_registry(registry_key)

Returns the state registry for the AnnDataField registered with this instance.

get_var_names([legacy_mudata_format])

Variable names of input data.

load(dir_path[, adata, accelerator, device, ...])

Instantiate a model from the saved output.

load_query_data([adata, reference_model, ...])

Online update of a reference model with the scArches algorithm.

load_registry(dir_path[, prefix])

Return the full registry saved with the model.

minify_adata([minified_data_type, ...])

Minify the model's adata.

plot_spatial_embedding([adata, basis, color])

Plot latent embedding overlaid on tissue spatial coordinates.

posterior_predictive_sample([adata, ...])

Generate predictive samples from the posterior predictive distribution.

predict_neighborhood([adata, indices, ...])

Predict the cell type composition of each cell niche in the dataset.

predict_niche_activation([adata, indices, ...])

Predict the activation of each cell niche in the dataset.

prepare_query_anndata(adata, reference_model)

Prepare data for query integration.

prepare_query_mudata(mdata, reference_model)

Prepare multimodal dataset for query integration.

preprocessing_anndata(adata[, k_nn, ...])

Preprocess an AnnData object for scVIVA analysis.

preprocessing_query_anndata(adata, ...[, ...])

Prepare data for query integration.

register_manager(adata_manager)

Registers an AnnDataManager instance with this model class.

save(dir_path[, prefix, overwrite, ...])

Save the state of the model.

setup_anndata(adata[, layer, batch_key, ...])

Sets up the AnnData object for this model.

setup_spatialdata(sdata[, table_key, region])

Register fields from a SpatialData object.

to_device(device)

Move the model to the device.

train([max_epochs, accelerator, devices, ...])

Train the model.

transfer_fields(adata, **kwargs)

Transfer fields from a model to an AnnData object.

update_setup_method_args(setup_method_args)

Update setup method args.

view_anndata_setup([adata, ...])

Print summary of the setup for the initial AnnData or a given AnnData object.

view_registry([hide_state_registries])

Prints summary of the registry.

view_setup_args(dir_path[, prefix])

Print the args used to set up a saved model.

view_setup_method_args()

Prints setup kwargs used to produce a given registry.

Attributes

adata

Data attached to model instance.

adata_manager

Manager instance associated with self.adata.

device

The current device that the module's params are on.

get_normalized_function_name

Name of the function that returns normalized expression.

history

Returns the metrics computed during training.

is_trained

Whether the model has been trained.

minified_data_type

The type of minified data associated with this model, if applicable.

registry

Data registry attached to the model instance.

run_id

Returns the run id of the model.

run_name

Returns the run name of the model.

summary_string

Summary string of the model.

test_indices

Observations that are in test set.

train_indices

Observations that are in train set.

validation_indices

Observations that are in validation set.