

naturomics commented on May 28, 2024

I haven't encountered this problem. Could you tell me the following:

  1. Apart from the data-reading pipeline, did you modify the code, and if so, which part?
  2. Is the number of training samples balanced across classes?
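One quick way to answer the second question is to count labels per class. A minimal sketch, assuming the labels have already been read out of the training records into a Python list of 0-based integers (the same 0-based `label` field that `parse_fn` below parses); `class_counts` and `imbalance_ratio` are hypothetical helpers, not part of capslayer:

```python
from collections import Counter

def class_counts(labels, num_classes):
    """Number of samples per class id 0..num_classes-1 (missing classes count as 0)."""
    counts = Counter(int(l) for l in labels)
    return {c: counts.get(c, 0) for c in range(num_classes)}

def imbalance_ratio(labels, num_classes):
    """Largest class count divided by smallest; 1.0 means perfectly balanced."""
    counts = class_counts(labels, num_classes).values()
    smallest = min(counts)
    # A class with no samples at all makes the ratio infinite.
    return float("inf") if smallest == 0 else max(counts) / smallest

# Hypothetical example: 0-based labels for 3 classes.
example = [0, 1, 1, 2, 2, 2]
print(class_counts(example, 3))     # {0: 1, 1: 2, 2: 3}
print(imbalance_ratio(example, 3))  # 3.0
```

Reporting zero counts explicitly matters here: a class that is entirely missing from the training records would otherwise silently disappear from a plain `Counter`.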

from capslayer.

Alek-dr commented on May 28, 2024

The dataset is balanced, but I wrote my own test code:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.framework.errors_impl import OutOfRangeError
from sklearn.metrics import accuracy_score, confusion_matrix
from seaborn import heatmap
import capslayer as cl
import os

chars_map = {
    1: '1',
    2: '2',
    3: '3',
    4: '4',
    5: '5',
    6: '6',
    7: '7',
    8: '8',
    9: '9',
    10: '0',
    11: "A",
    12: "B",
    13: "C",
    14: "E",
    15: "H",
    16: "K",
    17: "M",
    18: "P",
    19: "T",
    20: "X",
    21: "Y"
}

WIDTH = 20
HEIGHT = 30

def elem_conv(elem):
    image = elem['images']
    image = np.reshape(image, newshape=(HEIGHT,WIDTH))
    label = elem['labels']
    return image, label

def parse_fn(serialized_example):
    features = tf.parse_single_example(serialized_example,
                                       features={'image': tf.FixedLenFeature([], tf.string),
                                                 'label': tf.FixedLenFeature([], tf.int64),
                                                 'height': tf.FixedLenFeature([], tf.int64),
                                                 'width': tf.FixedLenFeature([], tf.int64),
                                                 'depth': tf.FixedLenFeature([], tf.int64)})
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    depth = tf.cast(features['depth'], tf.int32)
    image = tf.decode_raw(features['image'], tf.float32)
    image = tf.reshape(image, shape=[height * width * depth])
    image.set_shape([HEIGHT * WIDTH * 1])
    image = tf.cast(image, tf.float32) * (1. / 255)
    label = tf.cast(features['label'], tf.int32)
    features = {'images': image, 'labels': label}
    return features

def get_model():
    # Vector CapsNet
    num_label = 21
    in_images = tf.placeholder(tf.float32, [None, HEIGHT*WIDTH])

    with tf.variable_scope('Conv1_layer'):
        # Conv1, returns shape [batch_size, 22, 12, 256] for a 30x20 input (9x9 kernel, VALID padding)
        inputs = tf.reshape(in_images, shape=[-1, HEIGHT, WIDTH, 1])
        conv1 = tf.layers.conv2d(inputs,
                                 filters=256,
                                 kernel_size=9,
                                 strides=1,
                                 padding='VALID',
                                 activation=tf.nn.relu)

    with tf.variable_scope('PrimaryCaps_layer'):
        primaryCaps, activation = cl.layers.primaryCaps(conv1,
                                                        filters=32,
                                                        kernel_size=9,
                                                        strides=2,
                                                        out_caps_dims=[8, 1],
                                                        method="norm")

    with tf.variable_scope('DigitCaps_layer'):
        routing_method = "EMRouting"
        num_inputs = np.prod(cl.shape(primaryCaps)[1:4])
        primaryCaps = tf.reshape(primaryCaps, shape=[-1, num_inputs, 8, 1])
        activation = tf.reshape(activation, shape=[-1, num_inputs])
        poses, probs = cl.layers.dense(primaryCaps,
                                       activation,
                                       num_outputs=num_label,
                                       out_caps_dims=[16, 1],
                                       routing_method=routing_method)

    # Decoder structure
    # Reconstruct the inputs with 3 FC layers
    with tf.variable_scope('Decoder'):
        logits_idx = tf.to_int32(tf.argmax(cl.softmax(probs, axis=1), axis=1))
        labels = tf.one_hot(logits_idx, depth=num_label, axis=-1, dtype=tf.float32)
        labels_one_hoted = tf.reshape(labels, (-1, num_label, 1, 1))
        masked_caps = tf.multiply(poses, labels_one_hoted)
        num_inputs = np.prod(masked_caps.get_shape().as_list()[1:])
        active_caps = tf.reshape(masked_caps, shape=(-1, num_inputs))
        fc1 = tf.layers.dense(active_caps, units=512, activation=tf.nn.relu)
        fc2 = tf.layers.dense(fc1, units=1024, activation=tf.nn.relu)
        num_outputs = HEIGHT * WIDTH * 1
        recon_imgs = tf.layers.dense(fc2,
                                     units=num_outputs,
                                     activation=tf.sigmoid)
        recon_imgs = tf.reshape(recon_imgs, shape=[-1, HEIGHT, WIDTH, 1])

    return in_images, recon_imgs, probs

def show_reconstruct(original, reconstruct, true_lbl, pred_lbl, lbl=None):
    if lbl is not None:
        ind = np.where(true_lbl==lbl)[0][0]
    else:
        ind = 0
    original_image = original[ind]
    reconstruct_image = reconstruct[ind]
    original_image = np.reshape(original_image, newshape=(HEIGHT,WIDTH))
    true_lbl = chars_map[true_lbl[ind]]
    pred_lbl = chars_map[pred_lbl[ind]]
    title = true_lbl + 20*' ' + pred_lbl
    res = np.hstack((original_image, reconstruct_image))
    plt.imshow(res, cmap='gray')
    plt.title(title)
    plt.show()

def test(records):
    """
    param records: list of .record files
    """
    batch_size = 128
    dataset = tf.data.TFRecordDataset(records)
    dataset = dataset.map(parse_fn).batch(batch_size).repeat(1).shuffle(buffer_size=5000, seed=3)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    # The third value returned by get_model() is the class probabilities (probs), not one-hot labels
    inputs, recon_imgs, labels_one_hoted = get_model()

    saver = tf.train.Saver()

    true_labels, predicted = [], []
    with tf.Session() as sess:
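        ckpt_dir = os.path.dirname('../models/models/results/logdir/model.ckpt-6600')
        # get_checkpoint_state() reads the 'checkpoint' index file in ckpt_dir and returns
        # the most recent checkpoint there, not necessarily model.ckpt-6600
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Without a restored checkpoint the variables are uninitialized and sess.run() fails
            raise IOError('No checkpoint found in ' + ckpt_dir)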

        while True:
            try:
                elem = sess.run(next_element)

                raw_images = elem['images']
                true_lbls = elem['labels'] + 1

                reconstructed, pred_lbls = sess.run([recon_imgs,labels_one_hoted], feed_dict={inputs : raw_images})

                reconstructed = np.squeeze(reconstructed)
                pred_lbls = np.argmax(pred_lbls, axis=1) + 1

                #show_reconstruct(raw_images, reconstructed, true_lbls, pred_lbls, lbl=9)

                predicted.extend(pred_lbls)
                true_labels.extend(true_lbls)

            except OutOfRangeError:
                break

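    # Use all 21 class ids (1..21), not just the first 10, so every row/column is labeled
    labels_id = np.arange(1, len(chars_map) + 1)
    labels = [chars_map[lbl] for lbl in labels_id]
    print('accuracy: %.4f' % accuracy_score(true_labels, predicted))
    # labels= keeps a row/column for every class, even one that is never predicted
    conf_matr = confusion_matrix(true_labels, predicted, labels=labels_id)
    # Pass tick labels to heatmap directly; its default 'auto' may skip ticks with 21 classes
    ax = heatmap(conf_matr, annot=True, fmt='d', xticklabels=labels, yticklabels=labels)
    ax.set_xlabel('predicted')
    ax.set_ylabel('true')
    plt.show()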

if __name__ == '__main__':
    records = ['data/symbols/eval_symbols.tfrecord']
    test(records=records)
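The feature-map sizes for a 30x20 input differ from the 20x20 in the original Conv1 comment, which is what a 28x28 MNIST input would give. A small shape-arithmetic sketch; the PrimaryCaps step assumes capslayer's `primaryCaps` applies a VALID-padded convolution (the `tf.layers.conv2d` default), which I have not verified against the library source. Under that assumption, `num_inputs` in the DigitCaps layer would be 448:

```python
def valid_conv_out(size, kernel, stride):
    """Spatial output size of a convolution with VALID (no) padding."""
    return (size - kernel) // stride + 1

HEIGHT, WIDTH = 30, 20

# Conv1: 9x9 kernel, stride 1, VALID padding (as in get_model).
conv1_h = valid_conv_out(HEIGHT, 9, 1)
conv1_w = valid_conv_out(WIDTH, 9, 1)

# PrimaryCaps: 9x9 kernel, stride 2, 32 capsule types.
# ASSUMPTION: VALID padding inside cl.layers.primaryCaps.
caps_h = valid_conv_out(conv1_h, 9, 2)
caps_w = valid_conv_out(conv1_w, 9, 2)
num_primary_caps = caps_h * caps_w * 32

print(conv1_h, conv1_w)                    # 22 12
print(caps_h, caps_w, num_primary_caps)    # 7 2 448
```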

The checkpoint at step 6600 doesn't matter: I trained for 50,000 steps and the problem was the same.
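Since the symptom persists across checkpoints, it can help to quantify it per class instead of reading it off the heatmap. A minimal sketch that works on the `true_labels` and `predicted` lists collected by the test script above (1-based ids as in `chars_map`); `per_class_report` is a hypothetical helper, not part of capslayer or scikit-learn:

```python
import numpy as np

def per_class_report(true_labels, predicted, class_ids):
    """Per-class recall, per-class prediction counts, and classes never predicted."""
    true_labels = np.asarray(true_labels)
    predicted = np.asarray(predicted)
    class_ids = list(class_ids)  # allow ranges/generators; we iterate twice
    recall, pred_counts = {}, {}
    for c in class_ids:
        mask = true_labels == c
        # Recall is undefined for a class with no true samples; report NaN.
        recall[c] = float(np.mean(predicted[mask] == c)) if mask.any() else float("nan")
        pred_counts[c] = int(np.sum(predicted == c))
    never_predicted = [c for c in class_ids if pred_counts[c] == 0]
    return recall, pred_counts, never_predicted

# Hypothetical example: class 7 has true samples but is never predicted.
recall, counts, missing = per_class_report([1, 2, 7, 7], [1, 2, 1, 2], class_ids=[1, 2, 7])
print(missing)  # [7]
```

With the real data, `class_ids=range(1, 22)` covers all 21 classes; a class with zero recall and zero prediction count is the "never predicted" pattern discussed in this thread.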


naturomics commented on May 28, 2024

The code looks fine, and I'm not sure what the problem is. My guess is that CapsNet might have a bias problem, but before drawing that conclusion I suggest:

  1. Visualize some input images and print their corresponding labels (for both the training and validation sets) to make sure the dataset is correct;
  2. Then remove one or more classes from the dataset (other than the 7th class), train the model from scratch, and test it again, to see whether the 7th class, or any other class, is still never predicted.
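For the second experiment, the dropped classes have to be removed and the remaining labels renumbered so they stay contiguous, with `num_label` in `get_model` reduced to match. A minimal sketch on in-memory NumPy arrays (assumed to have been loaded from the TFRecords beforehand); `drop_classes` is a hypothetical helper:

```python
import numpy as np

def drop_classes(images, labels, classes_to_drop):
    """Remove samples whose label is in classes_to_drop and remap the remaining
    labels to a contiguous 0..K-1 range, returning the old->new mapping."""
    images = np.asarray(images)
    labels = np.asarray(labels)
    keep = ~np.isin(labels, list(classes_to_drop))
    kept_labels = labels[keep]
    # Sorted unique surviving classes define the new label order.
    old_ids = np.unique(kept_labels)
    remap = {int(old): new for new, old in enumerate(old_ids)}
    new_labels = np.array([remap[int(l)] for l in kept_labels], dtype=np.int64)
    return images[keep], new_labels, remap

# Hypothetical example: drop class 1 from 4 classes; class 2 becomes 1, class 3 becomes 2.
imgs = np.arange(6).reshape(6, 1)
_, new_labels, remap = drop_classes(imgs, [0, 1, 2, 3, 2, 0], {1})
print(new_labels.tolist(), remap)
```

Keeping the returned `remap` makes it possible to translate predictions back to the original class ids (and to `chars_map`) when comparing against the full model.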

Of course, it might also be an implementation problem; I will check my code again.

