Pyro and scvi-tools infrastructure classes

Base mixin classes (AutoGuide setup, posterior quantile computation, plotting & export)

cell2fate._pyro_mixin.init_to_value(site=None, values={})

Initializes the value of a site to a specified value.

Parameters:
  • site – The site dictionary containing information about the site. If None, returns a partial function.

  • values – A dictionary containing the values to initialize sites with.

Returns:

If site is None, returns a partial function with values preset. Otherwise, returns the value specified for the site in values or, if the site is not in values, initializes it to the mean using init_to_mean with fallback set to init_to_feasible.

Return type:

Callable or Any
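The dispatch described above can be sketched in plain Python. This is a hypothetical re-creation, not cell2fate's code: fallback_init is a stand-in for Pyro's init_to_mean(site, fallback=init_to_feasible), and site dictionaries are reduced to their "name" key.

```python
import functools
from typing import Any, Callable, Dict, Optional


def fallback_init(site: Dict[str, Any]) -> str:
    # Hypothetical stand-in for init_to_mean(site, fallback=init_to_feasible)
    return f"mean-of-prior({site['name']})"


def init_to_value_sketch(
    site: Optional[Dict[str, Any]] = None,
    values: Optional[Dict[str, Any]] = None,
    fallback: Callable[[Dict[str, Any]], Any] = fallback_init,
):
    values = {} if values is None else values
    if site is None:
        # Called without a site: return an initialiser with the values preset
        return functools.partial(init_to_value_sketch, values=values, fallback=fallback)
    if site["name"] in values:
        # The site has a user-supplied value: use it as-is
        return values[site["name"]]
    # Otherwise defer to the fallback initialiser
    return fallback(site)


init_fn = init_to_value_sketch(values={"w_sf": 0.5})
print(init_fn({"name": "w_sf"}))  # 0.5
print(init_fn({"name": "z_sr"}))  # mean-of-prior(z_sr)
```

The sketch uses values=None rather than the mutable {} default shown in the signature, a common Python idiom that avoids sharing one dictionary between calls.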

cell2fate._pyro_mixin.expand_zeros_along_dim(tensor, size, dim)

Expands a tensor with zeros along a specified dimension.

Parameters:
  • tensor – The input tensor.

  • size – The size to expand along the specified dimension.

  • dim – The dimension along which to expand the tensor.

Returns:

A new tensor with zeros expanded along the specified dimension.

Return type:

numpy.ndarray
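One reading of this description is an array shaped like tensor except along dim, where it has length size, for example to preallocate a full-dataset array before filling it minibatch by minibatch. The NumPy sketch below assumes that reading (a fresh zero array rather than the input padded with zeros); the _sketch name is hypothetical and the real implementation may differ.

```python
import numpy as np


def expand_zeros_along_dim_sketch(tensor: np.ndarray, size: int, dim: int) -> np.ndarray:
    # Copy the input shape and replace the length of axis `dim` with `size`
    shape = list(tensor.shape)
    shape[dim] = size
    # Return a zero-filled array of that shape; the input values are not copied
    return np.zeros(shape, dtype=tensor.dtype)


minibatch = np.ones((4, 3))  # e.g. 4 cells in a minibatch, 3 factors
full = expand_zeros_along_dim_sketch(minibatch, size=10, dim=0)
print(full.shape)  # (10, 3)
```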

cell2fate._pyro_mixin.complete_tensor_along_dim(tensor, indices, dim, value, mode='put')

Completes a tensor along a specified dimension with given indices and values.

Parameters:
  • tensor – The input tensor.

  • indices – The indices to complete along the specified dimension.

  • dim – The dimension along which to complete the tensor.

  • value – The values to insert into the tensor.

  • mode – The mode of completion: “put” writes value into tensor at indices, “take” reads the values of tensor at indices. Default is “put”.

Returns:

A new tensor with completed values along the specified dimension.

Return type:

numpy.ndarray
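A minimal NumPy sketch of the two modes, assuming they map onto np.put_along_axis (scatter value into tensor) and np.take_along_axis (gather from tensor); the function name is hypothetical and this is not cell2fate's code. In "take" mode the sketch uses value only for its number of dimensions.

```python
import numpy as np


def complete_tensor_along_dim_sketch(tensor, indices, dim, value, mode="put"):
    # Reshape indices so they broadcast along every axis except `dim`
    idx_shape = [1] * value.ndim
    idx_shape[dim] = len(indices)
    idx = np.asarray(indices).reshape(idx_shape)
    if mode == "take":
        # Read the entries of `tensor` at `indices` along `dim`
        return np.take_along_axis(tensor, idx, axis=dim)
    # Write `value` into `tensor` at `indices` along `dim` (modifies `tensor` in place)
    np.put_along_axis(tensor, idx, value, axis=dim)
    return tensor


full = np.zeros((5, 2))                     # full-dataset array, 5 cells
batch = np.array([[1.0, 2.0], [3.0, 4.0]])  # results for a 2-cell minibatch
rows = np.array([3, 0])                     # positions of those cells in `full`
complete_tensor_along_dim_sketch(full, rows, dim=0, value=batch)
taken = complete_tensor_along_dim_sketch(full, rows, dim=0, value=batch, mode="take")
print(full[3], full[0])  # [1. 2.] [3. 4.]
```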

class cell2fate._pyro_mixin.AutoGuideMixinModule

Bases: object

This mixin class provides methods for:

  • initialising standard AutoNormal guides

  • initialising amortised guides (AutoNormalEncoder)

  • initialising amortised guides with special additional inputs

class cell2fate._pyro_mixin.MyAutoHierarchicalNormalMessenger(model: ~typing.Callable, *, init_loc_fn: ~typing.Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), init_scale: float = 0.1, amortized_plates: ~typing.Tuple[str, ...] = (), init_weight: float = 1.0, hierarchical_sites: list | None = None)

Bases: AutoHierarchicalNormalMessenger

call_new(*args, **kwargs) → Dict[str, Tensor]

Draws posterior samples from the guide and replays the model against those samples.

Returns:

A dict mapping sample site name to sample value. This includes latent, deterministic, and observed values.

Return type:

Dict

get_posterior_quantile(name: str, prior: Distribution) → Distribution | Tensor

Get the posterior quantile or median.

Parameters:
  • name – The name of the parameter.

  • prior – The prior distribution.

Returns:

The posterior quantile or median.

Return type:

Union[Distribution, torch.Tensor]
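For a site whose guide is a Normal in unconstrained space pushed through a monotone increasing transform, a posterior quantile can be computed in unconstrained space and then transformed, because such transforms map quantiles to quantiles. The stdlib sketch below assumes a positive site with exp as the transform, so the median is exp(loc); the helper name and the loc/scale values are hypothetical and this is not the method's actual code.

```python
import math
from statistics import NormalDist


def transformed_normal_quantile(loc: float, scale: float, q: float) -> float:
    # Quantile of the Normal guide in unconstrained space
    z = NormalDist(mu=loc, sigma=scale).inv_cdf(q)
    # exp is monotone increasing, so it maps this quantile to the constrained-space quantile
    return math.exp(z)


loc, scale = math.log(2.0), 0.3
median = transformed_normal_quantile(loc, scale, 0.5)  # approximately 2.0
lower = transformed_normal_quantile(loc, scale, 0.05)
upper = transformed_normal_quantile(loc, scale, 0.95)
print(lower, median, upper)
```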

quantiles(quantiles, *args, **kwargs)

Compute quantiles of the posterior distribution.

Parameters:
  • quantiles – List of quantiles to compute.

  • *args, **kwargs – Additional arguments and keyword arguments to be passed to the underlying function.

Returns:

The computed quantiles of the posterior distribution.

Return type:

List

training: bool
class cell2fate._pyro_mixin.PyroTrainingPlan_ClippedAdamDecayingRate(pyro_module: ~scvi.module.base._base_module.PyroBaseModuleClass, loss_fn: ~pyro.infer.elbo.ELBO | None = None, optim: ~pyro.optim.optim.PyroOptim | None = <pyro.optim.optim.PyroOptim object>, optim_kwargs: dict | None = None)

Bases: PyroTrainingPlan

Lightning module task to train Pyro scvi-tools modules.

Parameters:
  • pyro_module – An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.

  • loss_fn – A Pyro loss. Should be a subclass of ELBO. If None, defaults to Trace_ELBO.

  • optim – A Pyro optimizer instance, e.g., Adam. If None, defaults to pyro.optim.Adam optimizer with a learning rate of 1e-3.

  • optim_kwargs – Keyword arguments for default optimiser pyro.optim.Adam.

  • n_aggressive_epochs – Number of epochs in aggressive optimisation of amortised variables.

  • n_aggressive_steps – Number of steps to spend optimising amortised variables before one step optimising global variables.

  • n_steps_kl_warmup – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

  • n_epochs_kl_warmup – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
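The precedence rule between the two warmup settings can be sketched as a small function. This assumes a linear ramp from 0 to 1, which "scale weight from 0 to 1" suggests but does not pin down; the function name is hypothetical and the default of 400 epochs is taken from the PyroAggressiveTrainingPlan1 signature.

```python
from typing import Optional


def kl_weight(
    epoch: int,
    step: int,
    n_epochs_kl_warmup: Optional[int] = 400,
    n_steps_kl_warmup: Optional[int] = None,
) -> float:
    if n_epochs_kl_warmup is not None:
        # Epoch-based warmup takes precedence when it is set
        progress = epoch / n_epochs_kl_warmup
    elif n_steps_kl_warmup is not None:
        # Step-based warmup is used only when the epoch-based one is None
        progress = step / n_steps_kl_warmup
    else:
        # No warmup: full KL weight from the start
        return 1.0
    # Linear ramp, clipped at 1 once warmup is over
    return min(1.0, progress)


print(kl_weight(epoch=200, step=0))  # 0.5
print(kl_weight(epoch=0, step=250, n_epochs_kl_warmup=None, n_steps_kl_warmup=1000))  # 0.25
```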

training: bool
class cell2fate._pyro_mixin.QuantileMixin

Bases: object

This mixin class provides methods for:

  • computing median and quantiles of the posterior distribution using both direct and amortised inference

posterior_quantile(q: float = 0.5, batch_size: int = 2048, use_gpu: bool | None = None, use_median: bool = False)

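Compute the q-th quantile (by default q=0.5, the median) of the posterior distribution of each parameter.

Parameters:
  • q – Quantile to compute.

  • batch_size – Number of observations per minibatch.

  • use_gpu – Whether to use the GPU.

  • use_median – If True and q=0.5, use the median method of the guide rather than its quantile method.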

class cell2fate._pyro_mixin.PyroAggressiveTrainingPlan1(pyro_module: PyroBaseModuleClass, loss_fn: ELBO | None = None, optim: PyroOptim | None = None, optim_kwargs: dict | None = None, n_aggressive_epochs: int = 1000, n_aggressive_steps: int = 20, n_steps_kl_warmup: int | None = None, n_epochs_kl_warmup: int | None = 400, aggressive_vars: list | None = None, invert_aggressive_selection: bool = False)

Bases: PyroTrainingPlan

Lightning module task to train Pyro scvi-tools modules.

Parameters:
  • pyro_module – An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.

  • loss_fn – A Pyro loss. Should be a subclass of ELBO. If None, defaults to Trace_ELBO.

  • optim – A Pyro optimizer instance, e.g., Adam. If None, defaults to pyro.optim.Adam optimizer with a learning rate of 1e-3.

  • optim_kwargs – Keyword arguments for default optimiser pyro.optim.Adam.

  • n_aggressive_epochs – Number of epochs in aggressive optimisation of amortised variables.

  • n_aggressive_steps – Number of steps to spend optimising amortised variables before one step optimising global variables.

  • n_steps_kl_warmup – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

  • n_epochs_kl_warmup – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
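The aggressive schedule described by n_aggressive_epochs and n_aggressive_steps can be sketched as a function that says which group of variables a training step would update. This is a hypothetical reading of the parameter descriptions, not the class's code; the actual switching presumably happens by toggling gradients, as the name and arguments of change_requires_grad suggest, but that method is undocumented here.

```python
def variables_to_update(
    epoch: int,
    step: int,
    n_aggressive_epochs: int = 1000,
    n_aggressive_steps: int = 20,
) -> str:
    # After the aggressive phase, optimise all variables together
    if epoch >= n_aggressive_epochs:
        return "all"
    # Within the aggressive phase: n_aggressive_steps amortised updates,
    # then one global update, repeated
    position = step % (n_aggressive_steps + 1)
    return "global" if position == n_aggressive_steps else "amortised"


schedule = [variables_to_update(0, s, n_aggressive_steps=2) for s in range(6)]
print(schedule)
# ['amortised', 'amortised', 'global', 'amortised', 'amortised', 'global']
```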

training: bool
change_requires_grad(aggressive_vars_status, non_aggressive_vars_status)
training_epoch_end(outputs)

Called at the end of the training epoch with the outputs of all training steps. Use this in case you need to do something with all the outputs returned by training_step().

# the pseudocode for these calls
train_outs = []
for train_batch in train_data:
    out = training_step(train_batch)
    train_outs.append(out)
training_epoch_end(train_outs)
Parameters:

outputs – List of outputs you defined in training_step(). If there are multiple optimizers, it is a list containing a list of outputs for each optimizer. If using truncated_bptt_steps > 1, each element is a list of outputs corresponding to the outputs of each processed split batch.

Returns:

None

Note

If this method is not overridden, this won’t be called.

def training_epoch_end(self, training_step_outputs):
    # do something with all training_step outputs
    for out in training_step_outputs:
        ...
training_step(batch, batch_idx)

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Parameters:
  • batch – The output of your DataLoader for this step.

  • batch_idx – The index of this batch.

Returns:

Any of:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'

  • None - Training will skip to the next batch. This is only for automatic optimization.

    This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model-specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

If you add truncated backpropagation through time, you will also get an additional argument with the hidden states of the previous step.

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(batch, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}

Note

The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.

Scvi-tools Module classes (initialising the model and the guide, PyroBaseModuleClass)

class cell2fate._pyro_base_cell2fate_module.Cell2FateBaseModule(model, amortised: bool = False, encoder_mode='single', encoder_kwargs=None, data_transform='log1p', guide_class=<class 'pyro.infer.autoguide.effect.AutoHierarchicalNormalMessenger'>, encoder_instance=None, **kwargs)

Bases: PyroBaseModuleClass, AutoGuideMixinModule

property model
property guide
property is_amortised
property list_obs_plate_vars

Model annotation for minibatch training with pyro plate.

A dictionary with:

  1. “name” – the name of the observation/minibatch plate;

  2. “in” – indexes of model args to provide to the encoder network when using amortised inference;

  3. “sites” – a dictionary whose keys are the names of variables that belong to the observation plate (used to recognise and merge posterior samples for minibatch variables), and whose values are the dimensions in the non-plate axis of each variable (used to construct the output layer of the encoder network when using amortised inference).
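How the “sites” keys could be used to recognise and merge minibatch posterior samples is sketched below. It assumes that samples for each minibatch are stored as arrays shaped (n_samples, n_cells, dims) with the plate on axis 1; the site names, the annotation dictionary and the rule of keeping one copy of global sites are hypothetical simplifications, not cell2fate's merging code.

```python
import numpy as np

# Hypothetical annotation: only "z_sr" lives on the observation plate
obs_plate_vars = {"name": "obs_plate", "sites": {"z_sr": 3}}


def merge_minibatch_samples(batches, obs_plate_vars, plate_axis=1):
    merged = {}
    for name in batches[0]:
        if name in obs_plate_vars["sites"]:
            # Minibatch site: stitch the cells from every minibatch back together
            merged[name] = np.concatenate([b[name] for b in batches], axis=plate_axis)
        else:
            # Global site: keep a single copy (a simplification for this sketch)
            merged[name] = batches[0][name]
    return merged


# Two minibatches of 4 and 6 cells, 2 posterior samples each
batches = [
    {"z_sr": np.zeros((2, 4, 3)), "g_fg": np.ones((2, 5))},
    {"z_sr": np.ones((2, 6, 3)), "g_fg": np.ones((2, 5))},
]
merged = merge_minibatch_samples(batches, obs_plate_vars)
print(merged["z_sr"].shape)  # (2, 10, 3)
print(merged["g_fg"].shape)  # (2, 5)
```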

init_to_value(site)
training: bool

Auto amortised hierarchical normal messenger for amortized inference

class cell2fate.AutoAmortisedNormalMessenger.FCLayersPyro(n_in: int, n_out: int, n_cat_list: ~typing.Iterable[int] | None = None, n_layers: int = 1, n_hidden: int = 128, dropout_rate: float = 0.1, use_batch_norm: bool = True, use_layer_norm: bool = False, use_activation: bool = True, bias: bool = True, inject_covariates: bool = True, activation_fn: ~torch.nn.modules.module.Module = <class 'torch.nn.modules.activation.ReLU'>)

Bases: FCLayers, PyroModule

training: bool
class cell2fate.AutoAmortisedNormalMessenger.AutoAmortisedHierarchicalNormalMessenger(model: ~typing.Callable, *, amortised_plate_sites: dict, n_in: int, n_hidden: dict | None = None, init_param_scale: float = 0.02, init_scale: float = 0.1, init_weight: float = 1.0, init_loc_fn: ~typing.Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), encoder_class=<class 'cell2fate.AutoAmortisedNormalMessenger.FCLayersPyro'>, encoder_kwargs=None, multi_encoder_kwargs=None, encoder_instance: ~torch.nn.modules.module.Module | None = None, encoder_mode: ~typing.Literal['single', 'multiple', 'single-multiple'] = 'single', hierarchical_sites: list | None = None, bias=True, use_posterior_lsw_encoders=False)

Bases: AutoHierarchicalNormalMessenger

EXPERIMENTAL automatic GuideMessenger, intended for use with Effect_ELBO or similar, that amortises specific sites.

The mean-field posterior at any site is a transformed normal distribution, the mean of which depends on the value of that site given its dependencies in the model:

loc_total = loc + transform.inv(prior.mean) * weight

Here the value of prior.mean is conditional on upstream sites in the model. This approach doesn’t work for distributions that don’t have a mean.
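To make the formula concrete, the scalar sketch below assumes a positive site whose transform is exp, so transform.inv is log; prior_mean stands for prior.mean conditioned on upstream samples. The helper names and numbers are hypothetical, and the real guide works on tensors with learned (or amortised) loc, scale and weight.

```python
import math
import random


def loc_total(loc: float, prior_mean: float, weight: float) -> float:
    # loc + transform.inv(prior.mean) * weight, with transform.inv = log (assumption)
    return loc + math.log(prior_mean) * weight


def sample_site(loc: float, prior_mean: float, weight: float, scale: float, rng: random.Random) -> float:
    # Draw from the Normal in unconstrained space, then map back with exp
    return math.exp(rng.gauss(loc_total(loc, prior_mean, weight), scale))


rng = random.Random(0)
centre = loc_total(0.0, prior_mean=2.0, weight=1.0)  # log(2), about 0.693
draw = sample_site(0.1, prior_mean=2.0, weight=1.0, scale=0.1, rng=rng)
print(centre, draw)
```

With loc = 0 and weight = 1 the posterior median exp(loc_total) equals prior.mean, so the guide starts from the prior's mean and learns deviations from it; with weight = 0 the dependence on upstream sites disappears.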

loc, scales and element-specific weight are amortised for each site specified in amortised_plate_sites.

Derived classes may override particular sites and use this simply as a default; see the AutoNormalMessenger documentation for an example.

Parameters:
  • model (callable) – A Pyro model.

  • amortised_plate_sites (dict) –

    Dictionary with amortised plate details: the name of the observation/minibatch plate, the indexes of model args to provide to the encoder, the variable names that belong to the observation plate and the number of dimensions in the non-plate axis of each variable, such as:

    {
        "name": "obs_plate",
        "input": [0],  # expression data + (optional) batch index ([0, 2])
        "input_transform": [torch.log1p],  # how to transform input data before passing to NN
        "sites": {
            "n_s": 1,
            "y_s": 1,
            "z_sr": R,
            "w_sf": F,
        },
    }

  • n_in (int) – Number of input dimensions (for encoder_class).

  • n_hidden (int or dict) –

    Number of hidden nodes in each layer, one of 3 options:
    1. An integer denoting the number of hidden nodes.
    2. A dictionary such as {"single": 200, "multiple": 200} denoting the number of hidden nodes for each encoder_mode (see below).
    3. A dictionary allowing a different number of hidden nodes for each model site, with the number of hidden nodes for the single encoder mode and for each model site:

    {
        "single": 200,
        "n_s": 5,
        "y_s": 5,
        "z_sr": 128,
        "w_sf": 200,
    }

  • init_param_scale (float) – How to scale/normalise initial values for weights converting hidden layers to loc and scales.

  • scales_offset (float) – Offset between the output of the NN and the scales.

  • encoder_class (Callable) – Class that defines encoder network.

  • encoder_kwargs (dict) – Keyword arguments for encoder class.

  • multi_encoder_kwargs (dict) – Optional separate keyword arguments for encoder_class, useful when encoder_mode == “single-multiple”.

  • encoder_instance (Callable) – Encoder network instance; overrides encoder_class. The provided instance is copied with deepcopy.

  • encoder_mode (str) – Use single encoder for all variables (“single”), one encoder per variable (“multiple”) or a single encoder in the first step and multiple encoders in the second step (“single-multiple”).

  • hierarchical_sites (list) – List of latent variables (model sites) that have hierarchical dependencies. If None, all sites are assumed to have hierarchical dependencies, and for sites that don’t have upstream sites the guide represents (learns) the deviation from the prior.
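The accepted shapes of amortised_plate_sites and n_hidden can be illustrated in plain Python. The resolve_n_hidden helper is hypothetical: it encodes one reading of the three n_hidden options (which entry each encoder would use), not the class's actual lookup code. R and F are placeholder sizes, and math.log1p stands in for torch.log1p so the sketch runs without torch.

```python
import math
from typing import Dict, Optional, Union

# Hypothetical placeholder sizes standing in for R and F
R, F = 10, 8

amortised_plate_sites = {
    "name": "obs_plate",
    "input": [0],
    # Scalar stand-in for torch.log1p (assumption for this sketch)
    "input_transform": [math.log1p],
    "sites": {"n_s": 1, "y_s": 1, "z_sr": R, "w_sf": F},
}


def resolve_n_hidden(
    n_hidden: Union[int, Dict[str, int]],
    encoder_mode: str,
    site: Optional[str] = None,
) -> int:
    # Option 1: a single integer applies to every encoder
    if isinstance(n_hidden, int):
        return n_hidden
    # The shared encoder uses the "single" entry
    if encoder_mode == "single":
        return n_hidden["single"]
    # Per-site encoders: option 3 (per-site entry) if present, else option 2
    if site is not None and site in n_hidden:
        return n_hidden[site]
    return n_hidden["multiple"]


per_site = {"single": 200, "n_s": 5, "y_s": 5, "z_sr": 128, "w_sf": 200}
for name, n_dims in amortised_plate_sites["sites"].items():
    width = resolve_n_hidden(per_site, "multiple", name)
    print(f"{name}: hidden={width}, non-plate dims={n_dims}")
```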

weight_type = 'element-wise'
get_posterior(name: str, prior: Distribution) → Distribution | Tensor

Abstract method to compute a posterior distribution or sample a posterior value given a prior distribution conditioned on upstream posterior samples.

Implementations may use pyro.param and pyro.sample inside this function, but pyro.sample statements should set infer={"is_auxiliary": True}.

Implementations may access further information for computations:

  • value = self.upstream_value(name) is the value of an upstream sample or deterministic site.

  • self.trace is a trace of upstream sites, and may be useful for other information such as self.trace.nodes["my_site"]["fn"] or self.trace.nodes["my_site"]["cond_indep_stack"].

  • args, kwargs = self.args_kwargs are the inputs to the model, and may be useful for amortization.

Parameters:
  • name (str) – The name of the sample site to sample.

  • prior (Distribution) – The prior distribution of this sample site (conditioned on upstream samples from the posterior).

Returns:

A posterior distribution or sample from the posterior distribution.

Return type:

Distribution or torch.Tensor

encode(name: str, prior: Distribution)

Apply the encoder network to input data to obtain the hidden-layer encoding.

Parameters:
  • args – Pyro model args

  • kwargs – Pyro model kwargs

median(*args, **kwargs)
quantiles(quantiles, *args, **kwargs)
training: bool