
Attention Model for Vehicle Routing Problems

TensorFlow 2.0 implementation of the article "Attention, Learn to Solve Routing Problems!".

This work was done as part of the final project for the DeepPavlov course: Advanced Topics in Deep Reinforcement Learning.

The code of the full project (dynamic version) is located at https://github.com/d-eremeev/ADM-VRP

Environment:

The current environment is implemented in the AgentVRP class in the Enviroment.py file.

The class stores the current state and the actions taken by the agent.

Main methods:

  • step(action): transitions to a new state according to the action.
  • get_costs(dataset, pi): returns the cost (tour length) of each graph in the batch for the paths pi.
  • get_mask(): returns a mask of available actions (allowed nodes).
  • all_finished(): checks whether all games in the batch are finished (all graphs are solved).
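The methods above can be illustrated with a deliberately simplified, batched CVRP environment in NumPy. ToyAgentVRP and its internals (depot at index 0, capacity normalised to 1.0, a mask in which True means "not allowed", matching how the mask is used with tf.where in the decoder) are assumptions for this sketch, not the repository's AgentVRP code; only the method names step, get_mask and all_finished mirror the list.

```python
import numpy as np


class ToyAgentVRP:
    """Hypothetical, simplified batched CVRP environment (not the repository's AgentVRP).

    Node 0 is the depot, nodes 1..n are customers, and the vehicle capacity
    is normalised to 1.0 (assumptions made for this sketch).
    """

    def __init__(self, demands):
        # demands: [batch, n_nodes]; demands[:, 0] is the depot and should be 0
        self.demands = np.asarray(demands, dtype=float)
        batch, n_nodes = self.demands.shape
        self.visited = np.zeros((batch, n_nodes), dtype=bool)
        self.used_capacity = np.zeros(batch)
        self.prev_node = np.zeros(batch, dtype=int)  # every tour starts at the depot

    def get_mask(self):
        """Boolean mask of shape [batch, n_nodes]; True marks a node that is NOT allowed."""
        # A customer is forbidden if it was already visited or does not fit in the vehicle
        too_heavy = self.demands + self.used_capacity[:, None] > 1.0 + 1e-9
        mask = self.visited | too_heavy
        # Forbid depot -> depot moves while some customers are still unserved
        customers_left = ~self.visited[:, 1:].all(axis=1)
        mask[:, 0] = (self.prev_node == 0) & customers_left
        return mask

    def step(self, action):
        """Move each instance in the batch to the node given in `action` ([batch] ints)."""
        action = np.asarray(action, dtype=int)
        rows = np.arange(len(action))
        to_depot = action == 0
        # Returning to the depot refills the vehicle; a customer consumes its demand
        self.used_capacity = np.where(
            to_depot, 0.0, self.used_capacity + self.demands[rows, action]
        )
        self.visited[rows, action] |= ~to_depot
        self.prev_node = action

    def all_finished(self):
        """True once every customer in every instance of the batch has been visited."""
        return bool(self.visited[:, 1:].all())


# Usage: a stand-in "policy" that always picks the highest-index allowed node
env = ToyAgentVRP([[0.0, 0.6, 0.5, 0.3]])
path = []
while not env.all_finished():
    allowed = ~env.get_mask()
    action = allowed.shape[1] - 1 - np.argmax(allowed[:, ::-1], axis=1)
    env.step(action)
    path.append(int(action[0]))
print(path)  # [3, 2, 0, 1]
```

The rollout returns to the depot (node 0) after customers 3 and 2, because customer 1 (demand 0.6) no longer fits in the remaining capacity of 0.2.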

A small dictionary connecting these terms to RL language:

  • State: $X$ - the graph instance (coordinates, demands, etc.) together with the node in which the agent is currently located.
  • Action: $\pi_t$ - the choice of the node the agent moves to next.
  • Reward: the negative tour length.
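To make the reward concrete, the sketch below computes the Euclidean tour length for a batch of paths pi and negates it. The helper tour_costs is hypothetical; it assumes node 0 is the depot and that every tour starts and ends there, which may differ in detail from the repository's get_costs.

```python
import numpy as np


def tour_costs(coords, pi):
    """Hypothetical helper: Euclidean length of each depot-based tour in a batch.

    coords: [batch, n_nodes, 2] coordinates; node 0 is the depot.
    pi:     [batch, steps] node indices chosen by the agent (depot revisits allowed).
    Every tour is assumed to start and end at the depot.
    """
    coords = np.asarray(coords, dtype=float)
    pi = np.asarray(pi, dtype=int)
    batch = coords.shape[0]
    depot = np.zeros((batch, 1), dtype=int)
    route = np.concatenate([depot, pi, depot], axis=1)      # depot -> pi -> depot
    points = coords[np.arange(batch)[:, None], route]       # [batch, steps + 2, 2]
    legs = np.linalg.norm(points[:, 1:] - points[:, :-1], axis=-1)
    return legs.sum(axis=1)


# Unit square: depot at (0, 0), customers at (0, 1), (1, 1), (1, 0)
square = [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]
coords = np.array([square, square])
pi = np.array([[1, 2, 3],    # one trip around the square: 4.0
               [1, 0, 2]])   # back to the depot in between: 2 + 2*sqrt(2)
costs = tour_costs(coords, pi)
rewards = -costs             # RL reward is the negative tour length
print(costs)
```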

Model Training:

The AM is trained by policy gradient, using the REINFORCE algorithm with a baseline.
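A minimal NumPy sketch of the REINFORCE-with-baseline gradient on a toy problem: a softmax policy over three candidate tours with fixed, hypothetical costs. The estimator averages (cost - baseline) * grad log p over sampled tours. Here the baseline is simply the batch-mean cost, a stand-in for the frozen-model baseline described in the next section. In an autodiff framework such as TensorFlow, the same gradient is obtained by differentiating the surrogate loss mean((cost - baseline) * log_likelihood), where log_likelihood is the summed log-probability of the chosen nodes.

```python
import numpy as np

rng = np.random.default_rng(0)
costs = np.array([4.0, 5.0, 6.0])  # hypothetical tour lengths of three candidate tours
theta = np.zeros(3)                # logits of a toy softmax policy over those tours
lr, batch = 0.5, 64


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


for _ in range(200):
    probs = softmax(theta)
    actions = rng.choice(3, size=batch, p=probs)   # sample tours from the policy
    sampled_cost = costs[actions]
    baseline = sampled_cost.mean()                 # batch-mean baseline (stand-in)
    advantage = sampled_cost - baseline
    # Gradient of log softmax w.r.t. the logits: one_hot(action) - probs
    grad_logp = np.eye(3)[actions] - probs
    # REINFORCE estimate of grad E[cost]: mean((cost - b) * grad log p)
    grad = (advantage[:, None] * grad_logp).mean(axis=0)
    theta -= lr * grad                             # descend to minimise expected cost

probs = softmax(theta)
print(np.round(probs, 3))
```

Subtracting the baseline does not change the expected gradient, but it reduces its variance: tours cheaper than the baseline are reinforced, and more expensive ones are discouraged.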

Baseline

  • The baseline is a copy of the model with weights frozen at one of the preceding epochs.
  • Warm-up is used for the early epochs: an exponential moving average of the model cost over past epochs is mixed with the baseline model.
  • At the end of each epoch, the baseline is updated if the difference in costs between the candidate model and the baseline is statistically significant (t-test).
  • The baseline uses a separate dataset for this validation. This dataset is updated after each baseline renewal.
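The warm-up mixing and the end-of-epoch update rule can be sketched as plain functions. The linear warm-up schedule, the EMA factor beta=0.8 and the significance level 0.05 are assumed values, and the function names are hypothetical. The test uses the paired t-statistic but, to avoid a SciPy dependency, approximates its one-sided p-value with the standard normal distribution, which is only reasonable for large validation sets (thousands of instances).

```python
import math

import numpy as np


def warmup_baseline(bl_model_cost, ema_cost, epoch, warmup_epochs=1):
    """Mix the frozen baseline model's cost with the EMA baseline during warm-up.

    alpha grows linearly from 1/warmup_epochs to 1 (an assumed schedule).
    """
    alpha = min(1.0, (epoch + 1) / warmup_epochs)
    return alpha * bl_model_cost + (1.0 - alpha) * ema_cost


def update_ema(ema_cost, batch_cost, beta=0.8):
    """Exponential moving average of the mean model cost (beta is an assumed value)."""
    return beta * ema_cost + (1.0 - beta) * float(np.mean(batch_cost))


def should_update_baseline(candidate_costs, baseline_costs, significance=0.05):
    """One-sided paired test: is the candidate's mean cost significantly lower?

    Paired t-statistic with a normal approximation to its p-value.
    """
    diff = np.asarray(candidate_costs, float) - np.asarray(baseline_costs, float)
    if diff.mean() >= 0:
        return False  # candidate is not better on average
    t_stat = diff.mean() / (diff.std(ddof=1) / math.sqrt(len(diff)))
    p_one_sided = 0.5 * math.erfc(-t_stat / math.sqrt(2))  # P(Z <= t_stat)
    return p_one_sided < significance


# Usage with hypothetical costs on a 10,000-instance validation set
rng = np.random.default_rng(0)
baseline_costs = rng.normal(10.0, 1.0, size=10_000)
candidate_costs = baseline_costs - 0.05 + rng.normal(0.0, 0.1, size=10_000)
print(should_update_baseline(candidate_costs, baseline_costs))  # True
print(warmup_baseline(12.0, 10.0, epoch=0, warmup_epochs=2))    # 11.0
```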

Files Description:

  1. Enviroment.py - environment for the VRP RL agent
  2. layers.py - multi-head attention (MHA) layers for the encoder
  3. attention_graph_encoder.py - Graph Attention Encoder
  4. attention_graph_decoder.py - Graph Attention Decoder
  5. attention_model.py - Attention Model
  6. reinforce_baseline.py - class for the REINFORCE baseline
  7. train.py - defines the training loop used in train_with_checkpoint.ipynb
  8. train_with_checkpoint.ipynb - starts training from scratch or continues training from a checkpoint
  9. generate_data.py - various auxiliary functions for data creation, saving and visualisation
  10. results folder: each subfolder is named ADM_VRP_{graph_size}_{batch_size} and contains training logs, learning curves and saved models

Training procedure:

  1. Open train_with_checkpoint.ipynb and choose the training parameters.
  2. All outputs are saved in the current directory.

Contributors:

alexeypustynnikov, d-eremeev

Issues:

tf.where

In the function decoder_mha in attention_graph_decoder.py, the code does the following:

    mask = mask[:, tf.newaxis, :, :]

The purpose of this operation is to expand the mask's dimensions so that the following code works:

    compatibility = tf.where(mask,
                             tf.ones_like(compatibility) * (-np.inf),
                             compatibility)

However, the shape of "mask" differs from that of "compatibility", and the code fails: the mask has shape [batch_size, 1, seq_len_q, seq_len_k], while compatibility has shape [batch_size, num_heads, seq_len_q, seq_len_k].
How can I solve this problem?
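In TensorFlow 2.x, tf.where broadcasts its three arguments, so a [batch_size, 1, seq_len_q, seq_len_k] mask should work against [batch_size, num_heads, seq_len_q, seq_len_k] scores as is. The TF 1.x version (now tf.compat.v1.where) does not broadcast: it requires the condition to have the same shape as x (or to be a vector matching its first dimension), which produces exactly this kind of error. One possible workaround, independent of the version, is to broadcast the mask explicitly, e.g. mask = tf.broadcast_to(mask, tf.shape(compatibility)). The sketch below shows the same shapes with NumPy's np.broadcast_to and np.where; the sizes and the masked node are arbitrary illustrative values.

```python
import numpy as np

batch_size, num_heads, seq_len_q, seq_len_k = 2, 8, 1, 5
rng = np.random.default_rng(0)
compatibility = rng.normal(size=(batch_size, num_heads, seq_len_q, seq_len_k))

# Mask shared by all heads: True marks a node that must not be selected
mask = np.zeros((batch_size, seq_len_q, seq_len_k), dtype=bool)
mask[:, :, 0] = True

mask = mask[:, np.newaxis, :, :]                    # [batch, 1, q, k]
mask = np.broadcast_to(mask, compatibility.shape)   # [batch, heads, q, k], explicit
masked = np.where(mask, -np.inf, compatibility)     # masked scores become -inf

# Softmax over keys: masked positions get probability exactly 0
shifted = masked - masked.max(axis=-1, keepdims=True)
weights = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
print(masked.shape, weights[..., 0].max())  # (2, 8, 1, 5) 0.0
```

np.broadcast_to returns a read-only view, so the expanded mask costs no extra memory; it is only read by np.where.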
