
MU-LLaMA:
Music Understanding Large Language Model


This is the official repository for Music Understanding LLaMA: Advancing Text-to-Music Generation with Question Answering and Captioning

The demo page with more information regarding the MU-LLaMA model is available here.

Introduction

The MU-LLaMA model is a Music Understanding Language Model designed to answer questions about a given piece of music. It is also designed to caption music files, which we use to generate Text-to-Music Generation datasets. The model uses MERT + LLaMA as the backbone and employs an adapter to incorporate music context information that guides LLaMA's output. MERT was chosen as the music encoder for our model after a comparison of different music representation models, which can be viewed here. We also provide the code for generating our MusicQA dataset from the MusicCaps and MagnaTagATune datasets.
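
To make the encoder side of this design concrete, the snippet below shows how MERT features can be extracted from an audio clip with the HuggingFace transformers library. This is a minimal sketch rather than the exact preprocessing used in this repo; the MERT-v1-330M checkpoint name, the resampling step and the layer stacking are assumptions for illustration.

import torch
import torchaudio
from transformers import AutoModel, Wav2Vec2FeatureExtractor

# Load the MERT music encoder and its feature extractor (checkpoint name assumed).
mert = AutoModel.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)

# Load a music clip, convert it to mono and resample to the rate MERT expects.
wav, sr = torchaudio.load("example.wav")  # path is illustrative
wav = torchaudio.functional.resample(wav.mean(dim=0), sr, processor.sampling_rate)

inputs = processor(wav.numpy(), sampling_rate=processor.sampling_rate, return_tensors="pt")
with torch.no_grad():
    out = mert(**inputs, output_hidden_states=True)

# Per-layer hidden states; an adapter can aggregate these into a music context
# representation that conditions LLaMA's output.
features = torch.stack(out.hidden_states)  # (layers, batch, time, dim)
print(features.shape)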

MU-LLaMA Demo

For the working of our model, Facebook's LLaMA-2 model weights are required; details on obtaining these weights are given on HuggingFace. Our pretrained weights for the MU-LLaMA model, finetuned from LLaMA-2 7B, can be downloaded here. Once downloaded, store the files in the ckpts folder within the MU-LLaMA directory.

Once downloaded, the directory structure will be as shown below.

.
├── ...
├── MU-LLaMA
│   ├── ckpts
│   │   ├── LLaMA
│   │   │   ├── 7B
│   │   │   │   ├── checklist.chk
│   │   │   │   ├── consolidated.00.pth
│   │   │   │   └── params.json
│   │   │   ├── llama.sh
│   │   │   ├── tokenizer.model
│   │   │   └── tokenizer_checklist.chk
│   │   ├── 7B.pth
│   │   └── checkpoint.pth
└── ...

We use Python 3.9.17 for this project, and the library requirements are given in requirements.txt (they can be installed with pip install -r requirements.txt). The demo can be run using gradio_app.py.

python gradio_app.py --model ./ckpts/checkpoint.pth --llama_dir ./ckpts/LLaMA

Training MU-LLaMA

To train the MU-LLaMA model, follow the steps below.

MusicQA Dataset

We use the MusicCaps and MagnaTagATune datasets to generate our training MusicQA dataset, and the MTG-Jamendo dataset for evaluation. You can download the generated MusicQA dataset here.

To generate the dataset yourself, first download the MusicCaps, MTT and MTG datasets. Once downloaded, the directory structure should be as shown below.

.
├── ...
├── MusicQA
│   ├── MTT
│   │   ├── audios
│   │   │   └── ...
│   │   └── annotations_final.csv
│   ├── MusicCaps
│   │   ├── audios
│   │   │   └── ...
│   │   └── musiccaps-public.csv
│   ├── MTG
│   │   ├── audios
│   │   │   ├── 00
│   │   │   ├── 01
│   │   │   └── ...
│   │   └── raw_30s_cleantags_50artists.tsv
│   ├── MTT_process.py
│   ├── musiccaps_process.py
│   ├── MTG_process.py
│   └── generate_dataset.py
└── ...

The MusicQA dataset generation is a very computationally intensive process that takes around 8 days per dataset on a Tesla V100-SXM2-32GB GPU, so it is recommended to download our generated dataset.

📝 Notes: Run the following command to flatten the MTT audio file structure once it is downloaded and extracted:

find ./MTT/audios -mindepth 2 -type f -exec mv -t ./MTT/audios -i '{}' +

We only use the folders 00 to 09 from the MTG dataset.

Running musiccaps_process.py, MTT_process.py and MTG_process.py generates the question-answer pairs from each of the datasets, and running generate_dataset.py then produces the final datasets for pretraining, finetuning and evaluation (an example invocation follows the usage below).

usage: generate_dataset.py [-h] --mtt MTT --mtg MTG --musiccaps MUSICCAPS --musicqa MUSICQA

optional arguments:
  -h, --help            show this help message and exit
  --mtt MTT             Directory of the MTT dataset
  --mtg MTG             Directory of the MTG dataset
  --musiccaps MUSICCAPS
                        Directory of the MusicCaps dataset
  --musicqa MUSICQA     Directory of the MusicQA dataset to be generated
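
As an illustration, with the directory layout shown above the final datasets might be generated as follows; the output directory name is only an example.

python generate_dataset.py --mtt ./MTT --mtg ./MTG --musiccaps ./MusicCaps --musicqa ./MusicQA_dataset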

MU-LLaMA Pretraining

To pretrain the MU-LLaMA model, we use the MusicCaps part of the MusicQA dataset and the Alpaca Instruction dataset with the pretrain.sh script.

./pretrain.sh ./ckpts/LLaMA-2 ./configs/pretrain.yaml ./ckpts/MU-LLaMA_Pretrain

This will pretrain the MU-LLaMA model for 150 epochs. The hyperparameters can be modified in the pretrain.sh file.

MU-LLaMA Finetuning

To finetune the MU-LLaMA model, we use the MTT part of the MusicQA dataset with the finetune.sh script.

./finetune.sh ./ckpts/LLaMA-2 ./ckpts/MU-LLaMA_Pretrain/checkpoint.pth ./configs/finetune.yaml ./ckpts/MU-LLaMA_Finetune

This will finetune the MU-LLaMA model for 20 epochs. The hyperparameters can be modified in the finetune.sh file. The MU-LLaMA model with 7B parameters takes approximately 2 days to train on a Tesla V100-SXM2-32GB GPU. Once trained, the model can be tested using the Gradio demo.

MU-LLaMA Inference

To test the model without Gradio, the inference.py script can be used; an example invocation follows the usage below.

usage: inference.py [-h] [--model MODEL] [--llama_type LLAMA_TYPE] [--llama_dir LLAMA_DIR] [--mert_path MERT_PATH] --audio_path AUDIO_PATH [--question QUESTION]

optional arguments:
  -h, --help            show this help message and exit
  --model MODEL         Name of or path to the trained checkpoint
  --llama_type LLAMA_TYPE
                        Type of llama original weight
  --llama_dir LLAMA_DIR
                        Path to LLaMA pretrained checkpoint
  --mert_path MERT_PATH
                        Path to MERT pretrained checkpoint
  --audio_path AUDIO_PATH
                        Path to the input music file
  --question QUESTION   Question to ask the model

MU-LLaMA Evaluation

Our model was compared against audio-enabled models such as the Listen, Think and Understand (LTU) model and the LLaMA Adapter model trained on our MusicQA dataset. We evaluate the models using BLEU (B-U), METEOR (M-R), ROUGE-L (R-L) and BERT-Score (BERT-S), which are common evaluation metrics for text generation. For the BLEU score, a weighted average of BLEU1, BLEU2, BLEU3 and BLEU4 (weight = 0.25 for each) is used.
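
As a rough sketch of how the B-U score can be computed, assuming it is the plain average of the four individual n-gram BLEU scores (our reading of the description above), NLTK can be used as follows; the reference and hypothesis strings are illustrative.

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def weighted_bleu(reference: str, hypothesis: str) -> float:
    """Average of BLEU-1..BLEU-4, each weighted 0.25 (assumed reading of B-U)."""
    ref_tokens = [reference.split()]
    hyp_tokens = hypothesis.split()
    smooth = SmoothingFunction().method1  # avoid zero scores on short answers
    scores = [
        sentence_bleu(ref_tokens, hyp_tokens, weights=w, smoothing_function=smooth)
        for w in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]
    ]
    return 0.25 * sum(scores)

print(weighted_bleu("the music features a slow piano melody",
                    "a slow piano melody is played"))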

The evaluation scripts are given in the ModelEvaluations folder. The generate scripts are used to generate the answers for all the questions in the dataset; an example invocation follows the usage listings below.

usage: generate_mullama.py [-h] [--model MODEL] [--knn KNN] [--llama_type LLAMA_TYPE] [--llama_dir LLAMA_DIR] [--mert_path MERT_PATH]

optional arguments:
  -h, --help            show this help message and exit
  --model MODEL         Name of or path to the trained checkpoint
  --knn KNN             Name of or path to the directory with knn checkpoint
  --llama_type LLAMA_TYPE
                        Type of llama original weight
  --llama_dir LLAMA_DIR
                        Path to LLaMA pretrained checkpoint
  --mert_path MERT_PATH
                        Path to MERT pretrained checkpoint
usage: generate_ltu.py [-h] [--demo DEMO]

optional arguments:
  -h, --help   show this help message and exit
  --demo DEMO  Link to the LTU Demo Page
usage: generate_llama-adapter.py [-h] [--model MODEL] [--llama_dir LLAMA_DIR]

optional arguments:
  -h, --help            show this help message and exit
  --model MODEL         Name of or path to the trained checkpoint
  --llama_dir LLAMA_DIR
                        Path to LLaMA pretrained checkpoint
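
As an illustration, the MU-LLaMA generation script might be run as below, assuming the checkpoints from the main repo and that the remaining options fall back to their defaults; the paths are illustrative.

python generate_mullama.py --model ../MU-LLaMA/ckpts/checkpoint.pth --llama_dir ../MU-LLaMA/ckpts/LLaMA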

Once generated, evaluate.py can be used to evaluate the generated answers for the three models. The results are shown below.

Model B-U ↑ M-R ↑ R-L ↑ BERT-S ↑
LTU 0.242 0.274 0.326 0.887
LLaMA Adapter 0.273 0.334 0.413 0.895
MU-LLaMA 0.306 0.385 0.466 0.901

Acknowledgements

This code contains elements from the following repos:

Cite our work

If you find this repo useful, please consider citing:

@article{liu2023music,
  title={{Music Understanding LLaMA: Advancing Text-to-Music Generation with Question Answering and Captioning}},
  author={Liu, Shansong and Hussain, Atin Sakkeer and Sun, Chenshuo and Shan, Ying},
  journal={arXiv preprint arXiv:2308.11276},
  year={2023}
}

