Run trained deep neural networks in the browser or node.js.
Check out the project page and examples.
Background
Training deep neural networks on any meaningful dataset requires massive computational resources and lots and lots of time. However, the forward pass prediction phase is relatively cheap - typically there is no backpropagation, computational graphs, loss functions, or optimization algorithms to worry about.
What do you do when you have a trained deep neural network and now wish to use it to power a part of your client-facing web application? Traditionally, you would deploy your model on a server and call it from your web application through an API. But what if you could deploy it in the browser alongside the rest of your webapp? Computation would be offloaded entirely to your end users!
Perhaps most users will not be able to run billion-parameter networks in their browsers quite yet, but smaller networks are certainly within the realm of possibility.
By restricting the scope to prediction with already-trained neural networks, we can concentrate on making forward passes through the network as computationally efficient as possible, taking full account of the constraints of client hardware and the current state of web browsers.
Computation is performed on the GPU where possible and advantageous. Currently, this is implemented using WebGL in a browser environment and ArrayFire in a node.js environment.
Ultimately, the goal of this project is a lightweight JavaScript library that can take a serialized Keras, Caffe, Torch, or [insert other deep learning framework here] model together with its pretrained weights, so that you can pack it into your webapp and be off and running.
Andrej Karpathy's ConvNetJS is of course a source of inspiration, as is the excellent Python deep learning framework Keras.
Examples
- CIFAR-10 VGGNet-like convolutional neural network / src / demo
- LSTM recurrent neural network for classifying astronomical object names / src / demo
You can also run the examples on your local machine at http://localhost:8000 by starting the examples server:
$ npm run examples-server
Usage
See the source code of the examples above. In particular, the CIFAR-10 example demonstrates a multi-threaded implementation using Web Workers.
In the browser:
<script src="neocortex.min.js"></script>
<script>
// use neural network here
</script>
In node.js:
$ npm install neocortex-js
import NeuralNet from 'neocortex-js';
The core steps involve:
- Instantiate the neural network class (NOTE: GPU mode is a work in progress)
let nn = new NeuralNet({
// relative URL in browser/webworker, absolute path in node.js
modelFilePath: 'model.json',
arrayType: 'float64', // float64 or float32
useGPU: false // if true, will try to use GPU for computations
});
- Load the model JSON file
nn.init().then(() => {
// do stuff with nn
});
- Feed input data into the neural network
nn.predict(input).then(predictions => {
// make use of predictions
});
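What predictions contains depends on the model. As a hedged illustration, assuming a classifier whose output is a flat array of per-class probabilities (like the softmax output of the CIFAR-10 example), a small helper such as the hypothetical topK below could turn it into readable labels. It is not part of neocortex-js, and the prediction values are made up.

```javascript
// Hypothetical post-processing helper, NOT part of neocortex-js.
// Assumes `predictions` is an array-like (plain array or typed array) of
// per-class probabilities, in the same order as `labels`.
function topK(predictions, labels, k = 3) {
  return Array.from(predictions, (p, i) => ({ label: labels[i], probability: p }))
    .sort((a, b) => b.probability - a.probability) // highest probability first
    .slice(0, k);
}

// The ten CIFAR-10 classes, in the dataset's standard order
const CIFAR10_LABELS = [
  'airplane', 'automobile', 'bird', 'cat', 'deer',
  'dog', 'frog', 'horse', 'ship', 'truck',
];

// Made-up output vector, used only to demonstrate the helper
const fakePredictions = new Float64Array([
  0.01, 0.02, 0.05, 0.6, 0.05, 0.2, 0.03, 0.02, 0.01, 0.01,
]);

console.log(topK(fakePredictions, CIFAR10_LABELS, 2)); // cat (0.6), then dog (0.2)
```

Inside the predict callback above, it would be called as topK(predictions, CIFAR10_LABELS).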
Build
To run the build yourself for both the browser (outputs to build/neocortex.min.js) and node.js (outputs to dist/):
$ npm run build
To build just for the browser:
$ npm run build-browser
Frameworks
Keras
A script to serialize a trained Keras model together with its HDF5-formatted weights is located in the utils/ folder here. It currently supports only sequential models built from the layers listed in the API section below. Support for graph models is planned.
API
The functions and layers currently implemented are listed below. More are forthcoming.
Activation functions
- linear
- relu
- sigmoid
- hard_sigmoid
- tanh
- softmax
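For reference, here are minimal element-wise sketches of the activation functions listed above, following their usual Keras definitions (hard_sigmoid is Keras's piecewise-linear approximation, clip(0.2x + 0.5, 0, 1)). The function names are illustrative; this is not neocortex-js's own code.

```javascript
// Illustrative sketches, not neocortex-js's implementation.
// Each function maps an array of pre-activations to a new array.
const linear = (x) => x.slice();
const relu = (x) => x.map((v) => Math.max(0, v));
const sigmoid = (x) => x.map((v) => 1 / (1 + Math.exp(-v)));
// Keras-style hard sigmoid: clip(0.2 * x + 0.5, 0, 1)
const hardSigmoid = (x) => x.map((v) => Math.min(1, Math.max(0, 0.2 * v + 0.5)));
const tanh = (x) => x.map(Math.tanh);

// Softmax over a 1D vector. Subtracting the max before exponentiating
// avoids overflow and leaves the result mathematically unchanged.
function softmax(x) {
  const max = Math.max(...x);
  const exps = x.map((v) => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

console.log(relu([-1, 0, 2]));        // [ 0, 0, 2 ]
console.log(hardSigmoid([-3, 0, 3])); // [ 0, 0.5, 1 ]
console.log(softmax([1000, 1000]));   // [ 0.5, 0.5 ]
```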
Advanced activation layers
- leakyReLULayer
- parametricReLULayer
- parametricSoftplusLayer
- thresholdedLinearLayer
- thresholdedReLuLayer
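Unlike the plain activations, these layers carry parameters: either a fixed slope or threshold, or learned per-unit weights that would come from the serialized model. The sketch below assumes the usual Keras formulations. The parameter names (alpha, alphas, betas, theta) are illustrative, and the strict comparison at the threshold is an assumption that may differ from the reference implementation exactly at the boundary value.

```javascript
// Illustrative sketches of the advanced activation layers (assumed
// formulations, not neocortex-js's code). alphas/betas are per-unit arrays.

// Numerically stable softplus: log(1 + exp(z)) without overflow for large z
const softplus = (z) => (z > 0 ? z + Math.log1p(Math.exp(-z)) : Math.log1p(Math.exp(z)));

// LeakyReLU: fixed small slope alpha for negative inputs
const leakyReLU = (x, alpha) => x.map((v) => (v > 0 ? v : alpha * v));

// Parametric ReLU: like LeakyReLU, but with a learned slope per unit
const parametricReLU = (x, alphas) => x.map((v, i) => (v > 0 ? v : alphas[i] * v));

// Parametric softplus: alpha * softplus(beta * x), both learned per unit
const parametricSoftplus = (x, alphas, betas) =>
  x.map((v, i) => alphas[i] * softplus(betas[i] * v));

// Thresholded linear: pass x through only where |x| exceeds theta (assumed strict)
const thresholdedLinear = (x, theta) => x.map((v) => (Math.abs(v) > theta ? v : 0));

// Thresholded ReLU: pass x through only where x exceeds theta (assumed strict)
const thresholdedReLU = (x, theta) => x.map((v) => (v > theta ? v : 0));

console.log(leakyReLU([-2, 3], 0.5));                 // [ -1, 3 ]
console.log(thresholdedLinear([-2, 0.5, 2], 1));      // [ -2, 0, 2 ]
console.log(thresholdedReLU([-2, 0.5, 2], 1));        // [ 0, 0, 2 ]
```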
Basic layers
- denseLayer
- flattenLayer
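Both basic layers are simple enough to sketch directly. The snippet below assumes the Keras weight layout for Dense (W of shape [inputDim][outputDim], output y = xW + b) and uses nested plain arrays for readability; a real implementation would more likely operate on flat typed arrays matching the arrayType option. The function names and example weights are hypothetical.

```javascript
// Flatten: collapse nested plain arrays (e.g. a 2D feature map) into 1D;
// typed arrays are copied into a plain array as-is.
const flatten = (x) => (Array.isArray(x) ? x.flat(Infinity) : Array.from(x));

// Dense: y[j] = b[j] + sum_i x[i] * W[i][j]   (assumed Keras layout)
function dense(x, W, b) {
  if (x.length !== W.length) {
    throw new Error(`expected ${W.length} inputs, got ${x.length}`);
  }
  return b.map((bj, j) => x.reduce((acc, xi, i) => acc + xi * W[i][j], bj));
}

// Example: a 2x2 "image" is flattened to 4 values, then projected to 2 outputs
const input = [[1, 2], [3, 4]];
const W = [
  [1, 0],
  [0, 1],
  [1, 0],
  [0, 1],
];
const b = [0.5, -0.5];

const out = dense(flatten(input), W, b);
console.log(out); // [ 4.5, 5.5 ]
```

In a sequential model, a forward pass is just this kind of composition: each layer's output array becomes the next layer's input, with nothing recorded for a backward pass.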
Recurrent layers
- rGRULayer (gated recurrent unit, or GRU)
- rLSTMLayer (long short-term memory, or LSTM)
- rJZS1Layer, rJZS2Layer, rJZS3Layer (mutated GRUs JZS1, JZS2, and JZS3 from Jozefowicz et al. 2015)
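All of these recurrent layers step through the input sequence while carrying a hidden state; the GRU and JZS variants differ mainly in their gate equations. As an illustration, here is one standard LSTM time step (no peephole connections) in the Keras-style formulation with separate weights per gate. The per-gate { W, U, b } layout and the function names are assumptions for the sketch, not neocortex-js's model format. Keras also makes the gate activation configurable (it defaulted to hard_sigmoid); plain sigmoid is used here.

```javascript
// Logistic sigmoid, used here as the gate ("inner") activation
const sigmoid = (v) => 1 / (1 + Math.exp(-v));

// Pre-activation for one gate: x*W + hPrev*U + b
// W: [inputDim][units], U: [units][units], b: [units]   (assumed layout)
function gatePreactivation(x, hPrev, { W, U, b }) {
  return b.map((bj, j) => {
    let s = bj;
    for (let i = 0; i < x.length; i++) s += x[i] * W[i][j];
    for (let k = 0; k < hPrev.length; k++) s += hPrev[k] * U[k][j];
    return s;
  });
}

// One LSTM time step. params holds one { W, U, b } entry per gate:
// i (input), f (forget), c (cell candidate), o (output).
function lstmStep(x, hPrev, cPrev, params) {
  const i = gatePreactivation(x, hPrev, params.i).map(sigmoid);
  const f = gatePreactivation(x, hPrev, params.f).map(sigmoid);
  const g = gatePreactivation(x, hPrev, params.c).map(Math.tanh);
  const o = gatePreactivation(x, hPrev, params.o).map(sigmoid);
  // New cell state: keep part of the old memory, add part of the candidate
  const c = cPrev.map((cj, j) => f[j] * cj + i[j] * g[j]);
  // New hidden state: gated, squashed view of the cell state
  const h = c.map((cj, j) => o[j] * Math.tanh(cj));
  return { h, c };
}

// Run a whole sequence from a zero initial state and return the last
// hidden state (the output of a layer that does not return sequences).
function lstmForward(sequence, params, units) {
  let h = new Array(units).fill(0);
  let c = new Array(units).fill(0);
  for (const x of sequence) {
    const state = lstmStep(x, h, c, params);
    h = state.h;
    c = state.c;
  }
  return h;
}

// Toy 1-input, 1-unit example with hand-picked weights: input and output
// gates saturated open, forget gate saturated shut, so the final hidden
// state depends only on the last input: tanh(tanh(0.5)).
const gate = (w, bias) => ({ W: [[w]], U: [[0]], b: [bias] });
const params = { i: gate(0, 100), f: gate(0, -100), c: gate(1, 0), o: gate(0, 100) };
console.log(lstmForward([[0.2], [0.5]], params, 1));
```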
Convolutional layers
- convolution2DLayer
- maxPooling2DLayer
- convolution1DLayer
- maxPooling1DLayer
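To show what the 2D layers compute, here is a single-channel sketch with 'valid' borders on nested arrays; real layers also handle multiple input and output channels, biases, and other border modes, and the 1D layers apply the same idea along a single axis. Note that it computes cross-correlation (the kernel is not flipped). Whether serialized kernels need flipping depends on the backend that trained them, so treat that as an assumption to check. Function names are hypothetical.

```javascript
// 'valid' 2D cross-correlation, single channel:
// out[r][c] = sum over (i, j) of image[r + i][c + j] * kernel[i][j]
function conv2dValid(image, kernel) {
  const kh = kernel.length;
  const kw = kernel[0].length;
  const outH = image.length - kh + 1;
  const outW = image[0].length - kw + 1;
  const out = [];
  for (let r = 0; r < outH; r++) {
    const row = [];
    for (let c = 0; c < outW; c++) {
      let s = 0;
      for (let i = 0; i < kh; i++) {
        for (let j = 0; j < kw; j++) s += image[r + i][c + j] * kernel[i][j];
      }
      row.push(s);
    }
    out.push(row);
  }
  return out;
}

// Non-overlapping max pooling (stride = size); trailing rows and columns
// that do not fill a whole window are dropped.
function maxPool2d(x, size = 2) {
  const out = [];
  for (let r = 0; r + size <= x.length; r += size) {
    const row = [];
    for (let c = 0; c + size <= x[0].length; c += size) {
      let m = -Infinity;
      for (let i = 0; i < size; i++) {
        for (let j = 0; j < size; j++) m = Math.max(m, x[r + i][c + j]);
      }
      row.push(m);
    }
    out.push(row);
  }
  return out;
}

const image = [
  [1, 2, 3, 0],
  [0, 1, 2, 3],
  [3, 0, 1, 2],
  [2, 3, 0, 1],
];
// 2x2 kernel that sums the main diagonal of each window
const kernel = [[1, 0], [0, 1]];
console.log(conv2dValid(image, kernel)); // [[2, 4, 6], [0, 2, 4], [6, 0, 2]]
console.log(maxPool2d(image));           // [[2, 3], [3, 2]]
```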
Embedding layers
- embeddingLayer: maps indices to the corresponding embedding vectors
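An embedding layer's forward pass is just a table lookup: the learned weight matrix has one row per vocabulary index, and each input index selects its row. The character vocabulary, weights, and function name below are hypothetical.

```javascript
// Embedding lookup: each integer index selects one row of the weight matrix
function embed(indices, weights) {
  return indices.map((idx) => {
    if (!Number.isInteger(idx) || idx < 0 || idx >= weights.length) {
      throw new RangeError(`index ${idx} out of range`);
    }
    return weights[idx];
  });
}

// Hypothetical character vocabulary and 2-dimensional embeddings
const vocab = { a: 0, b: 1, c: 2 };
const weights = [
  [0.1, 0.2], // 'a'
  [0.3, 0.4], // 'b'
  [0.5, 0.6], // 'c'
];

// Encode a string as indices, then look up one vector per character
const seq = Array.from('cab', (ch) => vocab[ch]); // [2, 0, 1]
console.log(embed(seq, weights)); // [ [ 0.5, 0.6 ], [ 0.1, 0.2 ], [ 0.3, 0.4 ] ]
```

The resulting sequence of vectors is the kind of input a recurrent layer would then consume step by step.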
Normalization layers
- batchNormalizationLayer: see Ioffe and Szegedy 2015
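At prediction time, batch normalization uses fixed statistics saved during training rather than per-batch ones, so per feature it computes y = gamma * (x - mean) / sqrt(variance + epsilon) + beta. The sketch assumes the serialized model provides a running mean and variance (check whether your export stores a variance or a standard deviation); the parameter names and the epsilon default are illustrative. Because everything is fixed, the layer can also be folded into a single scale and shift per feature, an optional optimization that is not necessarily what neocortex-js does.

```javascript
// Inference-time batch normalization, per feature j:
//   y[j] = gamma[j] * (x[j] - mean[j]) / sqrt(variance[j] + epsilon) + beta[j]
function batchNorm(x, { gamma, beta, mean, variance, epsilon = 1e-6 }) {
  return x.map(
    (v, j) => gamma[j] * ((v - mean[j]) / Math.sqrt(variance[j] + epsilon)) + beta[j]
  );
}

// With fixed statistics the layer is an affine map: y = scale * x + shift
function foldBatchNorm({ gamma, beta, mean, variance, epsilon = 1e-6 }) {
  const scale = gamma.map((g, j) => g / Math.sqrt(variance[j] + epsilon));
  const shift = beta.map((b, j) => b - scale[j] * mean[j]);
  return { scale, shift };
}

// Hypothetical parameters for two features (epsilon = 0 keeps the numbers exact)
const params = { gamma: [2, 1], beta: [0, 1], mean: [1, 0], variance: [4, 1], epsilon: 0 };
console.log(batchNorm([3, 2], params)); // [ 2, 3 ]
```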
Tests
$ npm test
Browser testing is planned.
Credits
Thanks to @halmos for the logo! It's an allusion to the fact that GPU computation in the browser currently needs to be interfaced through WebGL textures and shaders and all that good stuff.