from typing import List, Optional
from datetime import date
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from anndata import AnnData
from pyro import clear_param_store
from scvi.model._utils import parse_use_gpu_arg
from scvi.dataloaders import AnnDataLoader
from scvi.utils import track
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
)
from scvi.model.base import BaseModelClass, PyroSampleMixin, PyroSviTrainMixin
from scvi.utils import setup_anndata_dsp
import pyro.distributions as dist
import torch
import scanpy as sc
import contextlib
import io
from numpy.linalg import norm
from scipy.sparse import csr_matrix
from numpy import inner
import scvelo as scv
from scvelo.plotting.velocity_embedding_grid import compute_velocity_on_grid
from ._velocity_embedding_stream import velocity_embedding_stream_modules
import scipy
import gseapy as gp
from cell2fate._pyro_base_cell2fate_module import Cell2FateBaseModule
from cell2fate._pyro_mixin import QuantileMixin
from ._cell2fate_DynamicalModel_module import \
Cell2fate_DynamicalModel_module
from cell2fate.utils import multiplot_from_generator
from cell2fate.utils import mu_mRNA_continousAlpha_globalTime_twoStates
import cell2fate as c2f
[docs]
class Cell2fate_DynamicalModel(QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, BaseModelClass):
r"""
Cell2fate model. User-end model class.
.. note::
See Module class for description of the model.
Parameters
----------
adata
Single-cell AnnData object that has been registered via :func:`~scvi.data.setup_anndata`
and contains spliced and unspliced counts in ``adata.layers['spliced']``, ``adata.layers['unspliced']``
**model_kwargs
Keyword args for :class:`~scvi.external.LocationModelLinearDependentWMultiExperimentModel`
"""
def __init__(
    self,
    adata: AnnData,
    model_class=None,
    **model_kwargs,
):
    """
    Build the user-facing model around an AnnData object.

    Parameters
    ----------
    adata
        AnnData object passed to the parent constructor; the lengths of its
        ``obs_names`` and ``var_names`` are used as ``n_obs`` and ``n_vars`` of the module.
    model_class
        Model class handed to :class:`Cell2FateBaseModule`. When ``None``,
        ``Cell2fate_DynamicalModel_module`` is used.
    **model_kwargs
        Extra keyword arguments forwarded unchanged to :class:`Cell2FateBaseModule`.
    """
    # in case any other model was created before that shares the same parameter names.
    clear_param_store()
    super().__init__(adata)
    # Fall back to the default dynamical model when no class is supplied.
    if model_class is None:
        model_class = Cell2fate_DynamicalModel_module
    self.module = Cell2FateBaseModule(
        model=model_class,
        n_obs=len(adata.obs_names),
        n_vars=len(adata.var_names),
        n_batch=self.summary_stats["n_batch"],
        **model_kwargs,
    )
    self._model_summary_string = f'Cell2fate Dynamical Model with the following params: \nn_batch: {self.summary_stats["n_batch"]} '
    # locals() is captured at this point, so ``model_class`` already holds the resolved
    # default; do not introduce extra local variables above this line without checking
    # how ``_get_init_params`` treats them.
    self.init_params_ = self._get_init_params(locals())
[docs]
@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
    cls,
    adata: AnnData,
    layer: Optional[str] = None,
    batch_key: Optional[str] = None,
    labels_key: Optional[str] = None,
    unspliced_label = 'unspliced',
    spliced_label = 'spliced',
    cluster_label = None,
    **kwargs,
):
    """
    %(summary)s.

    Parameters
    ----------
    %(param_layer)s
    %(param_batch_key)s
    %(param_labels_key)s
    unspliced_label
        Layer handed to the ``'unspliced'`` count field.
    spliced_label
        Layer handed to the ``'spliced'`` count field.
    cluster_label
        Optional ``adata.obs`` key registered as the categorical ``'clusters'`` field.
        When falsy, no cluster field is registered.
    **kwargs
        Forwarded to ``AnnDataManager.register_fields``.
    """
    # Must stay the first statement: it snapshots the call arguments via locals().
    setup_method_args = cls._get_setup_method_args(**locals())
    # Running integer index of every cell, registered below as the indices field.
    adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")

    count_fields = [
        LayerField('unspliced', unspliced_label, is_count_data=True),
        LayerField('spliced', spliced_label, is_count_data=True),
    ]
    if cluster_label:
        obs_fields = [
            CategoricalObsField('clusters', cluster_label),
            NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
        ]
    else:
        obs_fields = [
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
        ]

    adata_manager = AnnDataManager(fields=count_fields + obs_fields, setup_method_args=setup_method_args)
    adata_manager.register_fields(adata, **kwargs)
    cls.register_manager(adata_manager)
[docs]
def train(
    self,
    max_epochs: int = 500,
    batch_size: int = 1000,
    train_size: float = 1,
    lr: float = 0.01,
    **kwargs,
):
    """
    Train the model by delegating to the parent ``train`` implementation.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset. The value is also stored on
        ``self.max_epochs``.
    batch_size
        Minibatch size to use during training.
    train_size
        Size of training set in the range [0.0, 1.0].
    lr
        Optimiser learning rate, forwarded as ``lr``.
    kwargs
        Other arguments for the parent ``train`` method. Entries named
        ``max_epochs``, ``batch_size``, ``train_size`` or ``lr`` are overwritten
        by the explicit arguments above.
    """
    self.max_epochs = max_epochs
    kwargs.update(
        max_epochs=max_epochs,
        batch_size=batch_size,
        train_size=train_size,
        lr=lr,
    )
    super().train(**kwargs)
def _export2adata(self, samples):
r"""
Export key model variables and samples
Parameters
----------
samples
Dictionary with posterior mean, 5%/95% quantiles, SD, samples, generated by ``.sample_posterior()``
Returns
-------
Dict
Updated dictionary with additional details is saved to ``adata.uns['mod']``.
"""
# add factor filter and samples of all parameters to unstructured data
results = {
"model_name": str(self.module.__class__.__name__),
"date": str(date.today()),
"var_names": self.adata.var_names.tolist(),
"obs_names": self.adata.obs_names.tolist(),
"post_sample_means": samples["post_sample_means"],
"post_sample_stds": samples["post_sample_stds"],
"post_sample_q05": samples["post_sample_q05"],
"post_sample_q95": samples["post_sample_q95"],
}
return results
def _export2adata_quantiles(self, samples):
r"""
Export key model variables and quantiles
Parameters
----------
samples
Dictionary with posterior mean, 5%/95% quantiles, SD, samples, generated by ``.sample_posterior()``
Returns
-------
Dict
Updated dictionary with additional details is saved to ``adata.uns['mod']``.
"""
# add factor filter and samples of all parameters to unstructured data
results = {
"model_name": str(self.module.__class__.__name__),
"date": str(date.today()),
"var_names": self.adata.var_names.tolist(),
"obs_names": self.adata.obs_names.tolist(),
"post_q25": samples["0.25"],
"post_q50": samples["0.5"],
"post_q75": samples["0.75"],
}
return results
[docs]
def plot_history(self, iter_start=0, iter_end=-1, ax=None):
    r"""Plot training history

    Parameters
    ----------
    iter_start
        Omit initial iterations from the plot.
    iter_end
        Omit last iterations from the plot. ``-1`` means "up to the last recorded value".
    ax
        Matplotlib axis. When ``None`` the current axes (``plt.gca()``) are used.
    """
    if ax is None:
        # Draw on the current axes instead of treating the pyplot module as an axis;
        # this way a user-supplied Axes and the default follow the same code path
        # and nothing is monkey-patched onto ``plt``.
        ax = plt.gca()
    elbo = self.history_["elbo_train"]
    if iter_end == -1:
        iter_end = len(elbo)
    ax.plot(
        elbo.index[iter_start:iter_end],
        np.array(elbo.values.flatten())[iter_start:iter_end],
        label="train",
    )
    ax.legend()
    # The x-axis always spans the full training run, even when only a window is drawn.
    ax.set_xlim(0, len(elbo))
    ax.set_xlabel("Training epochs")
    ax.set_ylabel("-ELBO loss")
    plt.tight_layout()
