Configurable Tensorflow DIN

DIN

This project is a TensorFlow implementation of DIN (Deep Interest Network for Click-Through Rate Prediction).

We abstract DIN into the following structure:

[Figure: DIN model structure]

  • vector: directly fed into the MLP, e.g., user profile
  • sequence: sequential values that the target attends over, e.g., the click history. There can be multiple sequences, e.g., purchase history and browsing history.
  • target: the target value, which is fed directly into the MLP of DIN and also serves as the attention query over the sequence inputs, e.g., the target page
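
How a target attends over a sequence can be sketched without any framework. The snippet below is a minimal, dependency-free illustration, not this repository's attention.py: `attention_pool` and `dot_score` are hypothetical names, and the dot-product scorer is a stand-in for the small MLP that the DIN paper uses as its attention unit. As in the paper, the weights are not softmax-normalized, and padded positions are masked out.

```python
def dot_score(target, item):
    """Stand-in attention unit (assumption): plain dot product."""
    return sum(t * x for t, x in zip(target, item))


def attention_pool(target, sequence, mask, score_fn=dot_score):
    """Weighted-sum pooling of `sequence`, with weights scored against `target`.

    target:   list[float], embedding of the target item
    sequence: list[list[float]], embeddings of the history items
    mask:     list[bool], False marks a padding position
    score_fn: callable(target, item) -> float
    """
    pooled = [0.0] * len(target)
    for item, valid in zip(sequence, mask):
        if not valid:
            continue  # padding contributes nothing to the pooled vector
        weight = score_fn(target, item)
        for i, x in enumerate(item):
            pooled[i] += weight * x
    return pooled


target = [1.0, 0.0]
history = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
mask = [True, True, False]  # the last position is padding
pooled = attention_pool(target, history, mask)
```

The pooled vector has the same dimension as a single sequence item, so it can be concatenated with the vector and target inputs before the MLP.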

Each input value can come in two ways:

  • value: continuous input values, e.g., time spent on a page
  • embedding: discrete input values which will be passed into an embedding layer, e.g., page name

Therefore, we classify the model inputs into six categories: emb_vec, emb_seq, emb_tgt, val_vec, val_seq, and val_tgt. DIN is implemented on top of this abstraction, independent of any specific model inputs. Pass a config dict to the model (example), and the model constructs the computation graph automatically.
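
The six category names are simply the product of the two input kinds and the three roles. The sketch below shows that mapping; the helper `input_category` and the feature names are hypothetical and are not taken from this repository.

```python
from itertools import product

# Short prefixes/suffixes as they appear in the category names above.
KINDS = {"embedding": "emb", "value": "val"}
ROLES = {"vector": "vec", "sequence": "seq", "target": "tgt"}


def input_category(kind, role):
    """Map an input kind and role to one of the six category names."""
    return f"{KINDS[kind]}_{ROLES[role]}"


ALL_CATEGORIES = [input_category(k, r) for k, r in product(KINDS, ROLES)]

# Hypothetical feature declarations, for illustration only.
features = {
    "user_age": ("value", "vector"),
    "clicked_pages": ("embedding", "sequence"),
    "dwell_times": ("value", "sequence"),
    "target_page": ("embedding", "target"),
}

# Group feature names by category, as a model builder might do internally.
grouped = {}
for name, (kind, role) in features.items():
    grouped.setdefault(input_category(kind, role), []).append(name)
```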

Usage

Fast Run

  1. Run data/generate_samples.py to generate demo samples:
    python data/generate_samples.py
  2. Create config.py under scripts/ to specify the absolute path of your project, e.g.,
    project_fd = '/root/din/'
  3. Run scripts/train_din.py to start the training process:
    python scripts/train_din.py

Details

  • If you have a customized dataset, modify common.py to specify:
    train_data_fp = "path/to/trainset/file"
    test_data_fp = "path/to/testset/file"
  • If you need a customized dataset with a customized model structure, modify fea_config.py and define a FEA_CONFIG dict and a SHARED_EMB_CONFIG dict.
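
As a rough illustration of what such a pair of dicts might contain, the sketch below declares a few features and checks that they are consistent. Every key used here (`category`, `shared_emb`, `max_len`, `vocab_size`, `dim`) and the helper `check_config` are assumptions made for this example; the actual schema expected by fea_config.py may differ, so consult that file before writing your own.

```python
# Hypothetical layout: the real fea_config.py may use different keys.
SHARED_EMB_CONFIG = {
    # One embedding table shared by the click history and the target page.
    "page": {"vocab_size": 1000, "dim": 16},
}

FEA_CONFIG = {
    "target_page": {"category": "emb_tgt", "shared_emb": "page"},
    "clicked_pages": {"category": "emb_seq", "shared_emb": "page", "max_len": 20},
    "dwell_times": {"category": "val_seq", "max_len": 20},
    "user_age": {"category": "val_vec"},
}

VALID_CATEGORIES = {"emb_vec", "emb_seq", "emb_tgt", "val_vec", "val_seq", "val_tgt"}


def check_config(fea_config, shared_emb_config):
    """Return a list of problems; an empty list means the config is consistent."""
    problems = []
    for name, spec in fea_config.items():
        category = spec.get("category")
        if category not in VALID_CATEGORIES:
            problems.append(f"{name}: unknown category {category!r}")
            continue
        table = spec.get("shared_emb")
        if table is None:
            continue
        if category.startswith("val_"):
            problems.append(f"{name}: value features cannot use an embedding table")
        elif table not in shared_emb_config:
            problems.append(f"{name}: shared embedding {table!r} is not defined")
    return problems
```

Running a check like this before building the model surfaces a misspelled category or a missing shared table early, rather than as a graph-construction error.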

Code Structure

scripts
├── common.py: global configuration
├── config.py: local configuration, e.g., the project path (project_fd)
├── data: data reading module
│   ├── dataset.py: object to read, shuffle, and preprocess data
│   └── fea_config.py: configuration for preprocess and model construction
├── model
│   ├── din.py: DIN object
│   ├── layers
│   │   ├── activation.py: activation layers (PReLU and Dice)
│   │   ├── attention.py: attention layers
│   │   ├── forward_net.py: MLP layers
│   │   └── interface.py: abstract class for layers
│   └── utils
│       ├── constant.py: constants
│       ├── input_fn.py: wrapper for estimator training
│       ├── proc_input_config.py: utility functions
│       └── widgets.py: utility functions
├── train
│   ├── evaluate.py: evaluation code
│   ├── metrics.py: evaluation metrics
│   └── vanilla_train.py: training code
└── train_din.py: entry file
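
For readers unfamiliar with the two activations listed under layers/activation.py, here is a dependency-free sketch of their formulas as given in the DIN paper. It is not the repository's implementation: the function names, the list-based batch, and the default `eps` are assumptions, and this version uses only the statistics of the current batch (the paper switches to moving averages at inference time).

```python
import math


def prelu(s, alpha):
    """PReLU: identity for positive inputs, slope `alpha` otherwise."""
    return s if s > 0 else alpha * s


def dice(batch, alpha, eps=1e-8):
    """Dice over a batch of scalar pre-activations.

    p(s) = sigmoid((s - mean) / sqrt(var + eps))
    f(s) = p(s) * s + (1 - p(s)) * alpha * s
    """
    n = len(batch)
    mean = sum(batch) / n
    var = sum((s - mean) ** 2 for s in batch) / n
    out = []
    for s in batch:
        p = 1.0 / (1.0 + math.exp(-(s - mean) / math.sqrt(var + eps)))
        out.append(p * s + (1.0 - p) * alpha * s)
    return out


activated = dice([-1.0, 0.0, 1.0], alpha=0.25)
```

Dice can be read as a PReLU whose switching point follows the batch mean instead of being fixed at zero, with a smooth transition between the two slopes.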

Key Methods

The Din class defined in din.py has the following methods:

  • build_graph_: build the computation graph and store the input and output nodes as attributes
    • attributes:
      • features_ph, labels_ph: input nodes
      • outputs: output nodes
      • session, graph: the session and its computation graph
      • saver: checkpoint saver
    • relevant methods:
      • switch_graph: switch to the specified graph if multiple graphs exist
      • load_from: load the variable values from a target graph
  • model_fn: for estimator invocation
  • freeze: set trainable=False for variables (must be invoked before build_graph)
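
The ordering requirement on freeze presumably follows from how TensorFlow 1.x graphs work: a variable's trainable flag is fixed when the variable is created, so changing the setting after the graph is built has no effect. The toy class below illustrates that ordering only. It is not the repository's Din class, and raising an error on a late call is an assumption of this sketch rather than documented behavior.

```python
class ToyModel:
    """Toy stand-in used to illustrate call order; not the real Din class."""

    def __init__(self):
        self._trainable = True
        self.variables = {}
        self.built = False

    def freeze(self):
        # Must run before the graph exists, because the flag is read at creation.
        if self.built:
            raise RuntimeError("freeze() must be called before build_graph()")
        self._trainable = False

    def build_graph(self):
        # The trainable flag is read once, when each variable is created.
        for name in ("embedding", "mlp/w", "mlp/b"):
            self.variables[name] = {"trainable": self._trainable}
        self.built = True


frozen = ToyModel()
frozen.freeze()        # correct order: freeze first
frozen.build_graph()   # every variable is created with trainable=False
```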
