lambda-networks's Introduction

Lambda Networks - Pytorch

Implementation of λ Networks, a new approach to image recognition that reaches SOTA on ImageNet. The method uses the λ layer, which captures interactions by transforming contexts into linear functions, termed lambdas, and applying these linear functions to each input separately.
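
To make the idea concrete, here is a minimal pure-Python sketch of a single content lambda, assuming one head, an intra-depth of 1, and no positional term. The names `content_lambda` and `apply_lambda` and the toy numbers are illustrative only and are not part of the library. Keys are softmax-normalized across the context positions, the context is summarized into a `k x v` matrix (the lambda), and that same linear function is then applied to each query.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def content_lambda(keys, values):
    """Summarize a context of m positions into a (k x v) linear map.

    keys:   m rows of length k
    values: m rows of length v
    """
    m, k, v = len(keys), len(keys[0]), len(values[0])
    # Normalize each key channel across the m context positions.
    cols = [softmax([keys[i][j] for i in range(m)]) for j in range(k)]
    # lam[j][d] = sum over positions of normalized_key * value
    return [[sum(cols[j][i] * values[i][d] for i in range(m)) for d in range(v)]
            for j in range(k)]

def apply_lambda(lam, query):
    """Apply the linear map to one query of length k, giving a vector of length v."""
    v = len(lam[0])
    return [sum(query[j] * lam[j][d] for j in range(len(query))) for d in range(v)]

# Toy context: m = 3 positions, k = 2 key channels, v = 3 value channels.
keys = [[0.1, 0.5], [0.3, -0.2], [0.0, 0.9]]
values = [[1.0, 0.0, 2.0], [0.0, 1.0, 2.0], [1.0, 1.0, 2.0]]
lam = content_lambda(keys, values)

# The same lambda is reused for every query.
queries = [[1.0, 0.0], [0.5, 0.5]]
outputs = [apply_lambda(lam, q) for q in queries]
```

The lambda is computed once per context and shared by all queries, so no per-query attention map over the context is ever materialized.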

Yannic Kilcher's paper review

Install

$ pip install lambda-networks

Usage

Global context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,       # channels going in
    dim_out = 32,   # channels out
    n = 64,         # size of the receptive window - max(height, width)
    dim_k = 16,     # key dimension
    heads = 4,      # number of heads, for multi-query
    dim_u = 1       # 'intra-depth' dimension
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

Localized context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,
    dim_out = 32,
    r = 23,         # the receptive field for relative positional encoding (23 x 23)
    dim_k = 16,
    heads = 4,
    dim_u = 4
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)
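
The output above keeps the 64 x 64 spatial size even with a 23 x 23 receptive field. A small sketch of why, assuming the positional convolution uses stride 1 and pads by `r // 2` on each side (as in the reference class quoted in the issues further down). The helper names here are hypothetical and only do the arithmetic.

```python
def conv_output_size(size, kernel, padding, stride=1):
    """Standard output-length formula for a convolution without dilation."""
    return (size + 2 * padding - kernel) // stride + 1

def same_size_padding(r):
    """Padding that keeps the spatial size unchanged for an odd kernel size r."""
    if r % 2 != 1:
        raise ValueError("receptive field r should be odd")
    return r // 2

h = w = 64
r = 23
pad = same_size_padding(r)            # 11 for r = 23
out_h = conv_output_size(h, r, pad)   # same as h
out_w = conv_output_size(w, r, pad)   # same as w
```

With an even `r`, padding by `r // 2` would grow the output by one pixel per axis, which is why an odd receptive field is the natural choice.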

For fun, you can also import this as follows

from lambda_networks import λLayer

Tensorflow / Keras version

Shinel94 has added a Keras implementation! It won't be officially supported in this repository, so either copy / paste the code under ./lambda_networks/tfkeras.py or make sure to install tensorflow and keras before running the following.

import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

layer = LambdaLayer(
    dim_out = 32,
    r = 23,
    dim_k = 16,
    heads = 4,
    dim_u = 1
)

x = tf.random.normal((1, 64, 64, 16)) # channel last format
layer(x) # (1, 64, 64, 32)

Citations

@inproceedings{
    anonymous2021lambdanetworks,
    title={LambdaNetworks: Modeling long-range Interactions without Attention},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=xTJEN-ggl1b},
    note={under review}
}

lambda-networks's People

Contributors

khanrc, lucidrains

lambda-networks's Issues

Image Size

Are non-square image blocks allowed for context? Using global context with non-square dimensions (96, 128), I get an error about dimension size on this line:

λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
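
A small sketch of the shape arithmetic behind that error. In the einsum above, the axis `m` is shared by the positional embedding and the values, and the values carry height * width context positions. The sketch assumes the global positional embedding is built for a square n x n grid, which is an assumption on my part about how `n` is used. The helper names are hypothetical.

```python
def global_context_axes(n, height, width):
    """Return (m_embedding, m_values) for the contracted axis `m`.

    Assumes the embedding covers an n x n grid of positions, while the
    values tensor is flattened to height * width positions.
    """
    return n * n, height * width

def padding_to_square(height, width):
    """Square side length plus the extra rows/columns needed to reach it."""
    side = max(height, width)
    return side, side - height, side - width

# Non-square input with n = max(height, width): the two `m` axes disagree.
emb_m, val_m = global_context_axes(n=128, height=96, width=128)
mismatch = emb_m != val_m

# One possible workaround: pad the input up to a square before the layer.
side, pad_h, pad_w = padding_to_square(96, 128)
emb_m_sq, val_m_sq = global_context_axes(n=side, height=96 + pad_h, width=128 + pad_w)
```

Padding to a square is only one option. The localized-context configuration takes `r` instead of `n`, so it may be another way around the constraint.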

Use of Keras Lambda

Hey! Thanks for the awesome implementations :D

I was wondering why tf.keras.layers.Lambda is used. It seems unnecessary: regular calls to TF operations work and are more readable.

Lc = Lambda(lambda x: einsum('b u k m, b u v m -> b k v', x[0], x[1]))([k, v])

You can also call the functional version of the softmax instead.

Implementation of Lambda convolution

Thanks for the great work in the implementations!

I would like to ask whether there is a difference in using 'Conv2d' as suggested in Eq. 3 in the paper and your implementation of 'Conv3d'. These two convs treat the (h x w)-dimension as a 1-d sequence and a 2-d image, respectively. I believe they are quite different in concept.

Please point out if I have misunderstood.

Thx a lot.

question about hybrid lambdaResnet

Hi,

In the paper, there is this paragraph:

When working with hybrid LambdaNetworks, we use a single lambda layer in c4 for LambdaResNet50,
3 lambda layers for LambdaResNet101, 6 lambda layers for LambdaResNet-152/200/270/350 and
8 lambda layers for LambdaResNet-420.

I have several questions about constructing the hybrid lambdaResnet:

  1. Do we only need to replace the 3x3 conv with a lambda layer in the C4 stage, rather than in both C4 and C5 (as in the ablation study)?
  2. When there is more than one lambda layer, as in the case of LambdaResNet101, are we replacing the 3x3 conv with 3 lambda layers? And in the ResNet50 case, do we replace the 3x3 conv with 1 lambda layer?

Please add clarity to code

Phil, I love your work, but I wish you would go a few extra steps to help out users.
I found this class by François-Guillaume (@frgfm), which adds clear math comments.
I wanted to merge it, but there's a bit of code drift and I don't want to introduce any bugs.
I beseech you to go the extra step to help users bridge from papers to code.

https://github.com/frgfm/Holocron/blob/bcc3ea19a477e4b28dc5973cdbe92a9b05c690bb/holocron/nn/modules/lambda_layer.py

For example:

Please annotate return types.
def forward(self, x: torch.Tensor) -> torch.Tensor:

Please document what the arguments and steps do.
# Project input and context to get queries, keys & values

Add some of the math as comments. This is great because it bridges the paper to the code.
B x (num_heads * dim_k) * H * W -> B x num_heads x dim_k x (H * W)

import torch
from torch import nn, einsum
import torch.nn.functional as F
from typing import Optional

__all__ = ['LambdaLayer']


class LambdaLayer(nn.Module):
    """Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention"
    <https://openreview.net/pdf?id=xTJEN-ggl1b>`_. The implementation was adapted from `lucidrains'
    <https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`_.
    Args:
        in_channels (int): input channels
        out_channels (int, optional): output channels
        dim_k (int): key dimension
        n (int, optional): number of input pixels
        r (int, optional): receptive field for relative positional encoding
        num_heads (int, optional): number of attention heads
        dim_u (int, optional): intra-depth dimension
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dim_k: int,
        n: Optional[int] = None,
        r: Optional[int] = None,
        num_heads: int = 4,
        dim_u: int = 1
    ) -> None:
        super().__init__()
        self.u = dim_u
        self.num_heads = num_heads

        if out_channels % num_heads != 0:
            raise AssertionError('values dimension must be divisible by number of heads for multi-head query')
        dim_v = out_channels // num_heads

        # Project input and context to get queries, keys & values
        self.to_q = nn.Conv2d(in_channels, dim_k * num_heads, 1, bias=False)
        self.to_k = nn.Conv2d(in_channels, dim_k * dim_u, 1, bias=False)
        self.to_v = nn.Conv2d(in_channels, dim_v * dim_u, 1, bias=False)

        self.norm_q = nn.BatchNorm2d(dim_k * num_heads)
        self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        self.local_contexts = r is not None
        if r is not None:
            if r % 2 != 1:
                raise AssertionError('Receptive kernel size should be odd')
            self.padding = r // 2
            self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))
        else:
            if n is None:
                raise AssertionError('You must specify the total sequence length (h x w)')
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape

        # Project inputs & context to retrieve queries, keys and values
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # Normalize queries & values
        q = self.norm_q(q)
        v = self.norm_v(v)

        # B x (num_heads * dim_k) * H * W -> B x num_heads x dim_k x (H * W)
        q = q.reshape(b, self.num_heads, -1, h * w)
        # B x (dim_k * dim_u) * H * W -> B x dim_u x dim_k x (H * W)
        k = k.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
        # B x (dim_v * dim_u) * H * W -> B x dim_u x dim_v x (H * W)
        v = v.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)

        # Normalized keys
        k = k.softmax(dim=-1)

        # Content function
        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b n h v', q, λc)

        # Position function
        if self.local_contexts:
            # B x dim_u x dim_v x (H * W) -> B x dim_u x dim_v x H x W
            v = v.reshape(b, self.u, v.shape[2], h, w)
            λp = F.conv3d(v, self.R, padding=(0, self.padding, self.padding))
            Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
        else:
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b n h v', q, λp)

        Y = Yc + Yp
        # B x (H * W) x num_heads x dim_v -> B x (num_heads * dim_v) x H x W
        out = Y.permute(0, 2, 3, 1).reshape(b, self.num_heads * v.shape[2], h, w)
        return out

why flops so high?

I used ResNet-50 and changed the C4 layers into LambdaBottleNeck, but the FLOPs are very high: about 20G for an input size of 224*244. Is that right, or is something wrong with my implementation?
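
One way to sanity-check such a number is to count multiply-accumulates (MACs) by hand. The sketch below derives rough per-image counts from the tensor shapes in the reference class quoted in the "Please add clarity to code" issue above (local context, bias-free 1x1 projections, normalization ignored). The function names are hypothetical, and the example settings (256 channels on a 14 x 14 feature map, `dim_k = 16`, 4 heads, `r = 23`) are illustrative assumptions, not values confirmed by the poster.

```python
def lambda_layer_macs(c_in, c_out, height, width, dim_k=16, heads=4, dim_u=1, r=23):
    """Rough multiply-accumulate counts per image for a local-context lambda layer."""
    n = height * width            # number of positions (queries and context)
    dim_v = c_out // heads
    return {
        "to_q": c_in * dim_k * heads * n,
        "to_k": c_in * dim_k * dim_u * n,
        "to_v": c_in * dim_v * dim_u * n,
        # einsum 'b u k m, b u v m -> b k v'
        "content_lambda": dim_u * dim_k * dim_v * n,
        # einsum 'b h k n, b k v -> b n h v'
        "content_apply": heads * dim_k * dim_v * n,
        # conv3d with a (1, r, r) kernel: output elements x kernel volume
        "position_lambda": dim_k * dim_v * n * dim_u * r * r,
        # einsum 'b h k n, b k v n -> b n h v'
        "position_apply": heads * dim_k * dim_v * n,
    }

def conv3x3_macs(c_in, c_out, height, width):
    """MACs of a plain 3x3 convolution with stride 1 and 'same' padding."""
    return c_in * c_out * 9 * height * width

macs = lambda_layer_macs(c_in=256, c_out=256, height=14, width=14)
total = sum(macs.values())
baseline = conv3x3_macs(256, 256, 14, 14)
share = macs["position_lambda"] / total   # fraction spent in the positional conv
```

Under these assumed settings the positional convolution accounts for most of the layer's cost and grows with `r * r`, while the layer total stays in the same range as the 3x3 convolution it replaces. If a profiler reports a figure far above what this kind of count suggests, it may be worth checking the feature-map size at which the layer runs and how the profiler counts the einsum and conv3d calls.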

How to load_model correctly

Hi everyone,
I am struggling to load this model from a saved .h5 file.

What is the correct way to load this network?
If I use custom_objects I get: __init__() got an unexpected keyword argument 'name'

Question on experiences with Efficiency

Hi @lucidrains,
Thanks a lot for providing this implementation so quickly.
I have a question regarding your (or others') experience with the efficiency of lambda layers.
I tried to implement a LambdaUNet in which I replaced the 3x3 conv layers with lambda layers and average pooling.

The Conv-UNet has 17M parameters while the LambdaUNet has only 3M. Still, inference and training take much longer with the LambdaUNet than with the Conv-UNet (approx. 1s for the Conv-UNet vs. 10s for the LambdaUNet). I also used a receptive field of r=23. I am not sure where this value originates from or what receptive field should be set. In the paper, the authors talk about "controlled experiments". I assume they chose the lambda layer hyperparameters to be (in some way) similar to the conv parameters? It is not very clear from the paper (at least from my initial reading).

I was wondering if others share my experience of slower training and inference when blindly replacing conv layers with lambda layers. Maybe someone can share their expertise on how I can configure my LambdaUNet to be comparable to a regular UNet, to reproduce the performance and efficiency results from the paper.
Thanks again
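
A rough parameter count may help explain the gap between model size and runtime. The sketch below follows the reference class quoted in the "Please add clarity to code" issue above (bias-free 1x1 projections, two batch norms, one positional embedding). The function names and the 256-channel setting are illustrative assumptions.

```python
def lambda_layer_params(c_in, c_out, dim_k=16, heads=4, dim_u=1, r=23):
    """Learnable parameters of a local-context lambda layer."""
    dim_v = c_out // heads
    return {
        "to_q": c_in * dim_k * heads,
        "to_k": c_in * dim_k * dim_u,
        "to_v": c_in * dim_v * dim_u,
        "norm_q": 2 * dim_k * heads,     # BatchNorm weight and bias
        "norm_v": 2 * dim_v * dim_u,
        "positional_embedding": dim_k * dim_u * r * r,
    }

def conv3x3_params(c_in, c_out, bias=False):
    """Learnable parameters of a plain 3x3 convolution."""
    return c_in * c_out * 9 + (c_out if bias else 0)

lambda_params = sum(lambda_layer_params(256, 256).values())
conv_params = conv3x3_params(256, 256)
ratio = conv_params / lambda_params
```

The positional embedding is small, but it is applied at every spatial position and for every value channel, so a low parameter count does not by itself imply a low compute cost. Shrinking `r` reduces both the embedding size and the work done per position.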

Code for long-range sequence?

Hi Lucidrains,
As far as I can tell, this LambdaLayer is designed for image-like 2-D tensors. How can I apply it to long-range sequences?

Thank you!

LambdaResNet Implementation?

I have been looking around and found one implementation of LambdaResNets, although there seem to be some metric performance problems and I've found wall-clock performance problems (runs ~7x slower than normal resnets).

Do you plan on putting out a lambdaresnet model in this repository?

Warning: Mixed memory format inputs detected while calling the operator.

I have added lambda layers to every block of a ResNet, but the following warning appears. Will it affect the result?

Warning: Mixed memory format inputs detected while calling the operator. The operator will output channels_last tensor even if some of the inputs are not in channels_last format. (function operator())

examples?

Are there any examples or implementations of a full network that we can try out, such as one of the lambda-resnets or even lambda-mask-rcnn?

Lambda for a sequence of images

Thanks for the quick implementation!

I have a problem where I have a sequence of images (a video) rather than a single image.
So instead of the dimensions batch, channels, height, width, I also have a length dimension after batch that gives the sequence length.

Given a known max_length (for the positional embedding), should conv4d be used in forward instead of conv3d to allow interaction between frames?

In the paper they do mention that this could serve as a general framework for sequences of images, so I wonder if you explored that in the implementation (where a single image is just the case length=1).

How does the lambda layer handle downsampling in LambdaResNet?

Hi,
Thanks for your clear code. I am trying to implement the LambdaResNet.
Does the lambda layer replace all conv2d layers?
If so, how does the lambda layer handle the downsampling done by conv2d, e.g. stride=2?
Or should I keep the conv2d layers with stride=2 and replace only the conv2d layers with stride=1?