[docs]
def compute_module_summary_statistics(self, adata):
    """
    Computes the contribution of each module to mRNA molecules in each cell.

    Requires ``self.samples['post_sample_means']`` to be populated (it is set by
    ``export_posterior`` / ``export_posterior_quantiles``).

    For every module ``m`` two columns are written to ``adata.obs``:

    - ``'Module m Activation'``: module expression summed over the last two axes
      of the value returned by ``mu_mRNA_continousAlpha_globalTime_twoStates``.
    - ``'Module m State'``: one of ``'OFF'``, ``'Induction'``, ``'Repression'`` or ``'ON'``.
      Cells start as ``'OFF'``; they become ``'Induction'`` after the module switch-on time,
      ``'Repression'`` after the switch-off time, and are finally overridden to ``'ON'`` /
      ``'OFF'`` when the activation relative to the steady-state total is above 0.95 /
      below 0.05.

    Parameters
    ----------
    adata
        AnnData object.

    Returns
    -------
    AnnData
        AnnData object with additional module-related summary statistics.
    """
    means = self.samples['post_sample_means']
    # Posterior time of each cell, compared against the module switch times below.
    cell_time = means['T_c'][:,0,0]
    for m in range(self.module.model.n_modules):
        activation_key = 'Module ' + str(m) + ' Activation'
        state_key = 'Module ' + str(m) + ' State'
        mu_m = mu_mRNA_continousAlpha_globalTime_twoStates(
            torch.tensor(means['A_mgON'][m,:]),
            torch.tensor(0.),
            torch.tensor(means['beta_g']),
            torch.tensor(means['gamma_g']),
            torch.tensor(means['lam_mi'][m,:]),
            torch.tensor(means['T_c'][:,:,0]),
            torch.tensor(means['T_mON'][:,:,m]),
            torch.tensor(means['T_mOFF'][:,:,m]),
            torch.zeros((self.module.model.n_obs, self.module.model.n_vars)))
        # Steady-state total (unspliced + spliced) the module would reach if fully on.
        ss_total = torch.sum(torch.tensor(means['A_mgON'][m,:])/torch.tensor(means['gamma_g']) + \
                             torch.tensor(means['A_mgON'][m,:])/torch.tensor(means['beta_g']), axis = 1)
        module_total = torch.sum(torch.sum(mu_m, axis = -1), axis = -1)
        relative_activation = (module_total/ss_total).detach().cpu().numpy()

        # Use .loc so the assignments always hit ``adata.obs`` itself; chained
        # ``obs[col][mask] = ...`` may write to a temporary copy.
        adata.obs[state_key] = 'OFF'
        adata.obs.loc[cell_time > means['T_mON'][0,0,m], state_key] = 'Induction'
        adata.obs.loc[cell_time > means['T_mOFF'][0,0,m], state_key] = 'Repression'
        adata.obs.loc[relative_activation > 0.95, state_key] = 'ON'
        adata.obs.loc[relative_activation < 0.05, state_key] = 'OFF'
        # The exported activation is the un-normalised module total.
        adata.obs[activation_key] = module_total.detach().cpu().numpy()
    return adata
[docs]
def plot_module_summary_statistics(self, adata, save = None):
    '''
    Plots, for each module, a UMAP coloured by module activation (left column)
    and a UMAP coloured by module state (right column).

    Expects the ``'Module i Activation'`` and ``'Module i State'`` columns in
    ``adata.obs`` for every module index ``i``.

    Parameters
    ----------
    adata
        AnnData object containing single-cell RNA sequencing data.
    save
        File path to save the plot. If not provided, the plot will not be saved.
    '''
    n_modules = self.module.model.n_modules

    def _upper_limit(values):
        # Value at the 99% position of the sorted activations; the index is clamped
        # so that small datasets (where rounding reaches len) do not index out of range.
        ordered = np.sort(values)
        position = min(int(np.round(0.99*len(ordered))), len(ordered) - 1)
        return ordered[position]

    # Shared colour scale across modules.
    limit = np.max([_upper_limit(adata.obs['Module ' + str(i) + ' Activation']) for i in range(n_modules)])
    # squeeze=False keeps ``ax`` two-dimensional even when there is a single module.
    fig, ax = plt.subplots(n_modules, 2, figsize = (10, 4*n_modules), squeeze = False)
    for i in range(n_modules):
        sc.pl.umap(adata, color = ['Module ' + str(i) + ' Activation'], legend_loc = None,
                   size = 200, color_map = 'viridis', ax = ax[i,0], show = False, vmin = 0, vmax = limit)
        sc.pl.umap(adata, color = ['Module ' + str(i) + ' State'], legend_loc = 'on data',
                   size = 200, ax = ax[i,1], show = False,
                   palette = {'ON': 'lime', 'OFF': 'grey', 'Induction': 'lightgreen', 'Repression': 'orange'})
    plt.tight_layout()
    if save:
        plt.savefig(save)
[docs]
def export_posterior(
    self,
    adata,
    sample_kwargs = {"num_samples": 30, "batch_size" : None,
                     "use_gpu" : True, 'return_samples': True},
    export_slot: str = "mod",
    full_velocity_posterior = False,
    normalize = True):
    """
    Summarises posterior distribution and exports results to anndata object.

    What is written:

    - **adata.uns[export_slot]:** output of ``_export2adata`` (posterior summaries, model name,
      date, variable and observation names); when ``return_samples`` is truthy the
      ``'posterior_samples'`` entry of the sampling result is added under ``'post_samples'``.
    - **adata.obs['Time (hours)']:** posterior mean of ``T_c`` shifted so that its minimum is 0.
    - **adata.obs['Time Uncertainty (sd)']:** posterior standard deviation of ``T_c``.
    - **self.samples:** the full result of ``self.sample_posterior``.

    Parameters
    ----------
    adata
        AnnData object where results should be saved.
    sample_kwargs
        Dictionary of arguments for ``self.sample_posterior``. The dictionary is copied,
        never modified. If ``batch_size`` is missing or ``None`` it is set to ``adata.n_obs``.
        ``return_samples`` (treated as ``False`` when missing) controls whether all
        posterior samples are exported. ``None`` selects the default dictionary.
    export_slot
        adata.uns slot where to export results.
    full_velocity_posterior
        Currently not used by this method.
    normalize
        Currently not used by this method.

    Returns
    -------
    AnnData
        AnnData object with posterior added in adata.obs and adata.uns.
    """
    if sample_kwargs is None:
        sample_kwargs = {"num_samples": 30, "batch_size" : None,
                         "use_gpu" : True, 'return_samples': True}
    # Work on a copy: the default dictionary is shared between calls, so writing the
    # resolved batch size into it would leak one dataset's size into later calls.
    sample_kwargs = dict(sample_kwargs)
    if sample_kwargs.get('batch_size') is None:
        sample_kwargs['batch_size'] = adata.n_obs
    # generate samples from posterior distributions for all parameters
    # and compute mean, 5%/95% quantiles and standard deviation
    self.samples = self.sample_posterior(**sample_kwargs)
    # export posterior distribution summary for all parameters and
    # annotation (model, date, var, obs and cell type names) to anndata object
    adata.uns[export_slot] = self._export2adata(self.samples)
    if sample_kwargs.get('return_samples', False):
        print('Warning: Saving ALL posterior samples. Specify "return_samples: False" to save just summary statistics.')
        adata.uns[export_slot]['post_samples'] = self.samples['posterior_samples']
    cell_time = self.samples['post_sample_means']['T_c'].flatten()
    adata.obs['Time (hours)'] = cell_time - np.min(cell_time)
    adata.obs['Time Uncertainty (sd)'] = self.samples['post_sample_stds']['T_c'].flatten()
    return adata
