import logging
from datetime import date
from functools import partial
from typing import Optional, Union
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyro
import pytorch_lightning as pl
import torch
from pyro import poutine
from pyro.infer.autoguide import AutoNormal, init_to_feasible, init_to_mean
from pytorch_lightning.callbacks import Callback
from scipy.sparse import issparse
from scvi import REGISTRY_KEYS
from scvi.dataloaders import AnnDataLoader
from scvi.model._utils import parse_use_gpu_arg
from scvi.module.base import PyroBaseModuleClass
from scvi.train import PyroTrainingPlan
from scvi.utils import track
from cell2fate.AutoAmortisedNormalMessenger import AutoAmortisedHierarchicalNormalMessenger
#from cell2fate.AutoAmortisedNormalMessenger import (
# AutoAmortisedHierarchicalNormalMessenger,
#)
logger = logging.getLogger(__name__)
import sys
from pyro.infer.autoguide.utils import deep_getattr, deep_setattr, helpful_support_errors
from torch.distributions import biject_to, constraints
import pyro.distributions as dist
from pyro.poutine.util import site_is_subsample
from pyro.nn.module import PyroModule, PyroParam, pyro_method
from pyro.infer.autoguide import AutoHierarchicalNormalMessenger
from pyro.poutine.runtime import get_plates
from pyro.distributions.distribution import Distribution
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Type,
Union,
ValuesView,
)
from scvi.module.base import PyroBaseModuleClass
from scvi.train import PyroTrainingPlan
from typing import Optional, Union
import pyro
# Default optimiser for ``PyroTrainingPlan_ClippedAdamDecayingRate``.
# ``lrd`` is chosen so that ``start_lr * lrd ** max_epochs == final_lr``; per
# Pyro's ClippedAdam documentation the learning rate is multiplied by ``lrd``
# at every optimiser update.
# NOTE(review): that schedule only reaches ``final_lr`` after ``max_epochs``
# *updates*; with several minibatches per epoch it decays faster — confirm.
max_epochs = 4000
start_lr = 0.01
final_lr = 0.001
lrd = (final_lr / start_lr) ** (1 / max_epochs)
clipped_adam = pyro.optim.ClippedAdam(
    {
        "lr": start_lr,
        "lrd": lrd,
        "clip_norm": 10.0,
    }
)
[docs]
def init_to_value(site=None, values=None):
    '''
    Initialise a sample site from a dictionary of user-supplied values.

    Parameters
    ----------
    site
        The Pyro site dictionary (must contain ``"name"``). If `None`, a
        partial function with `values` bound is returned, suitable for use as
        an ``init_loc_fn``.
    values
        Mapping from site name to initial value. `None` (the default) is
        treated as an empty mapping.

    Returns
    -------
    Function or any
        If `site` is None, a ``functools.partial`` with `values` preset.
        Otherwise ``values[site["name"]]`` when present, else the result of
        ``init_to_mean(site, fallback=init_to_feasible)``.
    '''
    # A None default (instead of ``{}``) gives each call its own dict, so no
    # mapping is shared between unrelated partials created from this function.
    if values is None:
        values = {}
    if site is None:
        return partial(init_to_value, values=values)
    if site["name"] in values:
        return values[site["name"]]
    return init_to_mean(site, fallback=init_to_feasible)
[docs]
def expand_zeros_along_dim(tensor, size, dim):
    '''
    Create a zero array shaped like `tensor` except along one axis.

    Parameters
    ----------
    tensor
        Array-like whose ``.shape`` is used as the template; it is not modified.
    size
        Length of the new array along `dim`.
    dim
        Axis to resize (negative values count from the end).

    Returns
    -------
    Numpy.ndarray
        A new float64 array of zeros with ``shape[dim] == size``.
    '''
    target_shape = list(tensor.shape)
    # int() mirrors the truncation of assigning into an integer numpy array
    target_shape[dim] = int(size)
    return np.zeros(target_shape)
[docs]
def complete_tensor_along_dim(tensor, indices, dim, value, mode="put"):
    '''
    Scatter `value` into `tensor` (or gather from it) at `indices` along `dim`.

    Parameters
    ----------
    tensor
        Full-size numpy array. In "put" mode it is modified in place.
    indices
        1-D integer array of positions along `dim`.
    dim
        The axis the indices refer to.
    value
        Array whose number of dimensions fixes how `indices` is reshaped for
        broadcasting; in "put" mode these are the values written.
    mode
        "put" (default) writes `value` into `tensor`; "take" returns the
        entries of `tensor` at `indices`.

    Returns
    -------
    Numpy.ndarray
        For "take", the gathered array; otherwise `tensor` itself after the
        in-place update.
    '''
    # Reshape indices to length len(indices) on `dim` and 1 elsewhere so they
    # broadcast against the other axes in *_along_axis.
    index_shape = [1] * len(value.shape)
    index_shape[dim] = len(indices)
    aligned = indices.reshape(index_shape)
    if mode != "take":
        np.put_along_axis(arr=tensor, indices=aligned, values=value, axis=dim)
        return tensor
    return np.take_along_axis(arr=tensor, indices=aligned, axis=dim)
