
Comments (1)

zhouzhongmi commented on August 11, 2024

I used a different way to construct the model network, and that fixed the problem. I will close this issue.

import numpy as np
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras.layers import Dense, Input, TimeDistributed
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

from cond_rnn import ConditionalRNN

# One sequence input plus two condition inputs (app ID and advertiser).
i1 = Input(shape=(X_train.shape[1], X_train.shape[2]))
ic_1 = Input(shape=(categorical_appid.shape[1],))
ic_2 = Input(shape=(categorical_advertiser.shape[1],))

# The conditional RNN takes the sequence first, then the conditions.
cond_rnn_layer = ConditionalRNN(units=64, cell='LSTM', return_sequences=True, mask=-1)([i1, ic_1, ic_2])

# Apply the same Dense layer at every timestep of the returned sequence.
out_layer = TimeDistributed(Dense(Y_train_non_lt.shape[2], activation='relu'), name='m_output')(cond_rnn_layer)
model = Model([i1, ic_1, ic_2], out_layer)

# `learning_rate` is the current name of the deprecated `lr` argument.
optim = Adam(learning_rate=0.003)
model.compile(optimizer=optim, loss={'m_output': 'mse'}, metrics={'m_output': 'mse'})

callbacks = [
    EarlyStopping(patience=30, monitor='val_mse'),
    ModelCheckpoint(
        filepath='model_tuning/day_20210511/seq_model.{epoch:02d}-{val_loss:.2f}.h5',
        monitor='val_loss',
        save_best_only=True,
    ),
    TensorBoard(log_dir='./training_logs_0511/seq'),
]

# `workers` only applies to generator or Sequence input, so it is omitted
# here because the data is passed as in-memory arrays.
out = model.fit(
    x=[X_train, categorical_appid, categorical_advertiser],
    y=Y_train,
    epochs=100,
    batch_size=1024,
    verbose=2,
    callbacks=callbacks,
    validation_data=([X_eval, categorical_appid_eval, categorical_advertiser_eval], Y_eval),
    sample_weight=None,
)
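
The model above implies a shape contract between its inputs and target: the sequence input is 3-D (samples, timesteps, features), each condition input is 2-D (samples, width), and, because the RNN returns sequences and the output layer is time-distributed, the target is 3-D with the same samples and timesteps. A minimal sketch of a check for that contract follows. The helper name `check_fit_shapes` is hypothetical and is not part of cond_rnn or Keras; it only inspects shape tuples.

```python
def check_fit_shapes(x_shape, cond_shapes, y_shape):
    """Validate shapes for a sequence model with per-sample conditions.

    x_shape:     (samples, timesteps, features) of the sequence input
    cond_shapes: list of (samples, width) shapes, one per condition input
    y_shape:     (samples, timesteps, outputs) of the per-timestep target
    Hypothetical helper for illustration; raises ValueError on a mismatch.
    """
    if len(x_shape) != 3:
        raise ValueError(f"sequence input must be 3-D, got {x_shape}")
    if len(y_shape) != 3:
        raise ValueError(f"target must be 3-D, got {y_shape}")

    samples, timesteps = x_shape[0], x_shape[1]

    for i, c_shape in enumerate(cond_shapes):
        if len(c_shape) != 2:
            raise ValueError(f"condition {i} must be 2-D, got {c_shape}")
        if c_shape[0] != samples:
            raise ValueError(
                f"condition {i} has {c_shape[0]} samples, expected {samples}"
            )

    if y_shape[0] != samples or y_shape[1] != timesteps:
        raise ValueError(
            f"target {y_shape} does not match (samples, timesteps) = "
            f"({samples}, {timesteps})"
        )
    return True
```

With the arrays from the comment, this could be called as `check_fit_shapes(X_train.shape, [categorical_appid.shape, categorical_advertiser.shape], Y_train.shape)` before `model.fit`.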