[docs]
def export_posterior_quantiles(
    self,
    adata,
    batch_size = None,
    export_slot: str = "mod",
    full_velocity_posterior = False,
    normalize = True,
    use_gpu = True,
    use_median = False):
    """
    Exports posteriors as quantiles (0.25, 0.5 and 0.75) instead of sample summaries.

    Writes the quantile dictionary to ``adata.uns[export_slot]``, the median-based
    ``'Time (hours)'`` and the quartile coefficient of dispersion
    ``'Time Uncertainty (QCD)'`` to ``adata.obs``, and the layers ``'spliced mean'`` and
    ``'velocity'`` to ``adata.layers``. ``self.samples['post_sample_means']`` is set to the
    median (0.5 quantile) so that downstream methods can use it.

    Parameters
    ----------
    adata
        AnnData object.
    batch_size
        Batch size passed to ``self.posterior_quantile``. ``None`` means ``adata.n_obs``.
    export_slot
        Slot in ``adata.uns`` for exporting quantiles. Defaults to "mod".
    full_velocity_posterior
        Not used by this method.
    normalize
        Not used by this method.
    use_gpu
        Forwarded to ``self.posterior_quantile``. Defaults to True.
    use_median
        Forwarded to ``self.posterior_quantile``. Defaults to False.

    Returns
    -------
    AnnData
        AnnData object with exported quantiles.
    """
    if batch_size is None:
        batch_size = adata.n_obs
    quantiles_dict = {
        str(level): self.posterior_quantile(q=level, batch_size=batch_size,
                                            use_gpu=use_gpu, use_median=use_median)
        for level in (0.25, 0.5, 0.75)
    }
    adata.uns[export_slot] = self._export2adata_quantiles(quantiles_dict)

    median = quantiles_dict["0.5"]
    self.samples = {'post_sample_means': median}

    median_time = median['T_c'].flatten()
    adata.obs['Time (hours)'] = median_time - np.min(median_time)
    # Quartile coefficient of dispersion: (Q3 - Q1) / (Q1 + Q3).
    lower_time = quantiles_dict['0.25']['T_c'].flatten()
    upper_time = quantiles_dict['0.75']['T_c'].flatten()
    adata.obs['Time Uncertainty (QCD)'] = (upper_time - lower_time)/ (lower_time + upper_time)

    expression = median['mu_expression']
    adata.layers['spliced mean'] = expression[...,1]
    # beta_g * (last-axis index 0) - gamma_g * (last-axis index 1) of mu_expression.
    adata.layers['velocity'] = torch.tensor(median['beta_g']) * expression[...,0] - \
        torch.tensor(median['gamma_g']) * expression[...,1]
    return adata
[docs]
def compute_velocity_graph_Bergen2020(mod, adata, n_neighbours = None, full_posterior = True, spliced_key = 'Ms',
                                      velocity_key = 'velocity'):
    """
    Computes a "velocity graph" similar to the method in:
    "Bergen et al. (2020), Generalizing RNA velocity to transient cell states through dynamical modeling"

    Side effects: ``scv.pp.neighbors`` is run on ``adata`` with ``n_neighbours`` and
    ``adata.obsp['binary']`` is set to the boolean pattern of ``adata.obsp['connectivities']``.

    Parameters
    ----------
    adata
        AnnData object with velocity information in ``adata.layers[velocity_key]`` (expectation value) or
        ``adata.uns['velocity_posterior']`` (full posterior, indexed as ``[:, cell, :]``), and
        spliced counts in ``adata.layers[spliced_key]``.
    n_neighbours
        How many nearest neighbours to consider (all non nearest neighbours have edge weights set to 0).
        If not specified, 5% of the total number of cells is used.
    full_posterior
        Whether to use full posterior to compute velocity graph (otherwise expectation value is used).
    velocity_key
        Key to access velocity information in adata.
    spliced_key
        Key to access normalized spliced counts in adata.

    Returns
    -------
    scipy.sparse.csr_matrix
        Velocity graph of shape (n_cells, n_cells); row ``i`` holds the transition
        weights from cell ``i`` to its neighbours.
    """
    M = len(adata.obs_names)
    if not n_neighbours:
        n_neighbours = int(np.round(M*0.05, 0))
    scv.pp.neighbors(adata, n_neighbors = n_neighbours)
    adata.obsp['binary'] = adata.obsp['connectivities'] != 0
    binary = adata.obsp['binary']
    spliced = adata.layers[spliced_key]
    n_genes = len(adata.var_names)

    # Collect the entries of all rows and build a single sparse matrix at the end,
    # instead of allocating one (M, M) matrix per cell and summing them.
    row_index, col_index, weights = [], [], []
    for i in range(M):
        # Densify only row i of the neighbour graph (not the whole matrix per cell).
        neighbour_mask = binary[i,:].toarray().ravel()
        distances = spliced[neighbour_mask,:] - spliced[i,:].flatten()
        if full_posterior:
            velocities = adata.uns['velocity_posterior'][:,i,:]
        else:
            velocities = adata.layers[velocity_key][i,:].reshape(1, n_genes)
        # NOTE(review): the denominator uses the norm of the whole neighbour-difference
        # matrix rather than one norm per neighbour; kept unchanged - confirm this is intended.
        cosines = inner(distances, velocities)/(norm(distances)*norm(velocities))
        transition_probabilities = np.exp(2*cosines)
        # Normalise over the neighbours (axis 0), separately for every velocity sample.
        transition_probabilities = transition_probabilities/np.sum(transition_probabilities, axis = 0)
        # Average over velocity samples (a single column when the expectation is used).
        weights.append(np.mean(np.array(transition_probabilities), axis = 1))
        row_index.append(np.repeat(i, len(transition_probabilities)))
        col_index.append(np.flatnonzero(neighbour_mask))
    return csr_matrix((np.concatenate(weights), (np.concatenate(row_index), np.concatenate(col_index))),
                      shape=(M, M))
