Comments (3)
I haven't encountered this problem. Could you tell me the following:
- Apart from the data-reading pipeline, did you modify the code? If so, which parts?
- Is the number of training samples for each class balanced?
from capslayer.
The dataset is balanced,
but I wrote my own test code:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.framework.errors_impl import OutOfRangeError
from sklearn.metrics import accuracy_score, confusion_matrix
from seaborn import heatmap
import capslayer as cl
import os
import cv2
# Map the 1-based class ids used for display in this script (the evaluation
# loop adds 1 to the labels read from the records) to the character each
# class stands for: ids 1-9 are the digits '1'-'9', id 10 is '0', and ids
# 11-21 are the letters of the string below, in order.
chars_map = {class_id: str(class_id % 10) for class_id in range(1, 11)}
chars_map.update(zip(range(11, 22), "ABCEHKMPTXY"))

# Image size in pixels expected by the record parser and the model.
WIDTH = 20
HEIGHT = 30
def elem_conv(elem):
    """Turn one parsed element into a 2-D image and its label.

    ``elem`` is a mapping with ``'images'`` and ``'labels'`` entries (the
    structure returned by ``parse_fn``). The flat image is reshaped to
    ``(HEIGHT, WIDTH)``; the label is passed through unchanged.

    Returns:
        tuple: ``(image, label)``.
    """
    flat_image = elem['images']
    image_2d = np.reshape(flat_image, newshape=(HEIGHT, WIDTH))
    return image_2d, elem['labels']
