Cell2Fate: RNA velocity (scvi-tools/pyro)
Cell2Fate RNA velocity modules for prediction of cell fates
Model class (scvi-tools BaseModelClass)
- class cell2fate.Cell2fate_DynamicalModel(adata: AnnData, model_class=None, **model_kwargs)[source]
Bases: QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, BaseModelClass
Cell2fate model. User-end model class.
Note
See Module class for description of the model.
- Parameters:
adata – Single-cell AnnData object that has been registered via setup_anndata() and contains spliced and unspliced counts in adata.layers['spliced'], adata.layers['unspliced']
**model_kwargs – Keyword args for LocationModelLinearDependentWMultiExperimentModel
- classmethod setup_anndata(adata: AnnData, layer: str | None = None, batch_key: str | None = None, labels_key: str | None = None, unspliced_label='unspliced', spliced_label='spliced', cluster_label=None, **kwargs)[source]
Sets up the AnnData object for this model. A mapping will be created between the data fields used by this model and their respective locations in adata.
None of the data in adata are modified. Only adds fields to adata.
- Parameters:
layer – if not None, uses this as the key in adata.layers for raw count data.
batch_key – key in adata.obs for batch information. Categories will automatically be converted into integer categories and saved to adata.obs['_scvi_batch']. If None, assigns the same batch to all the data.
labels_key – key in adata.obs for label information. Categories will automatically be converted into integer categories and saved to adata.obs['_scvi_labels']. If None, assigns the same label to all the data.
- train(max_epochs: int = 500, batch_size: int = 1000, train_size: float = 1, lr: float = 0.01, **kwargs)[source]
Training function for the model.
- Parameters:
max_epochs – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]).
train_size – Size of training set in the range [0.0, 1.0].
batch_size – Minibatch size to use during training. If None, no minibatching occurs and all data is copied to device (e.g., GPU).
lr – Optimiser learning rate (default optimiser is ClippedAdam). Specifying optimiser via plan_kwargs overrides this choice of lr.
kwargs – Other arguments to the scvi.model.base.PyroSviTrainMixin().train() method.
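The default epoch count quoted for max_epochs=None can be reproduced by hand. The sketch below is plain Python written for illustration; the helper name default_max_epochs is hypothetical and is not part of cell2fate or scvi-tools. Note that the signature above lists 500 as the default for max_epochs, so the heuristic is described only for the case where None is passed.

```python
def default_max_epochs(n_cells: int) -> int:
    """Heuristic quoted in the docs: min(round((20000 / n_cells) * 400), 400)."""
    if n_cells <= 0:
        raise ValueError("n_cells must be positive")
    return min(round((20000 / n_cells) * 400), 400)


# Hypothetical dataset sizes, chosen only to show how the cap behaves.
for n_cells in (1_000, 20_000, 100_000):
    print(n_cells, default_max_epochs(n_cells))
```

Under this formula, datasets of up to 20,000 cells hit the 400-epoch cap, and larger datasets get proportionally fewer epochs.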
- plot_history(iter_start=0, iter_end=-1, ax=None)[source]
Plot training history
- Parameters:
iter_start – Omit initial iterations from the plot.
iter_end – Omit last iterations from the plot.
ax – Matplotlib axis.
- compute_module_summary_statistics(adata)[source]
Computes the contribution of each module to mRNA molecules in each cell.
- Parameters:
adata – AnnData object.
- Returns:
AnnData object with additional module-related summary statistics.
- Return type:
AnnData
- export_posterior(adata, sample_kwargs={'batch_size': None, 'num_samples': 30, 'return_samples': True, 'use_gpu': True}, export_slot: str = 'mod', full_velocity_posterior=False, normalize=True)[source]
Summarises the posterior distribution and exports the results to the AnnData object. Also computes RNA velocity (based on the posterior of rates) and normalized counts (based on the posterior of technical variables).
adata.obs: Latent time, sequencing depth constant.
adata.var: Transcription/splicing/degradation rates, switch-on and switch-off times.
adata.uns: Posterior of all parameters ('mean', 'sd', 'q05', 'q95' and optionally all samples), model name, date.
adata.layers: velocity (expected gradient of spliced counts), velocity_sd (uncertainty in this gradient), spliced_norm, unspliced_norm (normalized counts).
adata.uns: If return_samples: True and full_velocity_posterior = True, the full posterior distribution for velocity is saved in adata.uns['velocity_posterior'].
- Parameters:
adata – AnnData object where results should be saved.
sample_kwargs – Optionally a dictionary of arguments for self.sample_posterior, namely:
num_samples: Number of samples to use (Default = 1000).
batch_size: Data batch size (keep low enough to fit on GPU, default 2048).
use_gpu: Use gpu for generating samples.
return_samples: Export all posterior samples (Otherwise just summary statistics).
export_slot – adata.uns slot where to export results.
full_velocity_posterior – Whether to save full posterior of velocity (only possible if “return_samples: True”).
normalize – Whether to compute normalized spliced and unspliced counts based on posterior of technical variables.
- Returns:
AnnData object with posterior added in adata.obs, adata.var and adata.uns.
- Return type:
AnnData
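The four summaries stored in adata.uns ('mean', 'sd', 'q05', 'q95') can be illustrated on a plain list of posterior samples. This is a minimal sketch in standard-library Python, not the cell2fate implementation: the helper names quantile and summarise_samples are hypothetical, the quantile uses linear interpolation, and the standard deviation is the population form. The library may compute these differently.

```python
import math
import statistics


def quantile(samples, q):
    """Quantile with linear interpolation between order statistics."""
    ordered = sorted(samples)
    pos = q * (len(ordered) - 1)
    lower = math.floor(pos)
    upper = min(lower + 1, len(ordered) - 1)
    frac = pos - lower
    return ordered[lower] * (1 - frac) + ordered[upper] * frac


def summarise_samples(samples):
    """Collapse posterior samples of one parameter into summary statistics."""
    return {
        "mean": statistics.fmean(samples),
        "sd": statistics.pstdev(samples),
        "q05": quantile(samples, 0.05),
        "q95": quantile(samples, 0.95),
    }


# Hypothetical posterior samples of a single rate parameter (made-up values).
samples = [0.8, 1.1, 0.9, 1.3, 1.0, 1.2, 0.7, 1.05, 0.95, 1.15]
summary = summarise_samples(samples)
print(summary)
```

Keeping only such summaries rather than every sample is what the return_samples option trades off: the summaries are compact, while the full samples are needed for anything like the full velocity posterior.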
- export_posterior_quantiles(adata, batch_size=None, export_slot: str = 'mod', full_velocity_posterior=False, normalize=True, use_gpu=True, use_median=False)[source]
Exports posteriors as quantiles. Similar to cell2fate.Cell2fate_DynamicalModel.export_posterior(), but intended for a more scalable workflow.
- Parameters:
adata – AnnData object.
batch_size – Batch size for exporting quantiles. Defaults to None.
export_slot – Slot for exporting quantiles. Defaults to “mod”.
full_velocity_posterior (bool, optional) – Flag to indicate whether to export the full velocity posterior. Defaults to False.
normalize – Flag to indicate whether to normalize. Defaults to True.
use_gpu – Flag to indicate whether to use GPU for computation. Defaults to True.
- Returns:
AnnData object with exported quantiles.
- Return type:
AnnData
- view_history()[source]
View training history over various training windows to assess convergence or spot potential training problems.
- get_module_top_features(adata, background, species='Mouse', p_adj_cutoff=0.01, n_top_genes=None)[source]
Returns a dataframe with top Genes, TFs, and GO terms of each module.
- Parameters:
adata – AnnData object.
background – List of genes to consider as background.
species – Species for which to consider TFs and gene sets. Defaults to ‘Mouse’.
p_adj_cutoff – Adjusted p-value cutoff for enrichment analysis. Defaults to 0.01.
n_top_genes – Number of top genes to consider for each module. Defaults to None.
- Returns:
A tuple containing:
DataFrame: DataFrame with top genes, TFs, and GO terms of each module.
List: List of DataFrames containing enriched GO terms for each module.
- Return type:
Tuple
Pyro and scvi-tools Module classes
Pyro Module class (defining the model using pyro, math description)
- class cell2fate._cell2fate_DynamicalModel_module.Cell2fate_DynamicalModel_module(n_obs, n_vars, n_batch, n_extra_categoricals=None, n_modules=10, stochastic_v_ag_hyp_prior={'alpha': 6.0, 'beta': 3.0}, factor_prior={'alpha': 1.0, 'rate': 1.0, 'states_per_gene': 3.0}, t_switch_alpha_prior={'alpha': 1000.0, 'mean': 1000.0}, splicing_rate_hyp_prior={'alpha': 5.0, 'alpha_hyp_alpha': 20.0, 'mean': 1.0, 'mean_hyp_alpha': 10.0}, degredation_rate_hyp_prior={'alpha': 5.0, 'alpha_hyp_alpha': 20.0, 'mean': 1.0, 'mean_hyp_alpha': 10.0}, activation_rate_hyp_prior={'mean_hyp_prior_mean': 2.0, 'mean_hyp_prior_sd': 0.33, 'sd_hyp_prior_mean': 0.33, 'sd_hyp_prior_sd': 0.1}, s_overdispersion_factor_hyp_prior={'alpha_mean': 100.0, 'alpha_sd': 1.0, 'beta_mean': 1.0, 'beta_sd': 0.1}, detection_hyp_prior={'alpha': 10.0, 'mean_alpha': 1.0, 'mean_beta': 1.0}, detection_i_prior={'alpha': 100, 'mean': 1}, detection_gi_prior={'alpha': 200, 'mean': 1}, gene_add_alpha_hyp_prior={'alpha': 9.0, 'beta': 3.0}, gene_add_mean_hyp_prior={'alpha': 1.0, 'beta': 100.0}, Tmax_prior={'mean': 50.0, 'sd': 50.0}, switch_time_sd=0.1, init_vals: dict | None = None)[source]
Bases: PyroModule
Models spliced and unspliced counts for each gene as a dynamical process in which transcriptional modules switch on at one point in time, increase the transcription rate by different values across genes, and then optionally switch off, returning the transcription rate to 0. Splicing and degradation rates are constant for each gene. The model also includes negative binomial noise, batch effects, and technical variables.
- Parameters:
n_obs – Number of observations in the dataset (e.g., number of cells or samples).
n_vars – Number of variables or features in the dataset (e.g., number of genes).
n_batch – Number of batches or experimental conditions in the dataset.
n_extra_categoricals – Number of additional categorical variables beyond the primary variables of interest.
gene_add_alpha_hyp_prior – Hyperparameter prior for the gene additive parameter.
gene_add_mean_hyp_prior – Hyperparameter prior for the mean of the gene additive parameter distribution.
detection_hyp_prior – Hyperparameter prior for the detection process.
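The dynamical process described for this module can be pictured with the standard RNA velocity equations for a single gene, du/dt = alpha(t) - beta * u and ds/dt = beta * u - gamma * s, where the transcription rate alpha(t) is switched on and later off while the splicing rate beta and degradation rate gamma stay constant. The sketch below is an illustration under those assumptions, not the module's actual equations or inference code: it uses one gene, a single hard on/off step, no noise, and simple Euler integration. The function simulate_gene and all parameter values are hypothetical.

```python
def simulate_gene(alpha, beta, gamma, t_on, t_off, t_end, dt=0.01):
    """Euler-integrate du/dt = a(t) - beta*u and ds/dt = beta*u - gamma*s.

    a(t) equals alpha while t_on <= t < t_off and 0 otherwise.
    Returns a list of (time, unspliced, spliced, velocity) tuples.
    """
    u, s = 0.0, 0.0
    trajectory = []
    n_steps = int(round(t_end / dt))
    for step in range(n_steps):
        t = step * dt
        rate = alpha if t_on <= t < t_off else 0.0
        velocity = beta * u - gamma * s  # gradient of spliced counts
        trajectory.append((t, u, s, velocity))
        u += dt * (rate - beta * u)
        s += dt * velocity
    return trajectory


# Hypothetical rates: transcription on from t=0 to t=40, then off until t=80.
traj = simulate_gene(alpha=2.0, beta=1.0, gamma=0.5,
                     t_on=0.0, t_off=40.0, t_end=80.0)
for index in (100, 3999, 4500, len(traj) - 1):
    t, u, s, v = traj[index]
    print(f"t={t:6.2f}  u={u:.4f}  s={s:.4f}  velocity={v:+.4f}")
```

While transcription is on, the trajectory approaches the steady state u = alpha / beta and s = alpha / gamma, where the velocity is zero. After the switch-off, both counts decay towards zero and the velocity is negative.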
- create_plates(u_data, s_data, idx, batch_index)[source]
Creates a Pyro plate for observations.
- Parameters:
u_data – Unspliced count data.
s_data – Spliced count data.
idx – Index tensor to subsample.
batch_index – Index tensor indicating batch assignments.
- Returns:
A Pyro plate representing the observations in the dataset.
- Return type:
pyro.plate
- list_obs_plate_vars()[source]
Creates a dictionary with the name of observation/minibatch plate, indexes of model args to provide to encoder, variable names that belong to the observation plate and the number of dimensions in non-plate axis of each variable.
- Returns:
A dictionary containing the following keys:
name: Name of the observation plate.
input: List of indexes of model arguments to provide to the encoder.
input_transform: List of transformations to apply to input data before passing to the neural network.
sites: Dictionary containing information about variables that belong to the observation plate, including their names and the number of dimensions in the non-plate axis of each variable.
- Return type:
Dict
- forward(u_data, s_data, idx, batch_index)[source]
Forward pass of the Cell2fate_DynamicalModel_module.
- Parameters:
u_data – Unspliced count data.
s_data – Spliced count data.
idx – Index tensor to subsample.
batch_index – Index tensor indicating batch assignments.
- training: bool