def _complete_full_tensors_using_plates(
means_global, means, plate_dict, obs_plate_sites, plate_indices, plate_dim
):
'''
Completes full-sized tensors with minibatch values given minibatch indices.
Parameters
----------
means_global
Dictionary containing global means.
means
Dictionary containing means.
plate_dict
Dictionary containing plate information.
obs_plate_sites
Dictionary containing observed plate sites.
plate_indices
Dictionary containing plate indices.
plate_dim
Dictionary containing plate dimensions.
Returns
-------
Dict
A dictionary with completed global means.
'''
# complete full sized tensors with minibatch values given minibatch indices
for k in means_global.keys():
# find which and how many plates contain this tensor
plates = [
plate for plate in plate_dict.keys() if k in obs_plate_sites[plate].keys()
]
if len(plates) == 1:
# if only one plate contains this tensor, complete it using the plate indices
means_global[k] = complete_tensor_along_dim(
means_global[k],
plate_indices[plates[0]],
plate_dim[plates[0]],
means[k],
)
elif len(plates) == 2:
# subset data to index for plate 0 and fill index for plate 1
means_global_k = complete_tensor_along_dim(
means_global[k],
plate_indices[plates[0]],
plate_dim[plates[0]],
means[k],
mode="take",
)
means_global_k = complete_tensor_along_dim(
means_global_k,
plate_indices[plates[1]],
plate_dim[plates[1]],
means[k],
)
# fill index for plate 0 in the full data
means_global[k] = complete_tensor_along_dim(
means_global[k],
plate_indices[plates[0]],
plate_dim[plates[0]],
means_global_k,
)
# TODO add a test - observed variables should be identical if this code works correctly
# This code works correctly but the test needs to be added eventually
# np.allclose(
# samples['data_chromatin'].squeeze(-1).T,
# mod_reg.adata_manager.get_from_registry('X')[
# :, ~mod_reg.adata_manager.get_from_registry('gene_bool').ravel()
# ].toarray()
# )
else:
NotImplementedError(
f"Posterior sampling/mean/median/quantile not supported for variables with > 2 plates: {k} has {len(plates)}"
)
return means_global
[docs]
class AutoGuideMixinModule:
    """
    Mixin that builds Pyro guides for scvi-tools Pyro modules.

    Provides:
    - standard (non-amortised) guides such as
      :class:`~pyro.infer.autoguide.AutoNormal`, or any messenger-based guide class
    - amortised guides (``AutoAmortisedHierarchicalNormalMessenger``)
    - amortised guides receiving extra categorical inputs via ``n_cat_list``
    """

    def _create_autoguide(
        self,
        model,
        amortised,
        encoder_kwargs,
        encoder_mode,
        init_loc_fn=init_to_mean(fallback=init_to_feasible),
        n_cat_list: list = [],
        encoder_instance=None,
        guide_class=AutoNormal,
        guide_kwargs: Optional[dict] = None,
    ):
        """
        Construct the guide for `model`.

        Parameters
        ----------
        model
            Pyro model. In the amortised case it must provide
            ``list_obs_plate_vars()``. Sites named in its optional
            ``discrete_variables`` attribute are hidden from the guide.
        amortised
            If True build an ``AutoAmortisedHierarchicalNormalMessenger``,
            otherwise instantiate `guide_class`.
        encoder_kwargs
            Encoder options (amortised only). Non-dict values are replaced by
            an empty dict; ``"n_hidden"`` defaults to 200. A dict passed here
            receives ``"n_cat_list"`` in place when the model lists two or
            more inputs.
        encoder_mode
            Forwarded to the amortised guide.
        init_loc_fn
            Initialisation function for guide locations.
        n_cat_list
            Stored as ``encoder_kwargs["n_cat_list"]`` (amortised only);
            presumably the category counts of extra covariates — verify.
        encoder_instance
            Optional pre-built encoder forwarded to the amortised guide.
        guide_class
            Guide class for the non-amortised case. Classes that are not Pyro
            messengers also receive ``create_plates=self.model.create_plates``.
        guide_kwargs
            Extra keyword arguments for the guide constructor.

        Returns
        -------
        The constructed guide instance.
        """
        extra_kwargs = dict() if guide_kwargs is None else guide_kwargs

        if amortised:
            enc_kwargs = encoder_kwargs if isinstance(encoder_kwargs, dict) else dict()
            hidden_units = enc_kwargs.get("n_hidden", 200)
            # query the plate variables before the model is wrapped by poutine.block
            plate_vars = model.list_obs_plate_vars()
            if len(plate_vars["input"]) >= 2:
                enc_kwargs["n_cat_list"] = n_cat_list
            input_dim = plate_vars["n_in"]
            if getattr(model, "discrete_variables", None) is not None:
                model = poutine.block(model, hide=model.discrete_variables)
            return AutoAmortisedHierarchicalNormalMessenger(
                model,
                amortised_plate_sites=plate_vars,
                n_in=input_dim,
                n_hidden=hidden_units,
                encoder_kwargs=enc_kwargs,
                encoder_mode=encoder_mode,
                encoder_instance=encoder_instance,
                init_loc_fn=init_loc_fn,
                **extra_kwargs,
            )

        if getattr(model, "discrete_variables", None) is not None:
            model = poutine.block(model, hide=model.discrete_variables)
        if issubclass(guide_class, poutine.messenger.Messenger):
            # messenger guides don't need create_plates function
            plate_kwargs = {}
        else:
            plate_kwargs = {"create_plates": self.model.create_plates}
        return guide_class(
            model,
            init_loc_fn=init_loc_fn,
            **extra_kwargs,
            **plate_kwargs,
        )