[docs]
def compute_and_plot_module_velocity(self, adata, delete = True, plot = True, save = None,
                                     plotting_kwargs = {"color": 'clusters', 'legend_fontsize': 10,
                                                        'legend_loc': 'right_margin', 'min_mass': 4}):
    """
    Computes the RNA velocity produced by each module, as well as associated "velocity graph" and
    then plots results on a UMAP based on the method in:
    "Bergen et al. (2020), Generalizing RNA velocity to transient cell states through dynamical modeling"

    Requires ``self.samples['post_sample_means']``. For every module ``m`` the graph is
    stored in ``adata.uns['Module m Velocity_graph']``; the temporary spliced-mean layer is
    always removed again.

    Parameters
    ----------
    adata
        AnnData object with spliced and unspliced count data.
    delete
        Whether to delete the ``'Module m Velocity'`` layers after processing.
    plot
        Whether to plot the results. A module whose stream plot fails is reported with
        a printed message and skipped.
    save
        Filepath to save the plot.
    plotting_kwargs
        Keyword arguments for ``scv.pl.velocity_embedding_stream``.
    """
    n_modules = self.module.model.n_modules
    means = self.samples['post_sample_means']
    # squeeze=False keeps ``ax`` two-dimensional, so a single module can be indexed too.
    fix, ax = plt.subplots(n_modules, 1, figsize = (6, n_modules*4), squeeze = False)
    for m in range(n_modules):
        spliced_key = 'Module ' + str(m) + 'Spliced Mean'
        velocity_key = 'Module ' + str(m) + ' Velocity'
        print('Computing velocity produced by Module ' + str(m) + ' ...')
        # Silence anything printed to stdout while computing.
        with contextlib.redirect_stdout(io.StringIO()):
            mu_m = mu_mRNA_continousAlpha_globalTime_twoStates(
                torch.tensor(means['A_mgON'][m,:]),
                torch.tensor(0., dtype = torch.float),
                torch.tensor(means['beta_g']),
                torch.tensor(means['gamma_g']),
                torch.tensor(means['lam_mi'][m,:]),
                torch.tensor(means['T_c'][:,:,0]),
                torch.tensor(means['T_mON'][:,:,m]),
                torch.tensor(means['T_mOFF'][:,:,m]),
                torch.zeros((self.module.model.n_obs, self.module.model.n_vars)))
            count_sum = torch.sum(torch.sum(mu_m, axis = 1), axis = -1)
            # Cells sharing the minimal module count; if there are more than 3 of them
            # their values are replaced by small random multiples (1e-5 scale) of mu_expression.
            problem_cells_index = count_sum == torch.min(count_sum)
            n_problem_cells = int(torch.sum(problem_cells_index))
            if n_problem_cells > 3:
                mu_m[problem_cells_index,:,0] = torch.tensor(np.random.sample(n_problem_cells)).unsqueeze(-1)*\
                    torch.tensor(means['mu_expression'][problem_cells_index,:,0])*torch.tensor(10**(-5))
                mu_m[problem_cells_index,:,1] = torch.tensor(np.random.sample(n_problem_cells)).unsqueeze(-1)*\
                    torch.tensor(means['mu_expression'][problem_cells_index,:,1])*torch.tensor(10**(-5))
            adata.layers[spliced_key] = mu_m[...,1]
            adata.layers[velocity_key] = torch.tensor(means['beta_g']) * \
                mu_m[...,0] - torch.tensor(means['gamma_g']) * mu_m[...,1]
            adata.uns[velocity_key + '_graph'] = self.compute_velocity_graph_Bergen2020(
                adata, n_neighbours = None, full_posterior = False,
                velocity_key = velocity_key,
                spliced_key = spliced_key)
        if plot:
            try:
                with contextlib.redirect_stdout(io.StringIO()):
                    scv.pl.velocity_embedding_stream(adata, basis='umap', save = False, vkey=velocity_key,
                                                     **plotting_kwargs, show = False, ax = ax[m, 0])
                ax[m, 0].set_title('Module ' + str(m) + '\n Velocity Graph UMAP Embedding')
            except Exception:
                # Best effort: a module without a usable velocity field is skipped.
                # The message is printed outside the stdout redirection so it is visible.
                print(f"no velocity for module {m}")
        del adata.layers[spliced_key]
        del mu_m
        if delete:
            del adata.layers[velocity_key]
    if save:
        plt.savefig(save)
[docs]
def compute_and_plot_total_velocity(self, adata, delete = True, plot = True, save = None,
                                    plotting_kwargs = {"color": 'clusters', 'legend_fontsize': 10, 'legend_loc': 'right_margin'}, return_adata=False):
    """
    Computes total RNA velocity, as well as associated "velocity graph" and
    then plots results on a UMAP based on the method in:
    "Bergen et al. (2020), Generalizing RNA velocity to transient cell states through dynamical modeling"

    Requires ``self.samples['post_sample_means']``. The graph is stored in
    ``adata.uns['Velocity_graph']``; the temporary ``'Spliced Mean'`` layer is always removed.

    Parameters
    ----------
    adata
        AnnData object with spliced and unspliced count data.
    delete
        Whether to delete the ``'Velocity'`` layer after processing.
    plot
        Whether to plot the results.
    save
        Filepath to save the plot (only used when ``plot`` is true, since the figure
        is created only then).
    plotting_kwargs
        Keyword arguments for ``scv.pl.velocity_embedding_stream``.
    return_adata
        If true, return ``adata`` after processing; otherwise nothing is returned.

    Returns
    -------
    AnnData or None
        ``adata`` when ``return_adata`` is true.
    """
    print('Computing total RNAvelocity ...')
    means = self.samples['post_sample_means']
    # Silence anything printed to stdout while computing and plotting.
    with contextlib.redirect_stdout(io.StringIO()):
        adata.layers['Spliced Mean'] = means['mu_expression'][...,1]
        adata.layers['Velocity'] = torch.tensor(means['beta_g']) * \
            means['mu_expression'][...,0] - \
            torch.tensor(means['gamma_g']) * \
            means['mu_expression'][...,1]
        adata.uns['Velocity' + '_graph'] = self.compute_velocity_graph_Bergen2020(
            adata, n_neighbours = None, full_posterior = False,
            velocity_key = 'Velocity',
            spliced_key = 'Spliced Mean')
        if plot:
            fix, ax = plt.subplots(1, 1, figsize = (6, 4))
            scv.pl.velocity_embedding_stream(adata, basis='umap', save = False, vkey='Velocity',
                                             **plotting_kwargs, show = False, ax = ax)
            if save:
                plt.savefig(save)
        del adata.layers['Spliced Mean']
        if delete:
            del adata.layers['Velocity']
    # The parameter used to be accepted but ignored.
    if return_adata:
        return adata
[docs]
def compute_and_plot_total_velocity_scvelo(self, adata, delete = True, plot = True, save = None,
                                           plotting_kwargs = {"color": 'clusters', 'legend_fontsize': 10, 'legend_loc': 'right_margin'}):
    """
    Computes total RNA velocity from the posterior means and lets scvelo build the
    neighbour graph, velocity graph and velocity embedding, then optionally plots a
    stream plot on the UMAP.

    Requires ``self.samples['post_sample_means']``. The temporary layers ``'Mu'`` and
    ``'Ms'`` are always removed at the end.

    Parameters
    ----------
    adata
        AnnData object with spliced and unspliced count data.
    delete
        Whether to delete the ``'velocity'`` layer after processing.
    plot
        Whether to plot the results.
    save
        Filepath to save the plot (used only when ``plot`` is true).
    plotting_kwargs
        Keyword arguments for ``scv.pl.velocity_embedding_stream``.
    """
    print('Computing total RNAvelocity ...')
    # Everything below runs with stdout redirected into a throw-away buffer.
    with contextlib.redirect_stdout(io.StringIO()):
        posterior_means = self.samples['post_sample_means']
        expression = posterior_means['mu_expression']
        adata.layers['Mu'] = expression[...,0]
        adata.layers['Ms'] = expression[...,1]
        production = torch.tensor(posterior_means['beta_g']) * expression[...,0]
        degradation = torch.tensor(posterior_means['gamma_g']) * expression[...,1]
        adata.layers['velocity'] = production - degradation

        scv.pp.neighbors(adata)
        scv.tl.velocity_graph(adata, vkey = 'velocity')
        scv.tl.velocity_embedding(adata, vkey = 'velocity')

        if plot:
            fix, ax = plt.subplots(1, 1, figsize = (6, 4))
            scv.pl.velocity_embedding_stream(adata, basis='umap', save = False, vkey='velocity',
                                             **plotting_kwargs, show = False, ax = ax)
            if save:
                plt.savefig(save)

        for temporary_layer in ('Ms', 'Mu'):
            del adata.layers[temporary_layer]
        if delete:
            del adata.layers['velocity']