def parse_fn(serialized_example):
    """Decode one serialized ``tf.train.Example`` into model-ready tensors.

    The record must provide a raw ``image`` byte string and int64 ``label``,
    ``height``, ``width`` and ``depth`` fields. The image bytes are decoded as
    float32, flattened to ``height * width * depth`` values, pinned to the
    static shape ``[HEIGHT * WIDTH * 1]`` and multiplied by 1/255.

    Returns:
        dict: ``{'images': float32 tensor, 'labels': int32 scalar tensor}``.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'depth': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    rows, cols, channels = [tf.cast(parsed[key], tf.int32)
                            for key in ('height', 'width', 'depth')]

    pixels = tf.decode_raw(parsed['image'], tf.float32)
    pixels = tf.reshape(pixels, shape=[rows * cols * channels])
    # The dynamic reshape above loses the static shape; restore it so
    # downstream batching knows the per-example size.
    pixels.set_shape([HEIGHT * WIDTH * 1])
    pixels = tf.cast(pixels, tf.float32) * (1. / 255)

    return {'images': pixels, 'labels': tf.cast(parsed['label'], tf.int32)}
def get_model(num_label=21):
    """Build the CapsNet inference graph and its reconstruction decoder.

    Pipeline: Conv1 -> PrimaryCaps -> dense capsule layer with EM routing ->
    3-layer fully connected decoder fed with the pose of the predicted class.

    Args:
        num_label: number of output capsules (classes). Defaults to 21, the
            number of entries in ``chars_map``.

    Returns:
        tuple: ``(in_images, recon_imgs, probs)`` where ``in_images`` is a
        float32 placeholder of shape ``[None, HEIGHT * WIDTH]``,
        ``recon_imgs`` is the sigmoid reconstruction reshaped to
        ``[-1, HEIGHT, WIDTH, 1]``, and ``probs`` is the per-class output
        returned by ``cl.layers.dense``.

    NOTE: ``tf.train.Saver.restore`` matches variables by name, so the
    variable-scope names and the order in which layers are created here must
    stay identical to the training graph. Do not rename or reorder them.
    """
    in_images = tf.placeholder(tf.float32, [None, HEIGHT * WIDTH])

    with tf.variable_scope('Conv1_layer'):
        # 9x9 VALID convolution with stride 1 on a HEIGHT x WIDTH (30 x 20)
        # input -> [batch_size, 22, 12, 256].
        inputs = tf.reshape(in_images, shape=[-1, HEIGHT, WIDTH, 1])
        conv1 = tf.layers.conv2d(inputs,
                                 filters=256,
                                 kernel_size=9,
                                 strides=1,
                                 padding='VALID',
                                 activation=tf.nn.relu)

    with tf.variable_scope('PrimaryCaps_layer'):
        # 9x9, stride-2 capsule convolution: 32 capsule types with 8x1 poses.
        # On a 22 x 12 map this presumably yields a 7 x 2 grid of capsules
        # (TODO confirm the exact output rank against capslayer).
        primaryCaps, activation = cl.layers.primaryCaps(conv1,
                                                        filters=32,
                                                        kernel_size=9,
                                                        strides=2,
                                                        out_caps_dims=[8, 1],
                                                        method="norm")

    with tf.variable_scope('DigitCaps_layer'):
        routing_method = "EMRouting"
        # Flatten dims 1..3 of the primary-capsule tensor into a single axis
        # of input capsules for the dense capsule layer.
        num_inputs = np.prod(cl.shape(primaryCaps)[1:4])
        primaryCaps = tf.reshape(primaryCaps, shape=[-1, num_inputs, 8, 1])
        activation = tf.reshape(activation, shape=[-1, num_inputs])
        poses, probs = cl.layers.dense(primaryCaps,
                                       activation,
                                       num_outputs=num_label,
                                       out_caps_dims=[16, 1],
                                       routing_method=routing_method)

    # Decoder: reconstruct the input with 3 fully connected layers.
    with tf.variable_scope('Decoder'):
        # Keep only the pose of the highest-scoring class; every other
        # capsule is zeroed by the one-hot mask.
        logits_idx = tf.to_int32(tf.argmax(cl.softmax(probs, axis=1), axis=1))
        labels = tf.one_hot(logits_idx, depth=num_label, axis=-1, dtype=tf.float32)
        labels_one_hoted = tf.reshape(labels, (-1, num_label, 1, 1))
        masked_caps = tf.multiply(poses, labels_one_hoted)
        num_inputs = np.prod(masked_caps.get_shape().as_list()[1:])
        active_caps = tf.reshape(masked_caps, shape=(-1, num_inputs))
        fc1 = tf.layers.dense(active_caps, units=512, activation=tf.nn.relu)
        fc2 = tf.layers.dense(fc1, units=1024, activation=tf.nn.relu)
        num_outputs = HEIGHT * WIDTH * 1
        recon_imgs = tf.layers.dense(fc2,
                                     units=num_outputs,
                                     activation=tf.sigmoid)
        recon_imgs = tf.reshape(recon_imgs, shape=[-1, HEIGHT, WIDTH, 1])

    return in_images, recon_imgs, probs
def pick_reconstruction_pair(original, reconstruct, true_lbl, lbl, shape):
    """Select one example and return ``(ind, original_image, reconstruct_image)``.

    ``ind`` is the first position where ``true_lbl == lbl``, or 0 when
    ``lbl`` is None. Both images are taken from that same position and
    reshaped to ``shape`` so they can be stacked side by side.

    Raises:
        IndexError: if ``lbl`` is given but does not occur in ``true_lbl``.
    """
    if lbl is None:
        ind = 0
    else:
        matches = np.flatnonzero(np.asarray(true_lbl) == lbl)
        if matches.size == 0:
            raise IndexError('label {} does not occur in this batch'.format(lbl))
        ind = int(matches[0])
    original_image = np.reshape(original[ind], shape)
    # The caller may have np.squeeze'd the reconstructions, which drops the
    # batch axis for a batch of one; restoring it makes indexing uniform.
    reconstruct_image = np.reshape(reconstruct, (-1,) + tuple(shape))[ind]
    return ind, original_image, reconstruct_image


def show_reconstruct(original, reconstruct, true_lbl, pred_lbl, lbl=None):
    """Plot one input image next to its reconstruction.

    Args:
        original: batch of flat input images (one row per example).
        reconstruct: batch of reconstructions for the same examples.
        true_lbl: 1-based true class ids for the batch (``chars_map`` keys).
        pred_lbl: 1-based predicted class ids for the batch.
        lbl: if given, show the first example whose true label equals it;
            otherwise show the first example of the batch.

    The title shows the true character on the left and the predicted one on
    the right.
    """
    ind, original_image, reconstruct_image = pick_reconstruction_pair(
        original, reconstruct, true_lbl, lbl, (HEIGHT, WIDTH))
    title = chars_map[true_lbl[ind]] + 20 * ' ' + chars_map[pred_lbl[ind]]
    plt.imshow(np.hstack((original_image, reconstruct_image)), cmap='gray')
    plt.title(title)
    plt.show()
def test(records, ckpt_dir='../models/models/results/logdir', batch_size=128):
    """Evaluate the latest checkpoint on ``records`` and plot a confusion matrix.

    Args:
        records: list of .record files.
        ckpt_dir: directory containing the TensorFlow ``checkpoint`` state
            file; the most recent checkpoint it lists is restored.
        batch_size: number of examples per evaluation batch.

    Raises:
        FileNotFoundError: if no checkpoint can be found in ``ckpt_dir``.
        ValueError: if ``records`` yield no examples.
    """
    dataset = tf.data.TFRecordDataset(records)
    # Each example is visited exactly once. Order does not affect the
    # confusion matrix, so no repeat/shuffle is needed.
    dataset = dataset.map(parse_fn).batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    inputs, recon_imgs, probs = get_model()
    saver = tf.train.Saver()

    true_labels, predicted = [], []
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            # Without a restore, every variable is uninitialised, and the
            # first sess.run would fail with a far less obvious error.
            raise FileNotFoundError('No checkpoint found in {!r}'.format(ckpt_dir))
        saver.restore(sess, ckpt.model_checkpoint_path)

        while True:
            try:
                elem = sess.run(next_element)
            except OutOfRangeError:
                break  # every record has been consumed
            raw_images = elem['images']
            # Shift labels to the 1-based ids used as chars_map keys.
            true_lbls = elem['labels'] + 1
            reconstructed, batch_probs = sess.run([recon_imgs, probs],
                                                  feed_dict={inputs: raw_images})
            reconstructed = np.squeeze(reconstructed)
            pred_lbls = np.argmax(batch_probs, axis=1) + 1
            # Debug aid: uncomment to view an input next to its reconstruction.
            # show_reconstruct(raw_images, reconstructed, true_lbls, pred_lbls, lbl=9)
            predicted.extend(pred_lbls)
            true_labels.extend(true_lbls)

    if not true_labels:
        raise ValueError('No examples were read from {!r}'.format(records))

    # Label the heatmap with exactly the classes the matrix contains, in the
    # same sorted order. A hard-coded list of ids mislabels the axes whenever
    # the set of observed classes differs from it.
    labels_id = np.union1d(true_labels, predicted)
    labels = [chars_map[lbl] for lbl in labels_id]
    conf_matr = confusion_matrix(true_labels, predicted, labels=labels_id)
    # confusion_matrix puts true classes on rows and predictions on columns.
    ax = heatmap(conf_matr, annot=True, fmt='d',
                 xticklabels=labels, yticklabels=labels)
    ax.set_xlabel('predicted')
    ax.set_ylabel('true')
    plt.show()
if __name__ == '__main__':
    # Run the evaluation on the eval symbol records.
    test(records=['data/symbols/eval_symbols.tfrecord'])
Checkpoint 6600 doesn't matter; I trained for 50,000 steps and the problem was the same.
from capslayer.
The code looks fine, and I'm not sure what the problem is. My guess is that CapsNet might have a bias problem, but before drawing that conclusion, I suggest:
- visualize the input images and print their corresponding labels (for both the training and validation sets) to make sure the dataset is correct;
- then remove one or more classes from the dataset (not the 7th class), train the model from scratch, and test it again, to see whether the 7th class or any other class is still never predicted.
Of course, it might be an implementation problem; I will check my code again.
from capslayer.
Related Issues (20)
- ValueError when using cl.layers.conv2d HOT 1
- Reproducing results from "Matrix Capsules with EM Routing" HOT 2
- EM Capsule Dense Layer Routing HOT 4
- Support for MSCOCO dataset HOT 1
- Spelling Error
- CapsLayer Neural Network Example
- Question for E_step in EM Routing
- ModuleNotFoundError: No module named 'capslayer.data.datasets.stanford_drone' HOT 2
- cifar10 dataset "Maximum allowed size exceeded" Error HOT 2
- tensorflow2.0 how to write config.py? HOT 1
- i have issue with cifar10
- OutOfRangeError (see above for traceback): End of sequence
- Hi, You can show an example application with CONV3D
- Routing by agreement with Transformer-based for NMT
- In the provided MNIST example not all capsules are seeing activation probability
- fashion mnist support
- Multi-label classification
- Performance issues in capslayer/data/datasets/cifar10/reader.py
- Performance issues in /capslayer/data/datasets (by P3) HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization, using data art.
-
Game
Something interesting about games, making everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from capslayer.