
BayesianNeuralFieldEstimator

Base class for BayesNF estimators.

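This class should not be initialized directly. Instead, instantiate one of the three subclasses, which implement different model learning procedures:

  • BayesianNeuralFieldVI, for ensembles of surrogate posteriors from variational inference.

  • BayesianNeuralFieldMAP, for stochastic ensembles of maximum-a-posteriori estimates.

  • BayesianNeuralFieldMLE, for stochastic ensembles of maximum likelihood estimates.

All three classes share the same __init__ method described below.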

Source code in bayesnf/spatiotemporal.py
class BayesianNeuralFieldEstimator:
  """Base class for BayesNF estimators.

  This class should not be initialized directly. Instead, instantiate one of
  the three subclasses, which implement different model learning procedures:

  - [BayesianNeuralFieldVI](BayesianNeuralFieldVI.md), for
    ensembles of surrogate posteriors from variational inference.

  - [BayesianNeuralFieldMAP](BayesianNeuralFieldMAP.md), for
    stochastic ensembles of maximum-a-posteriori estimates.

  - [BayesianNeuralFieldMLE](BayesianNeuralFieldMLE.md), for
    stochastic ensembles of maximum likelihood estimates.

  All three classes share the same `__init__` method described below.
  """

  _ensemble_dims: int
  _prior_weight: float = 1.0
  _scale_epochs_by_batch_size: bool = False

  def __init__(
      self,
      *,
      feature_cols: Sequence[str],
      target_col: str,
      seasonality_periods: Sequence[float | str] | None = None,
      num_seasonal_harmonics: Sequence[int] | None = None,
      fourier_degrees: Sequence[float] | None = None,
      interactions: Sequence[tuple[int, int]] | None = None,
      freq: str | None = None,
      timetype: str = 'index',
      depth: int = 2,
      width: int = 512,
      observation_model: str = 'NORMAL',
      standardize: Sequence[str] | None = None,
      ):
    """Shared initialization for subclasses of BayesianNeuralFieldEstimator.

    Args:
      feature_cols: Names of columns to use as features in the training data
        frame. The first entry denotes the name of the time variable; the
        remaining entries (if any) denote names of the spatial features.
      target_col: Name of the target column representing the spatial field.
      seasonality_periods: A list of numbers representing the seasonal
        periods of the data in the time domain. If `timetype == 'index'`, the
        periods can also be given as string shorthands such as 'W', 'D', etc.,
        which must correspond to valid Pandas frequencies. See Pandas [Offset
        Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
        for valid string values.
      num_seasonal_harmonics: A list of seasonal harmonics, one for each entry
        in `seasonality_periods`. The number of seasonal harmonics (h) for a
        given seasonal period `p` must satisfy `h < p//2`. It is an error if
        `len(num_seasonal_harmonics) != len(seasonality_periods)`. Should be
        used only if `timetype == 'index'`.
      fourier_degrees: A list of integer degrees for the Fourier features of the
        inputs. If given, must have the same length as `feature_cols`.
      interactions: A list of tuples of column indexes for the first-order
        interactions. For example, `[(0,1), (1,2)]` creates two interaction
        features: `feature_cols[0] * feature_cols[1]` and
        `feature_cols[1] * feature_cols[2]`.
      freq: A frequency string for the sampling rate at which the data is
        collected. See the Pandas [Offset
        Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
        for valid values. Should be used if and only if `timetype == 'index'`.
      timetype: Either `index` or `float`. If `index`, then the time column must
        be a `datetime` type and `freq` must be given. Otherwise, if `float`,
        then the time column must be `float`.
      depth: The number of hidden layers in the BayesNF architecture.
      width: The number of hidden units in each layer.
      observation_model: The aleatoric noise model for the observed data. The
        options are `NORMAL` (Gaussian noise), `NB` (negative binomial noise),
        or `ZNB` (zero-inflated negative binomial noise).
      standardize: List of columns that should be standardized. It is highly
        recommended to standardize `feature_cols[1:]`. It is an error if
        `feature_cols[0]` (the time variable) is in `standardize`.
    """
    self.num_seasonal_harmonics = num_seasonal_harmonics
    self.seasonality_periods = seasonality_periods
    self.observation_model = observation_model
    self.depth = depth
    self.width = width
    self.feature_cols = feature_cols
    self.target_col = target_col
    self.timetype = timetype
    self.freq = freq
    self.fourier_degrees = fourier_degrees
    self.standardize = standardize
    self.interactions = interactions

    self.losses_ = None
    self.params_ = None
    self.data_handler = SpatiotemporalDataHandler(
        self.feature_cols,
        self.target_col,
        self.timetype,
        self.freq,
        standardize=self.standardize)

  def _get_fourier_degrees(self, batch_shape: tuple[int, ...]) -> np.ndarray:
    """Set default fourier degrees, or verify shape is correct."""
    if self.fourier_degrees is None:
      fourier_degrees = np.full(batch_shape[-1], 5, dtype=int)
    else:
      fourier_degrees = np.atleast_1d(self.fourier_degrees).astype(int)
      if fourier_degrees.shape[-1] != batch_shape[-1]:
        raise ValueError(
            'The length of fourier_degrees ({}) must match the '
            'input dimension ({}).'.format(
                fourier_degrees.shape[-1], batch_shape[-1]
            )
        )
    return fourier_degrees

  def _get_interactions(self) -> np.ndarray:
    """Set default interactions, or verify shape is correct."""
    if self.interactions is None:
      interactions = np.zeros((0, 2), dtype=int)
    else:
      interactions = np.array(self.interactions).astype(int)
      if np.ndim(interactions) != 2 or interactions.shape[-1] != 2:
        raise ValueError(
            'The argument for `interactions` should be a 2-d array of integers '
            'of shape (N, 2), indicating the column indices to interact (the '
            f'passed shape was {interactions.shape}).')
    return interactions

  def _get_seasonality_periods(self):
    """Return array of seasonal periods."""
    if (
        (self.timetype == 'index' and self.freq is None) or
        (self.timetype == 'float' and self.freq is not None)):
      raise ValueError(f'Invalid {self.freq=} with {self.timetype=}.')
    if self.seasonality_periods is None:
      return np.zeros(0)
    if self.timetype == 'index':
      return seasonalities_to_array(self.seasonality_periods, self.freq)
    if self.timetype == 'float':
      return np.asarray(self.seasonality_periods, dtype=float)
    assert False, f'Impossible {self.timetype=}.'

  def _get_num_seasonal_harmonics(self):
    """Return array of seasonal harmonics per seasonal period."""
    # Discrete time.
    if self.timetype == 'index':
      return (
          np.array(self.num_seasonal_harmonics)
          if self.num_seasonal_harmonics is not None else
          np.zeros(0))
    # Continuous time.
    if self.timetype == 'float':
      if self.num_seasonal_harmonics is not None:
        raise ValueError(
            f'Cannot use num_seasonal_harmonics with {self.timetype=}.'
        )
      # HACK: models.make_seasonal_frequencies assumes the data is discrete
      # time where each harmonic h is between 1, ..., p/2 and the harmonic
      # factors are np.arange(1, h + 1). Since our goal with continuous
      # time data is exactly 1 harmonic per seasonal factor, any h between
      # 0 and min(0.5, p/2) will work, as np.arange(1, 1+h) = [1]
      return np.fmin(.5, self._get_seasonality_periods() / 2)
    assert False, f'Impossible {self.timetype=}.'

  def _model_args(self, batch_shape):
    return {
        'depth': self.depth,
        'input_scales': self.data_handler.get_input_scales(),
        'num_seasonal_harmonics': self._get_num_seasonal_harmonics(),
        'seasonality_periods': self._get_seasonality_periods(),
        'width': self.width,
        'init_x': batch_shape,
        'fourier_degrees': self._get_fourier_degrees(batch_shape),
        'interactions': self._get_interactions(),
    }

  def predict(self, table, quantiles=(0.5,), approximate_quantiles=False):
    """Make predictions of the target column at new times.

    Args:
      table (pandas.DataFrame):
        Field locations at which to make new predictions. Same as `table` in
        [`fit`](), except that `self.target_col` need not be in `table`.

      quantiles (Sequence[float]):
        The list of quantiles to compute.

      approximate_quantiles (bool):
        If `False`, uses Chandrupatla root finding to compute quantiles.
        If `True`, uses a heuristic approximation of the quantiles.

    Returns:
      means (np.ndarray):
        The predicted means from each particle in the learned ensemble.
        The shape is `(num_devices, ensemble_size // num_devices, len(table))`
        and can be flattened to a 2D array using `np.vstack(means)`.
        See also https://github.com/google/bayesnf/issues/17.

      quantiles (List[np.ndarray]):
        A list of numpy arrays, one per requested quantile.
        The length of each array in the list is `len(table)`.

    """
    test_data = self.data_handler.get_test(table)
    return inference.predict_bnf(
        test_data,
        self.observation_model,
        params=self.params_,
        model_args=self._model_args(test_data.shape),
        quantiles=quantiles,
        ensemble_dims=self._ensemble_dims,
        approximate_quantiles=approximate_quantiles,
    )

  def fit(self, table, seed):
    """Run inference given a training data `table` and `seed`.

    Cannot be directly called on `BayesianNeuralFieldEstimator`.

    Args:
      table (pandas.DataFrame):
        A pandas DataFrame representing the
        training data. It has the following requirements:

        - The columns of `table` should contain all `self.feature_cols`
          and the `self.target_col`.

        - If `timetype == 'index'`, the type of the "time" column (i.e.,
          `self.feature_cols[0]`) should be `datetime`. To ensure this
          requirement holds, see [`pandas.to_datetime`](
          https://pandas.pydata.org/docs/reference/api/pandas.to_datetime.html).
          If `timetype == 'float'`, the time column should be `float`. The
          types of the remaining feature columns should be numeric.

      seed (jax.random.PRNGKey): The jax random key.
    """
    raise NotImplementedError('Should be implemented by subclass')

  def likelihood_model(self, table: pd.DataFrame) -> tfd.Distribution:
    """Access the predictive distribution over new field values in `table`.

    NOTE: Must be called after [`fit`]().

    Args:
      table (pandas.DataFrame):
        Field locations at which to make new predictions. Same as `table` in
        [`fit`](), except that `self.target_col` need not be in `table`.

    Returns:
      A probability distribution representing the predictive distribution
        over `self.target_col` at the new field values in `table`.
        See [tfp.distributions.Distribution](
        https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Distribution)
        for the methods associated with this object.
    """
    test_data = self.data_handler.get_test(table)
    mlp, mlp_template = inference.make_model(
        **self._model_args(test_data.shape))
    for _ in range(self._ensemble_dims - 1):
      mlp.apply = jax.vmap(mlp.apply, in_axes=(0, None))
    mlp.apply = jax.pmap(mlp.apply, in_axes=(0, None))

    # This allows the likelihood to broadcast correctly with the batch of
    # predictions.
    params = self.params_._replace(**{  # pytype: disable=attribute-error
        self.params_._fields[i]: self.params_[i][..., jnp.newaxis]  # pytype: disable=unsupported-operands,attribute-error
        for i in range(3)})

    return models.make_likelihood_model(
        params,
        jnp.array(test_data),
        mlp,
        mlp_template,
        self.observation_model)

__init__

__init__(*, feature_cols, target_col, seasonality_periods=None, num_seasonal_harmonics=None, fourier_degrees=None, interactions=None, freq=None, timetype='index', depth=2, width=512, observation_model='NORMAL', standardize=None)

Shared initialization for subclasses of BayesianNeuralFieldEstimator.

PARAMETER DESCRIPTION
feature_cols

Names of columns to use as features in the training data frame. The first entry denotes the name of the time variable; the remaining entries (if any) denote names of the spatial features.

TYPE: Sequence[str]

target_col

Name of the target column representing the spatial field.

TYPE: str

seasonality_periods

A list of numbers representing the seasonal periods of the data in the time domain. If timetype == 'index', the periods can also be given as string shorthands such as 'W', 'D', etc., which must correspond to valid Pandas frequencies. See Pandas Offset Aliases for valid string values.

TYPE: Sequence[float | str] | None DEFAULT: None
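For timetype='index', a string period is interpreted relative to the sampling freq. The class source delegates this conversion to seasonalities_to_array, whose exact behavior is not documented on this page, so the stdlib-only sketch below only illustrates the idea. The ALIAS_TO_TIMEDELTA table and period_in_steps function are hypothetical, cover only a few fixed-width aliases, and assume numeric periods are already measured in time steps; calendar aliases such as 'M' are deliberately not handled.

```python
from datetime import timedelta

# Hypothetical lookup table: a few fixed-width Pandas offset aliases.
# Calendar aliases such as 'M' (month) have no fixed length and are omitted.
ALIAS_TO_TIMEDELTA = {
    'min': timedelta(minutes=1),
    'h': timedelta(hours=1),
    'D': timedelta(days=1),
    'W': timedelta(weeks=1),
}


def period_in_steps(period, freq):
  """Express a seasonal period as a number of `freq`-spaced time steps."""
  if isinstance(period, str):
    # Dividing two timedeltas gives a float ratio, e.g. 1 week / 1 day = 7.0.
    return ALIAS_TO_TIMEDELTA[period] / ALIAS_TO_TIMEDELTA[freq]
  # Assumption: numeric periods are already measured in time steps.
  return float(period)


print([period_in_steps(p, 'D') for p in ['W', 7, 365.25]])  # [7.0, 7.0, 365.25]
print(period_in_steps('D', 'h'))  # 24.0
```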

num_seasonal_harmonics

A list of seasonal harmonics, one for each entry in seasonality_periods. The number of seasonal harmonics (h) for a given seasonal period p must satisfy h < p//2. It is an error if len(num_seasonal_harmonics) != len(seasonality_periods). Should be used only if timetype == 'index'.

TYPE: Sequence[int] | None DEFAULT: None
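The constraints above can be checked before fitting. The sketch below is hypothetical (check_harmonics and harmonic_frequencies are not part of the library). It assumes, following the usual Fourier-series construction and the comment in _get_num_seasonal_harmonics that the harmonic factors are np.arange(1, h + 1), that a period p with h harmonics contributes the frequencies k / p for k = 1, ..., h.

```python
import numpy as np


def check_harmonics(seasonality_periods, num_seasonal_harmonics):
  """Raise ValueError if the documented constraints are violated."""
  if len(num_seasonal_harmonics) != len(seasonality_periods):
    raise ValueError('Need exactly one harmonic count per seasonal period.')
  for p, h in zip(seasonality_periods, num_seasonal_harmonics):
    if not h < p // 2:
      raise ValueError(f'Harmonics h={h} must satisfy h < p//2 for p={p}.')


def harmonic_frequencies(seasonality_periods, num_seasonal_harmonics):
  """Assumed frequencies k / p for k = 1..h, concatenated over all periods."""
  return np.concatenate([
      np.arange(1, h + 1) / p
      for p, h in zip(seasonality_periods, num_seasonal_harmonics)
  ])


# Weekly and yearly seasonality for daily data: 7 // 2 = 3, so h = 2 is valid.
periods, harmonics = [7, 365.25], [2, 10]
check_harmonics(periods, harmonics)
print(harmonic_frequencies(periods, harmonics).shape)  # (12,)
```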

fourier_degrees

A list of integer degrees for the Fourier features of the inputs. If given, must have the same length as feature_cols.

TYPE: Sequence[float] | None DEFAULT: None
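The defaulting and validation rules follow _get_fourier_degrees in the class source: when fourier_degrees is omitted, every input column gets degree 5; otherwise its length must match the number of input columns. The standalone function below restates that logic for illustration only; resolve_fourier_degrees is a hypothetical name, and num_inputs stands in for the last dimension of the model's input batch (batch_shape[-1] in the source).

```python
import numpy as np


def resolve_fourier_degrees(fourier_degrees, num_inputs):
  """Return one integer Fourier degree per input column."""
  if fourier_degrees is None:
    # Default used in the class source: degree 5 for every column.
    return np.full(num_inputs, 5, dtype=int)
  degrees = np.atleast_1d(fourier_degrees).astype(int)
  if degrees.shape[-1] != num_inputs:
    raise ValueError(
        f'The length of fourier_degrees ({degrees.shape[-1]}) must match '
        f'the input dimension ({num_inputs}).')
  return degrees


print(resolve_fourier_degrees(None, 3))       # [5 5 5]
print(resolve_fourier_degrees([0, 4, 4], 3))  # [0 4 4]
```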

interactions

A list of tuples of column indexes for the first-order interactions. For example, [(0,1), (1,2)] creates two interaction features: feature_cols[0] * feature_cols[1] and feature_cols[1] * feature_cols[2].

TYPE: Sequence[tuple[int, int]] | None DEFAULT: None
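The example in the description can be made concrete with NumPy. Assuming each interaction feature is the elementwise product of the two indexed input columns, as the description states, the hypothetical helper below builds them from an (N, 2) index array and applies the same shape check as _get_interactions. It is an illustration, not the library's feature pipeline.

```python
import numpy as np


def interaction_features(x, interactions):
  """Multiply the pairs of columns of `x` named by `interactions`."""
  pairs = np.asarray(interactions, dtype=int)
  if pairs.ndim != 2 or pairs.shape[-1] != 2:
    raise ValueError(f'interactions must have shape (N, 2), got {pairs.shape}')
  # x has shape (num_rows, num_features); the result has shape (num_rows, N).
  return x[:, pairs[:, 0]] * x[:, pairs[:, 1]]


x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
# Columns of the result: x[:, 0] * x[:, 1] and x[:, 1] * x[:, 2].
features = interaction_features(x, [(0, 1), (1, 2)])
print(features)
```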

freq

A frequency string for the sampling rate at which the data is collected. See the Pandas Offset Aliases for valid values. Should be used if and only if timetype == 'index'.

TYPE: str | None DEFAULT: None

timetype

Either 'index' or 'float'. If 'index', the time column must be a datetime type and freq must be given. If 'float', the time column must be float.

TYPE: str DEFAULT: 'index'
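The pairing rule between timetype and freq is enforced in _get_seasonality_periods in the class source. Restated as a standalone check (check_time_config is a hypothetical name, and the up-front rejection of unknown timetype values is an addition for clarity):

```python
def check_time_config(timetype, freq):
  """Mirror the timetype/freq validation in _get_seasonality_periods."""
  if timetype not in ('index', 'float'):
    raise ValueError(f'Unknown {timetype=}.')
  # 'index' time requires a sampling frequency; 'float' time forbids one.
  if (timetype == 'index' and freq is None) or (
      timetype == 'float' and freq is not None):
    raise ValueError(f'Invalid {freq=} with {timetype=}.')


check_time_config('index', 'D')   # OK: datetime column sampled daily.
check_time_config('float', None)  # OK: numeric time column.
```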

depth

The number of hidden layers in the BayesNF architecture.

TYPE: int DEFAULT: 2

width

The number of hidden units in each layer.

TYPE: int DEFAULT: 512
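As a rough guide to how depth and width drive model size, the sketch below counts the weights and biases of a plain fully connected network with depth hidden layers of width units. This is a simplifying assumption: the actual BayesNF network also includes feature transformations (Fourier, seasonal, and interaction features) and observation-model parameters, so its parameter count will differ. The function name and the input size of 10 are illustrative only.

```python
def dense_param_count(num_inputs, depth, width, num_outputs=1):
  """Weights + biases of an MLP with `depth` hidden layers of `width` units."""
  sizes = [num_inputs] + [width] * depth + [num_outputs]
  # Each dense layer maps sizes[i] -> sizes[i + 1], with one bias per output.
  return sum(m * n + n for m, n in zip(sizes[:-1], sizes[1:]))


# Defaults depth=2, width=512 with a hypothetical 10-dimensional input.
print(dense_param_count(num_inputs=10, depth=2, width=512))  # 268801
```

The hidden-to-hidden layers dominate: each contributes about width² parameters, so doubling width roughly quadruples the count, while adding a layer adds one more width² term.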

observation_model

The aleatoric noise model for the observed data. The options are NORMAL (Gaussian noise), NB (negative binomial noise), or ZNB (zero-inflated negative binomial noise).

TYPE: str DEFAULT: 'NORMAL'
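To make the count-valued options concrete, the sketch below evaluates a negative binomial probability mass function in the common mean/concentration parameterization, and its zero-inflated variant, which mixes in an extra point mass at zero. The parameter names mean, concentration, and zero_prob are assumptions for illustration; this page does not document how BayesNF parameterizes these distributions internally.

```python
import math


def nb_pmf(y, mean, concentration):
  """Negative binomial P(Y = y) with E[Y] = mean and shape `concentration`."""
  r = concentration
  log_p = (math.lgamma(y + r) - math.lgamma(r) - math.lgamma(y + 1)
           + r * math.log(r / (r + mean)) + y * math.log(mean / (r + mean)))
  return math.exp(log_p)


def znb_pmf(y, mean, concentration, zero_prob):
  """Zero-inflated NB: with probability zero_prob, Y is exactly 0."""
  base = (1.0 - zero_prob) * nb_pmf(y, mean, concentration)
  return base + (zero_prob if y == 0 else 0.0)


# Zero inflation raises P(Y = 0) above what the plain NB assigns.
print(nb_pmf(0, mean=2.0, concentration=1.5))
print(znb_pmf(0, mean=2.0, concentration=1.5, zero_prob=0.3))
```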

standardize

List of columns that should be standardized. It is highly recommended to standardize feature_cols[1:]. It is an error if feature_cols[0] (the time variable) is in standardize.

TYPE: Sequence[str] | None DEFAULT: None
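Standardization is commonly the per-column z-score transform (x - mean) / std. The sketch below assumes that definition (this page does not spell out the exact transform applied by SpatiotemporalDataHandler) and enforces the documented rule that the time column must not be standardized. standardize_columns is a hypothetical helper operating on a dict of NumPy columns.

```python
import numpy as np


def standardize_columns(columns, feature_cols, standardize):
  """Z-score the named columns; `columns` maps column name -> 1-D array."""
  if feature_cols[0] in standardize:
    raise ValueError(f'Cannot standardize the time column {feature_cols[0]!r}.')
  out = dict(columns)
  for name in standardize:
    x = np.asarray(columns[name], dtype=float)
    # Assumed transform: zero mean, unit (population) standard deviation.
    out[name] = (x - x.mean()) / x.std()
  return out


data = {'t': np.arange(4.0),
        'lat': np.array([10.0, 20.0, 30.0, 40.0]),
        'lon': np.array([-5.0, 0.0, 5.0, 10.0])}
scaled = standardize_columns(data, ['t', 'lat', 'lon'], ['lat', 'lon'])
```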


predict

predict(table, quantiles=(0.5,), approximate_quantiles=False)

Make predictions of the target column at new times.

PARAMETER DESCRIPTION
table

Field locations at which to make new predictions. Same as table in fit, except that self.target_col need not be in table.

TYPE: DataFrame

quantiles

The list of quantiles to compute.

TYPE: Sequence[float] DEFAULT: (0.5,)

approximate_quantiles

If False, uses Chandrupatla root finding to compute quantiles. If True, uses a heuristic approximation of the quantiles.

TYPE: bool DEFAULT: False
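One way to read approximate_quantiles=False: if the ensemble's predictive distribution at a location is treated as an equal-weight mixture of its members' distributions, a quantile has no closed form and must be found numerically by solving CDF(q) = level. The sketch below assumes Normal members and uses plain bisection in place of Chandrupatla's method (a faster bracketing root finder). It illustrates the idea and does not reproduce the library's implementation.

```python
import math


def normal_cdf(x, loc, scale):
  return 0.5 * (1.0 + math.erf((x - loc) / (scale * math.sqrt(2.0))))


def mixture_quantile(level, locs, scales, tol=1e-10):
  """Solve mean_i CDF_i(q) = level for q by bisection."""
  def cdf(q):
    return sum(normal_cdf(q, m, s) for m, s in zip(locs, scales)) / len(locs)
  # Bracket: each member has negligible mass beyond 40 standard deviations.
  lo = min(m - 40 * s for m, s in zip(locs, scales))
  hi = max(m + 40 * s for m, s in zip(locs, scales))
  while hi - lo > tol:
    mid = 0.5 * (lo + hi)
    if cdf(mid) < level:
      lo = mid
    else:
      hi = mid
  return 0.5 * (lo + hi)


# Two ensemble members predicting N(0, 1) and N(4, 1): the median is ~2.0.
print(mixture_quantile(0.5, locs=[0.0, 4.0], scales=[1.0, 1.0]))
```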

RETURNS DESCRIPTION
means

The predicted means from each particle in the learned ensemble. The shape is (num_devices, ensemble_size // num_devices, len(table)) and can be flattened to a 2D array using np.vstack(means). See also https://github.com/google/bayesnf/issues/17.

TYPE: ndarray

quantiles

A list of numpy arrays, one per requested quantile. The length of each array in the list is len(table).

TYPE: List[ndarray]
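The per-device layout of means can be collapsed into one row per ensemble member. The snippet below uses a synthetic array with the documented shape (the sizes are made up) and shows that np.vstack, of which np.row_stack is an alias, is equivalent to a reshape.

```python
import numpy as np

num_devices, ensemble_size, num_rows = 2, 8, 5
# Synthetic stand-in for the `means` returned by predict().
means = np.random.default_rng(0).normal(
    size=(num_devices, ensemble_size // num_devices, num_rows))

flat = np.vstack(means)  # shape (ensemble_size, num_rows)
assert flat.shape == (ensemble_size, num_rows)
assert np.array_equal(flat, means.reshape(-1, num_rows))

# One simple summary: the ensemble-averaged mean for each row of `table`.
point_forecast = flat.mean(axis=0)
```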


fit

fit(table, seed)

Run inference given a training data table and seed.

Cannot be directly called on BayesianNeuralFieldEstimator.

PARAMETER DESCRIPTION
table

A pandas DataFrame representing the training data. It has the following requirements:

  • The columns of table should contain all self.feature_cols and the self.target_col.

  • If timetype == 'index', the type of the "time" column (i.e., self.feature_cols[0]) should be datetime. To ensure this requirement holds, see pandas.to_datetime. If timetype == 'float', the time column should be float. The types of the remaining feature columns should be numeric.

TYPE: DataFrame

seed

The jax random key.

TYPE: PRNGKey
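A minimal sketch of preparing a training table that meets these requirements, assuming a daily dataset with timetype='index'. The column names date, latitude, longitude, and response are hypothetical; pandas.to_datetime and the pandas.api.types dtype checks are standard pandas APIs.

```python
import pandas as pd

feature_cols = ['date', 'latitude', 'longitude']  # time column comes first
target_col = 'response'

table = pd.DataFrame({
    'date': ['2024-01-01', '2024-01-02', '2024-01-03'],
    'latitude': [40.7, 40.7, 40.7],
    'longitude': [-74.0, -74.0, -74.0],
    'response': [3.0, 5.0, 4.0],
})
# Parse the time column so that it has a datetime dtype.
table['date'] = pd.to_datetime(table['date'])

# Check the documented requirements before calling fit().
missing = set(feature_cols + [target_col]) - set(table.columns)
assert not missing, f'missing columns: {missing}'
assert pd.api.types.is_datetime64_any_dtype(table[feature_cols[0]])
assert all(pd.api.types.is_numeric_dtype(table[c]) for c in feature_cols[1:])
```

The seed argument is a JAX random key, for example jax.random.PRNGKey(0).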


likelihood_model

likelihood_model(table)

Access the predictive distribution over new field values in table.

NOTE: Must be called after fit.

PARAMETER DESCRIPTION
table

Field locations at which to make new predictions. Same as table in fit, except that self.target_col need not be in table.

TYPE: DataFrame

RETURNS DESCRIPTION
Distribution

A probability distribution representing the predictive distribution over self.target_col at the new field values in table. See tfp.distributions.Distribution for the methods associated with this object.
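In the class source, likelihood_model appends a trailing axis to the first three fields of params_ so that per-particle parameters broadcast against a batch of predictions over the rows of table. The NumPy sketch below shows only that shape arithmetic; the sizes are made up, the arrays are hypothetical stand-ins, and the (num_devices, particles_per_device) leading shape is an assumption based on the predict return shape.

```python
import numpy as np

num_devices, per_device, num_rows = 2, 4, 6
rng = np.random.default_rng(0)

# Hypothetical per-particle scalar parameter, e.g. an observation noise scale.
noise_scale = rng.uniform(0.5, 1.5, size=(num_devices, per_device))
# Hypothetical network outputs: one prediction per particle per table row.
predictions = rng.normal(size=(num_devices, per_device, num_rows))

# (2, 4) and (2, 4, 6) do not broadcast; (2, 4, 1) and (2, 4, 6) do.
scaled = predictions * noise_scale[..., np.newaxis]
print(scaled.shape)  # (2, 4, 6)
```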