[docs]
def compare_module_activation(self, adata, chosen_modules, time_max = None, time_min = 0,
                              save = None, ncol = 1):
    """
    Compares the activation of chosen modules across time.

    For every chosen module the spliced abundance (last-axis index 1 of the value returned
    by ``mu_mRNA_continousAlpha_globalTime_twoStates``, summed over the final remaining axis)
    is scattered against the posterior mean time of each cell, shifted so that the
    earliest cell is at 0. The module switch-on and switch-off times are averaged with
    ``np.mean`` before being passed to the helper.

    Requires ``self.samples['post_sample_means']``.

    Parameters
    ----------
    adata
        AnnData object (not read by this method).
    chosen_modules
        List of module indices to compare.
    time_max
        Maximum time point for comparison.
    time_min
        Minimum time point for comparison.
    save
        Filepath to save the plot.
    ncol
        Number of columns in the legend.
    """
    means = self.samples['post_sample_means']
    # The x-axis is the same for every module, so compute it once.
    cell_time = means['T_c'][:,:,0].flatten()
    relative_time = cell_time - np.min(cell_time)
    fig, ax = plt.subplots(1, 1, figsize=(18, 5))
    for m in chosen_modules:
        abundance = torch.sum(mu_mRNA_continousAlpha_globalTime_twoStates(
            torch.tensor(means['A_mgON'][m,:]),
            torch.tensor(0.),
            torch.tensor(means['beta_g']),
            torch.tensor(means['gamma_g']),
            torch.tensor(means['lam_mi'][m,:]),
            torch.tensor(means['T_c'][:,:,0]),
            torch.tensor(np.mean(means['T_mON'][:,:,m])),
            torch.tensor(np.mean(means['T_mOFF'][:,:,m])),
            torch.zeros((self.module.model.n_obs, self.module.model.n_vars)))[...,1], axis = -1)
        ax.scatter(relative_time, abundance, s = 10, label = 'Module ' + str(m))
    ax.set_xlabel('Time (hours)')
    ax.set_ylabel('Total Spliced UMI Counts')
    # Only touch the x-limits when the caller asked for a window.
    if time_max or time_min > 0:
        ax.set_xlim(time_min, time_max)
    ax.legend(frameon=False, ncol = ncol)
    ax.set_title('Module Activation Across Cells In Dataset')
    if save:
        plt.savefig(save)
[docs]
def plot_technical_variables(self, adata, save = False):
    """
    Plot posterior means of the technical variables in the model on a 3 x 2 grid:
    detection efficiency per batch, per cell and per gene, ambient RNA per batch and
    per gene, and the overdispersion factor per gene.

    Requires ``self.samples['post_sample_means']``.

    Parameters
    ----------
    adata
        AnnData object; ``adata.obs['_scvi_batch']`` provides the batch labels.
    save
        Filepath to save the plot; nothing is saved when falsy.
    """
    means = self.samples['post_sample_means']
    # (legend label, colour); the position in this tuple is the index on the last axis.
    channels = (('unspliced', 'red'), ('spliced', 'blue'))
    fig, ax = plt.subplots(3, 2, figsize = (12, 9))

    # Detection efficiency per batch.
    panel = ax[0,0]
    panel.scatter([str(x) for x in np.unique(adata.obs['_scvi_batch'])],
                  means['detection_mean_y_e'], s = 150, c = 'black')
    panel.set_xlabel('Batch Number')
    panel.set_ylabel('Relative Detection Efficiency')
    panel.set_title('Mean Relative Detection Efficiency in each batch')

    # Detection efficiency per cell.
    detection_y_i = means['detection_y_i']
    panel = ax[0,1]
    for k, (name, colour) in enumerate(channels):
        panel.hist(means['detection_y_c'].flatten()*detection_y_i[0,0,k],
                   bins = 100, label = name, alpha = 0.75, color = colour)
    panel.legend(frameon=False)
    panel.set_xlabel('Relative Detection Efficiency')
    panel.set_ylabel('Number of Cells')
    panel.set_title('Relative Detection Efficiency across cells')

    # Ambient RNA per batch.
    panel = ax[1,0]
    for k, (name, colour) in enumerate(channels):
        panel.scatter([str(x) for x in np.unique(adata.obs['_scvi_batch'])],
                      means['s_g_gene_add_mean'][...,k],
                      s = 150, c = colour, label = name)
    panel.legend(frameon=False)
    panel.set_xlabel('Batch Number')
    panel.set_ylabel('Expected Ambient RNA')
    panel.set_title('Mean Ambient RNA counts in each batch')

    # Ambient RNA per gene.
    panel = ax[1,1]
    for k, (name, colour) in enumerate(channels):
        panel.hist(np.log10(means['s_g_gene_add'][...,k].flatten()),
                   bins = 100, alpha = 0.75, label = name, color = colour)
    panel.legend(frameon=False)
    panel.set_xlabel('log10 Expected Ambient RNA Counts')
    panel.set_ylabel('Number of Genes')
    panel.set_title('log10 Ambient RNA across genes')

    # Overdispersion factor per gene (this panel has no legend).
    panel = ax[2,1]
    for k, (name, colour) in enumerate(channels):
        panel.hist(np.log10(1./means['stochastic_v_ag_inv'][...,k].flatten()**2),
                   bins = 100, alpha = 0.75, label = name, color = colour)
    panel.set_xlabel('log10 Overdispersion Factor')
    panel.set_ylabel('Number of Genes')
    panel.set_title('log10 Overdispersion factor across genes \n (smaller = more variance)')

    # Detection efficiency per gene.
    panel = ax[2,0]
    for k, (name, colour) in enumerate(channels):
        panel.hist(means['detection_y_gi'][...,k].flatten(),
                   bins = 100, alpha = 0.75, label = name, color = colour)
    panel.set_xlabel('Relative Detection Efficiency')
    panel.set_ylabel('Number of Genes')
    panel.set_title('Relative Detection Efficiency across genes')
    panel.legend(frameon=False)

    plt.tight_layout()
    if save:
        plt.savefig(save)
[docs]
def view_history(self):
    """
    Show the training history in four panels: the full run, and the run with the
    first 1/8, 1/4 and 1/2 of ``self.max_epochs`` omitted, to assess convergence or
    spot potential training problems.
    """
    def generatePlots():
        # ``None`` stands for the full history; the numbers are the divisors of
        # ``max_epochs`` that give the first epoch shown. Each plot is drawn after a yield.
        for divisor in (None, 8, 4, 2):
            yield
            if divisor is None:
                self.plot_history()
            else:
                self.plot_history(int(np.round(self.max_epochs/divisor)))
    multiplot_from_generator(generatePlots(), 4)
