Pyro and scvi-tools infrastructure classes
Base mixin classes (AutoGuide setup, posterior quantile computation, plotting & export)
- cell2fate._pyro_mixin.init_to_value(site=None, values={})
Initializes the value of a site to a specified value.
- Parameters:
site – The site dictionary containing information about the site. If None, returns a partial function.
values – A dictionary containing the values to initialize sites with.
- Returns:
If site is None, returns a partial function with the values preset. Otherwise, returns the value given for the site in values; if the site is not in values, initializes it to the mean using init_to_mean, with fallback set to init_to_feasible.
- Return type:
Function or any
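The dispatch logic described above can be sketched in plain Python. This is a minimal sketch, not the cell2fate source: sites are modelled as plain dicts with a "name" key, and the hypothetical mean_fallback function stands in for Pyro's init_to_mean(site, fallback=init_to_feasible) so that the example runs without Pyro.

```python
from functools import partial


def mean_fallback(site):
    # Hypothetical stand-in for pyro's init_to_mean(site, fallback=init_to_feasible)
    return f"mean-of-{site['name']}"


def init_to_value_sketch(site=None, values=None, fallback=mean_fallback):
    values = {} if values is None else values  # avoid a shared mutable default
    if site is None:
        # No site given: return a partial with `values` preset (usable as init_loc_fn)
        return partial(init_to_value_sketch, values=values, fallback=fallback)
    if site["name"] in values:
        return values[site["name"]]  # user-supplied initial value
    return fallback(site)  # site not in values: fall back to the mean


init_fn = init_to_value_sketch(values={"alpha": 2.0})
print(init_fn({"name": "alpha"}))  # 2.0
print(init_fn({"name": "beta"}))   # mean-of-beta
```

Returning a partial when site is None is what lets the same function be handed to a guide as its init_loc_fn argument (e.g. Pyro's AutoNormal(model, init_loc_fn=...)), while still being callable on a single site.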
- cell2fate._pyro_mixin.expand_zeros_along_dim(tensor, size, dim)
Expands a tensor with zeros along a specified dimension.
- Parameters:
tensor – The input tensor.
size – The size to expand along the specified dimension.
dim – The dimension along which to expand the tensor.
- Returns:
A new tensor with zeros expanded along the specified dimension.
- Return type:
numpy.ndarray
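A minimal NumPy sketch, assuming the function returns an all-zero array whose shape matches tensor except that dimension dim has length size (for example, to hold values for all observations before minibatch results are filled in). The helper name expand_zeros_sketch is hypothetical.

```python
import numpy as np


def expand_zeros_sketch(tensor, size, dim):
    # Copy the input shape and replace the length of dimension `dim` with `size`
    shape = list(tensor.shape)
    shape[dim] = size
    return np.zeros(shape)


minibatch = np.ones((3, 5))  # e.g. 3 observations x 5 factors
full = expand_zeros_sketch(minibatch, size=10, dim=0)
print(full.shape)  # (10, 5)
```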
- cell2fate._pyro_mixin.complete_tensor_along_dim(tensor, indices, dim, value, mode='put')
Completes a tensor along a specified dimension with given indices and values.
- Parameters:
tensor – The input tensor.
indices – The indices to complete along the specified dimension.
dim – The dimension along which to complete the tensor.
value – The values to insert into the tensor.
mode – The mode of completion. “put” for putting values, “take” for taking values. Default is “put”.
- Returns:
A new tensor with completed values along the specified dimension.
- Return type:
numpy.ndarray
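The two modes can be illustrated with a short NumPy sketch. It is an assumption-based illustration rather than the cell2fate implementation: "put" is taken to scatter value into the positions indices along dim of a copy of tensor, and "take" to gather those positions (value is ignored, so the hypothetical complete_sketch makes it optional).

```python
import numpy as np


def complete_sketch(tensor, indices, dim, value=None, mode="put"):
    # Build an index that selects `indices` along `dim` and everything elsewhere
    idx = [slice(None)] * tensor.ndim
    idx[dim] = np.asarray(indices)
    idx = tuple(idx)
    if mode == "take":
        return tensor[idx]  # gather the entries at `indices` along `dim`
    if mode != "put":
        raise ValueError(f"unknown mode: {mode!r}")
    out = tensor.copy()  # leave the input untouched and return a new array
    out[idx] = value  # scatter `value` into positions `indices` along `dim`
    return out


full = np.zeros((6, 2))
batch_idx = np.array([4, 0, 2])
batch_vals = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
filled = complete_sketch(full, batch_idx, dim=0, value=batch_vals)
print(filled[:, 0])  # [2. 0. 3. 0. 1. 0.]
print(complete_sketch(filled, batch_idx, dim=0, mode="take")[:, 0])  # [1. 2. 3.]
```

Together with a zero-filled array of the full size, "put" is the kind of operation that merges per-minibatch results back into a full-size array.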
- class cell2fate._pyro_mixin.AutoGuideMixinModule
Bases:
object
This mixin class provides methods for:
initialising standard AutoNormal guides
initialising amortised guides (AutoNormalEncoder)
initialising amortised guides with special additional inputs
- class cell2fate._pyro_mixin.MyAutoHierarchicalNormalMessenger(model: Callable, *, init_loc_fn: Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), init_scale: float = 0.1, amortized_plates: Tuple[str, ...] = (), init_weight: float = 1.0, hierarchical_sites: list | None = None)
Bases:
AutoHierarchicalNormalMessenger
- call_new(*args, **kwargs) → Dict[str, Tensor]
Draws posterior samples from the guide and replays the model against those samples.
- Returns:
A dict mapping sample site name to sample value. This includes latent, deterministic, and observed values.
- Return type:
Dict
- get_posterior_quantile(name: str, prior: Distribution) → Distribution | Tensor
Get the posterior quantile or median.
- Parameters:
name – The name of the parameter.
prior – The prior distribution.
- Returns:
The posterior quantile or median.
- Return type:
Union[Distribution, torch.Tensor]
- quantiles(quantiles, *args, **kwargs)
Compute quantiles of the posterior distribution.
- Parameters:
quantiles – List of quantiles to compute.
*args, **kwargs – Additional arguments and keyword arguments to be passed to the underlying function.
- Returns:
The computed posterior quantiles.
- Return type:
List
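When the posterior at a site is a normal distribution in unconstrained space pushed through a monotonically increasing transform, quantiles can be computed in unconstrained space and then transformed, because such a transform preserves order. The sketch below illustrates this principle with the standard library only; it assumes a positive site with an exp transform and hypothetical loc/scale values, and it is not cell2fate's implementation.

```python
import math
from statistics import NormalDist


def transformed_normal_quantiles(loc, scale, quantiles, transform=math.exp):
    # Normal quantile in unconstrained space, then map to the constrained support;
    # valid because `transform` is monotonically increasing
    base = NormalDist(mu=loc, sigma=scale)
    return [transform(base.inv_cdf(q)) for q in quantiles]


loc, scale = math.log(2.0), 0.5  # hypothetical posterior parameters for one site
q05, q50, q95 = transformed_normal_quantiles(loc, scale, [0.05, 0.5, 0.95])
print(round(q50, 6))  # 2.0 (the median is exp(loc))
```

For q = 0.5 the result reduces to transform(loc), which is why a median can be obtained without evaluating the inverse CDF at all.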
- training: bool
- class cell2fate._pyro_mixin.PyroTrainingPlan_ClippedAdamDecayingRate(pyro_module: PyroBaseModuleClass, loss_fn: ELBO | None = None, optim: PyroOptim | None = <pyro.optim.optim.PyroOptim object>, optim_kwargs: dict | None = None)
Bases:
PyroTrainingPlan
Lightning module task to train Pyro scvi-tools modules.
- Parameters:
pyro_module – An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.
loss_fn – A Pyro loss. Should be a subclass of ELBO. If None, defaults to Trace_ELBO.
optim – A Pyro optimizer instance, e.g., Adam. If None, defaults to pyro.optim.Adam optimizer with a learning rate of 1e-3.
optim_kwargs – Keyword arguments for default optimiser pyro.optim.Adam.
n_aggressive_epochs – Number of epochs in aggressive optimisation of amortised variables.
n_aggressive_steps – Number of steps to spend optimising amortised variables before one step optimising global variables.
n_steps_kl_warmup – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.
n_epochs_kl_warmup – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
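The warmup rule described above can be sketched as a small function. This assumes a linear ramp of the KL weight from 0 to 1; the function name kl_weight is hypothetical, and the scvi-tools helper that the training plan relies on may differ in details (for example configurable minimum and maximum weights).

```python
def kl_weight(epoch, step, n_epochs_kl_warmup=None, n_steps_kl_warmup=None):
    # Epoch-based warmup takes precedence whenever it is set (None or 0 disables it)
    if n_epochs_kl_warmup:
        return min(1.0, epoch / n_epochs_kl_warmup)
    # Step-based warmup is used only when epoch-based warmup is disabled
    if n_steps_kl_warmup:
        return min(1.0, step / n_steps_kl_warmup)
    return 1.0  # no warmup: full KL weight from the start


print(kl_weight(epoch=100, step=5000, n_epochs_kl_warmup=400))  # 0.25
print(kl_weight(epoch=100, step=500, n_steps_kl_warmup=1000))   # 0.5
```

Ramping the weight up means early training concentrates on fitting the data likelihood before the KL (prior) regularisation is fully applied.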
- training: bool
- class cell2fate._pyro_mixin.QuantileMixin[source]
Bases:
object
This mixin class provides methods for:
computing median and quantiles of the posterior distribution using both direct and amortised inference
- posterior_quantile(q: float = 0.5, batch_size: int = 2048, use_gpu: bool | None = None, use_median: bool = False)
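Compute the q-th quantile of the posterior distribution of each parameter (the median with the default q=0.5).
- Parameters:
q – Quantile to compute.
batch_size – Batch size used during computation (default 2048).
use_gpu – Whether to use the GPU.
use_median – If True and q=0.5, use the median method of the guide rather than its quantile method.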
- class cell2fate._pyro_mixin.PyroAggressiveTrainingPlan1(pyro_module: PyroBaseModuleClass, loss_fn: ELBO | None = None, optim: PyroOptim | None = None, optim_kwargs: dict | None = None, n_aggressive_epochs: int = 1000, n_aggressive_steps: int = 20, n_steps_kl_warmup: int | None = None, n_epochs_kl_warmup: int | None = 400, aggressive_vars: list | None = None, invert_aggressive_selection: bool = False)
Bases:
PyroTrainingPlan
Lightning module task to train Pyro scvi-tools modules.
- Parameters:
pyro_module – An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.
loss_fn – A Pyro loss. Should be a subclass of ELBO. If None, defaults to Trace_ELBO.
optim – A Pyro optimizer instance, e.g., Adam. If None, defaults to pyro.optim.Adam optimizer with a learning rate of 1e-3.
optim_kwargs – Keyword arguments for default optimiser pyro.optim.Adam.
n_aggressive_epochs – Number of epochs in aggressive optimisation of amortised variables.
n_aggressive_steps – Number of steps to spend optimising amortised variables before one step optimising global variables.
n_steps_kl_warmup – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.
n_epochs_kl_warmup – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
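The aggressive schedule controlled by n_aggressive_epochs and n_aggressive_steps can be sketched as a function that decides, for each training step, which group of variables is updated. This is a hypothetical illustration: it assumes a repeating cycle of n_aggressive_steps amortised-only steps followed by one global step during the first n_aggressive_epochs, and joint updates afterwards. The function and its return labels are not part of cell2fate.

```python
def variables_to_update(epoch, step, n_aggressive_epochs=1000, n_aggressive_steps=20):
    # After the aggressive phase, all variables are optimised jointly (assumption)
    if epoch >= n_aggressive_epochs:
        return "all"
    # Each cycle has n_aggressive_steps amortised-only steps, then one global step
    position = step % (n_aggressive_steps + 1)
    return "global" if position == n_aggressive_steps else "amortised"


schedule = [variables_to_update(epoch=0, step=s, n_aggressive_steps=3) for s in range(8)]
print(schedule)
# ['amortised', 'amortised', 'amortised', 'global', 'amortised', 'amortised', 'amortised', 'global']
print(variables_to_update(epoch=1000, step=0))  # all
```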
- training: bool
- training_epoch_end(outputs)
Called at the end of the training epoch with the outputs of all training steps. Use this in case you need to do something with all the outputs returned by training_step().
    # the pseudocode for these calls
    train_outs = []
    for train_batch in train_data:
        out = training_step(train_batch)
        train_outs.append(out)
    training_epoch_end(train_outs)
- Parameters:
outputs – List of outputs you defined in training_step(). If there are multiple optimizers, it is a list containing a list of outputs for each optimizer. If using truncated_bptt_steps > 1, each element is a list of outputs corresponding to the outputs of each processed split batch.
- Returns:
None
Note
If this method is not overridden, this won’t be called.
    def training_epoch_end(self, training_step_outputs):
        # do something with all training_step outputs
        for out in training_step_outputs:
            ...
- training_step(batch, batch_idx)
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters:
batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
batch_idx (int) – Integer displaying index of this batch.
optimizer_idx (int) – When using multiple optimizers, this argument will also be present.
hiddens (Any) – Passed in if truncated_bptt_steps > 0.
- Returns:
Any of:
Tensor - The loss tensor
dict - A dictionary. Can include any keys, but must include the key 'loss'
None - Training will skip to the next batch. This is only for automatic optimization. This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
    def training_step(self, batch, batch_idx):
        x, y, z = batch
        out = self.encoder(x)
        loss = self.loss(out, x)
        return loss
If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.
    # Multiple optimizers (e.g.: GANs)
    def training_step(self, batch, batch_idx, optimizer_idx):
        if optimizer_idx == 0:
            # do training_step with encoder
            ...
        if optimizer_idx == 1:
            # do training_step with decoder
            ...
If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
    # Truncated back-propagation through time
    def training_step(self, batch, batch_idx, hiddens):
        # hiddens are the hidden states from the previous truncated backprop step
        out, hiddens = self.lstm(batch, hiddens)
        loss = ...
        return {"loss": loss, "hiddens": hiddens}
Note
The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
scvi-tools Module classes (initialising the model and the guide, PyroBaseModuleClass)
- class cell2fate._pyro_base_cell2fate_module.Cell2FateBaseModule(model, amortised: bool = False, encoder_mode='single', encoder_kwargs=None, data_transform='log1p', guide_class=<class 'pyro.infer.autoguide.effect.AutoHierarchicalNormalMessenger'>, encoder_instance=None, **kwargs)
Bases:
PyroBaseModuleClass, AutoGuideMixinModule
- property model
- property guide
- property is_amortised
- property list_obs_plate_vars
Model annotation for minibatch training with pyro plate.
A dictionary with:
1. “name” - the name of observation/minibatch plate;
2. “in” - indexes of model args to provide to encoder network when using amortised inference;
3. “sites” - dictionary with
  - keys - names of variables that belong to the observation plate (used to recognise and merge posterior samples for minibatch variables)
  - values - the dimensions in non-plate axis of each variable (used to construct output layer of encoder network when using amortised inference)
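To illustrate the structure, here is a hypothetical list_obs_plate_vars dictionary and one way its fields could be used: "in" selects the model arguments passed to the encoder, and "sites" lists the minibatch variables with their non-plate widths. The site names, widths and the log1p input transform (mirroring data_transform='log1p' in the Cell2FateBaseModule signature) are assumptions for illustration only.

```python
import numpy as np

# Hypothetical annotation; the site names and widths are illustrative only
plate_vars = {
    "name": "obs_plate",
    "in": [0],  # position of the expression matrix among the model args
    "sites": {"t_c": 1, "z_cf": 10},  # per-observation latent variables and their widths
}

counts = np.arange(12, dtype=float).reshape(4, 3)  # 4 observations x 3 genes
model_args = (counts, "unused arg")

# Select the model args listed under "in" and log1p-transform them for the encoder
encoder_inputs = [np.log1p(model_args[i]) for i in plate_vars["in"]]

# Total width of the non-plate axes, e.g. to size an encoder output layer
total_width = sum(plate_vars["sites"].values())
print(encoder_inputs[0].shape, total_width)  # (4, 3) 11
```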
- training: bool
Auto amortised hierarchical normal messenger for amortized inference
- class cell2fate.AutoAmortisedNormalMessenger.FCLayersPyro(n_in: int, n_out: int, n_cat_list: Iterable[int] | None = None, n_layers: int = 1, n_hidden: int = 128, dropout_rate: float = 0.1, use_batch_norm: bool = True, use_layer_norm: bool = False, use_activation: bool = True, bias: bool = True, inject_covariates: bool = True, activation_fn: torch.nn.Module = <class 'torch.nn.modules.activation.ReLU'>)
Bases:
FCLayers, PyroModule
- training: bool
- class cell2fate.AutoAmortisedNormalMessenger.AutoAmortisedHierarchicalNormalMessenger(model: Callable, *, amortised_plate_sites: dict, n_in: int, n_hidden: dict | None = None, init_param_scale: float = 0.02, init_scale: float = 0.1, init_weight: float = 1.0, init_loc_fn: Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), encoder_class=<class 'cell2fate.AutoAmortisedNormalMessenger.FCLayersPyro'>, encoder_kwargs=None, multi_encoder_kwargs=None, encoder_instance: torch.nn.Module | None = None, encoder_mode: Literal['single', 'multiple', 'single-multiple'] = 'single', hierarchical_sites: list | None = None, bias=True, use_posterior_lsw_encoders=False)
Bases:
AutoHierarchicalNormalMessenger
EXPERIMENTAL Automatic GuideMessenger, intended for use with Effect_ELBO or similar. Amortise specific sites.
The mean-field posterior at any site is a transformed normal distribution, the mean of which depends on the value of that site given its dependencies in the model:
loc = loc + transform.inv(prior.mean) * weight
where the value of prior.mean is conditional on upstream sites in the model. This approach doesn’t work for distributions that don’t have a mean.
loc, scales and element-specific weight are amortised for each site specified in amortised_plate_sites.
Derived classes may override particular sites and use this simply as a default; see the AutoNormalMessenger documentation for an example.
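A NumPy sketch of the formula loc = loc + transform.inv(prior.mean) * weight for one positive site, assuming an exp transform (so transform.inv is log) and element-wise weights. The names are illustrative; in the guide, loc, scale and weight are learned or amortised parameters rather than fixed arrays.

```python
import numpy as np


def hierarchical_posterior_sample(loc, scale, weight, prior_mean, eps):
    # Shift the guide loc by the prior mean mapped to unconstrained space
    # (transform.inv = log for a positive site with an exp transform)
    unconstrained_loc = loc + np.log(prior_mean) * weight
    # Reparameterised normal draw in unconstrained space, mapped back with exp
    return np.exp(unconstrained_loc + scale * eps)


prior_mean = np.array([0.5, 2.0, 10.0])  # depends on upstream sites in the model
loc, scale, weight = np.zeros(3), np.full(3, 0.1), np.ones(3)

# With loc = 0, weight = 1 and eps = 0 (the median), the posterior median equals the prior mean
median = hierarchical_posterior_sample(loc, scale, weight, prior_mean, eps=np.zeros(3))
print(np.allclose(median, prior_mean))  # True
```

This shows why the guide tracks the prior: the learned loc only has to encode the deviation from the prior mean implied by upstream sites.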
- Parameters:
model (callable) – A Pyro model.
amortised_plate_sites (dict) –
Dictionary with amortised plate details: the name of observation/minibatch plate, indexes of model args to provide to encoder, variable names that belong to the observation plate and the number of dimensions in non-plate axis of each variable - such as:
    {
        "name": "obs_plate",
        "input": [0],  # expression data + (optional) batch index ([0, 2])
        "input_transform": [torch.log1p],  # how to transform input data before passing to NN
        "sites": {
            "n_s": 1,
            "y_s": 1,
            "z_sr": R,
            "w_sf": F,
        },
    }
n_in (int) – Number of input dimensions (for encoder_class).
n_hidden (int or dict) –
Number of hidden nodes in each layer, one of 3 options:
1. Integer denoting the number of hidden nodes.
2. Dictionary with {"single": 200, "multiple": 200} denoting the number of hidden nodes for each encoder_mode (see below).
3. Dictionary allowing a different number of hidden nodes for each model site, with the number of hidden nodes for single encoder mode and for each model site:
    {
        "single": 200,
        "n_s": 5,
        "y_s": 5,
        "z_sr": 128,
        "w_sf": 200,
    }
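The three accepted forms of n_hidden could be resolved into a concrete layer size as in the following sketch. The resolve_n_hidden helper is hypothetical and only mirrors the description of the three options; how cell2fate handles other cases (such as encoder_mode="single-multiple" or n_hidden=None) is not covered here.

```python
def resolve_n_hidden(n_hidden, encoder_mode, site=None):
    # Option 1: a single integer applies to every encoder
    if isinstance(n_hidden, int):
        return n_hidden
    # Option 3: per-site sizes for the per-variable ("multiple") encoders
    if encoder_mode == "multiple" and site in n_hidden:
        return n_hidden[site]
    # Option 2 (and the "single" entry of option 3): size per encoder_mode
    return n_hidden[encoder_mode]


per_site = {"single": 200, "n_s": 5, "y_s": 5, "z_sr": 128, "w_sf": 200}
print(resolve_n_hidden(128, "multiple", "z_sr"))                             # 128
print(resolve_n_hidden({"single": 200, "multiple": 50}, "multiple", "n_s"))  # 50
print(resolve_n_hidden(per_site, "multiple", "n_s"))                         # 5
print(resolve_n_hidden(per_site, "single"))                                  # 200
```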
init_param_scale (float) – How to scale/normalise initial values for weights converting hidden layers to loc and scales.
scales_offset (float) – Offset between the output of the NN and the scales.
encoder_class (Callable) – Class that defines encoder network.
encoder_kwargs (dict) – Keyword arguments for encoder class.
multi_encoder_kwargs (dict) – Optional separate keyword arguments for encoder_class, useful when encoder_mode == “single-multiple”.
encoder_instance (Callable) – Encoder network instance. Overrides encoder_class; the instance is copied with deepcopy.
encoder_mode (str) – Use single encoder for all variables (“single”), one encoder per variable (“multiple”) or a single encoder in the first step and multiple encoders in the second step (“single-multiple”).
hierarchical_sites (list) – List of latent variables (model sites) that have hierarchical dependencies. If None, all sites are assumed to have hierarchical dependencies; in that case, for sites that have no upstream sites, the guide represents/learns the deviation from the prior.
- weight_type = 'element-wise'
- get_posterior(name: str, prior: Distribution) → Distribution | Tensor
Abstract method to compute a posterior distribution or sample a posterior value given a prior distribution conditioned on upstream posterior samples.
Implementations may use pyro.param and pyro.sample inside this function, but pyro.sample statements should set infer={"is_auxiliary": True}.
Implementations may access further information for computations:
- value = self.upstream_value(name) is the value of an upstream sample or deterministic site.
- self.trace is a trace of upstream sites, and may be useful for other information such as self.trace.nodes["my_site"]["fn"] or self.trace.nodes["my_site"]["cond_indep_stack"].
- args, kwargs = self.args_kwargs are the inputs to the model, and may be useful for amortization.
- Parameters:
name (str) – The name of the sample site to sample.
prior (Distribution) – The prior distribution of this sample site (conditioned on upstream samples from the posterior).
- Returns:
A posterior distribution or sample from the posterior distribution.
- Return type:
Distribution or torch.Tensor
- encode(name: str, prior: Distribution)
Apply encoder network to input data to obtain hidden layer encoding.
- Parameters:
args – Pyro model args
kwargs – Pyro model kwargs
- training: bool