
Efficient Relative Attention

How to use

import torch
from relative_attention import (RelativeAttention2d, EmbeddingPaddingMode,
                                PositionEmbeddingType, KeyStartPosition)
# you can also use RelativeAttention1d and RelativeAttention3d
net = RelativeAttention2d(num_heads, model_depth, max_relative_positions_past=[width, height],
                          max_relative_positions_future=None, # same as past
                          heads_share_relative_embeddings=[True, False], # share in width but not height
                          # extend embedding by using sin and cosine for the width dim, and zero padding for h
                          embedding_padding_modes=[EmbeddingPaddingMode.Extend, EmbeddingPaddingMode.Zero],
                          position_embedding_types=PositionEmbeddingType.Fixed,
                          key_start_positions=KeyStartPosition.BeforeQuery, 
                          add_bias_to_query_for_relative_logits=True, # the D term in transformer-xl
                          add_bias_to_query_for_key_logit=True, # the C term in transformer-xl
                          # use my custom kernel or the vanilla pytorch implementation
                          use_custom_cuda_kernel=True).cuda() 
q = torch.randn(batch_size, num_heads, q_height, q_width, model_depth // num_heads).cuda()
k = torch.randn(batch_size, num_heads, k_height, k_width, model_depth // num_heads).cuda()
if use_mask:
  mask = (torch.randn(batch_size * num_heads, q_height * q_width, k_height * k_width) > 0).cuda()
  # or just (q_height * q_width, k_height * k_width)
else:
  mask = None
logits = net(q, k, mask)
print(logits.size()) # batch_size * num_heads, q_height * q_width, k_height * k_width
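The returned logits are raw (pre-softmax) attention scores flattened over spatial positions. Completing the attention step is left to the caller; a minimal sketch in plain PyTorch (not part of this repo, with stand-in tensors whose shapes mirror the snippet above) would look like:

```python
import torch

batch_size, num_heads, model_depth = 2, 4, 32
height = width = 8
head_dim = model_depth // num_heads

# stand-in for the logits returned by net(q, k, mask):
# shape (batch_size * num_heads, q_positions, k_positions)
logits = torch.randn(batch_size * num_heads, height * width, height * width)
v = torch.randn(batch_size, num_heads, height, width, head_dim)

weights = torch.softmax(logits, dim=-1)  # attention distribution over key positions
v_flat = v.reshape(batch_size * num_heads, height * width, head_dim)
out = torch.bmm(weights, v_flat)         # (batch*heads, q_positions, head_dim)
out = out.reshape(batch_size, num_heads, height, width, head_dim)
```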

Algorithm

Efficient Relative Position Encoding
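The core of the efficient approach (as in tensor2tensor and the Music Transformer) is computing query-times-relative-embedding logits once per relative distance and then re-indexing them into absolute query/key positions with a pad-and-reshape "skew" trick, avoiding an O(L²) embedding gather. A sketch of the 1D causal case (my own illustration, not the repo's exact code):

```python
import torch

def relative_to_absolute(rel_logits):
    """Re-index relative-position logits to absolute positions ("skew" trick).

    rel_logits: (batch, heads, length, length), where column r holds the logit
    for relative distance r - (length - 1), i.e. distances -(L-1) .. 0.
    Returns (batch, heads, length, length) with out[i, j] = rel[i, j - i + L - 1]
    for j <= i; entries above the diagonal are invalid and should be masked.
    """
    b, h, l, _ = rel_logits.shape
    x = torch.nn.functional.pad(rel_logits, (1, 0))  # zero column on the left
    x = x.reshape(b, h, l + 1, l)                    # reshape shifts each row by one
    return x[:, :, 1:, :]                            # drop the shifted-off first row

rel = torch.randn(1, 1, 4, 4)
abs_logits = relative_to_absolute(rel)               # (1, 1, 4, 4)
```

The 2D version applies the same re-indexing independently along the width and height axes and sums the two contributions.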

Reasoning

I was trying to use relative position encoding in my 2D attention network and there wasn't a good implementation for PyTorch, so I decided to adapt the tensor2tensor implementation to PyTorch. Furthermore, our architecture uses this operation at every layer, so I made it more efficient by writing a custom CUDA kernel. It's not a general-purpose kernel and it might be slower than the vanilla PyTorch code, depending on your GPU, batch_size, query_size, and query_dim, so profile it with your settings before using it.

How to profile

You can see how to profile it by checking the speed_check() and run_profiler() functions in check.py. For example, in my setting (large batch size with 32x32 patch images) I get a 1.7x speedup in the forward and backward calls.
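The repo's entry points for this are speed_check() and run_profiler() in check.py; if you want a quick standalone comparison, a simple timing harness (a hypothetical helper, not part of the repo) could look like this:

```python
import time
import torch

def time_fn(fn, *args, warmup=3, iters=20):
    """Average wall-clock seconds per call, synchronizing CUDA if available."""
    for _ in range(warmup):  # warm-up excludes one-time allocation/autotune costs
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```

With this you could time net(q, k, mask) once with use_custom_cuda_kernel=True and once with False on your own shapes, and keep whichever is faster.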

Further Improvements

I also tried to fuse the logit calculation into my kernel, but it was way too slow compared to cuBLAS. You can check my experiment in experiment.py; even though it's really slow for training, it performs well in small inference settings. For example, with batch_size=1, patch_size=8x8, num_heads=1, and model_depth=16 you can get a 6.5x performance gain!

Embedding Class

The DistanceEmbedding class is an awesome wrapper for most of the common use cases, check it out! :))
