
attorch

Introduction
Installation
Layers
PyTorch Fallback
Tests

Introduction

attorch is a subset of PyTorch's nn module, written purely in Python using OpenAI's Triton. Its goal is to be an easily hackable, self-contained, and readable collection of neural network modules whilst maintaining or improving upon the efficiency of PyTorch. In other words, it intends to be a forkable project endowed with a simple, intuitive design that can serve as an accessible starting point for those who are seeking to develop custom deep learning operations but are not satisfied with the speed of a pure PyTorch implementation and do not have the technical expertise or resources to write CUDA kernels.

There already exist a number of wonderful PyTorch-like frameworks powered by Triton, including kernl, xFormers, Unsloth, and fla, but most concentrate mainly on Transformers and NLP applications, whereas attorch aims to be more inclusive by also presenting a variety of layers pertaining to areas besides NLP such as computer vision. Moreover, attorch is not an inference-only package and fully supports both forward and backward passes, meaning it can be used during training as well as inference, though its performance for the latter is generally not on par with dedicated inference engines.

Installation

The only dependencies of attorch are torch==2.2.0 and triton==2.2.0. Please install the specified versions of these two libraries and clone this repository to get started.
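Assuming a pip-based environment, setup might look like the following (the clone URL is a placeholder; substitute the repository's actual remote):

```shell
# Pin the exact versions attorch expects.
pip install torch==2.2.0 triton==2.2.0

# Clone the repository (placeholder URL).
git clone https://github.com/<user>/attorch.git
cd attorch
```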

Layers

Currently implemented layers, all with automatic mixed precision (AMP) support, are:

  • attorch.Conv2d: 2D-convolves over the input using weights, optionally adding bias.
  • attorch.MultiheadAttention: Applies multi-headed scaled dot-product attention to the inputs.
  • attorch.GELU: Applies GELU to the input, optionally fusing dropout.
  • attorch.ReLU: Applies ReLU to the input, optionally fusing dropout.
  • attorch.SiLU: Applies SiLU to the input, optionally fusing dropout.
  • attorch.Sigmoid: Applies sigmoid to the input, optionally fusing dropout.
  • attorch.Tanh: Applies tanh to the input, optionally fusing dropout.
  • attorch.GLU: Applies the gated linear unit with an arbitrary activation function to the input.
  • attorch.LogSoftmax: Normalizes the input using softmax and takes its log.
  • attorch.Softmax: Normalizes the input using softmax.
  • attorch.Softmin: Normalizes the input using softmin.
  • attorch.BatchNorm1d: Batch-normalizes the 2D or 3D input, optionally fusing an activation function and adding a residual to the pre-activation result.
  • attorch.BatchNorm2d: Batch-normalizes the 4D input, optionally fusing an activation function and adding a residual to the pre-activation result.
  • attorch.LayerNorm: Layer-normalizes the input.
  • attorch.Linear: Linearly transforms the input using weights, optionally adding bias and fusing an activation function.
  • attorch.Dropout: Randomly zeroes elements in the input during training.
  • attorch.L1Loss: Measures the mean absolute error between the input and target.
  • attorch.MSELoss: Measures the mean squared error between the input and target.
  • attorch.CrossEntropyLoss: Measures the mean cross-entropy loss between the input and target, with optional reweighting of each class.
  • attorch.NLLLoss: Measures the negative log-likelihood loss between the input and target, with optional reweighting of each class.

Unless otherwise noted in their docstrings, the aforementioned layers behave identically to their PyTorch equivalents.
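As a rough illustration of what "fusing an activation function" means in layers such as attorch.Linear, the sketch below applies ReLU to each output element as soon as it is computed, instead of materializing the full pre-activation matrix and making a second pass over it. This is pure-Python pseudocode of the idea, not attorch's actual Triton kernel:

```python
def linear_relu_fused(x, weight, bias):
    """Computes relu(x @ weight.T + bias) in a single pass: the activation is
    applied to each accumulator right away, so the pre-activation matrix is
    never stored. A real kernel would do this per tile on the GPU."""
    out = []
    for row in x:
        out_row = []
        for w_row, b in zip(weight, bias):
            acc = sum(xi * wi for xi, wi in zip(row, w_row)) + b
            out_row.append(max(acc, 0.0))  # ReLU fused into the matmul loop
        out.append(out_row)
    return out
```

Fusion of this kind saves a round trip to memory for the intermediate result, which is typically the bottleneck for elementwise operations.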

PyTorch Fallback

To enable easier integration of attorch and PyTorch layers, the attorch.nn module is offered, which provides an interface to attorch's modules with a PyTorch fallback should a desired layer not be available, as seen below.

from attorch import nn


lin = nn.Linear(10, 20) # Uses attorch's linear layer
gap = nn.AdaptiveAvgPool2d(1) # Uses PyTorch's global pooling since GAP is not available in attorch
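One way such a fallback can be implemented, sketched here with stand-in namespaces rather than attorch's actual internals, is to resolve each requested name against the primary package first and fall back to the secondary one on failure:

```python
from types import SimpleNamespace

# Stand-ins for the real packages; purely illustrative.
fake_attorch = SimpleNamespace(Linear=lambda i, o: f"attorch.Linear({i}, {o})")
fake_torch_nn = SimpleNamespace(AdaptiveAvgPool2d=lambda s: f"torch.nn.AdaptiveAvgPool2d({s})")

class FallbackNamespace:
    """Resolves attributes against a primary namespace, deferring to a
    fallback namespace when the primary lacks the requested name."""
    def __init__(self, primary, fallback):
        self._primary, self._fallback = primary, fallback

    def __getattr__(self, name):
        try:
            return getattr(self._primary, name)
        except AttributeError:
            return getattr(self._fallback, name)

nn = FallbackNamespace(fake_attorch, fake_torch_nn)
lin = nn.Linear(10, 20)        # found in the primary (attorch stand-in)
gap = nn.AdaptiveAvgPool2d(1)  # missing there, so the fallback resolves it
```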

Tests

Each module can be tested against its PyTorch counterpart to ensure correctness. These tests are included under tests/ and can be executed using pytest. It should be noted that some might fail owing to numerical precision issues, but in most practical use cases, that should not be a problem.
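The precision caveat stems from floating-point non-associativity: a Triton kernel may accumulate values in a different order than PyTorch does, so the two results agree only up to a tolerance rather than bit-exactly. A minimal standard-library illustration:

```python
import math

# Summing identical values by two different strategies yields
# slightly different floats, yet they agree within tolerance.
xs = [0.1] * 10
sequential = sum(xs)         # plain left-to-right accumulation
compensated = math.fsum(xs)  # error-compensated summation

print(sequential == compensated)              # bit-exact equality fails
print(math.isclose(sequential, compensated))  # tolerance-based comparison passes
```

This is why the tests compare outputs with tolerances instead of exact equality, and why occasional failures need not indicate a real bug.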


attorch's Issues

Difference between BLOCK and GROUP?

Hello, I have always been really interested in Triton and GPU programming, and I am just starting to learn about them. I stumbled upon your project, and first of all, attorch is really awesome. Thank you for creating and open-sourcing this project!

Currently, I am trying to understand the code, and I came across this GROUP concept, which I am having a hard time picturing; I also do not see why it is needed to parallelize the kernel. To me, it seems pretty similar to BLOCK (based on the naming too :D). Could you perhaps explain what GROUP means in Triton, or in GPU programming in general?
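For context, in Triton's matmul-style kernels (as in the official Triton tutorials), GROUP does not partition data the way BLOCK does; it controls the order in which flat program IDs are mapped to output tiles, so that GROUP_SIZE_M rows of tiles are traversed together and tiles sharing the same input blocks run close in time for better L2 cache reuse. A pure-Python sketch of that reordering, independent of Triton:

```python
def grouped_tile_order(num_tiles_m, num_tiles_n, group_size_m):
    """Maps flat program IDs to (tile_m, tile_n) coordinates using the
    grouped ordering from Triton's matmul tutorial: tiles are visited
    GROUP_SIZE_M rows at a time, column by column within each group."""
    order = []
    for pid in range(num_tiles_m * num_tiles_n):
        num_pid_in_group = group_size_m * num_tiles_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * group_size_m
        # The last group may contain fewer than group_size_m rows.
        group_rows = min(num_tiles_m - first_pid_m, group_size_m)
        pid_m = first_pid_m + (pid % num_pid_in_group) % group_rows
        pid_n = (pid % num_pid_in_group) // group_rows
        order.append((pid_m, pid_n))
    return order
```

With `grouped_tile_order(4, 4, 2)`, the first four programs cover tiles (0, 0), (1, 0), (0, 1), (1, 1): two rows advance together across the columns, instead of one row sweeping all columns before the next begins.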

Thank you in advance!

Fusing?

This project is really neat.

I am curious whether you have considered separating the load/store logic from the math to make it easier to fuse operations. For instance, it would be nice to be able to use dropout (forward/backward) from within an NN+dropout fused function.