[docs]
def get_module_top_features(self, adata, background, species = 'Mouse', p_adj_cutoff = 0.01, n_top_genes = None):
    """
    Returns a dataframe with top Genes, TFs, and GO terms of each module.

    Genes are ranked per module by the module's share of the total inferred spliced
    expression (summed over cells). Enrichment is run with ``gseapy.enrichr`` on the
    ``n_top_genes`` highest ranked genes of each module.

    Requires ``self.samples['post_sample_means']``.

    Parameters
    ----------
    adata
        AnnData object.
    background
        List of genes to consider as background.
    species
        Species for which to consider TFs and gene sets, ``'Mouse'`` or ``'Human'``.
        Defaults to 'Mouse'.
    p_adj_cutoff
        Adjusted p-value cutoff for enrichment analysis. Defaults to 0.01.
    n_top_genes
        Number of top genes to consider for each module. Defaults to None, which means
        ``int(n_vars / n_modules / 2)``.

    Returns
    -------
    Tuple
        A tuple containing:
        - **DataFrame:** DataFrame with top genes, TFs, and GO terms of each module.
        - **List:** List of DataFrames containing enriched GO terms for each module.

    Raises
    ------
    ValueError
        If ``species`` is neither ``'Mouse'`` nor ``'Human'``.
    """
    import os

    species_settings = {
        'Mouse': {'tf_file': 'Mouse_TFs.txt',
                  'gene_sets': ['GO_Biological_Process_2021'],
                  'organism': 'mouse'},
        'Human': {'tf_file': 'Human_TFs.txt',
                  'gene_sets': ['GO_Biological_Process_2021', 'GO_Cellular_Component_2021', 'KEGG_2021_Human'],
                  'organism': 'human'},
    }
    # Fail early with a clear message instead of an unbound variable further down.
    if species not in species_settings:
        raise ValueError("Unsupported species " + repr(species) + "; expected one of "
                         + ", ".join(repr(name) for name in species_settings) + ".")
    settings = species_settings[species]

    n_modules = self.module.model.n_modules
    n_vars = self.module.model.n_vars
    means = self.samples['post_sample_means']

    tab = pd.DataFrame(columns = ('Module Number', 'Genes Ranked', 'TFs Ranked', 'Terms Ranked'))
    tab['Module Number'] = list(range(n_modules))

    # The TF lists ship inside the installed cell2fate package directory.
    tf_path = os.path.join(os.path.dirname(c2f.__file__), settings['tf_file'])
    TFs = np.array(pd.read_csv(tf_path, header=None, index_col=False)).flatten()
    # Set membership keeps the gene/TF matching linear in the number of genes.
    tf_names = {tf for tf in TFs if tf in adata.var_names}
    TF_boolean = np.array([g in tf_names for g in adata.var_names])

    gene_by_module_weight = torch.zeros((n_modules, n_vars))
    gene_by_module_sorted = np.empty((n_modules, n_vars), dtype=object)
    inferred_total = torch.sum(torch.tensor(means['mu_expression'])[...,1], axis = 0)
    for m in range(n_modules):
        mu_m = mu_mRNA_continousAlpha_globalTime_twoStates(
            torch.tensor(means['A_mgON'][m,:]),
            torch.tensor(0.),
            torch.tensor(means['beta_g']),
            torch.tensor(means['gamma_g']),
            torch.tensor(means['lam_mi'][m,:]),
            torch.tensor(means['T_c'][:,:,0]),
            torch.tensor(means['T_mON'][:,:,m]),
            torch.tensor(means['T_mOFF'][:,:,m]),
            torch.zeros((self.module.model.n_obs, n_vars)))
        gene_by_module_weight[m,:] = torch.sum(mu_m[...,1], axis = 0)/inferred_total
        # Negated weights give a descending ranking.
        gene_by_module_sorted[m,:] = adata.var_names[np.argsort(-1*gene_by_module_weight[m,:])]
        tf_ranked = adata.var_names[TF_boolean][np.argsort(-1*gene_by_module_weight[m,TF_boolean])]
        tab.iloc[m,1] = ', '.join(list(gene_by_module_sorted[m,:]))
        tab.iloc[m,2] = ', '.join(list(tf_ranked))

    ### Select n_genes/n_modules top genes for each module
    ### Find enriched GO terms
    results = []
    if not n_top_genes:
        n_top_genes = int(n_vars/n_modules/2)
    for m in range(n_modules):
        gene_list = list(gene_by_module_sorted[m,:n_top_genes])
        enr = gp.enrichr(gene_list=gene_list,
                         background = background,
                         gene_sets=settings['gene_sets'],
                         organism=settings['organism'],
                         outdir=None,  # don't write to disk
                         )
        significant = enr.results.loc[enr.results['Adjusted P-value'] < p_adj_cutoff,:]
        tab.iloc[m,3] = ', '.join(list(significant['Term']))
        results += [significant]
    return tab, results
