from copy import deepcopy
from typing import Callable, Literal, Optional, Union
import numpy as np
import pyro.distributions as dist
import torch
from pyro.distributions.distribution import Distribution
from pyro.distributions.transforms import SoftplusTransform
from pyro.infer.autoguide import AutoHierarchicalNormalMessenger
from pyro.infer.autoguide.initialization import init_to_feasible, init_to_mean
from pyro.infer.autoguide.utils import (
deep_getattr,
deep_setattr,
helpful_support_errors,
)
from pyro.nn.module import PyroModule, PyroParam, to_pyro_module_
from torch.distributions import biject_to, constraints
from cell2fate.nn import FCLayers
[docs]
class FCLayersPyro(FCLayers, PyroModule):
    """
    ``FCLayers`` combined with :class:`~pyro.nn.module.PyroModule`.

    Used as the default ``encoder_class`` of ``AutoAmortisedHierarchicalNormalMessenger``;
    it defines no behaviour of its own beyond the two base classes.
    """
[docs]
class AutoAmortisedHierarchicalNormalMessenger(AutoHierarchicalNormalMessenger):
    """
    EXPERIMENTAL Automatic :class:`~pyro.infer.effect_elbo.GuideMessenger`,
    intended for use with :class:`~pyro.infer.effect_elbo.Effect_ELBO` or
    similar. Amortises specific sites.

    The mean-field posterior at any site is a transformed normal distribution,
    the mean of which depends on the value of that site given its dependencies in the model::

        loc = loc + transform.inv(prior.mean) * weight

    where the value of ``prior.mean`` is conditional on upstream sites in the model.
    This approach doesn't work for distributions that don't have a mean.

    ``loc``, ``scale`` and the element-specific ``weight`` are amortised (predicted by neural
    networks from the model input data) for each site listed in ``amortised_plate_sites["sites"]``;
    all other sites use the non-amortised parameters of the parent guide.

    Derived classes may override particular sites and use this simply as a
    default, see the AutoNormalMessenger documentation for an example.

    :param callable model: A Pyro model.
    :param dict amortised_plate_sites: Dictionary with amortised plate details:
        the name of observation/minibatch plate,
        indexes (or keyword names) of model args to provide to encoder,
        variable names that belong to the observation plate
        and the number of dimensions in non-plate axis of each variable - such as::

            {
                "name": "obs_plate",
                "input": [0],  # expression data + (optional) batch index ([0, 2])
                "input_transform": [torch.log1p],  # how to transform input data before passing to NN
                "sites": {
                    "n_s": 1,
                    "y_s": 1,
                    "z_sr": R,
                    "w_sf": F,
                },
            }

        A tuple value in ``"sites"`` (e.g. ``(A, B)``) makes the networks output ``A * B`` values
        that are reshaped (row-major) to ``(..., A, B)``. Optional keys: ``"site_transform"``
        (``{site: {"input_transform": [...], "n_in": int}}``) overrides the input transforms and
        the input size for individual sites; ``"input_normalisation"`` (one bool per input)
        multiplies the selected transformed inputs by a learnable parameter before encoding.
    :param int n_in: Number of input dimensions (for encoder_class).
    :param n_hidden: Number of hidden nodes in each layer, one of 3 options:

        1. Integer denoting the number of hidden nodes.
        2. Dictionary with ``{"single": 200, "multiple": 200}`` denoting the number of hidden nodes
           for each ``encoder_mode`` (see below).
        3. Dictionary allowing a different number of hidden nodes for each model site, with the
           number of hidden nodes for single encoder mode and each model site::

            {
                "single": 200,
                "n_s": 5,
                "y_s": 5,
                "z_sr": 128,
                "w_sf": 200,
            }

           Sites missing from the dictionary use its ``"multiple"`` entry.

        Defaults to ``{"single": 200, "multiple": 200}``.
    :param float init_param_scale: How to scale/normalise initial values for weights converting
        hidden layers to loc and scales (stored as ``self.init_param_scale``).
    :param float init_scale: Posterior scale produced when the scale network outputs zero;
        must be a positive ``float``.
    :param float init_weight: Weight of the prior mean produced when the weight network outputs
        zero (hierarchical sites only).
    :param callable init_loc_fn: Per-site initialisation function passed to the parent guide.
    :param Callable encoder_class: Class that defines encoder network.
    :param dict encoder_kwargs: Keyword arguments for encoder class (its ``n_hidden`` entry is
        overwritten with ``n_hidden["single"]``).
    :param dict multi_encoder_kwargs: Optional separate keyword arguments for encoder_class,
        useful when encoder_mode == "single-multiple". Defaults to a copy of ``encoder_kwargs``.
    :param torch.nn.Module encoder_instance: Encoder network instance, overrides ``encoder_class``;
        the instance is copied with deepcopy for every network that is created from it.
    :param str encoder_mode: Use single encoder for all variables ("single"), one encoder per variable ("multiple")
        or a single encoder in the first step and multiple encoders in the second step ("single-multiple").
    :param list hierarchical_sites: List of latent variables (model sites)
        that have hierarchical dependencies.
        If None, all sites are assumed to have hierarchical dependencies. If None, for the sites
        that don't have upstream sites, the guide is representing/learning deviation from the prior.
    :param bool bias: Whether the linear layers producing loc/scale/weight have a bias term.
    :param bool use_posterior_lsw_encoders: If True, loc, scale and weight each get their own
        per-site encoder between the hidden encoding and their linear output layer.
    """

    # 'element-wise' or 'scalar'
    weight_type = "element-wise"

    def __init__(
        self,
        model: Callable,
        *,
        amortised_plate_sites: dict,
        n_in: int,
        n_hidden: Optional[Union[int, dict]] = None,
        init_param_scale: float = 1 / 50,
        init_scale: float = 0.1,
        init_weight: float = 1.0,
        init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible),
        encoder_class=FCLayersPyro,
        encoder_kwargs: Optional[dict] = None,
        multi_encoder_kwargs: Optional[dict] = None,
        encoder_instance: Optional[torch.nn.Module] = None,
        encoder_mode: Literal["single", "multiple", "single-multiple"] = "single",
        hierarchical_sites: Optional[list] = None,
        bias=True,
        use_posterior_lsw_encoders=False,
    ):
        if not isinstance(init_scale, float) or not (init_scale > 0):
            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
        super().__init__(model, init_loc_fn=init_loc_fn)
        # override the values set by the parent constructor
        self._init_scale = init_scale
        self._init_weight = init_weight
        self._hierarchical_sites = hierarchical_sites
        self.amortised_plate_sites = amortised_plate_sites
        self.encoder_mode = encoder_mode
        self.bias = bias
        self.use_posterior_lsw_encoders = use_posterior_lsw_encoders
        self._computing_median = False
        self._computing_quantiles = False
        self._quantile_values = None
        self._computing_mi = False
        self.mi = {}
        self.samples_for_mi = None
        self.softplus = SoftplusTransform()

        # default n_hidden values and checking input
        if n_hidden is None:
            n_hidden = {"single": 200, "multiple": 200}
        elif isinstance(n_hidden, int):
            n_hidden = {"single": n_hidden, "multiple": n_hidden}
        elif not isinstance(n_hidden, dict):
            raise ValueError("n_hidden must be either int or dict")

        # process encoder kwargs, add n_hidden, create argument for multiple encoders
        encoder_kwargs = deepcopy(encoder_kwargs) if isinstance(encoder_kwargs, dict) else dict()
        encoder_kwargs["n_hidden"] = n_hidden["single"]
        if multi_encoder_kwargs is None:
            multi_encoder_kwargs = deepcopy(encoder_kwargs)

        # save encoder parameters
        self.encoder_kwargs = encoder_kwargs
        self.multi_encoder_kwargs = multi_encoder_kwargs
        self.single_n_in = n_in
        self.multiple_n_in = n_in
        self.n_hidden = n_hidden
        if ("single" in encoder_mode) and ("multiple" in encoder_mode):
            # if single network precedes multiple networks, the per-site encoders
            # consume the output of the single encoder
            self.multiple_n_in = self.n_hidden["single"]
        self.encoder_class = encoder_class
        self.encoder_instance = encoder_instance
        self.init_param_scale = init_param_scale

    def get_posterior(
        self,
        name: str,
        prior: Distribution,
    ) -> Union[Distribution, torch.Tensor]:
        """
        Return the guide distribution for site ``name``.

        When :meth:`median` or :meth:`quantiles` is running, a tensor of values is returned
        instead of a distribution.

        :param str name: Model site name.
        :param prior: Prior distribution of the site (conditioned on upstream sites).
        :return: ``TransformedDistribution`` of a Normal in unconstrained space, or a tensor.
        """
        if self._computing_median:
            return self._get_posterior_median(name, prior)
        if self._computing_quantiles:
            return self._get_posterior_quantiles(name, prior)

        with helpful_support_errors({"name": name, "fn": prior}):
            transform = biject_to(prior.support)
        # Hierarchical sites (all sites when hierarchical_sites is None) have their loc shifted
        # by the prior mean; other sites fall back to mean field.
        loc, scale = self._get_unconstrained_loc_scale(name, prior, transform)
        return dist.TransformedDistribution(
            dist.Normal(loc, scale).to_event(transform.domain.event_dim),
            transform.with_cache(),
        )

    def encode(self, name: str, prior: Distribution):
        """
        Apply the encoder network(s) to the model input data to obtain the hidden-layer
        encoding for site ``name``.

        The inputs are read from the args/kwargs of the current guide call (``self.args_kwargs``)
        at the positions listed in ``amortised_plate_sites["input"]``, transformed, optionally
        normalised, and passed through the single and/or per-site encoder depending on
        ``encoder_mode``. Missing encoders and normalisation parameters are created on first use.

        :param str name: Model site name.
        :param prior: Prior distribution of the site; used for the device of newly created modules.
        :return: Output of the last encoder applied.
        """
        raw_inputs = self._get_encoder_inputs()
        in_transforms = self._get_input_transforms(name)
        single_n_in, multiple_n_in = self._get_encoder_n_in(name)
        self._init_site_encoders(name, prior, single_n_in, multiple_n_in)

        # apply data transform before passing to NN
        x_in = [in_transforms[i](x) for i, x in enumerate(raw_inputs)]
        # apply learnable normalisation before passing to NN:
        input_normalisation = self.amortised_plate_sites.get("input_normalisation", None)
        if input_normalisation is not None:
            for i in range(len(self.amortised_plate_sites["input"])):
                if input_normalisation[i]:
                    x_in[i] = x_in[i] * deep_getattr(self, f"input_normalisation_{i}")

        if "single" in self.encoder_mode:
            # encode with a single encoder shared by all sites
            res = deep_getattr(self, "one_encoder")(*x_in)
            if "multiple" in self.encoder_mode:
                # second layer of per-site encoders takes the shared encoding as its first input
                x_in[0] = res
                res = deep_getattr(self.multiple_encoders, name)(*x_in)
        else:
            # one encoder per site applied directly to the data
            res = deep_getattr(self.multiple_encoders, name)(*x_in)
        return res

    def _get_params(self, name: str, prior: Distribution):
        """
        Return the posterior parameters for site ``name``.

        Sites not listed in ``amortised_plate_sites["sites"]`` use the parent class parameters.
        For amortised sites, loc/scale(/weight) are computed from the hidden encoding; the
        layers producing them are created on first use.

        :return: ``(loc, scale, weight)`` for hierarchical sites, otherwise ``(loc, scale)``.
        """
        if name not in self.amortised_plate_sites["sites"].keys():
            # don't use amortisation unless requested (site in the list)
            return super()._get_params(name, prior)

        hidden = self.encode(name, prior)
        if not self._has_deep_attr(f"hidden2locs.{name}"):
            self._init_site_params(name, prior)

        if self.use_posterior_lsw_encoders:
            # separate per-site encoders sit between the hidden encoding and each output layer
            x_in = self._get_lsw_encoder_inputs(name, hidden)
            loc_hidden = deep_getattr(self.hidden2locs, f"{name}.encoder")(*x_in)
            scale_hidden = deep_getattr(self.hidden2scales, f"{name}.encoder")(*x_in)
        else:
            x_in = None
            loc_hidden = scale_hidden = hidden
        loc = deep_getattr(self.hidden2locs, name)(loc_hidden)
        # offset so that a zero network output gives scale == init_scale
        scale = self.softplus(deep_getattr(self.hidden2scales, name)(scale_hidden) + self._init_scale_unconstrained)

        # restore multi-dimensional non-plate shape if requested
        out_dim = self.amortised_plate_sites["sites"][name]
        if isinstance(out_dim, tuple):
            loc = self._unflatten_last_dim(loc, out_dim)
            scale = self._unflatten_last_dim(scale, out_dim)

        if not self._is_hierarchical_site(name):
            return loc, scale

        if self.weight_type == "element-wise":
            # weight is element-wise
            if self.use_posterior_lsw_encoders:
                # the weight encoder receives the same inputs as the loc/scale encoders
                weight_hidden = deep_getattr(self.hidden2weights, f"{name}.encoder")(*x_in)
            else:
                weight_hidden = hidden
            weight = self.softplus(
                deep_getattr(self.hidden2weights, name)(weight_hidden) + self._init_weight_unconstrained
            )
            if isinstance(out_dim, tuple):
                weight = self._unflatten_last_dim(weight, out_dim)
        if self.weight_type == "scalar":
            # weight is a single value parameter
            weight = deep_getattr(self.weights, name)
        return loc, scale, weight

    def _init_site_params(self, name: str, prior: Distribution):
        """Create the layers (and optional lsw encoders) mapping the encoding of ``name`` to loc/scale/weight."""
        device = prior.mean.device
        with torch.no_grad():
            init_scale = torch.full((), self._init_scale)
            self._init_scale_unconstrained = self.softplus.inv(init_scale)
            init_weight = torch.full((), self._init_weight)
            self._init_weight_unconstrained = self.softplus.inv(init_weight)

        # size of the hidden encoding returned by `encode`
        if "multiple" in self.encoder_mode:
            encoding_size = self._get_site_n_hidden(name)
        elif "single" in self.encoder_mode:
            encoding_size = self.n_hidden["single"]
        # size of the input to the linear output layers
        if self.use_posterior_lsw_encoders:
            lsw_n_hidden = self._get_site_n_hidden(name)
            n_hidden = lsw_n_hidden
        else:
            n_hidden = encoding_size

        # number of outputs (flattened when the site has several non-plate dimensions)
        out_dim = self.amortised_plate_sites["sites"][name]
        if isinstance(out_dim, tuple):
            out_dim = int(np.prod(out_dim))

        def new_linear():
            return PyroModule[torch.nn.Linear](n_hidden, out_dim, bias=self.bias, device=device)

        deep_setattr(self, "hidden2locs." + name, new_linear())
        deep_setattr(self, "hidden2scales." + name, new_linear())
        hierarchical = self._is_hierarchical_site(name)
        if hierarchical:
            if self.weight_type == "scalar":
                # weight is a single value parameter
                deep_setattr(self, "weights." + name, PyroParam(init_weight, constraint=constraints.positive))
            if self.weight_type == "element-wise":
                # weight is element-wise
                deep_setattr(self, "hidden2weights." + name, new_linear())

        if self.use_posterior_lsw_encoders:
            multi_encoder_kwargs = deepcopy(self.multi_encoder_kwargs)
            multi_encoder_kwargs["n_hidden"] = lsw_n_hidden
            targets = ["hidden2locs", "hidden2scales"]
            if hierarchical:
                targets.append("hidden2weights")
            for target in targets:
                if self.encoder_instance is not None:
                    # a separate copy per target so loc/scale/weight encoders never share one module
                    encoder = self._copy_encoder_instance(device)
                else:
                    # input is the hidden encoding, output feeds the linear layer above
                    encoder = self.encoder_class(n_in=encoding_size, n_out=lsw_n_hidden, **multi_encoder_kwargs).to(
                        device
                    )
                deep_setattr(self, f"{target}.{name}.encoder", encoder)

    def _init_site_encoders(self, name: str, prior: Distribution, single_n_in: int, multiple_n_in: int):
        """
        Create the encoders / input normalisation parameters needed by :meth:`encode` that don't exist yet.

        Existing modules are never re-created, so trained parameters are not overwritten
        when a further site is encoded.
        """
        # create normalisation parameters if necessary:
        input_normalisation = self.amortised_plate_sites.get("input_normalisation", None)
        if input_normalisation is not None:
            for i in range(len(self.amortised_plate_sites["input"])):
                param_name = f"input_normalisation_{i}"
                if input_normalisation[i] and not self._has_deep_attr(param_name):
                    # NOTE(review): every normalised input gets `single_n_in` features - confirm for non-first inputs
                    deep_setattr(
                        self,
                        param_name,
                        PyroParam(torch.ones((1, single_n_in)).to(prior.mean.device).requires_grad_(True)),
                    )
        # shared encoder applied to the data of all sites
        if "single" in self.encoder_mode and not self._has_deep_attr("one_encoder"):
            if self.encoder_instance is not None:
                one_encoder = self._copy_encoder_instance(prior.mean.device)
            else:
                one_encoder = self.encoder_class(
                    n_in=single_n_in, n_out=self.n_hidden["single"], **self.encoder_kwargs
                ).to(prior.mean.device)
            deep_setattr(self, "one_encoder", one_encoder)
        # per-site encoders (applied to the data, or to the shared encoding in "single-multiple" mode)
        if "multiple" in self.encoder_mode and not self._has_deep_attr(f"multiple_encoders.{name}"):
            n_hidden = self._get_site_n_hidden(name)
            multi_encoder_kwargs = deepcopy(self.multi_encoder_kwargs)
            multi_encoder_kwargs["n_hidden"] = n_hidden
            if self.encoder_instance is not None:
                encoder = self._copy_encoder_instance(prior.mean.device)
            else:
                encoder = self.encoder_class(n_in=multiple_n_in, n_out=n_hidden, **multi_encoder_kwargs).to(
                    prior.mean.device
                )
            deep_setattr(self, f"multiple_encoders.{name}", encoder)

    def _get_encoder_inputs(self) -> list:
        """Select the args/kwargs of the current guide call listed in ``amortised_plate_sites["input"]``."""
        args, kwargs = self.args_kwargs  # stored as a tuple of (tuple, dict)
        return [kwargs[i] if i in kwargs.keys() else args[i] for i in self.amortised_plate_sites["input"]]

    def _get_input_transforms(self, name: str) -> list:
        """Input transforms for site ``name`` (site-specific ones take precedence)."""
        site_transform = self.amortised_plate_sites.get("site_transform", None)
        if site_transform is not None and name in site_transform.keys():
            # when input data transform differs between variables
            return site_transform[name]["input_transform"]
        return self.amortised_plate_sites["input_transform"]

    def _get_encoder_n_in(self, name: str):
        """Return ``(single_n_in, multiple_n_in)``: input sizes of the single and per-site encoders."""
        site_transform = self.amortised_plate_sites.get("site_transform", None)
        if site_transform is not None and name in site_transform.keys():
            # when input dimensions differ between variables
            single_n_in = site_transform[name]["n_in"]
            multiple_n_in = site_transform[name]["n_in"]
            if ("single" in self.encoder_mode) and ("multiple" in self.encoder_mode):
                # if single network precedes multiple networks
                multiple_n_in = self.multiple_n_in
            return single_n_in, multiple_n_in
        return self.single_n_in, self.multiple_n_in

    def _get_lsw_encoder_inputs(self, name: str, hidden):
        """Inputs of the lsw encoders: ``hidden`` replaces the first model input, the rest are transformed."""
        x_in = self._get_encoder_inputs()
        x_in[0] = hidden
        in_transforms = self._get_input_transforms(name)
        return [x if i == 0 else in_transforms[i](x) for i, x in enumerate(x_in)]

    def _get_site_n_hidden(self, name: str) -> int:
        """Number of hidden units of the per-site networks of ``name`` (falls back to ``"multiple"``)."""
        if name in self.n_hidden.keys():
            return self.n_hidden[name]
        return self.n_hidden["multiple"]

    def _copy_encoder_instance(self, device):
        """Return a deep copy of ``encoder_instance`` on ``device``, converted to a PyroModule."""
        encoder = deepcopy(self.encoder_instance).to(device)
        to_pyro_module_(encoder)
        return encoder

    def _has_deep_attr(self, path: str) -> bool:
        """Return True if the (dotted) attribute ``path`` exists on this guide."""
        try:
            deep_getattr(self, path)
        except AttributeError:
            return False
        return True

    def _is_hierarchical_site(self, name: str) -> bool:
        """Whether site ``name`` has its loc shifted by the prior mean (all sites if hierarchical_sites is None)."""
        return (self._hierarchical_sites is None) or (name in self._hierarchical_sites)

    @staticmethod
    def _unflatten_last_dim(x: torch.Tensor, out_dim: tuple) -> torch.Tensor:
        """Reshape the flattened last dimension of ``x`` into ``out_dim`` (row-major): ``(z, A*B) -> (z, A, B)``."""
        return x.reshape(tuple(x.shape[:-1]) + tuple(out_dim))

    def _get_unconstrained_loc_scale(self, name: str, prior: Distribution, transform):
        """Return unconstrained ``(loc, scale)``; hierarchical sites have loc shifted by ``transform.inv(prior.mean) * weight``."""
        if self._is_hierarchical_site(name):
            loc, scale, weight = self._get_params(name, prior)
            loc = loc + transform.inv(prior.mean) * weight
        else:
            loc, scale = self._get_params(name, prior)
        return loc, scale

    @torch.no_grad()
    def _get_posterior_median(self, name, prior):
        """Median of the posterior of ``name``: the transformed unconstrained loc."""
        transform = biject_to(prior.support)
        loc, _ = self._get_unconstrained_loc_scale(name, prior, transform)
        return transform(loc)

    def quantiles(self, quantiles, *args, **kwargs):
        """
        Run the guide with ``args``/``kwargs`` returning posterior quantiles instead of samples.

        :param quantiles: Quantile level(s) in (0, 1); converted with ``torch.tensor``. NOTE(review):
            they must broadcast against each site's loc - a single quantile is the safe case.
        :return: Whatever calling the guide returns, with per-site quantile values.
        """
        self._computing_quantiles = True
        self._quantile_values = quantiles
        try:
            return self(*args, **kwargs)
        finally:
            self._computing_quantiles = False

    @torch.no_grad()
    def _get_posterior_quantiles(self, name, prior):
        """Quantiles ``self._quantile_values`` of the posterior of ``name`` in constrained space."""
        transform = biject_to(prior.support)
        loc, scale = self._get_unconstrained_loc_scale(name, prior, transform)
        site_quantiles = torch.tensor(self._quantile_values, dtype=loc.dtype, device=loc.device)
        site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles)
        return transform(site_quantiles_values)