quadax


quadax is a library for numerical quadrature and integration using JAX.

  • vmap-able, jit-able, differentiable.

  • Scalar or vector valued integrands.

  • Finite or infinite domains with discontinuities or singularities within the domain of integration.

  • Globally adaptive Gauss-Kronrod and Clenshaw-Curtis quadrature for smooth integrands (similar to scipy.integrate.quad).

  • Adaptive tanh-sinh quadrature for singular or near singular integrands.

  • Quadrature from sampled values using the trapezoidal and Simpson's rules.

Coming soon:

  • Custom JVP/VJP rules (currently AD works by differentiating through the loop, which isn’t the most efficient approach).

  • N-D quadrature (cubature)

  • QMC methods

  • Integration with weight functions

  • Sparse grids (maybe, need to play with data structures and JAX)

Installation

quadax is installable with pip:

pip install quadax

Usage

import jax.numpy as jnp
import numpy as np
from quadax import quadgk

fun = lambda t: t * jnp.log(1 + t)

epsabs = epsrel = 1e-5  # JAX defaults to 32-bit floats; tighter tolerances require enabling 64-bit
a, b = 0, 1
y, info = quadgk(fun, [a, b], epsabs=epsabs, epsrel=epsrel)
assert info.err < max(epsabs, epsrel*abs(y))
np.testing.assert_allclose(y, 1/4, rtol=epsrel, atol=epsabs)

For full details of the available options, see the API documentation.

Which method should I choose?

Can you evaluate the integrand at an arbitrary point?

To start, quadgk or quadcc are probably your best options; they are similar to the methods in QUADPACK (or scipy.integrate.quad). quadgk is usually the most efficient for very smooth integrands (well approximated by a high-degree polynomial), while quadcc tends to be slightly more efficient for less smooth integrands. If neither performs well, think about your integrand a bit more:

  • Does your integrand have badly behaved singularities at the endpoints? Use quadts or rombergts.

  • Is your integrand only piecewise smooth or piecewise continuous? Use romberg or rombergts.
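
To see why a tanh-sinh rule copes with endpoint singularities, here is a minimal fixed-size sketch of the underlying change of variables, written in plain jax.numpy. It is not how quadts is implemented (quadts is adaptive); the function name tanh_sinh_unit_interval and the parameters n and tmax are hypothetical choices for illustration.

```python
import jax.numpy as jnp


def tanh_sinh_unit_interval(fun, n=65, tmax=3.0):
    """Non-adaptive tanh-sinh estimate of the integral of fun over (0, 1)."""
    t = jnp.linspace(-tmax, tmax, n)  # uniform grid in the transformed variable
    h = t[1] - t[0]
    u = 0.5 * jnp.pi * jnp.sinh(t)
    # x = 0.5 * (1 + tanh(u)), written so that small x does not round to exactly 0.
    x = 1.0 / (1.0 + jnp.exp(-2.0 * u))
    # dx/dt decays double-exponentially, which damps the endpoint singularity.
    w = 0.25 * jnp.pi * jnp.cosh(t) / jnp.cosh(u) ** 2
    # The terms at both ends of the grid are negligible, so a plain sum is used.
    return h * jnp.sum(fun(x) * w)


# 1 / sqrt(x) is unbounded at x = 0, yet its integral over (0, 1) is exactly 2.
estimate = tanh_sinh_unit_interval(lambda x: 1.0 / jnp.sqrt(x))
```

The nodes cluster toward the endpoints without ever touching them, so the integrand is never evaluated at the singular point itself.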

Do you only know your integrand at discrete points?

  • Use trapezoid or simpson.
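
For reference, the two sampled-data rules amount to the following arithmetic. This is a plain jax.numpy sketch with hypothetical helper names (trapezoid_sketch, simpson_sketch), not the library's code; the Simpson version assumes uniform spacing and an odd number of samples.

```python
import jax.numpy as jnp


def trapezoid_sketch(y, x):
    # Sum the areas of the trapezoids between consecutive samples.
    return jnp.sum(0.5 * (y[1:] + y[:-1]) * jnp.diff(x))


def simpson_sketch(y, h):
    # Composite Simpson's rule: uniform spacing h, odd number of samples.
    odd = jnp.sum(y[1:-1:2])   # interior samples with odd index
    even = jnp.sum(y[2:-1:2])  # interior samples with even index
    return h / 3.0 * (y[0] + 4.0 * odd + 2.0 * even + y[-1])


# Samples of the integrand from the usage example; the exact integral is 1/4.
x = jnp.linspace(0.0, 1.0, 101)
y = x * jnp.log(1.0 + x)

trap = trapezoid_sketch(y, x)
simp = simpson_sketch(y, x[1] - x[0])
```

On smooth data like this, Simpson's rule is typically closer to the exact value than the trapezoidal rule for the same samples.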

Notes on parallel efficiency

Adaptive algorithms are inherently somewhat sequential, so perfect parallelism is generally not achievable. romberg and rombergts are fully sequential, due to limitations on dynamically sized arrays in JAX. All of the quad* methods are parallelized at a local level (i.e., for each sub-interval, the function evaluations are vectorized). This means that quad* methods evaluate the integrand in batches whose size equals the order of the rule, so higher-order methods tend to be more efficient on GPU/TPU. However, if the integrand is not sufficiently smooth, using a higher-order method can slow down convergence, particularly for quadgk; quadts and quadcc are somewhat less sensitive to the smoothness of the integrand.
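
The local vectorization can be pictured with a single fixed-order rule on one sub-interval. The sketch below uses a Gauss-Legendre rule from NumPy as a stand-in (quadgk uses Gauss-Kronrod rules, which are not reproduced here), and the function name fixed_gauss_legendre is hypothetical. The point is only that all order nodes are evaluated in one batched call.

```python
import jax
import jax.numpy as jnp
import numpy as np


def fixed_gauss_legendre(fun, a, b, order):
    """Apply one Gauss-Legendre rule with `order` nodes to the interval [a, b]."""
    nodes, weights = np.polynomial.legendre.leggauss(order)
    nodes, weights = jnp.asarray(nodes), jnp.asarray(weights)
    half, mid = 0.5 * (b - a), 0.5 * (a + b)
    t = mid + half * nodes      # shape (order,): one batch of evaluation points
    f = jax.vmap(fun)(t)        # a single vectorized evaluation of the integrand
    return half * jnp.sum(weights * f)


fun = lambda t: t * jnp.log(1 + t)
estimate = fixed_gauss_legendre(fun, 0.0, 1.0, order=21)
```

A larger order means a larger batch per sub-interval, which keeps an accelerator busier; the adaptive part (choosing which sub-interval to refine next) is what remains sequential.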