[docs]
def plot_top_features(self, adata, tab, chosen_modules, mode = 'all genes',
                      n_top_features = 3, save = False, process = True):
    """
    Plot top features for chosen modules on the UMAP embedding.

    One row of panels is drawn per chosen module and one column per top feature.

    Parameters
    ----------
    adata
        AnnData object. When ``process`` is True its ``X`` is overwritten in place.
    tab
        Table containing feature rankings. The columns ``'Genes Ranked'`` and
        ``'TFs Ranked'`` are read positionally (row ``m`` for module ``m``) and
        are expected to hold comma-separated feature names.
    chosen_modules
        List of module indices to plot.
    mode
        Which ranking to use: ``'all genes'`` or ``'TFs'``.
    n_top_features
        Number of top features to plot per module.
    save
        Filepath to save the plot to; nothing is saved when falsy.
    process
        Whether to recompute ``adata.X`` from the spliced + unspliced layers
        (normalise, log1p, scale) before plotting.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    # Validate first so that a typo in ``mode`` fails before adata.X is overwritten.
    ranking_columns = {'all genes': 'Genes Ranked', 'TFs': 'TFs Ranked'}
    if mode not in ranking_columns:
        raise ValueError(
            "mode must be one of " + str(list(ranking_columns)) + ", got " + repr(mode)
        )
    ranked = tab[ranking_columns[mode]]

    if process:
        print('Reprocessing adata.X, set process = False if this is not desired.')
        adata.X = adata.layers['unspliced'] + adata.layers['spliced']
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        sc.pp.scale(adata, max_value=10)

    n_rows = len(chosen_modules)
    # squeeze=False keeps ``ax`` two-dimensional even for a single module or a
    # single feature, so ax[i, j] indexing is always valid.
    fig, ax = plt.subplots(n_rows, n_top_features,
                           figsize = (5*n_top_features, 4*n_rows), squeeze = False)
    for i, m in enumerate(chosen_modules):
        for_plotting = ranked.iloc[m].replace(" ", "").split(',')[:n_top_features]
        for j, feature in enumerate(for_plotting):
            sc.pl.umap(adata, color = feature, legend_loc = 'right margin',
                       size = 200, ncols = n_top_features, show = False, ax = ax[i,j])
        # A module may have fewer ranked features than requested; hide the
        # panels that would otherwise stay empty.
        for j in range(len(for_plotting), n_top_features):
            ax[i,j].set_visible(False)
    if save:
        plt.savefig(save)
def plot_module_summary_statistics_2(self, adata, chosen_modules, chosen_clusters,
                                     marker_genes, marker_TFs, cluster_key = 'clusters', save = None):
    """
    Plot summary statistics for chosen modules.

    Draws a standalone legend figure for the four module states, then a
    4 x len(chosen_modules) grid of UMAPs: module activation, module state,
    one marker gene and one marker TF per module.

    ``adata.X`` is overwritten in place with normalised, log-transformed and
    scaled total (spliced + unspliced) counts.

    Parameters
    ----------
    adata
        AnnData object. Must contain ``'Module <m> Activation'`` and
        ``'Module <m> State'`` columns in ``adata.obs`` for every chosen module.
    chosen_modules
        List of module indices to plot.
    chosen_clusters
        List of cluster names to include.
    marker_genes
        List of marker genes to plot, one per chosen module (same order).
    marker_TFs
        List of marker transcription factors to plot, one per chosen module.
    cluster_key
        Key in adata.obs storing cluster information.
    save
        Filepath to save the plot. The legend figure is saved next to it with
        ``_legend.pdf`` replacing the file extension.

    Raises
    ------
    ValueError
        If fewer marker genes or marker TFs than chosen modules are given.
    """
    import os

    n_modules = len(chosen_modules)
    # Fail before mutating adata.X rather than with an IndexError mid-plot.
    if len(marker_genes) < n_modules or len(marker_TFs) < n_modules:
        raise ValueError(
            "marker_genes and marker_TFs need one entry per chosen module "
            "(" + str(n_modules) + " required)."
        )

    def _activation_cap(values):
        # Value at the ~99% position of the sorted activations. The rounded
        # index equals len(values) for 50 or fewer cells, so clamp it to the
        # last valid position.
        ordered = np.sort(values)
        idx = min(int(np.round(0.99*len(ordered))), len(ordered) - 1)
        return ordered[idx]

    # Shared colour-scale ceiling so activation panels are comparable across modules.
    limit = np.max([_activation_cap(adata.obs['Module ' + str(m) + ' Activation'])
                    for m in chosen_modules])

    adata.X = adata.layers['unspliced'] + adata.layers['spliced']
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.scale(adata, max_value=10)
    plt.rcParams.update({'font.size': 14})
    subset = np.array([c in chosen_clusters for c in adata.obs[cluster_key]])

    # Standalone legend figure: one dummy point per module state.
    state_palette = {'ON': 'lime', 'OFF': 'grey', 'Induction': 'lightgreen', 'Repression': 'orange'}
    for state in ('Induction', 'ON', 'Repression', 'OFF'):
        plt.scatter(1, 1, label=state, s = 100, c=state_palette[state])
    plt.legend()
    if save:
        # splitext handles extensions of any length and path-like objects.
        plt.savefig(os.path.splitext(save)[0] + '_legend.pdf')
    plt.show()

    # squeeze=False keeps ``ax`` two-dimensional when only one module is chosen.
    fig, ax = plt.subplots(4, n_modules, figsize = (3*n_modules, 20), squeeze = False)
    adata = adata[subset,:]
    for i, m in enumerate(chosen_modules):
        sc.pl.umap(adata, color = ['Module ' + str(m) + ' Activation'], legend_loc = None,
                   size = 200, color_map = 'viridis', ax = ax[0,i], show = False, s = 300,
                   vmin = 0, vmax = limit)
        # NOTE(review): 'right_margin' (underscore) differs from the
        # 'right margin' spelling used for the other panels; kept as-is because
        # the separate legend figure above suggests it may be intentional - confirm.
        sc.pl.umap(adata, color = ['Module ' + str(m) + ' State'],
                   size = 200, ax = ax[1,i], show = False, legend_fontsize = 'x-large', s = 300, legend_loc = 'right_margin',
                   palette = state_palette)
        sc.pl.umap(adata, color = marker_genes[i], legend_loc = 'right margin',
                   size = 200, show = False, ax = ax[2,i])
        sc.pl.umap(adata, color = marker_TFs[i], legend_loc = 'right margin',
                   size = 200, show = False, ax = ax[3,i])
    plt.tight_layout()
    if save:
        plt.savefig(save)
def example_module_activation(self, adata, chosen_module, time_max = None, time_min = 0,
                              save = None):
    """
    Plot an example activation curve for one module.

    The total spliced abundance of the module (summed over genes) is evaluated
    on a regular time grid spanning the range of the posterior-mean ``T_c``,
    once with the inferred switch-off time (black) and once with the switch-off
    time multiplied by 1000 (grey). Shaded bands mark the 0-5 %, 5-95 % and
    95-100 % ranges of the summed ``A_mgON / gamma_g`` level.

    Parameters
    ----------
    adata
        Unused; kept for interface compatibility.
    chosen_module
        Index of the module to plot.
    time_max
        Upper limit of the time axis. Defaults to the maximum posterior-mean
        ``T_c`` when None.
    time_min
        Lower limit of the time axis.
    save
        Filepath to save the plot.

    Raises
    ------
    ValueError
        If ``time_max`` is not greater than ``time_min``.
    """
    # Resolve the time window first: the band boundaries below are expressed
    # as fractions of (time_max - time_min), so both ends must be known.
    if time_max is None:
        time_max = float(np.max(self.samples['post_sample_means']['T_c']))
    if time_max <= time_min:
        raise ValueError("time_max must be greater than time_min.")

    post_means = self.samples['post_sample_means']
    m = chosen_module
    n_obs = 10000
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))

    # Regular time grid over the inferred range of cell times.
    T_c_all = torch.tensor(post_means['T_c'])
    T_c_min = torch.min(T_c_all)
    T_c_max = torch.max(T_c_all)
    T_c = torch.arange(T_c_min, T_c_max, (T_c_max - T_c_min)/n_obs).unsqueeze(-1).unsqueeze(-1)

    A_mgON = torch.tensor(post_means['A_mgON'][m,:])
    beta_g = torch.tensor(post_means['beta_g'])
    gamma_g = torch.tensor(post_means['gamma_g'])
    lam_mi = torch.tensor(post_means['lam_mi'][m,:])
    T_on = np.mean(post_means['T_mON'][:,:,m])
    T_off = np.mean(post_means['T_mOFF'][:,:,m])
    n_vars = self.module.model.n_vars

    def _total_spliced(t_off):
        # Size the initial state from the grid itself: a float-step arange
        # is not guaranteed to yield exactly n_obs points.
        mu = mu_mRNA_continousAlpha_globalTime_twoStates(
            A_mgON,
            torch.tensor(0.),
            beta_g,
            gamma_g,
            lam_mi,
            T_c[:,:,0],
            torch.tensor(T_on),
            t_off,
            torch.zeros((T_c.shape[0], n_vars)))
        return torch.sum(mu[...,1], axis = -1)

    abundance = _total_spliced(torch.tensor(T_off))
    # Same module with the switch-off pushed far into the future.
    abundance2 = _total_spliced(torch.tensor(T_off)*1000.)

    steady_state = float(torch.sum(A_mgON/gamma_g))
    # Position of the switch-off time as a fraction of the x-axis
    # (axhspan takes xmin/xmax in axes coordinates).
    off_fraction = (T_off - time_min)/(time_max - time_min)

    ax.axhspan(xmin = 0, xmax = 1, ymin = steady_state*0.95, ymax = steady_state,
               facecolor='lime', alpha=0.5)
    ax.axhspan(xmin = 0, xmax = off_fraction, ymin = steady_state*0.05, ymax = steady_state*0.95,
               facecolor='lightgreen', alpha=0.5)
    ax.axhspan(xmin = off_fraction, xmax = 1,
               ymin = steady_state*0.05, ymax = steady_state*0.95,
               facecolor='orange', alpha=0.5)
    ax.axhspan(xmin = 0, xmax = 1,
               ymin = steady_state*0, ymax = steady_state*0.05,
               facecolor='grey', alpha=0.5)
    ax.scatter(T_c,
               abundance2, s = 3, label = 'Module ' + str(m), c = 'grey', alpha = 0.25)
    ax.scatter(T_c,
               abundance, s = 5, label = 'Module ' + str(m), c = 'black')
    # Guide lines span the whole axis. The extents previously passed here were
    # data-unit values, although axhline/axvline expect axes fractions in [0, 1].
    for level in (steady_state, steady_state*0.05, steady_state*0.95):
        ax.axhline(y = level, linestyle = '--', linewidth = 1, c = 'black')
    ax.axvline(x = T_off, linestyle = '--', linewidth = 1, c = 'black')
    ax.set_xlabel('Time (hours)')
    ax.set_ylabel('Total Spliced UMI Counts')
    ax.set_ylim(-10, steady_state+20)
    # Always pin the x-limits: the band fractions above assume this window.
    ax.set_xlim(time_min, time_max)
    ax.set_title('Example Module Activation')
    if save:
        plt.savefig(save)
[docs]
def plot_genes(self, adata, chosen_clusters, marker_genes, cluster_key = 'clusters', save = None):
    """
    Plot expression of marker genes across chosen clusters.

    ``adata.X`` is overwritten in place with normalised, log-transformed and
    scaled total (spliced + unspliced) counts before plotting.

    Parameters
    ----------
    adata
        AnnData object.
    chosen_clusters
        List of cluster names to include.
    marker_genes
        List of marker genes to plot, one UMAP panel per gene.
    cluster_key
        Key in adata.obs storing cluster information.
    save
        Filepath to save the plot.

    Raises
    ------
    ValueError
        If ``marker_genes`` is empty.
    """
    n_genes = len(marker_genes)
    # Fail before adata.X is overwritten; a zero-column figure cannot be created.
    if n_genes == 0:
        raise ValueError("marker_genes must contain at least one gene.")

    adata.X = adata.layers['unspliced'] + adata.layers['spliced']
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.scale(adata, max_value=10)
    plt.rcParams.update({'font.size': 12})
    subset = np.array([c in chosen_clusters for c in adata.obs[cluster_key]])
    # squeeze=False keeps ``ax`` as a 2-D array, so a single marker gene works too.
    fig, ax = plt.subplots(1, n_genes, figsize = (3*n_genes, 5), squeeze = False)
    adata = adata[subset,:]
    for panel, gene in zip(ax[0], marker_genes):
        sc.pl.umap(adata, color = gene, legend_loc = 'right margin',
                   size = 200, show = False, ax = panel)
    plt.tight_layout()
    if save:
        plt.savefig(save)
def _posterior_samples_minibatch(
    self, use_gpu: bool = None, batch_size: Optional[int] = None, **sample_kwargs
):
    """
    Temporary solution for batch sampling problem.

    Cell-specific sites are sampled minibatch by minibatch and concatenated
    along axis 1; all remaining sites are sampled once, using the last
    minibatch, and merged into the result.

    Parameters
    ----------
    use_gpu
        Load model on default GPU if available (if None or True),
        or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
    **sample_kwargs
        Forwarded to ``self._get_posterior_samples``. ``"return_sites"`` must
        be present; ``"return_observed"`` is optional.

    Returns
    -------
    dictionary
        {variable_name: [array with samples in 0 dimension]}
    """
    # ``settings`` is not imported at module level in this file, so bring it
    # into scope here for the default batch size.
    from scvi import settings

    samples = dict()
    _, device = parse_use_gpu_arg(use_gpu)
    batch_size = batch_size if batch_size is not None else settings.batch_size
    train_dl = AnnDataLoader(
        self.adata_manager, shuffle=False, batch_size=batch_size
    )
    # Sites treated as per-cell; only these are concatenated across batches.
    cell_specific = ['t_c', 'T_c', 'mu_expression', 'detection_y_c', 'mu', 'data_target']

    # sample local parameters
    first_batch = True
    for tensor_dict in track(
        train_dl,
        style="tqdm",
        description="Sampling local variables, batch: ",
    ):
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)
        if first_batch:
            first_batch = False
            # sample_kwargs is a dict, so the flag has to be read with .get();
            # getattr() on a dict never finds the key.
            return_observed = sample_kwargs.get("return_observed", False)
            obs_plate_sites = self._get_obs_plate_sites(
                args, kwargs, return_observed=return_observed
            )
            if len(obs_plate_sites) == 0:
                # if no local variables - don't sample
                break
            sample_kwargs_obs_plate = sample_kwargs.copy()
            sample_kwargs_obs_plate[
                "return_sites"
            ] = self._get_obs_plate_return_sites(
                sample_kwargs["return_sites"], list(obs_plate_sites.keys())
            )
            sample_kwargs_obs_plate["show_progress"] = False
            samples = self._get_posterior_samples(
                args, kwargs, **sample_kwargs_obs_plate
            )
        else:
            samples_ = self._get_posterior_samples(
                args, kwargs, **sample_kwargs_obs_plate
            )
            for k in samples:
                if k in cell_specific and samples_[k].ndim > 1:
                    samples[k] = np.concatenate([samples[k], samples_[k]], axis=1)

    # Global (non cell-specific) sites are sampled once from the last minibatch.
    global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs)
    for k, v in global_samples.items():
        if k not in cell_specific:
            samples[k] = v
    self.module.to(device)
    return samples
[docs]
def visualize_module_trajectories(self, adata, chosen_module, delete = True, plot = True, save = None, smooth=None, min_mass=None, n_neighbors=None, cutoff_perc=None, plotting_kwargs = {"color": 'clusters', 'legend_fontsize': 10, 'legend_loc': 'on data', 'dpi': 300, 'cmap': 'Greys'}):
"""
Visualize relative module activation trajectories using velocity-based embedding.
Parameters
----------
adata
AnnData object containing cell information and embeddings.
chosen_module
The number / name of the chosen module.
delete
Delete temporary data structures after use, by default True.
plot
Whether to generate the plot, by default True.
save
File path to save the generated plot, by default None.
smooth
Smoothing parameter for grid-based velocity calculations, by default None.
min_mass
Minimum cell mass for grid-based velocity calculations, by default None.
n_neighbors
Number of neighbors for grid-based velocity calculations, by default None.
cutoff_perc
Cutoff percentile for adjusting grid-based velocity calculations, by default None.
plotting_kwargs
Additional keyword arguments for customizing the plot appearance, by default ``{"color": 'clusters', 'legend_fontsize': 10, 'legend_loc': 'on data', 'dpi': 300, 'cmap': 'inferno'}``.
"""
adata_groups = {}
for i in range(len(adata.obs['Module 0 State'])):
adata_groups[i]='No'
X_emb = np.array(adata.obsm['X_umap'])
V_emb = np.array(adata.obsm['velocity_umap'])
X_grid, V_grid = compute_velocity_on_grid(
X_emb=X_emb,
V_emb=V_emb,
density=1,
smooth=smooth,
min_mass=min_mass,
n_neighbors=n_neighbors,
autoscale=False,
adjust_for_stream=True,
cutoff_perc=cutoff_perc,
)
color_array=np.zeros((V_grid[0].shape[0],V_grid[0].shape[1]))
#Calculate whether the grid has module specific velocities
for a in range(X_grid.shape[1]):
for b in range(X_grid.shape[1]):
x_up = X_grid[0][a]
y_up = X_grid[1][b]
if a ==0:
x_down = -1000
else:
x_down = X_grid[0][a-1]
if b == 0:
y_down = -1000
else:
y_down = X_grid[1][b-1]
module_activation=0
total_activation=0
for i in range(len(adata.obsm['X_umap'])):
if x_down <= adata.obsm['X_umap'][i][0] <= x_up:
if y_down <= adata.obsm['X_umap'][i][1] <= y_up:
for key in adata.obs.keys():
if 'Activation' in key:
total_activation = total_activation + adata.obs[key][i]
if key == f'Module {chosen_module} Activation':
module_activation = module_activation + adata.obs[key][i]
if total_activation != 0:
rel_activation = module_activation / total_activation
else:
rel_activation = 0
color_array[b][a] = rel_activation
plotting_kwargs['arrow_color']=color_array
velocity_embedding_stream_modules(adata, basis='umap', save = False, vkey='velocity', modules=[chosen_module], **plotting_kwargs, show = False)