AlphaX-NASBench101

AlphaX is a new Neural Architecture Search (NAS) agent that uses Monte Carlo Tree Search (MCTS) for efficient model architecture search, with Meta-DNN as a predictive model to estimate the accuracy of a sampled architecture. Compared with Random Search, AlphaX builds an online model that guides the future search. Compared with greedy methods, e.g. Q-learning, Regularized Evolution, or Top-K methods, AlphaX dynamically trades off exploration and exploitation and can escape from locally optimal solutions in fewer search trials. For details of AlphaX, please refer to AlphaX: eXploring Neural Architectures with Deep Neural Networks and Monte Carlo Tree Search.
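The exploration/exploitation trade-off described above is commonly realized in MCTS with a UCB1-style selection rule. The following is a minimal sketch of that generic rule, not the code in this repository; the function names, the dictionary layout of a child node, and the exploration constant `c` are assumptions made for illustration.

```python
import math


def uct_score(total_value, visits, parent_visits, c=1.0):
    """UCB1-style score: mean value plus an exploration bonus.

    `c` is a hypothetical exploration constant; larger values favor
    rarely visited children.
    """
    if visits == 0:
        # Unvisited children are always tried first.
        return float("inf")
    exploit = total_value / visits
    explore = c * math.sqrt(math.log(parent_visits) / visits)
    return exploit + explore


def select_child(children, c=1.0):
    """Pick the child with the highest score.

    Each child is a dict with the (assumed) keys "value" (sum of observed
    accuracies) and "visits" (number of times the child was selected).
    """
    parent_visits = sum(ch["visits"] for ch in children)
    return max(
        children,
        key=lambda ch: uct_score(ch["value"], ch["visits"], parent_visits, c),
    )
```

With `c=0` the rule is purely greedy and always follows the child with the best mean accuracy; a positive `c` lets a rarely visited child win the selection, which is what allows the search to leave a locally optimal branch.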

This repository hosts the implementation of AlphaX for searching a design domain defined by NASBench-101. NASBench-101 is a NAS dataset that contains 420k+ networks with their actual training and validation accuracies. For details of NASBench-101, please check here.

In 200 search trials, AlphaX is on average 2.8x and 3.0x faster than Regularized Evolution and Random Search, respectively, on the NASBench dataset.

This is how AlphaX progressively probes the search domain. Each node represents an MCTS state; the node color reflects its value, i.e. accuracy, indicating how promising a search branch is.

Please cite our work if it helps your research ;)

@article{wang2018alphax,
  title={AlphaX: eXploring Neural Architectures with Deep Neural Networks and Monte Carlo Tree Search},
  author={Wang, Linnan and Zhao, Yiyang and Jinnai, Yuu},
  journal={arXiv preprint arXiv:1805.07440},
  year={2018}
}

Requirements

Python >= 3.5.5, numpy >= 1.9.1, keras >= 2.1.6, jsonpickle

Setup

  1. Clone this repo.
git clone git@github.com:linnanwang/AlphaX-NASBench101.git
cd AlphaX-NASBench101
  2. (Optional) Create a virtualenv for this library.
virtualenv --system-site-packages -p python3 ./venv
source venv/bin/activate
  3. Install the dependencies.
pip install numpy
pip install keras
pip install jsonpickle

Download the dataset

The full NASBench dataset in our format is available here. Please place the dataset in the AlphaX-NASBench101 directory.

Usage

After preparing all dependencies, execute the following commands to start the search:

MCTS without meta_DNN (Fast on CPU)

python MCTS.py

MCTS assisted by meta_DNN (slow; please run on a GPU)

python MCTS_metaDNN.py

Note: meta_DNN requires training to predict the accuracy of an unseen architecture. Running it on a GPU is highly recommended.
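To predict the accuracy of an unseen architecture, a predictor needs a fixed-length numeric encoding of each network. The sketch below shows one possible encoding; it is an assumption for illustration and not necessarily the encoding used by MCTS_metaDNN.py. The function name `encode_arch` and the `LABELS` ordering are hypothetical; the operation names are the ones NASBench-101 defines.

```python
# Node labels used by NASBench-101; the ordering here is an arbitrary choice.
LABELS = ["input", "conv3x3-bn-relu", "conv1x1-bn-relu", "maxpool3x3", "output"]


def encode_arch(adj_mat, node_list, max_nodes=6):
    """Flatten an architecture into a fixed-length list of 0/1 features.

    The result has max_nodes * (max_nodes - 1) / 2 edge features (the strict
    upper triangle of the adjacency matrix) followed by one one-hot label
    vector per node slot. Networks with fewer than max_nodes nodes are
    zero-padded.
    """
    n = len(node_list)
    if n > max_nodes:
        raise ValueError("architecture has more nodes than max_nodes")

    features = []
    # Edge features: only entries above the diagonal can be non-zero in a DAG
    # whose nodes are listed in topological order.
    for i in range(max_nodes):
        for j in range(i + 1, max_nodes):
            features.append(adj_mat[i][j] if i < n and j < n else 0)

    # Node features: one-hot label per slot, all zeros for padded slots.
    for i in range(max_nodes):
        one_hot = [0] * len(LABELS)
        if i < n:
            one_hot[LABELS.index(node_list[i])] = 1
        features.extend(one_hot)
    return features
```

With `max_nodes=6` every network maps to 15 edge features plus 30 label features, so networks of different sizes can share one input layer.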

Changing the size of search domain

By default, we constrain the number of nodes to <= 6, which yields a search domain of 60000+ valid networks. The following steps illustrate how to expand or shrink the search domain.

  • In arch_generator.py, change MAX_NODES to any value in [3, 4, 5, 6, 7] (line 20). NASBench-101 provides all the networks up to 7 nodes.
class arch_generator:
    MAX_NODES     = 6 #inclusive
    MAX_EDGES     = 9 #inclusive
  • Lines 74-80 of net_training.py define the search target. The search stops once it hits the target. The target consists of two parts: the adjacency matrix and the node list. Please change it to a different target after you change the maximum number of nodes.
# 6 nodes
t_adj_mat = [[0, 1, 1, 1, 1, 1],
             [0, 0, 0, 0, 1, 0],
             [0, 0, 0, 1, 0, 0],
             [0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0]]
t_node_list = ['input', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'output']
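When you pick a new target, it helps to check that it is consistent with the MAX_NODES and MAX_EDGES limits before starting a search. The helper below is a minimal sketch of such a check; `is_valid_target` is a hypothetical name and is not part of this repository. It assumes the NASBench-101 conventions that the adjacency matrix is strictly upper triangular and that the first and last nodes are 'input' and 'output'.

```python
def is_valid_target(adj_mat, node_list, max_nodes=6, max_edges=9):
    """Return True if the target fits the configured search domain."""
    n = len(node_list)
    # Need at least an input and an output node, and no more than max_nodes.
    if n < 2 or n > max_nodes:
        return False
    # The adjacency matrix must be square and match the node list.
    if len(adj_mat) != n or any(len(row) != n for row in adj_mat):
        return False
    if node_list[0] != 'input' or node_list[-1] != 'output':
        return False
    for i in range(n):
        for j in range(n):
            if adj_mat[i][j] not in (0, 1):
                return False
            # Entries on or below the diagonal must be zero (DAG in
            # topological order).
            if j <= i and adj_mat[i][j] != 0:
                return False
    n_edges = sum(sum(row) for row in adj_mat)
    return n_edges <= max_edges


# The 6-node target shown above uses exactly 9 edges.
t_adj_mat = [[0, 1, 1, 1, 1, 1],
             [0, 0, 0, 0, 1, 0],
             [0, 0, 0, 1, 0, 0],
             [0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0]]
t_node_list = ['input', 'conv3x3-bn-relu', 'conv3x3-bn-relu',
               'conv3x3-bn-relu', 'conv3x3-bn-relu', 'output']
print(is_valid_target(t_adj_mat, t_node_list))               # fits MAX_NODES = 6
print(is_valid_target(t_adj_mat, t_node_list, max_nodes=5))  # too large for 5
```

The first call accepts the default target; the second rejects it, which is the situation you would hit after lowering MAX_NODES without updating the target.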

Contributors

Linnan Wang (Brown University), Yiyang Zhao (Unaffiliated)

We're also sincerely grateful for the valuable suggestions from Yuu Jinnai (Brown), Yuandong Tian (Facebook AI Research), and Rodrigo Fonseca (my awesome advisor at Brown).
