
Bayesian Neural Fields for Spatiotemporal Prediction

This is not an officially supported Google product.

Spatially referenced time series (i.e., spatiotemporal) datasets are ubiquitous in scientific, engineering, and business-intelligence applications. This repository contains an implementation of the Bayesian Neural Field (BayesNF), a novel spatiotemporal modeling method that integrates hierarchical probabilistic modeling for accurate uncertainty estimation with deep neural networks for high-capacity function approximation.

Bayesian Neural Fields infer joint probability distributions over field values at arbitrary points in time and space, which makes the model suitable for many data-analysis tasks, including spatial interpolation, temporal forecasting, and variography. Posterior inference uses variationally learned surrogates trained with mini-batch stochastic gradient descent, which allows the method to handle large-scale data.

The probabilistic model and inference algorithm are described in the following paper:

title   = {Scalable Spatiotemporal Prediction with {Bayesian} Neural Fields},
author  = {Saad, Feras and Burnim, Jacob and Carroll, Colin and Patton, Brian and Köster, Urs and Saurous, Rif A. and Hoffman, Matthew},
journal = {arXiv},
volume  = {2403.07657},
year    = {2024},
doi     = {10.48550/arXiv.2403.07657},


bayesnf can be installed from the Python Package Index (PyPI) using:

python -m pip install bayesnf

The typical install time is about 1 minute. This software is tested on Python 3.11 with a standard Debian GNU/Linux setup. The large-scale experiments in scripts/ were run using TPU v3-8 accelerators. Running BayesNF locally on medium- to large-scale data requires at least a GPU.
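A quick way to confirm that a local environment resembles the setup described above is sketched below. The helper name `check_environment` is hypothetical and not part of bayesnf; it relies only on the standard library plus an optional `jax` import, and it treats Python 3.11 as a minimum purely as an assumption (the text above only states that 3.11 is the tested version).

```python
import importlib.metadata
import sys


def check_environment(package="bayesnf", min_python=(3, 11)):
    """Summarize the local setup (illustrative helper, not part of bayesnf)."""
    report = {
        # Assumption: versions at or above the tested 3.11 are acceptable.
        "python_at_least_tested": sys.version_info[:2] >= min_python,
    }

    # Look up the installed version of the package, if it is installed.
    try:
        report["package_version"] = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        report["package_version"] = None

    # JAX reports its default backend as 'cpu', 'gpu', or 'tpu'.
    try:
        import jax
        report["backend"] = jax.default_backend()
    except Exception:  # JAX missing or backend initialization failed.
        report["backend"] = None

    return report


if __name__ == "__main__":
    print(check_environment())
```

A `backend` value of `'cpu'` suggests that only small datasets will be practical on that machine.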

Documentation and Tutorials

Please visit

Quick start

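import jax
import pandas as pd
from bayesnf.spatiotemporal import BayesianNeuralFieldMAP

# Load a dataframe with "long" format spatiotemporal data.
df_train = pd.read_csv('chickenpox.5.train.csv',
  index_col=0, parse_dates=['datetime'])

# Build a Bayesian Neural Field estimator (MAP variant).
model = BayesianNeuralFieldMAP(
  seasonality_periods=['M', 'Y'],
  num_seasonal_harmonics=[2, 10],
  feature_cols=['datetime', 'latitude', 'longitude'],
  standardize=['latitude', 'longitude'],
  interactions=[(0, 1), (0, 2), (1, 2)])

# Fit the model. The arguments shown are an assumption; consult the
# documentation for the full fit() signature and training options.
model = model.fit(df_train, seed=jax.random.PRNGKey(0))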

# Make predictions of means and quantiles on test data.
df_test = pd.read_csv('chickenpox.5.test.csv',
  index_col=0, parse_dates=['datetime'])

yhat, yhat_quantiles = model.predict(df_test, quantiles=(0.025, 0.5, 0.975))
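
Once predictions are available, a common follow-up is to check how often the observed values fall inside the predicted interval and how far the median prediction is from the truth. The sketch below uses plain Python lists with made-up numbers rather than real model output. It assumes the lower, median, and upper quantile predictions have already been reduced to one value per test row (for example, by averaging over ensemble members), which is an assumption about post-processing and not something the snippet above guarantees. The function names are hypothetical.

```python
def interval_coverage(y_true, lower, upper):
    """Fraction of observations that fall inside [lower, upper]."""
    if not (len(y_true) == len(lower) == len(upper)):
        raise ValueError("all inputs must have the same length")
    if not y_true:
        raise ValueError("inputs must be non-empty")
    hits = sum(lo <= y <= hi for y, lo, hi in zip(y_true, lower, upper))
    return hits / len(y_true)


def mean_absolute_error(y_true, y_pred):
    """Average absolute difference between observations and predictions."""
    if len(y_true) != len(y_pred):
        raise ValueError("inputs must have the same length")
    if not y_true:
        raise ValueError("inputs must be non-empty")
    return sum(abs(y - p) for y, p in zip(y_true, y_pred)) / len(y_true)


# Made-up values standing in for held-out observations and predictions.
y_true = [1.0, 2.5, 0.0, 4.0]
lower = [0.5, 1.0, 0.5, 3.0]   # e.g. the 0.025 quantile per row
median = [1.2, 2.0, 1.0, 3.5]  # e.g. the 0.5 quantile per row
upper = [2.0, 3.0, 2.0, 5.0]   # e.g. the 0.975 quantile per row

coverage = interval_coverage(y_true, lower, upper)  # 3 of 4 rows are covered
mae = mean_absolute_error(y_true, median)
print(f"coverage={coverage:.2f}, mae={mae:.2f}")
```

For a well-calibrated 95% interval, the empirical coverage on held-out data should be close to 0.95; a much lower value suggests the predictive intervals are too narrow.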