[docs]
class MyAutoHierarchicalNormalMessenger(AutoHierarchicalNormalMessenger):
    """
    :class:`~pyro.infer.autoguide.AutoHierarchicalNormalMessenger` with a
    custom quantile mode.

    While :meth:`quantiles` runs, the posterior quantile of every latent site
    is computed in constrained space and cached in ``self.quantile_dict``. The
    model is then re-traced conditioned on those values, and all of that
    trace's non-subsample sample sites are returned, including observed and
    deterministic ones.
    """

    @pyro_method
    def __call__(self, *args, **kwargs):
        # Since this guide creates parameters lazily, we need to avoid batching
        # those parameters by a particle plate, in case the first time this
        # guide is called is inside a particle plate. We assume all plates
        # outside the model are particle plates.
        self._outer_plates = tuple(f.name for f in get_plates())
        # The quantile flag may not exist yet (e.g. before quantiles() was
        # ever called); default it to "not computing quantiles".
        if not hasattr(self, "_computing_quantiles"):
            self._computing_quantiles = False
        try:
            return self.call_new(*args, **kwargs)
        finally:
            del self._outer_plates

    def call_new(self, *args, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore[override]
        """
        Draw posterior samples from the guide and replay the model against them.

        When quantiles are being computed, the model is instead re-traced
        conditioned on ``self.quantile_dict`` and the values of that trace are
        returned.

        :returns: A dict mapping sample site name to sample value.
            This includes latent, deterministic, and observed values.
        :rtype: Dict
        """
        self.args_kwargs = args, kwargs
        try:
            with self:
                self.model(*args, **kwargs)
        finally:
            del self.args_kwargs

        model_trace, guide_trace = self.get_traces()
        if self._computing_quantiles:
            # Re-run the model with every latent fixed to its quantile so that
            # dependent sites are computed from those values.
            with poutine.block():
                model = poutine.condition(self.model, self.quantile_dict)
                trace = poutine.trace(model).get_trace(*args, **kwargs)
                return {
                    name: site["value"]
                    for name, site in trace.nodes.items()
                    if site["type"] == "sample" and not site_is_subsample(site)
                }
        return {
            name: site["value"]
            for name, site in model_trace.nodes.items()
            if site["type"] == "sample"
        }

    def _pyro_sample(self, msg):
        if msg["is_observed"] or site_is_subsample(msg):
            return
        prior = msg["fn"]
        msg["infer"]["prior"] = prior
        posterior = self.get_posterior(msg["name"], prior)
        if isinstance(posterior, torch.Tensor):
            posterior = dist.Delta(posterior, event_dim=prior.event_dim)
        if posterior.batch_shape != prior.batch_shape:
            posterior = posterior.expand(prior.batch_shape)
        if self._computing_quantiles:
            # Called for its side effect: _get_posterior_quantiles stores this
            # site's quantile in self.quantile_dict, which call_new conditions on.
            self.get_posterior_quantile(msg["name"], prior)
        msg["fn"] = posterior

    def get_posterior_quantile(
        self,
        name: str,
        prior: Distribution,
    ) -> Union[Distribution, torch.Tensor]:
        """
        Get the posterior median, quantile or distribution for a site.

        Parameters
        ----------
        name
            The name of the parameter.
        prior
            The prior distribution.

        Returns
        -------
        Union[Distribution, torch.Tensor]
            The median while :meth:`median` runs, the quantile while
            :meth:`quantiles` runs, otherwise the regular posterior.
        """
        if self._computing_median:
            return self._get_posterior_median(name, prior)
        if self._computing_quantiles:
            return self._get_posterior_quantiles(name, prior)
        # The previous fallback called this method again, recursing forever.
        return self.get_posterior(name, prior)

    @torch.no_grad()
    def quantiles(self, quantiles, *args, **kwargs):
        """
        Compute quantiles of the posterior distribution.

        Parameters
        ----------
        quantiles
            List of quantiles to compute (only a single quantile is expected;
            it is broadcast against each site's parameters).
        *args, **kwargs
            Arguments passed to the model.

        Returns
        -------
        Dict
            Sample-site values of the model conditioned on the quantiles.
        """
        self._computing_quantiles = True
        self._quantile_values = quantiles
        # start from an empty cache so values from a previous call can never
        # be used for conditioning
        self.quantile_dict = {}
        try:
            return self(*args, **kwargs)
        finally:
            # Reset the mode; otherwise later median()/regular calls would keep
            # conditioning the model on stale quantile values.
            self._computing_quantiles = False
            print(f"Sampled for quantile: {quantiles[0]}")

    @torch.no_grad()
    def _get_posterior_quantiles(self, name, prior):
        transform = biject_to(prior.support)
        if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
            loc, scale, weight = self._get_params(name, prior)
            loc = loc + transform.inv(prior.mean) * weight
        else:
            loc, scale = self._get_params(name, prior)
        site_quantiles = torch.tensor(self._quantile_values, dtype=loc.dtype, device=loc.device)
        site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles)
        constrained = transform(site_quantiles_values)
        if getattr(self, "quantile_dict", None) is None:
            self.quantile_dict = {}
        self.quantile_dict[name] = constrained
        return constrained
