
BayesianNeuralFieldVI

Bases: BayesianNeuralFieldEstimator

Fits models using stochastic ensembles of surrogate posteriors from VI.

Implementation of BayesianNeuralFieldEstimator using variational inference (VI).
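
A minimal end-to-end sketch is shown below. Only `fit`, documented on this page, is specific to this class; the constructor options, column names, and the commented-out `predict` call are illustrative assumptions based on the [base class](BayesianNeuralFieldEstimator.md) interface, not a definitive recipe.

```python
import jax
import pandas as pd

from bayesnf.spatiotemporal import BayesianNeuralFieldVI

# Hypothetical spatiotemporal training table; column names are assumptions.
df_train = pd.read_csv("train.csv")  # e.g. datetime, latitude, longitude, reading

model = BayesianNeuralFieldVI(
    feature_cols=["datetime", "latitude", "longitude"],  # assumed base-class option
    target_col="reading",                                # assumed base-class option
)

# fit() as documented below; `seed` is a jax.random.PRNGKey.
model = model.fit(df_train, seed=jax.random.PRNGKey(0), ensemble_size=16)

# fit() stores the variational loss traces in `losses_` for diagnostics.
losses = model.losses_

# Prediction follows the base-class interface (assumed signature):
# yhat, yhat_quantiles = model.predict(df_test, quantiles=(0.025, 0.5, 0.975))
```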

Source code in bayesnf/spatiotemporal.py
```python
class BayesianNeuralFieldVI(BayesianNeuralFieldEstimator):
  """Fits models using stochastic ensembles of surrogate posteriors from VI.

  Implementation of
  [BayesianNeuralFieldEstimator](BayesianNeuralFieldEstimator.md) using
  variational inference (VI).
  """

  _ensemble_dims = 3
  _scale_epochs_by_batch_size = True

  def fit(
      self,
      table,
      seed,
      ensemble_size=16,
      learning_rate=0.01,
      num_epochs=1_000,
      sample_size_posterior=30,
      sample_size_divergence=5,
      kl_weight=0.1,
      batch_size=None,
      ) -> BayesianNeuralFieldEstimator:
    """Run inference using stochastic variational inference ensembles.

    Args:
      table (pandas.DataFrame):
        See documentation of
        [`table`][bayesnf.spatiotemporal.BayesianNeuralFieldEstimator.fit]
        in the base class.

      seed (jax.random.PRNGKey): The jax random key.

      ensemble_size (int): Number of particles (i.e., surrogate posteriors)
        in the ensemble, **per device**. The available devices can be found
        via `jax.devices()`.

      learning_rate (float): Learning rate for SGD.

      num_epochs (int): Number of full epochs through the training data.

      sample_size_posterior (int): Number of samples of "posterior" model
        parameters drawn from each surrogate posterior when making
        predictions.

      sample_size_divergence (int): Number of Monte Carlo samples to use in
        estimating the variational divergence. Larger values may stabilize
        the optimization, but at higher cost per step in time and memory.
        See [`tfp.vi.fit_surrogate_posterior_stateless`](
        https://www.tensorflow.org/probability/api_docs/python/tfp/vi/fit_surrogate_posterior_stateless)
        for further details.

      kl_weight (float): Weighting of the KL divergence term in VI. The
        goal is to find a surrogate posterior `q(z)` that maximizes a
        version of the ELBO with the `KL(surrogate posterior || prior)`
        term scaled by `kl_weight`:

            E_z~q [log p(x|z)] - kl_weight * KL(q || p)

        Reference
        > Weight Uncertainty in Neural Network
        > Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, Daan Wierstra.
        > Proceedings of the 32nd International Conference on Machine Learning.
        > PMLR 37:1613-1622, 2015.
        > <https://proceedings.mlr.press/v37/blundell15>

      batch_size (None | int): If specified, the log probability in each
        step of variational inference is computed on a batch of this size.
        Default is `None`, meaning full-batch.

    Returns:
      Instance of self.
    """
    train_data = self.data_handler.get_train(table)
    train_target = self.data_handler.get_target(table)
    if batch_size is None:
      batch_size = train_data.shape[0]
    if self._scale_epochs_by_batch_size:
      num_epochs = num_epochs * (train_data.shape[0] // batch_size)
    model_args = self._model_args((batch_size, train_data.shape[-1]))
    _, self.losses_, self.params_ = inference.fit_vi(
        train_data,
        train_target,
        seed=seed,
        observation_model=self.observation_model,
        model_args=model_args,
        ensemble_size=ensemble_size,
        learning_rate=learning_rate,
        num_epochs=num_epochs,
        sample_size_posterior=sample_size_posterior,
        sample_size_divergence=sample_size_divergence,
        kl_weight=kl_weight,
        batch_size=batch_size,
    )
    return self
```

fit

`fit(table, seed, ensemble_size=16, learning_rate=0.01, num_epochs=1000, sample_size_posterior=30, sample_size_divergence=5, kl_weight=0.1, batch_size=None)`

Run inference using stochastic variational inference ensembles.

PARAMETERS

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `table` | `pandas.DataFrame` | required | See the documentation of [`table`](BayesianNeuralFieldEstimator.md) in the base class. |
| `seed` | `jax.random.PRNGKey` | required | The jax random key. |
| `ensemble_size` | `int` | `16` | Number of particles (i.e., surrogate posteriors) in the ensemble, **per device**. The available devices can be found via `jax.devices()`; see the sketch after this section. |
| `learning_rate` | `float` | `0.01` | Learning rate for SGD. |
| `num_epochs` | `int` | `1000` | Number of full epochs through the training data. |
| `sample_size_posterior` | `int` | `30` | Number of samples of "posterior" model parameters drawn from each surrogate posterior when making predictions. |
| `sample_size_divergence` | `int` | `5` | Number of Monte Carlo samples used to estimate the variational divergence. Larger values may stabilize the optimization, at a higher cost per step in time and memory. See [`tfp.vi.fit_surrogate_posterior_stateless`](https://www.tensorflow.org/probability/api_docs/python/tfp/vi/fit_surrogate_posterior_stateless) for details. |
| `kl_weight` | `float` | `0.1` | Weighting of the KL divergence term in VI. The goal is to find a surrogate posterior `q(z)` that maximizes a version of the ELBO in which the KL term between the surrogate posterior and the prior is scaled by `kl_weight` (expanded in the note after this table). Reference: Blundell, Cornebise, Kavukcuoglu, and Wierstra. Weight Uncertainty in Neural Network. ICML 2015, PMLR 37:1613-1622. <https://proceedings.mlr.press/v37/blundell15> |
| `batch_size` | `None \| int` | `None` | If specified, the log probability in each step of variational inference is computed on a batch of this size. `None`, the default, means full-batch. |
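
As noted in the `kl_weight` row above, each particle's surrogate posterior is trained by maximizing a KL-weighted variant of the ELBO. Writing \(\beta\) for `kl_weight`, \(x\) for the observed data, and \(z\) for the latent model parameters, the objective is:

```latex
\mathcal{L}_{\beta}(q)
  = \mathbb{E}_{z \sim q(z)}\!\left[\log p(x \mid z)\right]
  - \beta \, \mathrm{KL}\!\left(q(z) \,\|\, p(z)\right)
```

Setting \(\beta = 1\) recovers the standard ELBO; the default \(\beta = 0.1\) weakens the regularization toward the prior relative to the data-fit term.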

RETURNS

| Type | Description |
| --- | --- |
| `BayesianNeuralFieldEstimator` | Instance of self. |
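
Two bookkeeping details of `fit` are worth making concrete: `ensemble_size` counts particles per device, and when `batch_size` is set, `num_epochs` is rescaled by the number of batches per epoch (the `_scale_epochs_by_batch_size` branch in the source above). A minimal sketch of the resulting arithmetic, using assumed example numbers:

```python
import jax

# ensemble_size is per device, so the effective ensemble is larger on
# multi-device hosts (e.g. a machine with 4 accelerators).
ensemble_size = 16
total_particles = ensemble_size * len(jax.devices())

# When batch_size is given, fit() multiplies num_epochs by the number of
# batches per epoch before handing it to the inference routine
# (see the source listing above).
n_rows = 10_000   # assumed number of training rows
batch_size = 500
num_epochs = 1_000
scaled_epochs = num_epochs * (n_rows // batch_size)  # 1_000 * 20 = 20_000

print(total_particles, scaled_epochs)
```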
