music-transformer-comp6248

This re-implementation is meant for the COMP6248 reproducibility challenge.

Re-implementation of the Music Transformer paper published at ICLR 2019.
Link to paper: https://openreview.net/forum?id=rJe4ShAcF7

Scope of re-implementation

We re-implement a subset of the experiments described in the published paper, using only the JSB Chorales dataset to train the three model variations listed below:

  • Baseline Transformer (TF) (5L, 256hs, 256att, 1024ff, 8h)
  • Baseline TF + concatenating positional sinusoids
  • TF with efficient relative attention (5L, 512hs, 512att, 512ff, 256r, 8h)
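For readers unfamiliar with the third variant, the "skewing" step that the Music Transformer paper uses to make relative attention memory-efficient can be sketched in plain Python. This is a minimal illustration of the paper's pad, reshape, and slice procedure on nested lists. It is not taken from this repository's code, and the function name `skew` is hypothetical.

```python
def skew(qe):
    """Skew an L x L matrix of query/relative-embedding logits.

    qe[i][c] holds the logit of query i for relative distance c - (L - 1),
    so the last column corresponds to distance 0 (a position attending to itself).
    Returns s with s[i][j] == qe[i][L - 1 - i + j] for every j <= i.
    Entries with j > i are leftovers that a causal mask would discard.
    """
    length = len(qe)
    # 1. Pad one dummy column on the left: L x (L + 1).
    padded = [[0] + list(row) for row in qe]
    # 2. Flatten, then reshape to (L + 1) x L.
    flat = [value for row in padded for value in row]
    reshaped = [flat[r * length:(r + 1) * length] for r in range(length + 1)]
    # 3. Drop the first row to get back an L x L matrix.
    return reshaped[1:]


if __name__ == "__main__":
    for row in skew([[1, 2, 3], [4, 5, 6], [7, 8, 9]]):
        print(row)
```

After skewing, the entry at row i, column j (for j <= i) is the logit for relative distance j - i, so it can be added directly to the usual query-key logits without building an L x L x D tensor of relative embeddings.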

Results

Unfortunately, we were unable to reproduce the results reported in the published paper. Several factors contributed to this, the main one being that the paper gives insufficient detail about the JSB dataset and how the data was processed for the transformer model.

Nonetheless, these are the results obtained after training for 300 epochs:

Model Variations                     Final Training Loss   Final Validation Loss
Baseline TF                          1.731                 3.398
Baseline TF + concat pos sinusoids   2.953                 3.575
TF with relative attention           3.028                 3.743
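As a quick reading aid for these numbers, the gap between validation and training loss can be computed per model. The sketch below only does arithmetic on the values listed above; the helper name `generalization_gap` is hypothetical.

```python
# Final losses after 300 epochs, copied from the results above:
# model name -> (final training loss, final validation loss)
results = {
    "Baseline TF": (1.731, 3.398),
    "Baseline TF + concat pos sinusoids": (2.953, 3.575),
    "TF with relative attention": (3.028, 3.743),
}


def generalization_gap(train_loss, val_loss):
    """Return validation loss minus training loss, rounded to 3 decimals."""
    return round(val_loss - train_loss, 3)


gaps = {name: generalization_gap(*losses) for name, losses in results.items()}

if __name__ == "__main__":
    for name, gap in gaps.items():
        print(f"{name}: gap = {gap}")
```

The gaps come out as 1.667, 0.622, and 0.715 respectively. The baseline reaches the lowest training loss but also shows the largest gap, which may indicate overfitting; the report is the place to check for the authors' own analysis.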

Please read the report for more detailed information regarding the re-implementation.

Usage

Note: Please create the following folders in this directory before running the scripts:

  • ./weights/ - for storing the trained weights.
  • ./outputs/ - for storing training/validation loss values and generated outputs.
  • ./plots/ - for storing the plots.
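The three folders can also be created from Python. This is a small convenience sketch, not a script shipped with the repository; the function name `create_output_dirs` is hypothetical.

```python
from pathlib import Path

# Folder names taken from the list above.
REQUIRED_DIRS = ("weights", "outputs", "plots")


def create_output_dirs(root="."):
    """Create the required folders under root and return their paths.

    Existing folders are left untouched.
    """
    created = []
    for name in REQUIRED_DIRS:
        path = Path(root) / name
        path.mkdir(parents=True, exist_ok=True)
        created.append(path)
    return created


if __name__ == "__main__":
    for path in create_output_dirs():
        print(f"ready: {path}")
```

Run it once from the repository root before calling the training or generation scripts.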

To train the models, simply run train.py and add the arguments accordingly:

python train.py -src_data datasets/JSB-Chorales-dataset/Jsb16thSeparated.npz -epochs 300 -weights_name baselineTF_300epoch -device cuda:2 -checkpoint 10
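To show how the flags in this command fit together, here is a hypothetical `argparse` mirror of them. It is not the parser used by train.py: the flag names come from the command above, while the types, the `required` setting, and the function name `build_train_parser` are assumptions. The README does not state what unit `-checkpoint` uses, so the sketch only parses it as an integer.

```python
import argparse


def build_train_parser():
    """Build a parser that accepts the flags shown in the example command.

    Types and required-ness are assumptions, not taken from train.py.
    """
    parser = argparse.ArgumentParser(
        description="Hypothetical mirror of the train.py flags")
    parser.add_argument("-src_data", required=True)  # path to the .npz dataset
    parser.add_argument("-epochs", type=int)         # number of training epochs
    parser.add_argument("-weights_name")             # name for the saved weights
    parser.add_argument("-device")                   # e.g. "cuda:2"
    parser.add_argument("-checkpoint", type=int)     # unit not stated in the README
    return parser


if __name__ == "__main__":
    example = (
        "-src_data datasets/JSB-Chorales-dataset/Jsb16thSeparated.npz "
        "-epochs 300 -weights_name baselineTF_300epoch "
        "-device cuda:2 -checkpoint 10"
    ).split()
    print(vars(build_train_parser().parse_args(example)))
```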

To generate music using a trained model, run generate.py and add the arguments accordingly:

python generate.py -src_data datasets/JSB-Chorales-dataset/Jsb16thSeparated.npz -load_weights weights -weights_name baselineTF_300epoch -device cuda:1 -k 3

To generate the plots for training/validation loss, please use gen_training_plots.py:

python gen_training_plots.py -t_loss_file <t_loss_npy_file> -v_loss_file <v_loss_npy_file>

<t_loss_npy_file> and <v_loss_npy_file> are generated during training and saved in the outputs directory.

The generated sequences can be plotted and listened to using this IPython Notebook.
Please note that this requires Magenta to be installed.

Datasets

  1. JSB Chorales dataset
  2. MAESTRO dataset (not used)

Environment Setup

  1. PyTorch with CUDA-enabled GPUs.
    • Install CUDA 9.0 and cuDNN 7.4.1.5.
    • Then follow these steps:
conda create -n torch python=3.6
conda activate torch
conda install pytorch torchvision cuda90 -c pytorch
  2. Magenta for plotting and playing the generated notes.
    Steps for installing on Ubuntu:
conda create -n magenta python=3.6
conda activate magenta
sudo apt-get update
sudo apt-get install build-essential libasound2-dev libjack-dev libfluidsynth1 fluid-soundfont-gm
pip install --pre python-rtmidi
pip install jupyter magenta pyfluidsynth pretty_midi
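After setting up the PyTorch environment, a short check can confirm that PyTorch imports and sees a GPU. This is a sketch rather than part of the repository; the function name `check_environment` is hypothetical, and it reports a missing installation instead of raising.

```python
import importlib.util


def check_environment():
    """Report whether PyTorch is installed and whether CUDA is usable."""
    report = {"torch_installed": False, "cuda_available": False}
    # find_spec returns None when the package cannot be found.
    if importlib.util.find_spec("torch") is None:
        return report
    import torch  # imported lazily so the check also runs without PyTorch

    report["torch_installed"] = True
    report["cuda_available"] = bool(torch.cuda.is_available())
    return report


if __name__ == "__main__":
    print(check_environment())
```

If `cuda_available` is false, the `-device cuda:N` arguments in the usage commands are unlikely to work on that machine.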

Contributors

jinyi12, zckoh
