## See Also

Other CUDA implementations of KAN:

I am interested in the performance aspects of KAN, and I am happy to discuss or receive more information on this topic :)
This project is for personal practice purposes; use at your own risk. Tested on my RTX 3050 as well as a remote RTX 3090 with CUDA 12.x.

CUDA implementation of the paper introducing Wavelet KAN (Wav-KAN): https://arxiv.org/abs/2405.12832.

This is significantly faster than the original implementation: roughly 50x faster in the forward pass and 5x faster in the backward pass, as measured with the benchmark scripts from https://github.com/Jerry-Master/KAN-benchmarking:
|          | forward (time) | backward (time) | forward (mem) | backward (mem) | num params | num trainable params |
|----------|----------------|-----------------|---------------|----------------|------------|----------------------|
| cuda-gpu | 29.10 ms       | 65.27 ms        | 0.28 GB       | 1.03 GB        | 3151362    | 3151362              |
| orig-gpu | 522.00 ms      | 1461.29 ms      | 5.53 GB       | 5.53 GB        | 3151362    | 3151362              |
There are no further optimizations in this implementation. I am a CUDA beginner and welcome optimization suggestions :)

Currently, the Mexican hat and Morlet wavelets are implemented.
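For reference, the two mother wavelets can be written in plain Python as follows. These are the standard textbook forms; the exact normalization and the `omega0` value used by this repo's CUDA kernels are assumptions here.

```python
import math

def mexican_hat(x: float) -> float:
    # Mexican hat (Ricker) wavelet: the negated second derivative of a
    # Gaussian, with the usual L2 normalization constant:
    # psi(x) = 2 / (sqrt(3) * pi^(1/4)) * (1 - x^2) * exp(-x^2 / 2)
    c = 2.0 / (math.sqrt(3.0) * math.pi ** 0.25)
    return c * (1.0 - x * x) * math.exp(-x * x / 2.0)

def morlet(x: float, omega0: float = 5.0) -> float:
    # Real-valued Morlet wavelet: a cosine modulated by a Gaussian
    # envelope. omega0 = 5.0 is a common default; the value used by the
    # kernels in this repo is an assumption.
    return math.cos(omega0 * x) * math.exp(-x * x / 2.0)
```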
## Install

    pip install -e .

Make sure the version of nvcc on your PATH is compatible with your installed PyTorch version (a minor version difference seems to be OK).
## Run

Run the test on MNIST:

    python test.py
## Benchmark

    python benchmark.py --method all --reps 100 --just-cuda

Note: the Morlet wavelet performs badly on MNIST, but with a shallow net you can observe it learning.
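For intuition about what each edge learns, here is a minimal sketch of the Wav-KAN construction from the paper: each edge applies a learnable univariate function built from a scaled and translated mother wavelet, and each node sums its incoming edge activations. The parameter names (`weight`, `scale`, `translation`) are illustrative, not the repo's actual names.

```python
import math

def morlet(x: float, omega0: float = 5.0) -> float:
    # Real Morlet mother wavelet (cosine times Gaussian envelope).
    return math.cos(omega0 * x) * math.exp(-x * x / 2.0)

def wavkan_edge(x: float, weight: float, scale: float, translation: float) -> float:
    # One Wav-KAN edge: a trainable univariate function obtained by
    # scaling and translating the mother wavelet, then weighting it.
    return weight * morlet((x - translation) / scale)

def wavkan_node(xs, params):
    # A KAN node simply sums its incoming edge activations.
    # params: one (weight, scale, translation) triple per input.
    return sum(wavkan_edge(x, w, s, t) for x, (w, s, t) in zip(xs, params))
```

The CUDA kernels in this repo evaluate many such edges in parallel over a batch; this scalar version is only meant to show the math.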