Cell2Fate: RNA velocity (scvi-tools/pyro)

Cell2Fate model class for predicting cell fates from RNA velocity modules (scvi-tools BaseModelClass)

class cell2fate.Cell2fate_DynamicalModel(adata: AnnData, model_class=None, **model_kwargs)[source]

Bases: QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, BaseModelClass

Cell2fate model. User-end model class.

Note

See the Module class for a description of the model.

Parameters:
  • adata – Single-cell AnnData object that has been registered via setup_anndata() and contains spliced and unspliced counts in adata.layers['spliced'] and adata.layers['unspliced'].

  • **model_kwargs – Keyword args for LocationModelLinearDependentWMultiExperimentModel

classmethod setup_anndata(adata: AnnData, layer: str | None = None, batch_key: str | None = None, labels_key: str | None = None, unspliced_label='unspliced', spliced_label='spliced', cluster_label=None, **kwargs)[source]
Sets up the AnnData object for this model.

A mapping will be created between data fields used by this model to their respective locations in adata.

None of the data in adata are modified; fields are only added to adata.

Parameters:
  • layer – if not None, uses this as the key in adata.layers for raw count data.

  • batch_key – key in adata.obs for batch information. Categories will automatically be converted into integer categories and saved to adata.obs['_scvi_batch']. If None, assigns the same batch to all the data.

  • labels_key – key in adata.obs for label information. Categories will automatically be converted into integer categories and saved to adata.obs['_scvi_labels']. If None, assigns the same label to all the data.
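The conversion of batch_key and labels_key categories into integer categories can be illustrated with a minimal, library-free sketch. The helper name encode_categories and the sorted ordering of categories are assumptions made for illustration; the actual encoding is handled inside scvi-tools and may differ in detail.

```python
def encode_categories(values):
    """Map category labels to integer codes (sorted order is an assumption)."""
    categories = sorted(set(values))
    lookup = {cat: code for code, cat in enumerate(categories)}
    return [lookup[v] for v in values], categories


# Hypothetical batch labels, one per cell.
batches = ["sampleB", "sampleA", "sampleB", "sampleC"]
codes, categories = encode_categories(batches)
print(codes)       # [1, 0, 1, 2]
print(categories)  # ['sampleA', 'sampleB', 'sampleC']

# With batch_key=None every cell is assigned the same batch, i.e. one code.
no_batch_codes = [0] * len(batches)
```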

train(max_epochs: int = 500, batch_size: int = 1000, train_size: float = 1, lr: float = 0.01, **kwargs)[source]

Training function for the model.

Parameters:
  • max_epochs – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400])

  • train_size – Size of training set in the range [0.0, 1.0].

  • batch_size – Minibatch size to use during training. If None, no minibatching occurs and all data is copied to device (e.g., GPU).

  • lr – Optimiser learning rate (default optimiser is ClippedAdam). Specifying optimiser via plan_kwargs overrides this choice of lr.

  • kwargs – Other arguments passed to the scvi.model.base.PyroSviTrainMixin.train() method.
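The fallback documented for max_epochs=None can be evaluated by hand. The sketch below restates that formula in plain Python (using the built-in min instead of np.min); the function name default_max_epochs is hypothetical and is not part of the cell2fate API.

```python
def default_max_epochs(n_cells):
    """Documented fallback: min(round((20000 / n_cells) * 400), 400)."""
    return min(round((20000 / n_cells) * 400), 400)


# Small datasets are capped at 400 epochs; larger ones get fewer epochs.
for n_cells in (5_000, 20_000, 40_000, 100_000):
    print(n_cells, default_max_epochs(n_cells))
```

Under this formula, any dataset with 20,000 cells or fewer trains for 400 epochs, and the epoch count shrinks in inverse proportion to the number of cells beyond that.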

plot_history(iter_start=0, iter_end=-1, ax=None)[source]

Plot training history

Parameters:
  • iter_start – Omit initial iterations from the plot.

  • iter_end – Omit last iterations from the plot.

  • ax – Matplotlib axis.
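One way to read iter_start and iter_end is as a window over the recorded loss values. The sketch below assumes that iter_end=-1 means "up to and including the last iteration"; this interpretation and the helper name select_history_window are assumptions, not confirmed behaviour of plot_history.

```python
def select_history_window(losses, iter_start=0, iter_end=-1):
    """Return the slice of the loss history that would be plotted.

    Assumption: iter_end == -1 is a sentinel for "through the last iteration".
    """
    if iter_end == -1:
        iter_end = len(losses)
    return losses[iter_start:iter_end]


# Hypothetical loss values, one per training iteration.
losses = [100.0, 50.0, 30.0, 20.0, 15.0]
print(select_history_window(losses, iter_start=2))              # [30.0, 20.0, 15.0]
print(select_history_window(losses, iter_start=1, iter_end=4))  # [50.0, 30.0, 20.0]
```

Omitting the first iterations is often useful because the initial loss values are typically much larger than the rest and compress the scale of the plot.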

compute_module_summary_statistics(adata)[source]

Computes the contribution of each module to mRNA molecules in each cell.

Parameters:
  • adata – AnnData object.

Returns:

AnnData object with additional module-related summary statistics.

Return type:

AnnData

export_posterior(adata, sample_kwargs={'batch_size': None, 'num_samples': 30, 'return_samples': True, 'use_gpu': True}, export_slot: str = 'mod', full_velocity_posterior=False, normalize=True)[source]

Summarises the posterior distribution and exports results to the AnnData object. Also computes RNA velocity (based on the posterior of rates) and normalized counts (based on the posterior of technical variables).

  • adata.obs: Latent time, sequencing depth constant.

  • adata.var: transcription/splicing/degradation rates, switch-on and switch-off times.

  • adata.uns: Posterior of all parameters ('mean', 'sd', 'q05', 'q95' and optionally all samples), model name, date.

  • adata.layers: velocity (expected gradient of spliced counts), velocity_sd (uncertainty in this gradient), spliced_norm, unspliced_norm (normalized counts).

  • adata.uns: If return_samples is True and full_velocity_posterior is True, the full posterior distribution of velocity is saved in adata.uns['velocity_posterior'].

Parameters:
  • adata – AnnData object where results should be saved.

  • sample_kwargs – Optionally a dictionary of arguments for self.sample_posterior, namely:

    • num_samples: Number of samples to use (Default = 1000).

    • batch_size: Data batch size (keep low enough to fit on GPU, default 2048).

    • use_gpu: Use gpu for generating samples.

    • return_samples: Export all posterior samples (Otherwise just summary statistics).

  • export_slot – adata.uns slot where to export results.

  • full_velocity_posterior – Whether to save the full posterior of velocity (only possible if return_samples is True).

  • normalize – Whether to compute normalized spliced and unspliced counts based on posterior of technical variables.

Returns:

AnnData object with posterior added in adata.obs, adata.var and adata.uns.

Return type:

AnnData
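The four summary statistics stored per parameter ('mean', 'sd', 'q05', 'q95') can be sketched for a single scalar parameter as follows. This is only an illustration using the standard library: the helper names, the linear-interpolation quantile, and the use of the sample standard deviation are assumptions, and the library may compute these quantities differently.

```python
import statistics


def quantile(sorted_vals, q):
    """Linear-interpolation quantile of an already sorted list."""
    pos = q * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    frac = pos - lo
    return sorted_vals[lo] * (1 - frac) + sorted_vals[hi] * frac


def summarise(samples):
    """Collapse posterior samples of one parameter into summary statistics."""
    ordered = sorted(samples)
    return {
        "mean": statistics.fmean(samples),
        "sd": statistics.stdev(samples),  # sample SD; the estimator is an assumption
        "q05": quantile(ordered, 0.05),
        "q95": quantile(ordered, 0.95),
    }


# Hypothetical posterior samples of a single rate parameter.
samples = [float(i) for i in range(21)]
summary = summarise(samples)
print(summary)
```

Keeping only these summaries is much cheaper in memory than exporting every sample, which is the trade-off controlled by return_samples.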

export_posterior_quantiles(adata, batch_size=None, export_slot: str = 'mod', full_velocity_posterior=False, normalize=True, use_gpu=True, use_median=False)[source]

Exports posteriors as quantiles. Similar to cell2fate.Cell2fate_DynamicalModel.export_posterior(), but intended for more scalable workflows.

Parameters:
  • adata – AnnData object.

  • batch_size – Batch size for exporting quantiles. Defaults to None.

  • export_slot – Slot for exporting quantiles. Defaults to 'mod'.

  • full_velocity_posterior (bool, optional) – Flag to indicate whether to export the full velocity posterior. Defaults to False.

  • normalize – Flag to indicate whether to normalize. Defaults to True.

  • use_gpu – Flag to indicate whether to use GPU for computation. Defaults to True.

Returns:

AnnData object with exported quantiles.

Return type:

AnnData

view_history()[source]

View training history over various training windows to assess convergence or spot potential training problems.

get_module_top_features(adata, background, species='Mouse', p_adj_cutoff=0.01, n_top_genes=None)[source]

Returns a DataFrame with the top genes, TFs, and GO terms of each module.

Parameters:
  • adata – AnnData object.

  • background – List of genes to consider as background.

  • species – Species for which to consider TFs and gene sets. Defaults to 'Mouse'.

  • p_adj_cutoff – Adjusted p-value cutoff for enrichment analysis. Defaults to 0.01.

  • n_top_genes – Number of top genes to consider for each module. Defaults to None.

Returns:

A tuple containing:

  • DataFrame: DataFrame with top genes, TFs, and GO terms of each module.

  • List: List of DataFrames containing enriched GO terms for each module.

Return type:

Tuple

Pyro and scvi-tools Module classes

Pyro Module class (defining the model using pyro, math description)

class cell2fate._cell2fate_DynamicalModel_module.Cell2fate_DynamicalModel_module(n_obs, n_vars, n_batch, n_extra_categoricals=None, n_modules=10, stochastic_v_ag_hyp_prior={'alpha': 6.0, 'beta': 3.0}, factor_prior={'alpha': 1.0, 'rate': 1.0, 'states_per_gene': 3.0}, t_switch_alpha_prior={'alpha': 1000.0, 'mean': 1000.0}, splicing_rate_hyp_prior={'alpha': 5.0, 'alpha_hyp_alpha': 20.0, 'mean': 1.0, 'mean_hyp_alpha': 10.0}, degredation_rate_hyp_prior={'alpha': 5.0, 'alpha_hyp_alpha': 20.0, 'mean': 1.0, 'mean_hyp_alpha': 10.0}, activation_rate_hyp_prior={'mean_hyp_prior_mean': 2.0, 'mean_hyp_prior_sd': 0.33, 'sd_hyp_prior_mean': 0.33, 'sd_hyp_prior_sd': 0.1}, s_overdispersion_factor_hyp_prior={'alpha_mean': 100.0, 'alpha_sd': 1.0, 'beta_mean': 1.0, 'beta_sd': 0.1}, detection_hyp_prior={'alpha': 10.0, 'mean_alpha': 1.0, 'mean_beta': 1.0}, detection_i_prior={'alpha': 100, 'mean': 1}, detection_gi_prior={'alpha': 200, 'mean': 1}, gene_add_alpha_hyp_prior={'alpha': 9.0, 'beta': 3.0}, gene_add_mean_hyp_prior={'alpha': 1.0, 'beta': 100.0}, Tmax_prior={'mean': 50.0, 'sd': 50.0}, switch_time_sd=0.1, init_vals: dict | None = None)[source]

Bases: PyroModule

Models spliced and unspliced counts for each gene as a dynamical process in which transcriptional modules switch on at one point in time, increase the transcription rate by different amounts across genes, and then optionally switch off, returning the transcription rate to 0. Splicing and degradation rates are constant for each gene. The model also includes negative binomial noise, batch effects, and technical variables.

Parameters:
  • n_obs – Number of observations in the dataset (e.g., number of cells or samples).

  • n_vars – Number of variables or features in the dataset (e.g., number of genes).

  • n_batch – Number of batches or experimental conditions in the dataset.

  • n_extra_categoricals – Number of additional categorical variables beyond the primary variables of interest.

  • gene_add_alpha_hyp_prior – Hyperparameter prior for the gene additive parameter.

  • gene_add_mean_hyp_prior – Hyperparameter prior for the mean of the gene additive parameter distribution.

  • detection_hyp_prior – Hyperparameter prior for the detection process.
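The switch-on/switch-off dynamics described for this module can be sketched for one gene driven by one module. The sketch assumes the standard RNA velocity equations, du/dt = alpha(t) - beta * u and ds/dt = beta * u - gamma * s, with a piecewise-constant transcription rate alpha(t). It uses simple Euler integration, leaves out noise, batch effects, and technical variables, and is not the module's implementation; the function name and all parameter values are hypothetical.

```python
def simulate_gene(alpha_on, beta, gamma, t_on, t_off, t_max, dt=0.01):
    """Euler integration of unspliced (u) and spliced (s) abundance.

    alpha_on: transcription rate while the module is on (0 otherwise)
    beta:     constant splicing rate
    gamma:    constant degradation rate
    Returns a list of (t, u, s, velocity) tuples, where velocity = ds/dt.
    """
    u, s = 0.0, 0.0
    trajectory = []
    for step in range(int(round(t_max / dt))):
        t = step * dt
        alpha = alpha_on if t_on <= t < t_off else 0.0
        du = alpha - beta * u
        ds = beta * u - gamma * s  # gradient of spliced abundance
        trajectory.append((t, u, s, ds))
        u += du * dt
        s += ds * dt
    return trajectory


# Hypothetical rates and switch times.
traj = simulate_gene(alpha_on=2.0, beta=1.0, gamma=0.5,
                     t_on=1.0, t_off=30.0, t_max=60.0)

# State just before the module switches off, and at the end of the simulation.
t_before_off, u_on, s_on, _ = [p for p in traj if p[0] < 30.0][-1]
t_end, u_end, s_end, _ = traj[-1]
print(u_on, s_on)    # approaches alpha/beta and alpha/gamma
print(u_end, s_end)  # decays back towards zero after switch-off
```

While the module is on, velocity is positive and the gene approaches the steady state u = alpha/beta, s = alpha/gamma; after switch-off, velocity turns negative and both abundances decay towards zero.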

create_plates(u_data, s_data, idx, batch_index)[source]

Creates a Pyro plate for observations.

Parameters:
  • u_data – Unspliced count data.

  • s_data – Spliced count data.

  • idx – Index tensor to subsample.

  • batch_index – Index tensor indicating batch assignments.

Returns:

A Pyro plate representing the observations in the dataset.

Return type:

pyro.plate

list_obs_plate_vars()[source]

Creates a dictionary with the name of observation/minibatch plate, indexes of model args to provide to encoder, variable names that belong to the observation plate and the number of dimensions in non-plate axis of each variable.

Returns:

A dictionary containing the following keys:

  • name: Name of the observation plate.

  • input: List of indexes of model arguments to provide to the encoder.

  • input_transform: List of transformations to apply to input data before passing to the neural network.

  • sites: Dictionary containing information about variables that belong to the observation plate, including their names and the number of dimensions in the non-plate axis of each variable.

Return type:

Dict
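The shape of the returned dictionary can be illustrated with placeholder values. Only the four keys come from the description above; the plate name, the argument indexes, the transformations, and the site name 't_c' are hypothetical and are not taken from the module's source.

```python
import math

# Illustrative structure only; every value here is a placeholder assumption.
obs_plate_vars = {
    "name": "obs_plate",                          # name of the observation plate
    "input": [0, 1],                              # indexes of model args passed to the encoder
    "input_transform": [math.log1p, math.log1p],  # one transformation per input
    "sites": {"t_c": 1},                          # site name -> number of non-plate dimensions
}

for key, value in obs_plate_vars.items():
    print(key, value)
```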

forward(u_data, s_data, idx, batch_index)[source]

Forward pass of the Cell2fate_DynamicalModel_module.

Parameters:
  • u_data – Unspliced count data.

  • s_data – Spliced count data.

  • idx – Index tensor to subsample.

  • batch_index – Index tensor indicating batch assignments.

training: bool