[docs]
class PyroTrainingPlan_ClippedAdamDecayingRate(PyroTrainingPlan):
    """
    Lightning module task to train Pyro scvi-tools modules with a ClippedAdam
    optimiser whose learning rate decays during training.

    Parameters
    ----------
    pyro_module
        An instance of :class:`~scvi.module.base.PyroBaseModuleClass`. This object should have callable `model` and `guide` attributes or methods.
    loss_fn
        A Pyro loss. Should be a subclass of :class:`~pyro.infer.ELBO`. If `None`, the parent
        :class:`~scvi.train.PyroTrainingPlan` chooses the default (:class:`~pyro.infer.Trace_ELBO`).
    optim
        A Pyro optimizer instance. Defaults to the module-level ``clipped_adam``
        (:class:`pyro.optim.ClippedAdam` with ``lr=start_lr``, ``lrd`` and ``clip_norm=10.0``).
    optim_kwargs
        Passed through to :class:`~scvi.train.PyroTrainingPlan`.
    """

    def __init__(
        self,
        pyro_module: PyroBaseModuleClass,
        loss_fn: Optional[pyro.infer.ELBO] = None,
        # NOTE: this default is evaluated once, so every plan created without an
        # explicit optimiser shares the same ClippedAdam object (and its state).
        optim: Optional[pyro.optim.PyroOptim] = clipped_adam,
        optim_kwargs: Optional[dict] = None,
    ):
        super().__init__(
            pyro_module=pyro_module,
            loss_fn=loss_fn,
            optim=optim,
            optim_kwargs=optim_kwargs,
        )
        self.svi = pyro.infer.SVI(
            model=pyro_module.model,
            guide=pyro_module.guide,
            optim=self.optim,
            loss=self.loss_fn,
        )
