quadax
quadax is a library for numerical quadrature and integration using JAX.
vmap-able, jit-able, differentiable.
Scalar or vector valued integrands.
Finite or infinite domains with discontinuities or singularities within the domain of integration.
Globally adaptive Gauss-Kronrod and Clenshaw-Curtis quadrature for smooth integrands (similar to scipy.integrate.quad).
Adaptive tanh-sinh quadrature for singular or near-singular integrands.
Quadrature from sampled values using trapezoidal and Simpson's methods.
Coming soon:
Custom JVP/VJP rules (currently AD works by differentiating through the loop, which isn't the most efficient approach)
N-D quadrature (cubature)
QMC methods
Integration with weight functions
Sparse grids (maybe, need to play with data structures and JAX)
Installation
quadax is installable with pip:
pip install quadax
Usage
import jax.numpy as jnp
import numpy as np
from quadax import quadgk
fun = lambda t: t * jnp.log(1 + t)
epsabs = epsrel = 1e-5 # by default jax uses 32 bit, higher accuracy requires going to 64 bit
a, b = 0, 1
y, info = quadgk(fun, [a, b], epsabs=epsabs, epsrel=epsrel)
assert info.err < max(epsabs, epsrel*abs(y))
np.testing.assert_allclose(y, 1/4, rtol=epsrel, atol=epsabs)
For full details of the various options, see the API documentation.
Which method should I choose?
Can you evaluate the integrand at an arbitrary point?
To start, quadgk or quadcc are probably your best options; both are similar to methods in QUADPACK (or scipy.integrate.quad). quadgk is usually the most efficient for very smooth integrands (well approximated by a high-degree polynomial), while quadcc tends to be slightly more efficient for less smooth integrands. If neither performs well, you should think about your integrand a bit more:
Does your integrand have badly behaved singularities at the endpoints? Use quadts or rombergts.
Is your integrand only piecewise smooth or piecewise continuous? Use romberg or rombergts.
Do you only know your integrand at discrete points? Use trapezoid or simpson.
Notes on parallel efficiency
Adaptive algorithms are inherently somewhat sequential, so perfect parallelism is generally not achievable. romberg and rombergts are fully sequential, due to limitations on dynamically sized arrays in JAX. All of the quad* methods are parallelized on a local level (i.e., for each sub-interval, the function evaluations are vectorized). This means that quad* methods will evaluate the integrand in batch sizes of order (the order of the method), and hence higher order methods will tend to be more efficient on GPU/TPU. However, if the integrand is not sufficiently smooth, using a higher order method can slow down convergence, particularly for quadgk; quadts and quadcc are somewhat less sensitive to the smoothness of the integrand.