keras-mnist-center-loss-with-visualization's Issues

the issue about fit_generator

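Hello, and thank you for providing such great source code; I have learned a lot from it :)
I have a few questions about fit_generator below and hope you can help clear them up.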

Question 1:

hist = model.fit_generator(generator=data_generator_centerloss(X=[x_train, y_train_a_class_value], Y=[y_train_a_class,y_train_a, random_y_train_a], batch_size=batch_size),
                                   steps_per_epoch=train_num // batch_size,
                                   validation_data=([x_test,y_test_a_class_value], [y_test_a_class,y_test_a, random_y_test_a]),
                                   epochs=nb_epochs, verbose=1,
                                   callbacks=callbacks)

Above is the fit_generator template you provided. What I'm curious about is why Y=[y_train_a_class, y_train_a, random_y_train_a] has three outputs. What does y_train_a represent?
I ask because TTY.mnist.py does this with .fit: model_centerloss.fit([x_train,y_train_value], [y_train, random_y_train], batch_size=batch_size, epochs=epochs, verbose=1, validation_data=([x_test,y_test_value], [y_test,random_y_test]), callbacks=[histories]). As I understand it, that is a two-input, two-output format, so I assumed .fit_generator should also be two-input, two-output.
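
In Keras 2, a generator passed to fit_generator must yield (inputs, targets) tuples whose structure matches the model, so for the two-input, two-output model trained with .fit in TTY.mnist.py each batch would be ([images, integer labels], [one-hot labels, dummy targets]). A minimal NumPy sketch, assuming integer class labels like y_train_value; center_loss_generator is a hypothetical name, not the repository's data_generator_centerloss, and it does not explain the third output y_train_a in the quoted template.

```python
import numpy as np


def center_loss_generator(x, y_value, num_classes, batch_size, seed=None):
    """Yield ([images, integer labels], [one-hot labels, dummy targets]) forever.

    Hypothetical helper mirroring the two-input / two-output layout of the
    .fit() call in TTY.mnist.py; it is not the repository's
    data_generator_centerloss.
    """
    rng = np.random.default_rng(seed)
    y_value = np.asarray(y_value).reshape(-1).astype(int)  # flat integer labels
    n = x.shape[0]
    while True:
        order = rng.permutation(n)  # reshuffle at the start of every pass
        # drop the last incomplete batch so every batch has batch_size samples
        for start in range(0, n - batch_size + 1, batch_size):
            idx = order[start:start + batch_size]
            labels = y_value[idx]
            inputs = [x[idx], labels[:, None]]  # second input matches Input(shape=(1,))
            targets = [np.eye(num_classes)[labels],  # target for the softmax output
                       np.zeros((batch_size, 1))]  # dummy target for l2_loss
            yield inputs, targets
```

With steps_per_epoch=train_num // batch_size, one epoch then corresponds to roughly one pass over the training data.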

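Question 2

Following the two-input, two-output idea, I built my own fit_generator and ran into a very strange problem:
[image]
Below are the .shape values of my inputs and outputs; the sizes look correct to me:
[image]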

So I'm a bit stumped. Could my generator's format be incompatible with l2_loss? When I inspect l2_loss, its shape shows up as all ?, which seems strange QQ
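
In graph-mode TensorFlow (1.x), a ? in a printed shape marks a dimension that is unknown when the graph is built, usually the batch size, so an l2_loss shape of (?, 1) is expected rather than a sign of a mismatch. The NumPy mirror of the l2_loss Lambda below assumes, as in the center-loss code elsewhere on this page, features of shape (batch, dim) and Embedding output of shape (batch, 1, dim); it shows that the output is always (batch, 1), so a (batch, 1) dummy target for that output is the safe choice. l2_loss_numpy is a hypothetical name.

```python
import numpy as np


def l2_loss_numpy(features, centers_per_sample):
    """NumPy mirror of the l2_loss Lambda:
    K.sum(K.square(x[0] - x[1][:, 0]), 1, keepdims=True)

    features:           (batch, dim), output of the feature layer
    centers_per_sample: (batch, 1, dim), output of the center Embedding
    returns:            (batch, 1), one squared distance per sample
    """
    diff = features - centers_per_sample[:, 0]  # drop Embedding's length-1 axis
    return np.sum(np.square(diff), axis=1, keepdims=True)


# The batch axis can be any size, which is why graph-mode TensorFlow prints it as ?.
for batch in (2, 5):
    out = l2_loss_numpy(np.ones((batch, 3)), np.zeros((batch, 1, 3)))
    print(out.shape)  # (2, 1), then (5, 1)
```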

Addendum

Here is how I implemented the generator; it returns the following:
[image]
And here is how I call fit_generator:
[image]

Sorry to bother you, and thank you so much for such a great program; it gave me an excellent reference while implementing center loss. Thanks!

Tina

Centers in the center loss are not updated

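Following the method above, the center values are never updated and remain random values every time. Updating them would require either a gradient update or recomputing the center values each time.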
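
Of the two options mentioned, recomputing the centers is the simpler one to sketch: after each pass (or each batch), set every class center to the mean feature vector of the samples carrying that label. A minimal NumPy sketch; recompute_centers is a hypothetical helper, and writing the result back into a Keras Embedding layer (for example with the layer's set_weights method) is an assumption not shown here.

```python
import numpy as np


def recompute_centers(features, labels, num_classes, old_centers):
    """Set each class center to the mean feature vector of that class.

    features:    (n, dim) feature vectors
    labels:      (n,) integer class labels
    old_centers: (num_classes, dim); classes absent from `labels` keep these rows
    """
    centers = old_centers.copy()  # leave the caller's array untouched
    for j in range(num_classes):
        mask = labels == j
        if np.any(mask):
            centers[j] = features[mask].mean(axis=0)
    return centers
```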

How to use Center_loss with fit_generator instead of fit

fit

history_softmax = resnet_model.fit_generator(
train_generator,
steps_per_epoch=123520//BATCH_SIZE,
epochs=EPOCHS,
validation_data=val_generator,
validation_steps=7119//BATCH_SIZE,
callbacks=[lr]
)

fit_generator

history2 = model_centerloss.fit_generator(train_generator, [y_train, random_y_train],
batch_size=batch_size, epochs=epochs,
verbose=1, validation_data=(val_generator, [y_test, random_y_test]),
callbacks=[lr, modelcheckpoint])
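
In Keras 2, fit_generator takes no target arrays and no batch_size argument (its second positional parameter is steps_per_epoch), so the generator itself has to yield both inputs and targets. Assuming train_generator and val_generator yield (images, one-hot labels) batches, as ImageDataGenerator.flow_from_directory does with class_mode='categorical', a small wrapper can add the integer-label input and the dummy l2_loss target. add_center_loss_targets is a hypothetical helper.

```python
import numpy as np


def add_center_loss_targets(base_generator):
    """Adapt (images, one_hot_labels) batches to a two-input, two-output center-loss model.

    Yields ([images, integer_labels], [one_hot_labels, dummy_targets]).
    """
    for images, one_hot in base_generator:
        labels = np.argmax(one_hot, axis=1).reshape(-1, 1)  # integer labels for Input(shape=(1,))
        dummy = np.zeros((images.shape[0], 1))  # also handles a smaller final batch
        yield [images, labels], [one_hot, dummy]
```

It would then be used as model_centerloss.fit_generator(add_center_loss_targets(train_generator), steps_per_epoch=123520 // BATCH_SIZE, epochs=EPOCHS, validation_data=add_center_loss_targets(val_generator), validation_steps=7119 // BATCH_SIZE, callbacks=[lr, modelcheckpoint]).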

Some questions about center loss

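Hello,
I read the Chinese explanation you posted on 3/2; it was very clear and helped me understand a lot. But I'd like to ask: as I remember, center loss keeps updating the center features during training, yet I don't see any code for this in your program. Is something wrong with this part, or did I just miss it? Thanks.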
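
In the center-loss snippets on this page, the centers are the weights of an Embedding layer, and l2_loss is a model output with loss weight lambda_c, so the optimizer changes those weights through the gradient of l2_loss like any other parameter rather than through a separate update rule. A minimal NumPy sketch of one such step, assuming plain SGD without momentum and a loss averaged over the batch; only the rows for classes present in the batch move, and they move toward those classes' features. sgd_step_on_centers is a hypothetical name.

```python
import numpy as np


def sgd_step_on_centers(centers, features, labels, lr, lambda_c):
    """One plain SGD step on an Embedding-style center table.

    Per-sample loss: lambda_c * ||f_i - c_{y_i}||^2, averaged over the batch.
    Its gradient w.r.t. row c_{y_i} is -2 * lambda_c * (f_i - c_{y_i}) / batch.
    """
    grad = np.zeros_like(centers)
    batch = features.shape[0]
    for f, y in zip(features, labels):
        grad[y] += -2.0 * lambda_c * (f - centers[y]) / batch
    return centers - lr * grad  # rows of absent classes have zero gradient
```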

Validation accuracy does not improve?

hi @shamangary
I trained on my own dataset. The result on the test set is very poor, only 14%, and dense_2_acc does not improve, topping out at 19%. Without center_loss, the test-set accuracy is 50%.
[screenshot, 2018-05-27 6 14 57]
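Load the dataset:

import random
import numpy as np
from keras import backend as K
from keras.applications import densenet
from keras.applications.vgg16 import VGG16
from keras.layers import BatchNormalization, Dense, Embedding, GlobalAveragePooling2D, Input, Lambda
from keras.models import Model
from keras.optimizers import SGD
from keras.utils import np_utils

# Training and validation sets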
train,trainlabel = xrp.loadtrain()
train = densenet.preprocess_input(train)
index = [i for i in range(len(train))]
random.shuffle(index)
train = train[index]
trainlabel = trainlabel[index]
(x_train,x_test)=(train[0:11000],train[11000:])
(y_train,y_test)=(trainlabel[0:11000],trainlabel[11000:])
y_train_value = y_train
y_test_value = y_test
y_train = np_utils.to_categorical(y_train,nb_classes)
y_test = np_utils.to_categorical(y_test,nb_classes)
# Test set
X_test,Y_test = xrp.loadtest()
Z_test_value = Y_test
Y_test = np_utils.to_categorical(Y_test,nb_classes)
X_test = densenet.preprocess_input(X_test)

The network architecture is as follows:
base_model = VGG16(weights='imagenet',include_top=False)
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = BatchNormalization(axis=1)(x)
ip1 = Dense(1024,activation='relu')(x)
predictions = Dense(8,activation='softmax')(ip1)
model = Model(inputs=base_model.input, outputs=predictions)

for layer in base_model.layers:
    layer.trainable = True

sgd = SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True)
# note: sgd is defined above, but this compile call uses 'rmsprop'
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

# this is the center_loss
isCenterloss = True

if isCenterloss:
    lambda_c = 0.2
    input_target = Input(shape=(1,))  # integer class label of each sample
    centers = Embedding(8, 1024)(input_target)  # one learnable 1024-d center per class
    # squared distance between the ip1 features and the center of the sample's class
    l2_loss = Lambda(lambda x: K.sum(K.square(x[0] - x[1][:, 0]), 1, keepdims=True),
                     name='l2_loss')([ip1, centers])
    model_centerloss = Model(inputs=[base_model.input, input_target], outputs=[predictions, l2_loss])
    sgd_1 = SGD(lr=0.00001, decay=1e-6, momentum=0.9, nesterov=True)
    model_centerloss.compile(optimizer=sgd_1,
                             loss=["categorical_crossentropy", lambda y_true, y_pred: y_pred],
                             loss_weights=[1, lambda_c], metrics=['accuracy'])

if isCenterloss:
    # dummy targets for the l2_loss output; the identity loss ignores them
    random_y_train = np.random.rand(x_train.shape[0], 1)
    random_y_test = np.random.rand(x_test.shape[0], 1)
    model_centerloss.fit([x_train, y_train_value], [y_train, random_y_train],
                         batch_size=batch_size, epochs=nb_epoch, verbose=1,
                         validation_data=([x_test, y_test_value], [y_test, random_y_test]))
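
With loss_weights=[1, lambda_c] and lambda y_true, y_pred: y_pred as the second loss, the quantity being minimized is the softmax cross-entropy plus lambda_c times the mean squared distance to the class centers; random_y_train and random_y_test are placeholders that only satisfy Keras's requirement of one target per output. A NumPy sketch of that batch loss, assuming Keras's usual averaging over the batch; total_loss is a hypothetical name. Keras typically logs the two terms separately (here as dense_2_loss and l2_loss_loss), so their relative sizes can be compared during training.

```python
import numpy as np


def total_loss(y_true_onehot, y_pred_softmax, l2_per_sample, lambda_c, eps=1e-7):
    """Batch loss matching loss_weights=[1, lambda_c] with the identity l2 'loss'.

    The second loss is `lambda y_true, y_pred: y_pred`, so its target
    (random_y_train) never enters the computation.
    """
    probs = np.clip(y_pred_softmax, eps, 1.0)  # avoid log(0)
    ce = -np.sum(y_true_onehot * np.log(probs), axis=1)  # per-sample cross-entropy
    return ce.mean() + lambda_c * l2_per_sample.mean()
```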

The prediction code is as follows:

privateLabel_0 = model_centerloss.predict([X_test, Z_test_value], batch_size=1, verbose=1)
print(len(privateLabel_0))
privateLabel = privateLabel_0[0]
list_label1 = []
list_label2 = []
x= len(privateLabel)

for j in range(x):
     list_label1.append(np.argmax(Y_test[j]))
     list_label2.append(np.argmax(privateLabel[j]))

privateAcc = len([1 for i in range(len(Y_test)) if list_label2[i] == list_label1[i]]) / float(len(Y_test))
print('the privateAcc is', privateAcc)
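
The loop above computes top-1 accuracy; the same number can be obtained with one vectorized NumPy expression. A minimal sketch; top1_accuracy is a hypothetical helper, and index [0] selects the softmax output because predict on this two-output model returns one array per output.

```python
import numpy as np


def top1_accuracy(y_true_onehot, y_pred_probs):
    """Fraction of samples whose argmax prediction equals the argmax label."""
    true_cls = np.argmax(y_true_onehot, axis=1)
    pred_cls = np.argmax(y_pred_probs, axis=1)
    return float(np.mean(true_cls == pred_cls))
```

Usage: privateAcc = top1_accuracy(Y_test, privateLabel_0[0]).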

I lowered the learning rate to 0.00001, but the results are still poor and dense_2_acc does not improve.
Could you tell me where the problem might be? Thanks.

Question about learning rate of centers

Hello,

Thanks for your code. I just have a very quick question.

The paper mentions a scalar alpha that controls the learning rate of the centers, which I find interesting. However, I don't think you implemented it; if you did, please correct me.
Also, do you know how to implement it? I am very curious about this, but right now I don't have a concrete idea of how to do so.
Thx
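
In the center-loss paper (Wen et al., 2016), the centers are updated once per mini-batch with their own step size alpha: for class j, Δc_j is the sum of (c_j − x_i) over the batch samples labelled j, divided by 1 + n_j (n_j being the number of such samples), and then c_j ← c_j − α·Δc_j. A NumPy sketch of that rule; update_centers_alpha is a hypothetical name, and using it in Keras (for example writing the result into the center Embedding layer with set_weights after each batch) is an assumption, not something this repository does.

```python
import numpy as np


def update_centers_alpha(centers, features, labels, alpha):
    """Mini-batch center update with step size alpha (Wen et al., 2016).

    delta_c_j = sum_{i: y_i = j} (c_j - x_i) / (1 + n_j)
    c_j <- c_j - alpha * delta_c_j
    Classes absent from the batch are left unchanged.
    """
    new_centers = centers.copy()
    for j in np.unique(labels):
        mask = labels == j
        delta = np.sum(centers[j] - features[mask], axis=0) / (1.0 + mask.sum())
        new_centers[j] = centers[j] - alpha * delta
    return new_centers
```

The 1 + n_j denominator keeps the step bounded when only a few samples of a class appear in the batch, and alpha lets the centers move at a rate independent of the network's learning rate.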

Don't quite understand "single value ground truth labels as inputs"

@shamangary
Thanks.

input_target = Input(shape=(1,)) # single value ground truth labels as inputs

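I don't quite understand this line of code.
In the center_loss formula this is C, but your comment in the code describes it as single value ground truth labels as inputs.
C is the cluster center; are you using the single value ground truth label as its initial value?
Thanks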
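
In the snippets on this page, the label input is not C itself: the ground-truth label is an integer index, and the Embedding(num_classes, feat_dim) layer applied to it holds C as its weight matrix, one randomly initialised row per class that is then learned. Feeding the label selects the row c_{y_i} for each sample. A NumPy sketch of that lookup; the sizes are arbitrary, and the uniform(-0.05, 0.05) initialisation mirrors Keras's default Embedding initializer.

```python
import numpy as np

num_classes, feat_dim = 10, 2

# The learnable center matrix C: one row per class, randomly initialised.
rng = np.random.default_rng(0)
C = rng.uniform(-0.05, 0.05, size=(num_classes, feat_dim))

# input_target: integer class labels, shape (batch, 1), as fed to Input(shape=(1,)).
labels = np.array([[3], [7], [3], [0]])

# Embedding lookup: output shape (batch, 1, feat_dim), one center per sample.
centers_per_sample = C[labels]

# x[1][:, 0] in the l2_loss Lambda drops the length-1 axis -> (batch, feat_dim).
centers_flat = centers_per_sample[:, 0]
```

Samples with the same label (rows 0 and 2 here) receive the same center, which is what lets the loss pull all features of a class toward one point.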
