Giter Club home page Giter Club logo

lightcapsnet's Introduction

light-CapsNet

背景

2017年11月に提案されたカプセルネットワークの実行速度を上げる. 論文へのリンクはこちら

変更点

  • squash関数をベクトル版step関数に変更
  • routingアルゴリズムの変数bを削除

■step関数

def step(vectors, axis=-1):
    """
    カプセルネットワークでは非線形の活性化関数が使用される. この関数はベクトルの長さを0~1に圧縮する.
    :param vectors: 圧出される複数のベクトル, 4次元テンソル
    :param axis: 圧縮する軸
    :return: 複数の入力ベクトルと同じ形の一つのテンソル
    """
    s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    return vectors / s_squared_norm

■light-routingアルゴリズム

# 前処理として係数を1に初期化
# c.shape = [None, self.num_capsule, self.input_num_capsule].
c = tf.ones(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])

assert self.routings > 0, 'The routings should be > 0.'
for i in range(self.routings):
    # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
    # The first two dimensions as `batch` dimension,
    # then matmal: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule].
    # outputs.shape=[None, num_capsule, dim_capsule]
    outputs = step(K.batch_dot(c, inputs_hat, [2, 2]))  # [None, 10, 16]

    if i < self.routings - 1:
        # outputs.shape =  [None, num_capsule, dim_capsule]
        # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
        # The first two dimensions as `batch` dimension,
        # then matmal: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule].
        # c.shape=[batch_size, num_capsule, input_num_capsule]
        c += K.batch_dot(outputs, inputs_hat, [2, 3])

結果

■実行環境

google colabのGPUを使用.

■実行速度

変更点 時間 精度
変更なし 3分36秒 loss: 0.1284 - capsnet_loss: 0.1013 - decoder_loss: 0.0690 - capsnet_acc: 0.8994 - val_loss: 0.0445 - val_capsnet_loss: 0.0242 - val_decoder_loss: 0.0516 - val_capsnet_acc: 0.9878
step関数 3分35秒 loss: 0.2779 - capsnet_loss: 0.2505 - decoder_loss: 0.0699 - capsnet_acc: 0.7014 - val_loss: 0.0679 - val_capsnet_loss: 0.0487 - val_decoder_loss: 0.0490 - val_capsnet_acc: 0.9776
light-routingアルゴリズム 3分33秒 loss: 0.2068 - capsnet_loss: 0.1794 - decoder_loss: 0.0699 - capsnet_acc: 0.8095 - val_loss: 0.0542 - val_capsnet_loss: 0.0364 - val_decoder_loss: 0.0456 - val_capsnet_acc: 0.9829
どちらも 3分31秒 loss: 0.2227 - capsnet_loss: 0.1949 - decoder_loss: 0.0710 - capsnet_acc: 0.7932 - val_loss: 0.1243 - val_capsnet_loss: 0.1060 - val_decoder_loss: 0.0466 - val_capsnet_acc: 0.8827

使用方法

■Step 1. インストール

TensorFlow>=1.2 Keras>=2.0.7をインストール

pip install tensorflow-gpu
pip install keras

■Step 2. リポジトリをクローン

git clone https://github.com/XifengGuo/CapsNet-Keras.git capsnet-keras
cd capsnet-keras

■Step 3. 実行

デフォルト設定

python capsulenet.py

ヘルプ機能

python capsulenet.py -h

■Step 4. モデル検証

下記のコマンドでresult/trained_model.h5にモデルを保存することができます.

$ python capsulenet.py -t -w result/trained_model.h5

テストaccuracyと再構成された画像を出力してくれます.

学習済みモデルのダウンロードはこちら

■Step 5. GPUで学習

(注)Keras 2.0.9が必要ですので満たしていない方はアップデートをしてください.

python capsulenet-multi-gpu.py --gpus 2

このコマンドで自動的にGPUを用いて処理してくれます. なお,トレーニング中はaccuracyを出力しません.

別の手法

リンク集

lightcapsnet's People

Watchers

James Cloos avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.