cuda-wavelet-kan

CUDA implementation of Wavelet KAN

  • See also other CUDA implementations of KAN:

  • I am interested in the performance aspects of KAN and happy to discuss or receive more information on this topic :)

  • This is for personal practice purposes; use at your own risk. Tested on my RTX 3050 as well as a remote RTX 3090 with CUDA 12.x.

Introduction

A CUDA implementation of Wavelet KAN, as introduced in the paper at https://arxiv.org/abs/2405.12832.
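For readers new to Wavelet KAN, the per-edge computation can be sketched in plain Python. This is a hedged reference sketch, not this repository's kernel: it assumes the formulation in which every input/output edge carries its own learnable weight, scale and translation, and it uses the Mexican hat wavelet. Any base (linear) branch or normalization layer that a full implementation may add is omitted. The function name `wavelet_kan_forward` and the parameter layout are hypothetical. The explicit loops mirror the per-(batch, output) work a naive CUDA kernel would do.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def mexican_hat(t: float) -> float:
    # Unnormalized Mexican hat wavelet; a constant factor can be
    # absorbed into the learnable edge weight.
    return (1.0 - t * t) * math.exp(-0.5 * t * t)


def wavelet_kan_forward(
    x: Sequence[Sequence[float]],  # [batch][in_features]
    weight: Matrix,                # [out_features][in_features]
    scale: Matrix,                 # [out_features][in_features], nonzero
    translation: Matrix,           # [out_features][in_features]
) -> Matrix:
    """y[n][j] = sum_i weight[j][i] * psi((x[n][i] - translation[j][i]) / scale[j][i])."""
    out_features = len(weight)
    in_features = len(weight[0])
    y: Matrix = []
    for row in x:                      # one batch element
        assert len(row) == in_features
        out_row = []
        for j in range(out_features):  # a naive kernel might map one thread here
            acc = 0.0
            for i in range(in_features):
                t = (row[i] - translation[j][i]) / scale[j][i]
                acc += weight[j][i] * mexican_hat(t)
            out_row.append(acc)
        y.append(out_row)
    return y


if __name__ == "__main__":
    # psi(0) = 1 and psi(1) = 0, so the single output is 1.0
    print(wavelet_kan_forward([[0.0, 1.0]], [[1.0, 1.0]], [[1.0, 1.0]], [[0.0, 0.0]]))  # [[1.0]]
```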

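This is significantly faster than the original implementation and uses much less memory. In the benchmark below, run with the scripts from https://github.com/Jerry-Master/KAN-benchmarking, the forward pass is about 18x faster and the backward pass about 22x faster, while memory use drops by roughly 20x (forward) and 5x (backward).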

          |     fwd time  |     bwd time  |   fwd memory  |   bwd memory  |   num params  |  num trainable params
--------------------------------------------------------------------------------------------------------------------
cuda-gpu  |     29.10 ms  |     65.27 ms  |      0.28 GB  |      1.03 GB  |      3151362  |               3151362
orig-gpu  |    522.00 ms  |   1461.29 ms  |      5.53 GB  |      5.53 GB  |      3151362  |               3151362
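The speedup and memory ratios can be recomputed directly from the table by dividing each baseline figure by the corresponding CUDA figure. This small sketch copies the numbers above; it assumes the GB columns report GPU memory usage for each pass, which the table itself does not spell out. The helper name `ratios` is hypothetical.

```python
# Timings (ms) and memory (GB) copied from the benchmark table.
results = {
    "cuda-gpu": {"fwd_ms": 29.10, "bwd_ms": 65.27, "fwd_gb": 0.28, "bwd_gb": 1.03},
    "orig-gpu": {"fwd_ms": 522.00, "bwd_ms": 1461.29, "fwd_gb": 5.53, "bwd_gb": 5.53},
}


def ratios(baseline: dict, candidate: dict) -> dict:
    """Return baseline / candidate per metric (higher means the candidate is better)."""
    return {k: baseline[k] / candidate[k] for k in baseline}


if __name__ == "__main__":
    r = ratios(results["orig-gpu"], results["cuda-gpu"])
    for metric, value in r.items():
        print(f"{metric}: {value:.1f}x")
```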

Note

  • This implementation is not optimized. I am a CUDA beginner and welcome optimization suggestions :)

  • Currently, the Mexican hat and Morlet wavelets are implemented.
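Both wavelets, and the derivatives a hand-written backward pass needs, have closed forms. The sketch below assumes the real-valued Morlet with center frequency `OMEGA0 = 5` and drops normalization constants (they can be absorbed into the learnable weights); the exact constants used by this repository's kernels are not stated here. A central-difference check confirms the analytic gradients.

```python
import math

OMEGA0 = 5.0  # assumed Morlet center frequency; implementations may differ


def mexican_hat(t: float) -> float:
    """Unnormalized Mexican hat: (1 - t^2) * exp(-t^2 / 2)."""
    return (1.0 - t * t) * math.exp(-0.5 * t * t)


def mexican_hat_grad(t: float) -> float:
    """d/dt of mexican_hat: (t^3 - 3t) * exp(-t^2 / 2)."""
    return (t ** 3 - 3.0 * t) * math.exp(-0.5 * t * t)


def morlet(t: float, omega0: float = OMEGA0) -> float:
    """Real Morlet: cos(omega0 * t) * exp(-t^2 / 2)."""
    return math.cos(omega0 * t) * math.exp(-0.5 * t * t)


def morlet_grad(t: float, omega0: float = OMEGA0) -> float:
    """d/dt of morlet: -(omega0 * sin(omega0 t) + t * cos(omega0 t)) * exp(-t^2 / 2)."""
    return -(omega0 * math.sin(omega0 * t) + t * math.cos(omega0 * t)) * math.exp(-0.5 * t * t)


def central_diff(f, t: float, h: float = 1e-5) -> float:
    """Numerical derivative used to sanity-check the analytic gradients."""
    return (f(t + h) - f(t - h)) / (2.0 * h)


if __name__ == "__main__":
    for t in (-1.3, 0.0, 0.7, 2.1):
        assert abs(mexican_hat_grad(t) - central_diff(mexican_hat, t)) < 1e-6
        assert abs(morlet_grad(t) - central_diff(morlet, t)) < 1e-6
    print("analytic gradients match finite differences")
```

With t = (x - b) / a for translation b and scale a, the chain rule gives dpsi/dx = psi'(t) / a, dpsi/db = -psi'(t) / a and dpsi/da = -psi'(t) * t / a, which is what a backward kernel would accumulate per edge.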

Start

  1. Install
pip install -e .

Make sure the nvcc on your PATH is compatible with the CUDA version your PyTorch build was compiled against (a minor version difference appears to be fine).
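One way to check this before running `pip install -e .` is to compare the release reported by `nvcc --version` with `torch.version.cuda`, the CUDA version PyTorch was built with. The helper names below are hypothetical, and the "minor mismatch is usually OK" rule simply mirrors the observation above; PyTorch's C++/CUDA extension builder typically warns on a minor mismatch and raises an error on a major one.

```python
import re
import subprocess
from typing import Optional, Tuple

Version = Optional[Tuple[int, int]]


def parse_nvcc_release(text: str) -> Version:
    """Extract (major, minor) from `nvcc --version` output, e.g. 'release 12.1, V12.1.105'."""
    m = re.search(r"release (\d+)\.(\d+)", text)
    return (int(m.group(1)), int(m.group(2))) if m else None


def parse_torch_cuda(version: Optional[str]) -> Version:
    """Parse torch.version.cuda such as '12.1'; CPU-only builds report None."""
    if not version:
        return None
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def compare(nvcc: Version, torch_cuda: Version) -> str:
    if nvcc is None or torch_cuda is None:
        return "unknown"
    if nvcc == torch_cuda:
        return "match"
    if nvcc[0] == torch_cuda[0]:
        return "minor mismatch (usually OK)"
    return "major mismatch (build likely rejected)"


def main() -> None:
    try:
        import torch  # imported lazily so the parsers work without PyTorch
        torch_cuda = parse_torch_cuda(torch.version.cuda)
    except ImportError:
        torch_cuda = None
    try:
        out = subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
    except FileNotFoundError:  # nvcc not on PATH
        out = ""
    print(compare(parse_nvcc_release(out), torch_cuda))


if __name__ == "__main__":
    main()
```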

  2. Run

    • Run test on MNIST:
    python test.py
  3. Benchmark

python benchmark.py --method all --reps 100 --just-cuda
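The `--reps` flag presumably sets how many timed repetitions are run. The general technique for timing GPU code from Python can be sketched as below; this is an illustrative assumption, not how `benchmark.py` is actually written, and `time_fn` is a hypothetical name. The key point is that CUDA launches are asynchronous, so the device must be synchronized before the clock is read.

```python
import statistics
import time
from typing import Callable, Optional


def time_fn(
    fn: Callable[[], object],
    reps: int = 100,
    warmup: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the median wall-clock time of fn() in milliseconds.

    On a GPU, pass a synchronizing callable (for example
    torch.cuda.synchronize) as `sync`; otherwise the timer only
    measures kernel launch overhead.
    """
    for _ in range(warmup):  # exclude one-off costs such as compilation and caching
        fn()
    if sync:
        sync()
    samples = []
    for _ in range(reps):
        start = time.perf_counter()
        fn()
        if sync:
            sync()  # wait for queued GPU work before stopping the clock
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


if __name__ == "__main__":
    ms = time_fn(lambda: sum(range(10_000)), reps=20)
    print(f"median: {ms:.3f} ms")
```

The median is used here because it is less sensitive than the mean to occasional slow iterations; the actual benchmark script may aggregate differently.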

Please note:

  1. The Morlet wavelet performs poorly on MNIST, but with a shallow network you can still observe it learning.

Contributors

da1sypetals, juvi21