### Added for amortized model ###
[docs]
class QuantileMixin:
    """
    This mixin class provides methods for:

    - computing median and quantiles of the posterior distribution using both
      direct and amortised inference

    Classes using it are expected to provide ``self.module`` (with ``guide``,
    ``model``, ``_get_fn_args_from_batch`` and ``list_obs_plate_vars``),
    ``self.adata_manager`` and ``self.to_device``.
    """

    # Upper bound on attempts when a guide/model evaluation raises ValueError;
    # a persistently failing call then raises instead of recursing forever.
    _max_sample_retries = 100

    @classmethod
    def _retry_on_value_error(cls, fn):
        """
        Call ``fn()`` and return its result, retrying when it raises ``ValueError``.

        If sampling is unsuccessful the call is simply repeated. After
        ``cls._max_sample_retries`` failed attempts the last error is re-raised.
        """
        for attempt in range(cls._max_sample_retries):
            try:
                return fn()
            except ValueError:
                if attempt + 1 >= cls._max_sample_retries:
                    raise
        raise RuntimeError("_max_sample_retries must be a positive integer")

    def _optim_param(
        self,
        lr: float = 0.01,
        autoencoding_lr: float = None,
        clip_norm: float = 200,
        module_names: list = ["encoder", "hidden2locs", "hidden2scales"],
    ):
        """
        Build a per-parameter optimiser-argument function for Pyro optimisers.

        Parameters
        ----------
        lr
            Learning rate for parameters not matched by `module_names`.
        autoencoding_lr
            Learning rate for parameters whose ``"<module_name>.<param_name>"``
            contains any of `module_names`. If None, `lr` is used everywhere.
        clip_norm
            Gradient-clipping norm returned for every parameter (limits the
            gradient step from becoming too large).
        module_names
            Substrings identifying autoencoding-guide parameters.

        Returns
        -------
        callable
            ``optim_param(module_name, param_name) -> {"lr": ..., "clip_norm": ...}``
        """
        # TODO implement custom training method that can use this function.

        def optim_param(module_name, param_name):
            full_name = module_name + "." + param_name
            # detect variables in autoencoding guide
            is_autoencoding = autoencoding_lr is not None and any(
                n in full_name for n in module_names
            )
            return {
                "lr": autoencoding_lr if is_autoencoding else lr,
                "clip_norm": clip_norm,
            }

        return optim_param

    @torch.no_grad()
    def _get_obs_plate_sites_v2(
        self,
        args: list,
        kwargs: dict,
        plate_name: str = None,
        return_observed: bool = False,
        return_deterministic: bool = True,
    ):
        """
        Automatically guess which model sites belong to observation/minibatch plate.

        This function requires minibatch plate name specified in
        `self.module.list_obs_plate_vars["name"]`.

        Parameters
        ----------
        args
            Arguments to the model.
        kwargs
            Keyword arguments to the model.
        plate_name
            Plate to look for; defaults to ``self.module.list_obs_plate_vars["name"]``.
        return_observed
            Record samples of observed variables.
        return_deterministic
            Also record sites whose ``infer`` dict sets ``"_deterministic"``.

        Returns
        -------
        Dictionary with keys corresponding to site names; each value maps the
        matching plate name(s) to the frame from the site's ``cond_indep_stack``.
        """
        if plate_name is None:
            plate_name = self.module.list_obs_plate_vars["name"]

        def trace_guide_then_model():
            guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs)
            return poutine.trace(
                poutine.replay(self.module.model, guide_trace)
            ).get_trace(*args, **kwargs)

        trace = self._retry_on_value_error(trace_guide_then_model)

        def keep_site(site):
            if site["type"] != "sample":
                return False
            # don't save observed unless requested ...
            wanted = (not site.get("is_observed", True)) or return_observed
            # ... unless it is deterministic ("infer" may be missing)
            infer = site.get("infer") or {}
            wanted = wanted or (infer.get("_deterministic", False) and return_deterministic)
            # don't save plates
            return wanted and not isinstance(
                site.get("fn", None), poutine.subsample_messenger._Subsample
            )

        return {
            name: {
                frame.name: frame
                for frame in site["cond_indep_stack"]
                if (frame.name in plate_name) or (frame.name == plate_name)
            }
            for name, site in trace.nodes.items()
            if keep_site(site)
            if any(f.name == plate_name for f in site["cond_indep_stack"])
        }

    @torch.no_grad()
    def _posterior_quantile_minibatch(
        self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None, use_median: bool = False
    ):
        """
        Compute median of the posterior distribution of each parameter, separating local (minibatch) variable
        and global variables, which is necessary when performing amortised inference.

        Note for developers: requires model class method which lists observation/minibatch plate
        variables (self.module.model.list_obs_plate_vars()).

        Parameters
        ----------
        q
            quantile to compute
        batch_size
            number of observations per batch
        use_gpu
            Bool, use gpu?
        use_median
            Bool, when q=0.5 use median rather than quantile method of the guide

        Returns
        -------
        dictionary {variable_name: posterior quantile}
        """
        gpus, device = parse_use_gpu_arg(use_gpu)
        self.module.eval()
        train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size)
        # Initialised before the loop: when the model has no local variables the
        # loop stops at the first batch before `means` would otherwise exist.
        means = {}
        obs_plate_sites = {}
        # sample local parameters
        i = 0
        for tensor_dict in train_dl:
            args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
            args = [a.to(device) for a in args]
            kwargs = {k: v.to(device) for k, v in kwargs.items()}
            self.to_device(device)
            if i == 0:
                # find plate sites
                obs_plate_sites = self._get_obs_plate_sites(args, kwargs, return_observed=True)
                if len(obs_plate_sites) == 0:
                    # if no local variables - don't sample
                    break
                # find plate dimension
                obs_plate_dim = list(obs_plate_sites.values())[0]
                if use_median and q == 0.5:
                    means = self.module.guide.median(*args, **kwargs)
                else:
                    means = self.module.guide.quantiles([q], *args, **kwargs)
                means = {k: means[k].cpu().numpy() for k in means.keys() if k in obs_plate_sites}
            else:
                if use_median and q == 0.5:
                    means_ = self.module.guide.median(*args, **kwargs)
                else:
                    means_ = self.module.guide.quantiles([q], *args, **kwargs)
                means_ = {k: means_[k].cpu().numpy() for k in means_.keys() if k in obs_plate_sites}
                means = {k: np.concatenate([means[k], means_[k]], axis=obs_plate_dim) for k in means.keys()}
            i += 1
        # sample global parameters
        tensor_dict = next(iter(train_dl))
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)
        if use_median and q == 0.5:
            global_means = self.module.guide.median(*args, **kwargs)
        else:
            global_means = self.module.guide.quantiles([q], *args, **kwargs)
        global_means = {k: global_means[k].cpu().numpy() for k in global_means.keys() if k not in obs_plate_sites}
        for k in global_means.keys():
            means[k] = global_means[k]
        self.module.to(device)
        return means

    @torch.no_grad()
    def _posterior_quantile(
        self, q: float = 0.5, batch_size: int = None, use_gpu: bool = None, use_median: bool = False
    ):
        """
        Compute median of the posterior distribution of each parameter pyro models trained without amortised inference.

        Parameters
        ----------
        q
            Quantile to compute
        batch_size
            Observations in the single batch used; defaults to all observations.
        use_gpu
            Bool, use gpu?
        use_median
            Bool, when q=0.5 use median rather than quantile method of the guide

        Returns
        -------
        dictionary {variable_name: posterior quantile}
        """
        self.module.eval()
        gpus, device = parse_use_gpu_arg(use_gpu)
        if batch_size is None:
            batch_size = self.adata_manager.adata.n_obs
        train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size)
        # sample global parameters
        tensor_dict = next(iter(train_dl))
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)
        if use_median and q == 0.5:
            means = self.module.guide.median(*args, **kwargs)
        else:
            means = self.module.guide.quantiles([q], *args, **kwargs)
        means = {k: means[k].cpu().detach().numpy() for k in means.keys()}
        return means

    def posterior_quantile(
        self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None, use_median: bool = False
    ):
        """
        Compute median of the posterior distribution of each parameter.

        Parameters
        ----------
        q
            Quantile to compute
        batch_size
            number of observations per batch
        use_gpu
            Bool, use gpu?
        use_median
            Bool, when q=0.5 use median rather than quantile method of the guide

        Returns
        -------
        dictionary {variable_name: posterior quantile}, as returned by
        ``_posterior_quantile_minibatch_v2``.
        """
        return self._posterior_quantile_minibatch_v2(
            q=q, batch_size=batch_size, use_gpu=use_gpu, use_median=use_median
        )

    def _guide_point_estimate(self, q, use_median, args, kwargs, exclude_vars=()):
        """Evaluate the guide median (``use_median`` and ``q == 0.5``) or ``q`` quantile as numpy arrays, retrying on ValueError."""
        if use_median and q == 0.5:
            # use median rather than quantile method
            estimate = self._retry_on_value_error(
                lambda: self.module.guide.median(*args, **kwargs)
            )
        else:
            estimate = self._retry_on_value_error(
                lambda: self.module.guide.quantiles([q], *args, **kwargs)
            )
        return {
            k: estimate[k].detach().cpu().numpy()
            for k in estimate.keys()
            if k not in exclude_vars
        }

    def _minibatch_plates(self, args, kwargs, minibatch_plate_names):
        """Return ``{name: plate}`` for model plates matching ``minibatch_plate_names``."""
        plates = self.module.model.create_plates(*args, **kwargs)
        if not isinstance(plates, list):
            plates = [plates]
        return {
            plate.name: plate
            for plate in plates
            if (
                (plate.name in minibatch_plate_names)
                or (plate.name == minibatch_plate_names)
            )
        }

    def _obs_plate_sites_per_plate(self, args, kwargs, plate_dict, return_observed):
        """Return ``{plate_name: _get_obs_plate_sites_v2(...)}`` for every plate in ``plate_dict``."""
        return {
            plate: self._get_obs_plate_sites_v2(
                args, kwargs, plate_name=plate, return_observed=return_observed
            )
            for plate in plate_dict.keys()
        }

    @torch.no_grad()
    def _posterior_quantile_minibatch_v2(
        self,
        q: float = 0.5,
        batch_size: int = 128,
        gene_batch_size: int = 50,
        use_gpu: bool = None,
        use_median: bool = False,
        return_observed: bool = True,
        exclude_vars: Optional[list] = None,
        data_loader_indices=None,
        show_progress: bool = True,
    ):
        """
        Compute median of the posterior distribution of each parameter, separating local (minibatch) variable
        and global variables, which is necessary when performing amortised inference.

        Note for developers: requires model class method which lists observation/minibatch plate
        variables (self.module.model.list_obs_plate_vars()).

        Parameters
        ----------
        q
            quantile to compute
        batch_size
            number of observations per batch
        gene_batch_size
            Currently unused.
        use_gpu
            Bool, use gpu?
        use_median
            Bool, when q=0.5 use median rather than quantile method of the guide
        return_observed
            Also return observed sites.
        exclude_vars
            Site names to drop from the result (default: none).
        data_loader_indices
            Optional subset of observation indices for the data loader; the
            minibatch plate size is then set to the number of indices.
        show_progress
            Show a progress bar over data batches.

        Returns
        -------
        dictionary {variable_name: posterior quantile}
        """
        gpus, device = parse_use_gpu_arg(use_gpu)
        if exclude_vars is None:
            exclude_vars = []
        self.module.eval()
        # `indices` restricts the loader to the requested observations, matching
        # the plate-size adjustment made for `data_loader_indices` below.
        train_dl = AnnDataLoader(
            self.adata_manager,
            shuffle=False,
            batch_size=batch_size,
            indices=data_loader_indices,
        )
        # stays empty (rather than undefined) if the loader yields no batches
        means_global = {}
        # sample local parameters
        for i, tensor_dict in enumerate(
            track(
                train_dl,
                style="tqdm",
                description=f"Computing posterior quantile {q}, data batch: ",
                disable=not show_progress,
            )
        ):
            args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
            args = [a.to(device) for a in args]
            kwargs = {k: v.to(device) for k, v in kwargs.items()}
            self.to_device(device)
            if i == 0:
                minibatch_plate_names = self.module.list_obs_plate_vars["name"]
                # find plate indices & dim
                plate_dict = self._minibatch_plates(args, kwargs, minibatch_plate_names)
                plate_size = {name: plate.size for name, plate in plate_dict.items()}
                if data_loader_indices is not None:
                    # set total plate size to the number of indices in DL not total number of observations
                    # this option is not really used
                    plate_size = {
                        name: len(train_dl.indices)
                        for name, plate in plate_dict.items()
                        if plate.name == minibatch_plate_names
                    }
                plate_dim = {name: plate.dim for name, plate in plate_dict.items()}
                plate_indices = {
                    name: plate.indices.detach().cpu().numpy()
                    for name, plate in plate_dict.items()
                }
                # find plate sites
                obs_plate_sites = self._obs_plate_sites_per_plate(
                    args, kwargs, plate_dict, return_observed
                )
                means = self._guide_point_estimate(q, use_median, args, kwargs, exclude_vars)
                means_global = means.copy()
                for plate in plate_dict.keys():
                    # create full sized tensors according to plate size
                    means_global = {
                        k: (
                            expand_zeros_along_dim(
                                means_global[k], plate_size[plate], plate_dim[plate]
                            )
                            if k in obs_plate_sites[plate].keys()
                            else means_global[k]
                        )
                        for k in means_global.keys()
                    }
                # complete full sized tensors with minibatch values given minibatch indices
                means_global = _complete_full_tensors_using_plates(
                    means_global=means_global,
                    means=means,
                    plate_dict=plate_dict,
                    obs_plate_sites=obs_plate_sites,
                    plate_indices=plate_indices,
                    plate_dim=plate_dim,
                )
                if np.all([len(v) == 0 for v in obs_plate_sites.values()]):
                    # if no local variables - don't sample further - return results now
                    break
            else:
                means = self._guide_point_estimate(q, use_median, args, kwargs, exclude_vars)
                # find plate indices & dim
                plate_dict = self._minibatch_plates(args, kwargs, minibatch_plate_names)
                plate_indices = {
                    name: plate.indices.detach().cpu().numpy()
                    for name, plate in plate_dict.items()
                }
                # TODO - is this correct to call this function again? find plate sites
                obs_plate_sites = self._obs_plate_sites_per_plate(
                    args, kwargs, plate_dict, return_observed
                )
                # complete full sized tensors with minibatch values given minibatch indices
                means_global = _complete_full_tensors_using_plates(
                    means_global=means_global,
                    means=means,
                    plate_dict=plate_dict,
                    obs_plate_sites=obs_plate_sites,
                    plate_indices=plate_indices,
                    plate_dim=plate_dim,
                )
        self.module.to(device)
        return means_global
