Source code for cell2fate._cell2fate_DynamicalModel_module

from typing import Optional
import numpy as np
import pandas as pd
import pyro
import pyro.distributions as dist
import torch
from pyro.nn import PyroModule
from scvi import REGISTRY_KEYS
import pandas as pd
from scvi.nn import one_hot
from cell2fate.utils import G_a, G_b, mu_mRNA_continousAlpha_globalTime_twoStates
from pyro.infer import config_enumerate
from pyro.ops.indexing import Vindex
from torch.distributions import constraints

[docs] class Cell2fate_DynamicalModel_module(PyroModule): r""" Models spliced and unspliced counts for each gene as a dynamical process in which transcriptional modules switch on at one point in time and increase the transcription rate by different values across genes and then optionally switches off to a transcription rate of 0. Splicing and degradation rates are constant for each gene. The model also includes negative binomial noise, batch effects, and technical variables. Parameters ---------- n_obs Number of observations in the dataset (e.g., number of cells or samples). n_vars Number of variables or features in the dataset (e.g., number of genes). n_batch Number of batches or experimental conditions in the dataset. n_extra_categoricals Number of additional categorical variables beyond the primary variables of interest. gene_add_alpha_hyp_prior Hyperparameter prior for the gene additive parameter. gene_add_mean_hyp_prior Hyperparameter prior for the mean of the gene additive parameter distribution. detection_hyp_prior Hyperparameter prior for the detection process. """ def __init__( self, n_obs, n_vars, n_batch, n_extra_categoricals=None, n_modules = 10, stochastic_v_ag_hyp_prior={"alpha": 6.0, "beta": 3.0}, factor_prior={"rate": 1.0, "alpha": 1.0, "states_per_gene": 3.0}, t_switch_alpha_prior = {"mean": 1000., "alpha": 1000.}, splicing_rate_hyp_prior={"mean": 1.0, "alpha": 5.0, "mean_hyp_alpha": 10., "alpha_hyp_alpha": 20.}, degredation_rate_hyp_prior={"mean": 1.0, "alpha": 5.0, "mean_hyp_alpha": 10., "alpha_hyp_alpha": 20.}, activation_rate_hyp_prior={"mean_hyp_prior_mean": 2., "mean_hyp_prior_sd": 0.33, "sd_hyp_prior_mean": 0.33, "sd_hyp_prior_sd": 0.1}, s_overdispersion_factor_hyp_prior={'alpha_mean': 100., 'beta_mean': 1., 'alpha_sd': 1., 'beta_sd': 0.1}, detection_hyp_prior={"alpha": 10.0, "mean_alpha": 1.0, "mean_beta": 1.0}, detection_i_prior={"mean": 1, "alpha": 100}, detection_gi_prior={"mean": 1, "alpha": 200}, gene_add_alpha_hyp_prior={"alpha": 9.0, "beta": 3.0}, gene_add_mean_hyp_prior={"alpha": 1.0, "beta": 100.0}, Tmax_prior={"mean": 50., "sd": 50.}, switch_time_sd = 0.1, init_vals: Optional[dict] = None ): ############# Initialise parameters ################ super().__init__() self.n_modules = n_modules self.n_obs = n_obs self.n_vars = n_vars self.n_batch = n_batch self.n_extra_categoricals = n_extra_categoricals self.factor_prior = factor_prior self.stochastic_v_ag_hyp_prior = stochastic_v_ag_hyp_prior self.gene_add_alpha_hyp_prior = gene_add_alpha_hyp_prior self.gene_add_mean_hyp_prior = gene_add_mean_hyp_prior self.detection_hyp_prior = detection_hyp_prior self.splicing_rate_hyp_prior = splicing_rate_hyp_prior self.degredation_rate_hyp_prior = degredation_rate_hyp_prior self.s_overdispersion_factor_hyp_prior = s_overdispersion_factor_hyp_prior self.t_switch_alpha_prior = t_switch_alpha_prior self.detection_gi_prior = detection_gi_prior self.detection_i_prior = detection_i_prior if (init_vals is not None) & (type(init_vals) is dict): self.np_init_vals = init_vals for k in init_vals.keys(): if k == 'I_cm': self.register_buffer(f"init_val_{k}", torch.tensor(init_vals[k], dtype = torch.long)) else: self.register_buffer(f"init_val_{k}", torch.tensor(init_vals[k])) self.register_buffer( "n_modules_torch", torch.tensor(n_modules, dtype = torch.float32), ) self.register_buffer( "s_overdispersion_factor_alpha_mean", torch.tensor(self.s_overdispersion_factor_hyp_prior["alpha_mean"]), ) self.register_buffer( "s_overdispersion_factor_beta_mean", torch.tensor(self.s_overdispersion_factor_hyp_prior["beta_mean"]), ) self.register_buffer( "s_overdispersion_factor_alpha_sd", torch.tensor(self.s_overdispersion_factor_hyp_prior["alpha_sd"]), ) self.register_buffer( "s_overdispersion_factor_beta_sd", torch.tensor(self.s_overdispersion_factor_hyp_prior["beta_sd"]), ) self.register_buffer( "detection_gi_prior_alpha", torch.tensor(self.detection_gi_prior["alpha"]), ) self.register_buffer( "detection_gi_prior_beta", torch.tensor(self.detection_gi_prior["alpha"] / self.detection_gi_prior["mean"]), ) self.register_buffer( "detection_i_prior_alpha", torch.tensor(self.detection_i_prior["alpha"]), ) self.register_buffer( "detection_i_prior_beta", torch.tensor(self.detection_i_prior["alpha"] / self.detection_i_prior["mean"]), ) self.register_buffer( "Tmax_mean", torch.tensor(Tmax_prior["mean"]), ) self.register_buffer( "Tmax_sd", torch.tensor(Tmax_prior["sd"]), ) self.register_buffer( "switch_time_sd", torch.tensor(switch_time_sd), ) self.register_buffer( "t_mi_alpha_alpha", torch.tensor(t_switch_alpha_prior['alpha']), ) self.register_buffer( "t_mi_alpha_mu", torch.tensor(t_switch_alpha_prior['alpha']), ) self.t_switch_alpha_prior self.register_buffer( "detection_mean_hyp_prior_alpha", torch.tensor(self.detection_hyp_prior["mean_alpha"]), ) self.register_buffer( "detection_mean_hyp_prior_beta", torch.tensor(self.detection_hyp_prior["mean_beta"]), ) self.register_buffer( "stochastic_v_ag_hyp_prior_alpha", torch.tensor(self.stochastic_v_ag_hyp_prior["alpha"]), ) self.register_buffer( "stochastic_v_ag_hyp_prior_beta", torch.tensor(self.stochastic_v_ag_hyp_prior["beta"]), ) self.register_buffer( "gene_add_alpha_hyp_prior_alpha", torch.tensor(self.gene_add_alpha_hyp_prior["alpha"]), ) self.register_buffer( "gene_add_alpha_hyp_prior_beta", torch.tensor(self.gene_add_alpha_hyp_prior["beta"]), ) self.register_buffer( "gene_add_mean_hyp_prior_alpha", torch.tensor(self.gene_add_mean_hyp_prior["alpha"]), ) self.register_buffer( "gene_add_mean_hyp_prior_beta", torch.tensor(self.gene_add_mean_hyp_prior["beta"]), ) self.register_buffer( "detection_hyp_prior_alpha", torch.tensor(self.detection_hyp_prior["alpha"]), ) self.register_buffer("ones_n_batch_1", torch.ones((self.n_batch, 1))) self.register_buffer("ones", torch.ones((1, 1))) self.register_buffer("ones2", torch.ones((self.n_obs, self.n_vars))) self.register_buffer("eps", torch.tensor(1e-8)) self.register_buffer("alpha_OFFg", torch.tensor(10**(-5))) self.register_buffer("one", torch.tensor(1.)) self.register_buffer("zero", torch.tensor(0.)) self.register_buffer("zero_point_one", torch.tensor(0.1)) self.register_buffer("one_point_one", torch.tensor(1.1)) self.register_buffer("one_point_two", torch.tensor(1.2)) self.register_buffer("zeros", torch.zeros(self.n_obs, self.n_vars)) self.register_buffer("ones_g", torch.ones((1,self.n_vars,1))) self.register_buffer("ones_m", torch.ones(n_modules)) # Register parameters for activation rate hyperprior: self.register_buffer( "activation_rate_mean_hyp_prior_mean", torch.tensor(activation_rate_hyp_prior["mean_hyp_prior_mean"]), ) self.register_buffer( "activation_rate_mean_hyp_prior_sd", torch.tensor(activation_rate_hyp_prior["mean_hyp_prior_sd"]), ) self.register_buffer( "activation_rate_sd_hyp_prior_mean", torch.tensor(activation_rate_hyp_prior["sd_hyp_prior_mean"]), ) self.register_buffer( "activation_rate_sd_hyp_prior_sd", torch.tensor(activation_rate_hyp_prior["sd_hyp_prior_sd"]), ) # Register parameters for splicing rate hyperprior: self.register_buffer( "splicing_rate_alpha_hyp_prior_mean", torch.tensor(self.splicing_rate_hyp_prior["alpha"]), ) self.register_buffer( "splicing_rate_mean_hyp_prior_mean", torch.tensor(self.splicing_rate_hyp_prior["mean"]), ) self.register_buffer( "splicing_rate_alpha_hyp_prior_alpha", torch.tensor(self.splicing_rate_hyp_prior["alpha_hyp_alpha"]), ) self.register_buffer( "splicing_rate_mean_hyp_prior_alpha", torch.tensor(self.splicing_rate_hyp_prior["mean_hyp_alpha"]), ) # Register parameters for degredation rate hyperprior: self.register_buffer( "degredation_rate_alpha_hyp_prior_mean", torch.tensor(self.degredation_rate_hyp_prior["alpha"]), ) self.register_buffer( "degredation_rate_mean_hyp_prior_mean", torch.tensor(self.degredation_rate_hyp_prior["mean"]), ) self.register_buffer( "degredation_rate_alpha_hyp_prior_alpha", torch.tensor(self.degredation_rate_hyp_prior["alpha_hyp_alpha"]), ) self.register_buffer( "degredation_rate_mean_hyp_prior_alpha", torch.tensor(self.degredation_rate_hyp_prior["mean_hyp_alpha"]), ) # per gene rate priors self.register_buffer( "factor_prior_alpha", torch.tensor(self.factor_prior["alpha"]), ) self.register_buffer( "factor_prior_beta", torch.tensor(self.factor_prior["alpha"] / self.factor_prior["rate"]), ) self.register_buffer("n_factors_torch", torch.tensor(self.n_modules)) self.register_buffer( "factor_states_per_gene", torch.tensor(self.factor_prior["states_per_gene"]), ) self.register_buffer( "t_c_init", self.one*torch.ones((self.n_obs, 1, 1))/2., ) self.register_buffer( "t_mON_init", torch.ones((1, 1, self.n_modules))/2., ) self.register_buffer( "t_mOFF_init", torch.zeros((1, 1, self.n_modules)) + 0.1, ) self.register_buffer( "probs_I_cm", torch.tensor(1.-10**(-10)), ) ############# Define the model ################ @staticmethod def _get_fn_args_from_batch_no_cat(tensor_dict): u_data = tensor_dict['unspliced'] s_data = tensor_dict['spliced'] ind_x = tensor_dict["ind_x"].long().squeeze() batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY] return (u_data, s_data, ind_x, batch_index), {} @staticmethod def _get_fn_args_from_batch_cat(tensor_dict): u_data = tensor_dict['unspliced'] s_data = tensor_dict['spliced'] ind_x = tensor_dict["ind_x"].long().squeeze() batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY] extra_categoricals = tensor_dict[REGISTRY_KEYS.CAT_COVS_KEY] return (u_data, s_data, ind_x, batch_index), {} @property def _get_fn_args_from_batch(self): if self.n_extra_categoricals is not None: return self._get_fn_args_from_batch_cat else: return self._get_fn_args_from_batch_no_cat
[docs] def create_plates(self, u_data, s_data, idx, batch_index): """ Creates a Pyro plate for observations. Parameters ---------- u_data Unspliced count data. s_data Spliced count data. idx Index tensor to subsample. batch_index Index tensor indicating batch assignments. Returns ------- Pyro.plate A Pyro plate representing the observations in the dataset. """ return pyro.plate("obs_plate", size=self.n_obs, dim=-3, subsample=idx)
[docs] def list_obs_plate_vars(self): """ Creates a dictionary with the name of observation/minibatch plate, indexes of model args to provide to encoder, variable names that belong to the observation plate and the number of dimensions in non-plate axis of each variable. Returns ------- Dict A dictionary containing the following keys: - **name:** Name of the observation plate. - **input:** List of indexes of model arguments to provide to the encoder. - **input_transform:** List of transformations to apply to input data before passing to the neural network. - **sites:** Dictionary containing information about variables that belong to the observation plate, including their names and the number of dimensions in the non-plate axis of each variable. """ return { "name": "obs_plate", "input": [], # expression data + (optional) batch index "input_transform": [], # how to transform input data before passing to NN "sites": {}, }
[docs] def forward(self, u_data, s_data, idx, batch_index): """ Forward pass of the :class:`~cell2fate.Cell2fate_DynamicalModel_module`. Parameters ---------- u_data Unspliced count data. s_data Spliced count data. idx Index tensor to subsample. batch_index Index tensor indicating batch assignments. """ batch_size = len(idx) obs2sample = one_hot(batch_index, self.n_batch) obs_plate = self.create_plates(u_data, s_data, idx, batch_index) # ===================== Kinetic Rates ======================= # # Splicing rate: splicing_alpha = pyro.sample('splicing_alpha', dist.Gamma(self.splicing_rate_alpha_hyp_prior_alpha, self.splicing_rate_alpha_hyp_prior_alpha/self.splicing_rate_alpha_hyp_prior_mean)) splicing_mean = pyro.sample('splicing_mean', dist.Gamma(self.splicing_rate_mean_hyp_prior_alpha, self.splicing_rate_mean_hyp_prior_alpha/self.splicing_rate_mean_hyp_prior_mean)) beta_g = pyro.sample('beta_g', dist.Gamma(splicing_alpha, splicing_alpha/splicing_mean).expand([1,self.n_vars]).to_event(2)) # Degredation rate: degredation_alpha = pyro.sample('degredation_alpha', dist.Gamma(self.degredation_rate_alpha_hyp_prior_alpha, self.degredation_rate_alpha_hyp_prior_alpha/self.degredation_rate_alpha_hyp_prior_mean)) degredation_alpha = degredation_alpha + 0.001 degredation_mean = pyro.sample('degredation_mean', dist.Gamma(self.degredation_rate_mean_hyp_prior_alpha, self.degredation_rate_mean_hyp_prior_alpha/self.degredation_rate_mean_hyp_prior_mean)) gamma_g = pyro.sample('gamma_g', dist.Gamma(degredation_alpha, degredation_alpha/degredation_mean).expand([1,self.n_vars]).to_event(2)) # Transcription rate contribution of each module: factor_level_g = pyro.sample( "factor_level_g", dist.Gamma(self.factor_prior_alpha, self.factor_prior_beta) .expand([1, self.n_vars]) .to_event(2) ) g_fg = pyro.sample( # (g_fg corresponds to module's spliced counts in steady state) "g_fg", dist.Gamma( self.factor_states_per_gene / self.n_factors_torch, self.ones / factor_level_g, ) .expand([self.n_modules, self.n_vars]) .to_event(2) ) A_mgON = pyro.deterministic('A_mgON', g_fg*gamma_g) # (transform from spliced counts to transcription rate) A_mgOFF = self.alpha_OFFg # Activation and Deactivation rate: lam_mu = pyro.sample('lam_mu', dist.Gamma(G_a(self.activation_rate_mean_hyp_prior_mean, self.activation_rate_mean_hyp_prior_sd), G_b(self.activation_rate_mean_hyp_prior_mean, self.activation_rate_mean_hyp_prior_sd))) lam_sd = pyro.sample('lam_sd', dist.Gamma(G_a(self.activation_rate_sd_hyp_prior_mean, self.activation_rate_sd_hyp_prior_sd), G_b(self.activation_rate_sd_hyp_prior_mean, self.activation_rate_sd_hyp_prior_sd))) lam_m_mu = pyro.sample('lam_m_mu', dist.Gamma(G_a(lam_mu, lam_sd), G_b(lam_mu, lam_sd)).expand([self.n_modules, 1, 1]).to_event(3)) lam_mi = pyro.sample('lam_mi', dist.Gamma(G_a(lam_m_mu, lam_m_mu*0.05), G_b(lam_m_mu, lam_m_mu*0.05)).expand([self.n_modules, 1, 2]).to_event(3)) # =====================Time======================= # # Global time for each cell: T_max = pyro.sample('Tmax', dist.Gamma(G_a(self.Tmax_mean, self.Tmax_sd), G_b(self.Tmax_mean, self.Tmax_sd))) t_c_loc = pyro.sample('t_c_loc', dist.Gamma(self.one, self.one/0.5)) t_c_scale = pyro.sample('t_c_scale', dist.Gamma(self.one, self.one/0.25)) with obs_plate: t_c = pyro.sample('t_c', dist.Normal(t_c_loc, t_c_scale).expand([batch_size, 1, 1])) T_c = pyro.deterministic('T_c', t_c*T_max) # Global switch on time for each gene: # t_mON = pyro.sample('t_mON', dist.Uniform(self.zero, self.one).expand([1, 1, self.n_modules]).to_event(2)) t_delta = pyro.sample('t_delta', dist.Gamma(self.one*20, self.one * 20 *self.n_modules_torch). expand([self.n_modules]).to_event(1)) t_mON = torch.cumsum(torch.concat([self.zero.unsqueeze(0), t_delta[:-1]]), dim = 0).unsqueeze(0).unsqueeze(0) T_mON = pyro.deterministic('T_mON', T_max*t_mON) # Global switch off time for each gene: t_mOFF = pyro.sample('t_mOFF', dist.Exponential(self.n_modules_torch).expand([1, 1, self.n_modules]).to_event(2)) T_mOFF = pyro.deterministic('T_mOFF', T_mON + T_max*t_mOFF) # =========== Mean expression according to RNAvelocity model ======================= # mu_total = torch.stack([self.zeros[idx,...], self.zeros[idx,...]], axis = -1) for m in range(self.n_modules): mu_total += mu_mRNA_continousAlpha_globalTime_twoStates( A_mgON[m,:], A_mgOFF, beta_g, gamma_g, lam_mi[m,...], T_c[:,:,0], T_mON[:,:,m], T_mOFF[:,:,m], self.zeros[idx,...]) with obs_plate: mu_expression = pyro.deterministic('mu_expression', mu_total) # =============Detection efficiency of spliced and unspliced counts =============== # # Cell specific relative detection efficiency with hierarchical prior across batches: detection_mean_y_e = pyro.sample( "detection_mean_y_e", dist.Beta( self.ones * self.detection_mean_hyp_prior_alpha, self.ones * self.detection_mean_hyp_prior_beta, ) .expand([self.n_batch, 1]) .to_event(2), ) detection_hyp_prior_alpha = pyro.deterministic( "detection_hyp_prior_alpha", self.detection_hyp_prior_alpha, ) beta = detection_hyp_prior_alpha / (obs2sample @ detection_mean_y_e) with obs_plate: detection_y_c = pyro.sample( "detection_y_c", dist.Gamma(detection_hyp_prior_alpha.unsqueeze(dim=-1), beta.unsqueeze(dim=-1)), ) # (self.n_obs, 1) # Global relative detection efficiency between spliced and unspliced counts detection_y_i = pyro.sample( "detection_y_i", dist.Gamma( self.ones * self.detection_i_prior_alpha, self.ones * self.detection_i_prior_alpha, ) .expand([1, 1, 2]).to_event(3) ) # Gene specific relative detection efficiency between spliced and unspliced counts detection_y_gi = pyro.sample( "detection_y_gi", dist.Gamma( self.ones * self.detection_gi_prior_alpha, self.ones * self.detection_gi_prior_alpha, ) .expand([1, self.n_vars, 2]) .to_event(3), ) # =======Gene-specific additive component (Ambient RNA/ "Soup") for spliced and unspliced counts ====== # # Independently sampled for spliced and unspliced counts: s_g_gene_add_alpha_hyp = pyro.sample( "s_g_gene_add_alpha_hyp", dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta).expand([2]).to_event(1), ) s_g_gene_add_mean = pyro.sample( "s_g_gene_add_mean", dist.Gamma( self.gene_add_mean_hyp_prior_alpha, self.gene_add_mean_hyp_prior_beta, ) .expand([self.n_batch, 1, 2]) .to_event(3), ) s_g_gene_add_alpha_e_inv = pyro.sample( "s_g_gene_add_alpha_e_inv", dist.Exponential(s_g_gene_add_alpha_hyp).expand([self.n_batch, 1, 2]).to_event(3), ) s_g_gene_add_alpha_e = self.ones / s_g_gene_add_alpha_e_inv.pow(2) s_g_gene_add = pyro.sample( "s_g_gene_add", dist.Gamma(s_g_gene_add_alpha_e, s_g_gene_add_alpha_e / s_g_gene_add_mean) .expand([self.n_batch, self.n_vars, 2]) .to_event(3), ) # =========Gene-specific overdispersion of spliced and unspliced counts ============== # # Overdispersion of unspliced counts: stochastic_v_ag_hyp = pyro.sample( "stochastic_v_ag_hyp", dist.Gamma( self.stochastic_v_ag_hyp_prior_alpha, self.stochastic_v_ag_hyp_prior_beta, ).expand([1, 2]).to_event(2)) stochastic_v_ag_hyp = stochastic_v_ag_hyp + 0.001 stochastic_v_ag_inv = pyro.sample( "stochastic_v_ag_inv", dist.Exponential(stochastic_v_ag_hyp) .expand([1, self.n_vars, 2]).to_event(3), ) stochastic_v_ag = (self.ones / stochastic_v_ag_inv.pow(2)) # =====================Expected expression ======================= # # biological expression with obs_plate: mu = pyro.deterministic('mu', (mu_expression + torch.einsum('cbi,bgi->cgi', obs2sample.unsqueeze(dim=-1), s_g_gene_add)) * \ detection_y_c * detection_y_i * detection_y_gi) # =====================DATA likelihood ======================= # # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial with obs_plate: pyro.sample("data_target", dist.GammaPoisson(concentration= stochastic_v_ag, rate= stochastic_v_ag / mu), obs=torch.stack([u_data, s_data], axis = 2))