
Training-Tricks-for-Binarized-Neural-Networks

A collection of training tricks for binarized neural networks (BNNs), gathered from previously published and pre-print work on binary networks.

1. Modified ResNet block structure

import torch
import torch.nn as nn
import torch.nn.functional as F

class BinActiveF(torch.autograd.Function):
    """Sign binarization with a straight-through estimator (STE) gradient."""
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # STE: copy the incoming gradient, then zero it where |input| >= 1
        grad_input = grad_output.clone()
        grad_input[input.ge(1.0)] = 0.
        grad_input[input.le(-1.0)] = 0.
        return grad_input

class BinActive(nn.Module):
    """Binarize activations with sign(); falls back to ReLU when bin=False."""
    def __init__(self, bin=True):
        super(BinActive, self).__init__()
        self.bin = bin

    def forward(self, x):
        if self.bin:
            x = BinActiveF.apply(x)
        else:
            x = F.relu(x, inplace=True)
        return x
        
class BasicBlock(nn.Module):
    """Modified binary ResNet block: BN -> BinActive -> 3x3 Conv,
    with the shortcut added before a PReLU non-linearity."""
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.bn = nn.BatchNorm2d(inplanes)
        self.ba = BinActive()
        self.conv = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1)
        self.prelu = nn.PReLU(planes)
        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.bn(x)
        out = self.ba(out)
        out = self.conv(out)
        
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out

class Bottleneck(nn.Module):
    def __init__(self, inplanes, planes, stride=1, has_branch=True):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, stride=1, padding=0)

        self.bn1 = nn.BatchNorm2d(inplanes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.bn3 = nn.BatchNorm2d(planes)
        self.bn4 = nn.BatchNorm2d(planes*4)
        
        self.ba1 = BinActive()

        self.has_branch = has_branch
        self.stride = stride

        if self.has_branch:
            if self.stride == 1:
                self.bn_bran1 = nn.Sequential(
                    nn.Conv2d(inplanes, planes*4, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.BatchNorm2d(planes*4, eps=1e-4, momentum=0.1, affine=True),
                    nn.AvgPool2d(kernel_size=3, stride=1, padding=1))
                self.prelu = nn.PReLU(planes*4)
            else:
                self.branch1 = nn.Sequential(
                    nn.Conv2d(inplanes, inplanes*2, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.BatchNorm2d(inplanes*2, eps=1e-4, momentum=0.1, affine=True),
                    nn.AvgPool2d(kernel_size=2, stride=2))
                self.prelu = nn.PReLU(inplanes*2)

    def forward(self, x):
        if self.stride == 2:
            short_cut = self.branch1(x)
        else:
            if self.has_branch:
                short_cut = self.bn_bran1(x)
            else:
                short_cut = x

        out = self.bn1(x)
        out = self.ba1(out)
        out = self.conv1(out)
        add = out
        out = self.bn2(out)

        out = self.ba1(out)
        out = self.conv2(out)
        out += add
        out = self.bn3(out)

        out = self.ba1(out)
        out = self.conv3(out)
        out = self.bn4(out)
        out += short_cut
        
        if self.has_branch:
            out = self.prelu(out)
        return out

2. PReLU Activation

PReLU is used in place of ReLU as the block activation; please refer to the block structures above.

3. Double skip connection

Replace each original basic block in ResNet-18 with two of the BasicBlock modules defined above, so that every binary 3x3 convolution gets its own skip connection (a sketch follows).
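A minimal sketch of how one stage could be assembled under this replacement; the helper name make_double_skip_layer is an assumption, not code from this repo.

def make_double_skip_layer(inplanes, planes, stride=1, downsample=None):
    # One original ResNet-18 block (two 3x3 convs, one skip) becomes two
    # binary BasicBlocks, each with a single 3x3 conv and its own skip.
    return nn.Sequential(
        BasicBlock(inplanes, planes, stride=stride, downsample=downsample),
        BasicBlock(planes, planes, stride=1, downsample=None),
    )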

4. Full precision downsampling layers

# v1 (recommended)
downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, \
                                 stride=1, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
                nn.AvgPool2d(kernel_size=2, stride=2)
            )
# v2
downsample = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, \
                                 stride=1, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
# v3
downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, \
                                 stride=1, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
                nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            )

5. 2-stage training strategy

  • Stage 1: train with full-precision weights and binarized activations (or, alternatively, binarized weights and full-precision activations).
  • Stage 2: initialize from the stage-1 model, then train the fully binarized 1-bit network (see the sketch below).
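A minimal sketch of the two stages; build_model(), train() and the bin_acts/bin_weights flags are hypothetical helpers, not code from this repo, while the learning rates and weight decays follow the settings listed in the sections below.

# Stage 1: binarized activations, full-precision weights.
model = build_model(bin_acts=True, bin_weights=False)   # hypothetical helper
train(model, lr=1e-3, weight_decay=1e-5)                # stage-1 hyper-parameters
torch.save(model.state_dict(), 'stage1.pth')

# Stage 2: fully binarized (1-bit) network, initialized from stage 1.
model = build_model(bin_acts=True, bin_weights=True)
model.load_state_dict(torch.load('stage1.pth'))
train(model, lr=2e-4, weight_decay=0.0)                 # stage-2 hyper-parameters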

6. Weight decay setting

  • 1e-5 for stage 1.
  • 0.0 for stage 2.

7. Optimizer

  • Adam with stepwise scheduler.
optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
  • Adam with cosine decay (initial lr = 1e-3); see the sketch below.
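A minimal sketch of the cosine option, assuming args.epochs holds the total epoch count and train_one_epoch() is a hypothetical training-loop helper.

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

for epoch in range(args.epochs):
    train_one_epoch(model, optimizer)  # hypothetical helper
    scheduler.step()                   # cosine-anneal the learning rate each epoch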

8. Learning rate

  • 1e-3 for stage 1.
  • 2e-4 for stage 2.
  • CIFAR-100: multiply the learning rate by 0.2 at epochs 150, 250, and 320; train for 350 epochs in total.
  • ImageNet: multiply the learning rate by 0.1 at epochs 40, 60, and 70; train for 75 epochs in total.
parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--gammas', type=float, nargs='+', default=[0.1, 0.1],
                    help='LR is multiplied by gamma on schedule, number of gammas should be equal to schedule')

def adjust_learning_rate(optimizer, epoch, gammas, schedule):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr
    assert len(gammas) == len(schedule), "length of gammas and schedule should be equal"
    for (gamma, step) in zip(gammas, schedule):
        if (epoch >= step):
            lr = lr * gamma
        else:
            break
    print('learning rate : %.6f.' % lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

9. Data augmentation

  • CIFAR-100: random crop, random horizontal flip, random rotation (±15 degrees), mix-up / auto-augmentation.
transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            normalize,
        ])
  • ImageNet: random crop, random flip, colour jitter (only in first stage, disabled for stage 2).
transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.ToTensor(),
            normalize,
        ])
  • ImageNet: Lighting.
# Lighting (PCA-based colour) data augmentation
import numpy as np
from PIL import Image

imagenet_pca = {
    'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
    'eigvec': np.asarray([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])
}

class Lighting(object):
    def __init__(self, alphastd,
                 eigval=imagenet_pca['eigval'],
                 eigvec=imagenet_pca['eigvec']):
        self.alphastd = alphastd
        assert eigval.shape == (3,)
        assert eigvec.shape == (3, 3)
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        if self.alphastd == 0.:
            return img
        rnd = np.random.randn(3) * self.alphastd
        rnd = rnd.astype('float32')
        v = rnd
        old_dtype = np.asarray(img).dtype
        v = v * self.eigval
        v = v.reshape((3, 1))
        inc = np.dot(self.eigvec, v).reshape((3,))
        img = np.add(img, inc)
        if old_dtype == np.uint8:
            img = np.clip(img, 0, 255)
        img = Image.fromarray(img.astype(old_dtype), 'RGB')
        return img

    def __repr__(self):
        return self.__class__.__name__ + '()'

lighting_param = 0.1
train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
        Lighting(lighting_param),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize])

10. Momentum in Batch Normalization layers

nn.BatchNorm2d(128, momentum=0.2, affine=True),

11. Reorder pooling block

From Conv + BN + ReLU + Pooling to Conv + Pooling + BN + ReLU (a minimal sketch follows).
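A minimal sketch of the reordering, assuming 2x2 max pooling and example channel sizes; these are illustrative values, not taken from this repo.

# Before: Conv -> BN -> ReLU -> Pool
block_before = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
)

# After: Conv -> Pool -> BN -> ReLU (pooling moved before normalization)
block_after = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
)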

12. Knowledge-distillation

  • KL divergence matching between student and teacher logits.
  • Feature-map matching after L2 normalization (see the sketch below).
  • Label refinery.
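A minimal sketch of the first two losses; the temperature T and the flattening of the feature maps are assumptions, not values from this repo.

def kd_kl_loss(student_logits, teacher_logits, T=4.0):
    # KL divergence between softened teacher and student distributions
    p_s = F.log_softmax(student_logits / T, dim=1)
    p_t = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(p_s, p_t, reduction='batchmean') * (T * T)

def feature_match_loss(student_feat, teacher_feat):
    # L2-normalize each flattened feature map, then match with an L2 loss
    s = F.normalize(student_feat.flatten(1), dim=1)
    t = F.normalize(teacher_feat.flatten(1), dim=1)
    return (s - t).pow(2).sum(dim=1).mean()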

13. Channel-attention

x = BN(x)
out = x.sign()
out = conv(out)
out *= SE(x) # SE() generates [batchsize x C x 1 x 1] attention tensor
out = prelu(out)

where SE() can be any channel-attention module, such as SE-Net, CGD, CBAM, or BAM; a minimal SE sketch follows.
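A minimal sketch of a squeeze-and-excitation (SE) module matching the pseudo-code above; the reduction ratio of 16 is an assumption.

class SEModule(nn.Module):
    """Produces a [batchsize x C x 1 x 1] channel-attention tensor."""
    def __init__(self, channels, reduction=16):
        super(SEModule, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        return self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)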

14. Auxiliary loss function

  • Center loss for stage 2 (marginal improvements).
  • Cross-entropy loss with label smoothing (~+0.5% Top-1).
criterion_smooth = CrossEntropyLabelSmooth(num_classes=1000, epsilon=0.1).cuda()

class CrossEntropyLabelSmooth(nn.Module):

  def __init__(self, num_classes, epsilon):
    super(CrossEntropyLabelSmooth, self).__init__()
    self.num_classes = num_classes
    self.epsilon = epsilon
    self.logsoftmax = nn.LogSoftmax(dim=1)

  def forward(self, inputs, targets):
    log_probs = self.logsoftmax(inputs)
    targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
    targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
    loss = (-targets * log_probs).mean(0).sum()
    return loss

15. Double/treble channel number

Widen the network by doubling or trebling the number of channels in each convolutional layer.
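A minimal sketch of the widening, assuming standard ResNet-18 stage widths of [64, 128, 256, 512]; these base values are assumptions, not taken from this repo.

width_mult = 2  # or 3
base_widths = [64, 128, 256, 512]
widths = [w * width_mult for w in base_widths]
# pass `widths` to the model constructor so every stage uses 2x/3x channels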

16. Full-precision pre-training

  • Step 1: replace ReLU with the following leaky clip (a module sketch follows this list).
index = x.abs() > 1.
x[index] = x[index] * 0.1 + x[index].sign() * 0.9  # slope 0.1 outside [-1, 1]
  • Step 2: replace the leaky clip with x.clamp_(-1, 1).
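A minimal module sketch of the leaky clip above, written with torch.where so it stays autograd-friendly; the class name LeakyClip is an assumption.

class LeakyClip(nn.Module):
    """Identity inside [-1, 1]; slope 0.1 with a ±0.9 offset outside."""
    def forward(self, x):
        # Equivalent to the indexed assignment above, without in-place updates
        return torch.where(x.abs() > 1., x * 0.1 + x.sign() * 0.9, x)

For step 2, swap LeakyClip for a module that returns x.clamp(-1, 1) before moving on to the binarized training stages.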

Cite:

If you find this repo useful, please cite

@misc{tricks4BNN,
  author =       {Shuan},
  title =        {Training-Tricks-for-Binarized-Neural-Networks},
  howpublished = {\url{https://github.com/HolmesShuan/Training-Tricks-for-Binarized-Neural-Networks}},
  year =         {2019}
}
