The different implementations of the tridiagonal solver can be found in the tridiag folder. The C++/CUDA code is in the cpp subfolder.
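For orientation, here is a minimal reference sketch of what a tridiagonal solve computes, written in plain Python using the standard Thomas algorithm. It is not taken from the tridiag folder: the function name `solve_tridiagonal` and the argument layout (`a[0]` and `c[-1]` unused) are assumptions made for illustration, and the sketch does no pivoting, so it assumes a well-conditioned (for example diagonally dominant) system.

```python
def solve_tridiagonal(a, b, c, d):
    """Solve a tridiagonal system with the Thomas algorithm (illustrative sketch).

    a: sub-diagonal   (a[0] is unused)
    b: main diagonal
    c: super-diagonal (c[-1] is unused)
    d: right-hand side
    All inputs are sequences of the same length n. No pivoting is performed.
    """
    n = len(d)
    cp = [0.0] * n  # modified super-diagonal
    dp = [0.0] * n  # modified right-hand side

    # Forward sweep: eliminate the sub-diagonal.
    cp[0] = c[0] / b[0]
    dp[0] = d[0] / b[0]
    for i in range(1, n):
        denom = b[i] - a[i] * cp[i - 1]
        cp[i] = c[i] / denom
        dp[i] = (d[i] - a[i] * dp[i - 1]) / denom

    # Back substitution.
    x = [0.0] * n
    x[-1] = dp[-1]
    for i in range(n - 2, -1, -1):
        x[i] = dp[i] - cp[i] * x[i + 1]
    return x


if __name__ == "__main__":
    # 3x3 system with diagonals (-1, 2, -1).
    print(solve_tridiagonal([0.0, -1.0, -1.0],
                            [2.0, 2.0, 2.0],
                            [-1.0, -1.0, 0.0],
                            [0.0, 0.0, 4.0]))
```

A sequential loop like this is useful mainly as a correctness reference when comparing the outputs of different implementations.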
The optimized superbee kernels are in this file.
The Futhark code for this routine is in this file.
The code for interfacing CUDA with JAX is located in this folder. It requires the JAX package (https://github.com/google/jax) to be installed, and can itself be installed with the command pip install -e jax_xla
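Since the interface depends on JAX, it can help to confirm that JAX is visible to the Python environment before running the install command. The following is a small sketch using only the standard library; the helper name `jax_is_installed` is hypothetical and not part of this repository.

```python
import importlib.util


def jax_is_installed():
    """Return True if the `jax` package can be located by the import system."""
    return importlib.util.find_spec("jax") is not None


if __name__ == "__main__":
    if jax_is_installed():
        print("JAX found; the jax_xla package can be installed.")
    else:
        print("JAX not found; install it first (see https://github.com/google/jax).")
```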
To compile the CUDA kernels through Python setuptools, I used code from this repository.
A similar integration into JAX can also be found here.