"""Source code for ``cell2fate._pyro_base_cell2fate_module``: the base Pyro module class used by cell2fate models."""

from scvi.module.base import PyroBaseModuleClass
from pyro.infer.autoguide import AutoNormal, AutoHierarchicalNormalMessenger

from cell2fate._pyro_mixin import AutoGuideMixinModule, init_to_value

class Cell2FateBaseModule(PyroBaseModuleClass, AutoGuideMixinModule):
    """Pyro module that wraps a cell2fate model together with an automatically built guide."""

    def __init__(
        self,
        model,
        amortised: bool = False,
        encoder_mode="single",
        encoder_kwargs=None,
        data_transform="log1p",
        guide_class=AutoHierarchicalNormalMessenger,
        encoder_instance=None,
        **kwargs,
    ):
        """
        Module class which defines an AutoGuide for a given model.

        Supports multiple model architectures.

        Parameters
        ----------
        model
            Model class (callable); it is instantiated here as ``model(**kwargs)``.
        amortised
            Whether to use a neural network to approximate the posterior
            distribution of location-specific (local) parameters.
        encoder_mode
            Use a single encoder for all variables ("single"), one encoder per
            variable ("multiple"), or a single encoder in the first step and
            multiple encoders in the second step ("single-multiple").
        encoder_kwargs
            Arguments for neural network construction (scvi.nn.FCLayers).
        data_transform
            Accepted by the signature but not used by this constructor: it is
            not forwarded to ``_create_autoguide``.
        guide_class
            Guide class forwarded to ``_create_autoguide``; defaults to
            ``AutoHierarchicalNormalMessenger``.
        encoder_instance
            Forwarded unchanged to ``_create_autoguide``; presumably a
            pre-built encoder to reuse -- TODO confirm against the mixin.
        **kwargs
            Arguments for the specific model class, e.g. number of genes or
            values of the prior distribution. Must contain ``"n_batch"``: it is
            read as ``kwargs["n_batch"]`` to build the single-entry
            ``n_cat_list`` passed to ``_create_autoguide``.
        """
        super().__init__()
        # NOTE(review): presumably filled with training history elsewhere -- confirm.
        self.hist = []

        self._model = model(**kwargs)
        self._amortised = amortised

        autoguide_options = dict(
            model=self.model,
            amortised=self.is_amortised,
            encoder_kwargs=encoder_kwargs,
            encoder_mode=encoder_mode,
            # Bound method, so initial values are looked up on this module's model.
            init_loc_fn=self.init_to_value,
            guide_class=guide_class,
            n_cat_list=[kwargs["n_batch"]],
            encoder_instance=encoder_instance,
        )
        self._guide = self._create_autoguide(**autoguide_options)

        # Re-export the model's ``_get_fn_args_from_batch`` on the module itself.
        self._get_fn_args_from_batch = self._model._get_fn_args_from_batch

    @property
    def model(self):
        """The model instance created from ``model(**kwargs)``."""
        return self._model

    @property
    def guide(self):
        """The guide returned by ``_create_autoguide``."""
        return self._guide

    @property
    def is_amortised(self):
        """The ``amortised`` flag given to the constructor."""
        return self._amortised

    @property
    def list_obs_plate_vars(self):
        """Result of the model's ``list_obs_plate_vars()`` method."""
        return self.model.list_obs_plate_vars()

    def init_to_value(self, site):
        """
        Guide initialisation function backed by values stored on the model.

        For every name in ``model.np_init_vals`` (when that attribute exists
        and is not ``None``), the value of the model attribute
        ``init_val_<name>`` is collected. The resulting mapping (empty
        otherwise) is handed to ``cell2fate._pyro_mixin.init_to_value``
        together with ``site``.

        Parameters
        ----------
        site
            Pyro sample site being initialised.

        Returns
        -------
        Whatever ``cell2fate._pyro_mixin.init_to_value`` returns for ``site``;
        its handling of sites absent from the mapping is defined there.
        """
        np_init_vals = getattr(self.model, "np_init_vals", None)
        values = {}
        if np_init_vals is not None:
            for name in np_init_vals.keys():
                values[name] = getattr(self.model, f"init_val_{name}")
        # The bare name below resolves to the module-level ``init_to_value``
        # imported from cell2fate._pyro_mixin, not to this method: methods are
        # not visible as bare names inside their own bodies.
        return init_to_value(site=site, values=values)