[docs]
class PyroAggressiveTrainingPlan1(PyroTrainingPlan):
    """
    Lightning module task to train Pyro scvi-tools modules with alternating
    ("aggressive") updates of two groups of guide parameters.

    For the first `n_aggressive_epochs` epochs, `n_aggressive_steps`
    consecutive steps update only the parameters matched by `aggressive_vars`,
    followed by one step updating only the remaining parameters. The groups'
    roles are swapped when `invert_aggressive_selection` is True. After that
    phase, all parameters are updated together.

    Parameters
    ----------
    pyro_module
        An instance of :class:`~scvi.module.base.PyroBaseModuleClass`. This object
        should have callable `model` and `guide` attributes or methods.
    loss_fn
        A Pyro loss. Should be a subclass of :class:`~pyro.infer.ELBO`.
        If `None`, defaults to :class:`~pyro.infer.Trace_ELBO`.
    optim
        A Pyro optimizer instance, e.g., :class:`~pyro.optim.Adam`. If `None`,
        defaults to :class:`pyro.optim.Adam` optimizer with a learning rate of `1e-3`.
    optim_kwargs
        Keyword arguments for **default** optimiser :class:`pyro.optim.Adam`.
    n_aggressive_epochs
        Number of epochs in aggressive optimisation of amortised variables.
    n_aggressive_steps
        Number of steps to spend optimising amortised variables before one step optimising global variables.
    n_steps_kl_warmup
        Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
        Only activated when `n_epochs_kl_warmup` is set to None.
    n_epochs_kl_warmup
        Number of epochs to scale weight on KL divergences from 0 to 1.
        Overrides `n_steps_kl_warmup` when both are not `None`.
    aggressive_vars
        Substrings matched against guide parameter names to form the
        "aggressive" group. Defaults to the site names in
        ``self.module.list_obs_plate_vars["sites"]`` plus their ``_initial``
        and ``_unconstrained`` variants.
    invert_aggressive_selection
        Swap which group is updated during the aggressive steps.
    """

    def __init__(
        self,
        pyro_module: PyroBaseModuleClass,
        loss_fn: Optional[pyro.infer.ELBO] = None,
        optim: Optional[pyro.optim.PyroOptim] = None,
        optim_kwargs: Optional[dict] = None,
        n_aggressive_epochs: int = 1000,
        n_aggressive_steps: int = 20,
        n_steps_kl_warmup: Union[int, None] = None,
        n_epochs_kl_warmup: Union[int, None] = 400,
        aggressive_vars: Union[list, None] = None,
        invert_aggressive_selection: bool = False,
    ):
        super().__init__(
            pyro_module=pyro_module,
            loss_fn=loss_fn,
            optim=optim,
            optim_kwargs=optim_kwargs,
            n_steps_kl_warmup=n_steps_kl_warmup,
            n_epochs_kl_warmup=n_epochs_kl_warmup,
        )
        # schedule settings and progress counters
        self.n_aggressive_epochs = n_aggressive_epochs
        self.n_aggressive_steps = n_aggressive_steps
        self.aggressive_steps_counter = 0
        self.aggressive_epochs_counter = 0
        self.mi = []
        self.n_epochs_patience = 0

        if aggressive_vars is None:
            # default to the amortised (observation-plate) sites and their
            # "_initial" variants, then add "_unconstrained" to all of those
            site_names = list(self.module.list_obs_plate_vars["sites"].keys())
            with_initial = site_names + [f"{name}_initial" for name in site_names]
            aggressive_vars = with_initial + [f"{name}_unconstrained" for name in with_initial]
        self.aggressive_vars = aggressive_vars
        self.invert_aggressive_selection = invert_aggressive_selection

        self.svi = pyro.infer.SVI(
            model=pyro_module.model,
            guide=pyro_module.guide,
            optim=self.optim,
            loss=self.loss_fn,
        )

    def change_requires_grad(self, aggressive_vars_status, non_aggressive_vars_status):
        """
        Freeze ("hide") or unfreeze ("expose") the two groups of guide parameters.

        A parameter belongs to the aggressive group when its name contains any
        entry of ``self.aggressive_vars``. Any status other than "hide" or
        "expose" leaves that group unchanged.
        """
        for param_name, param in self.module.guide.named_parameters():
            in_group = any(var in param_name for var in self.aggressive_vars)
            status = aggressive_vars_status if in_group else non_aggressive_vars_status
            if status == "hide" and param.requires_grad:
                param.requires_grad = False
            elif status == "expose" and not param.requires_grad:
                param.requires_grad = True

    def training_epoch_end(self, outputs):
        """Count the finished epoch and log the mean step loss as ``elbo_train``."""
        self.aggressive_epochs_counter += 1
        step_losses = [out["loss"] for out in outputs]
        elbo = sum(step_losses) / len(step_losses)
        self.log("elbo_train", elbo, prog_bar=True)

    def training_step(self, batch, batch_idx):
        """
        Run one SVI step, first freezing/unfreezing parameter groups according
        to the aggressive schedule.
        """
        args, kwargs = self.module._get_fn_args_from_batch(batch)
        # Set KL weight if necessary.
        # Note: if applied, ELBO loss in progress bar is the effective KL annealed loss, not the true ELBO.
        if self.use_kl_weight:
            kwargs.update({"kl_weight": self.kl_weight})

        if self.aggressive_epochs_counter >= self.n_aggressive_epochs:
            # Do parameter update for both types of variables
            statuses = ("expose", "expose")
        elif self.aggressive_steps_counter < self.n_aggressive_steps:
            # Do parameter update exclusively for amortised variables
            self.aggressive_steps_counter += 1
            statuses = ("expose", "hide")
        else:
            # Do parameter update exclusively for non-amortised variables
            self.aggressive_steps_counter = 0
            statuses = ("hide", "expose")
        if self.invert_aggressive_selection:
            statuses = statuses[::-1]

        self.change_requires_grad(
            aggressive_vars_status=statuses[0],
            non_aggressive_vars_status=statuses[1],
        )
        loss = torch.Tensor([self.svi.step(*args, **kwargs)])
        return {"loss": loss}