AERO

Audio Super Resolution in the Spectral Domain

Checkpoint files are available! Details below.

Requirements

Install requirements specified in requirements.txt:
pip install -r requirements.txt

We ran our code on CUDA 11.3, so we installed pytorch/torchvision/torchaudio with the following command:

pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113

Our code uses hydra to set the parameters for different experiments.

ViSQOL

To run the code without ViSQOL, set visqol: False in conf/main_config.yaml.

In order to evaluate model output with the ViSQOL metric, one first needs to install Bazel and then ViSQOL.
In our code, we use ViSQOL via its command line API by using a Python subprocess.

Build Bazel and ViSQOL following directions from here.

Add the absolute path of the ViSQOL root directory (the one containing the WORKSPACE file) to the visqol path parameter in main_config.yaml.
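As a rough illustration of the subprocess approach described above, the sketch below builds a ViSQOL command line and parses a score from its text output. The binary location (bazel-bin/visqol under the ViSQOL root), the --reference_file and --degraded_file flags, and the "MOS-LQO" output label are assumptions based on ViSQOL's documented command-line usage, not on this repository; the function names are hypothetical and the project's own wrapper may differ.

```python
import re
import subprocess
from pathlib import Path


def build_visqol_command(visqol_root, reference_wav, degraded_wav):
    """Build the argument list for one ViSQOL comparison.

    Assumption: the Bazel build left the binary at <visqol_root>/bazel-bin/visqol.
    """
    binary = Path(visqol_root) / "bazel-bin" / "visqol"
    return [
        str(binary),
        "--reference_file", str(reference_wav),
        "--degraded_file", str(degraded_wav),
    ]


def parse_mos_lqo(stdout_text):
    """Extract the score from a line such as 'MOS-LQO: 3.21' (assumed format).

    Returns None when no such line is found.
    """
    match = re.search(r"MOS-LQO:\s*([0-9]*\.?[0-9]+)", stdout_text)
    return float(match.group(1)) if match else None


def run_visqol(visqol_root, reference_wav, degraded_wav):
    """Run ViSQOL from its root directory and return the parsed score."""
    command = build_visqol_command(visqol_root, reference_wav, degraded_wav)
    result = subprocess.run(
        command, cwd=visqol_root, capture_output=True, text=True, check=True
    )
    return parse_mos_lqo(result.stdout)
```

Keeping command construction and output parsing in separate functions lets both be checked without a ViSQOL build; only run_visqol needs the binary.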

Preparing Data

Resample data

Data are a collection of high/low resolution pairs. Corresponding high and low resolution signals should be in different folders.

To create both folders, run resample_data.py twice: once for the low resolution and once for the high resolution.

E.g. for 4 and 16 kHz:
python data_prep/resample_data.py --data_dir <path for 48 kHz data> --out_dir <path for 4 kHz data> --target_sr 4
python data_prep/resample_data.py --data_dir <path for 48 kHz data> --out_dir <path for 16 kHz data> --target_sr 16

Create egs files

For each low/high resolution pair, create "egs files" twice: once for the low-resolution data and once for the high-resolution data.
create_meta_files.py creates a pair of train and val "egs files", each under its respective folder. Each "egs file" contains meta information about the signals: paths and signal lengths.
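To make the idea of an "egs file" concrete, the sketch below collects the path and length in samples of every .wav file in a folder and writes them out as JSON. The JSON list-of-pairs layout, the function names, and the use of the standard-library wave module are illustrative assumptions; create_meta_files.py may store its meta information differently.

```python
import json
import wave
from pathlib import Path


def collect_meta(data_dir):
    """Return a sorted list of [path, num_samples] for each .wav in data_dir."""
    meta = []
    for wav_path in sorted(Path(data_dir).glob("*.wav")):
        with wave.open(str(wav_path), "rb") as wav_file:
            # getnframes() is the signal length in samples per channel.
            meta.append([str(wav_path), wav_file.getnframes()])
    return meta


def write_egs_file(data_dir, egs_path):
    """Write the collected meta information to egs_path as JSON.

    Assumption: a JSON list of [path, length] pairs; the real format may differ.
    """
    meta = collect_meta(data_dir)
    egs_path = Path(egs_path)
    egs_path.parent.mkdir(parents=True, exist_ok=True)
    egs_path.write_text(json.dumps(meta, indent=2))
    return meta
```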

e.g. to create egs files for the various speech settings:

python data_prep/create_meta_files.py <path for 4 kHz data> egs/vctk/4-16 lr
python data_prep/create_meta_files.py <path for 16 kHz data> egs/vctk/4-16 hr

python data_prep/create_meta_files.py <path for 8 kHz data> egs/vctk/8-16 lr
python data_prep/create_meta_files.py <path for 16 kHz data> egs/vctk/8-16 hr

python data_prep/create_meta_files.py <path for 8 kHz data> egs/vctk/8-24 lr
python data_prep/create_meta_files.py <path for 24 kHz data> egs/vctk/8-24 hr

python data_prep/create_meta_files.py <path for 12 kHz data> egs/vctk/12-48 lr
python data_prep/create_meta_files.py <path for 48 kHz data> egs/vctk/12-48 hr

Creating dummy egs files (for debugging code)

To debug the code on a small number of samples, create dummy egs files with --n_samples_limit. This might be a little buggy, so make sure that the same files appear in both the high and low resolution meta (egs) files.

python data_prep/create_meta_files.py <path for 4 kHz data> egs/vctk/4-16 lr --n_samples_limit=32
python data_prep/create_meta_files.py <path for 16 kHz data> egs/vctk/4-16 hr --n_samples_limit=32
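Because the low- and high-resolution egs files need to list the same recordings, a quick check such as the following can catch mismatches after creating a limited set. It assumes that each egs file is a JSON list of [path, length] entries and that paired files share a base filename; both are assumptions made for illustration, and the function names are hypothetical.

```python
import json
from pathlib import Path


def file_names(entries):
    """Return the set of base filenames from a list of [path, length] entries."""
    return {Path(path).name for path, _length in entries}


def find_mismatches(lr_entries, hr_entries):
    """Return (only_in_lr, only_in_hr) as sorted lists of filenames."""
    lr_names = file_names(lr_entries)
    hr_names = file_names(hr_entries)
    return sorted(lr_names - hr_names), sorted(hr_names - lr_names)


def load_entries(egs_path):
    """Load the entries of one egs file (assumed here to be JSON)."""
    return json.loads(Path(egs_path).read_text())
```

Two empty lists mean that both egs files cover the same set of filenames.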

Train

Run train.py with the dset and experiment parameters.
Make sure that the lr_sr and hr_sr parameters in the experiment match the sample rates of the dataset.

e.g. for upsampling from 4kHz to 16kHz, with n_fft=512 and hop_length=64:

python train.py dset=4-16 experiment=aero_4-16_512_64

To train with multiple GPUs, run with parameter ddp=true. e.g.

python train.py dset=4-16 experiment=aero_4-16_512_64 ddp=true

Test (on whole dataset)

  • Create the appropriate egs files for the specific LR to HR setting
    • e.g. for 4-16:
      python data_prep/create_meta_files.py <path for 4 kHz data> egs/vctk/4-16 lr
      python data_prep/create_meta_files.py <path for 16 kHz data> egs/vctk/4-16 hr
  • Create a directory with experiment name in the format: aero-nfft=<NFFT>-hl=<HOP_LENGTH> (e.g. aero-nfft=512-hl=64)
  • Copy or download the appropriate checkpoint.th file to this directory (make sure that its nfft and hop_length parameters match the experiment file)
  • Run python test.py dset=<LR>-<HR> experiment=aero_<LR>-<HR>_<NFFT>_<HOP_LENGTH>

e.g. for upsampling from 4kHz to 16kHz, with n_fft=512 and hop_length=64:

python test.py \
  dset=4-16 \
  experiment=aero_4-16_512_64
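The naming patterns used in the steps above can be spelled out with a small helper. The class and function names are hypothetical; only the string formats come from the text.

```python
from typing import NamedTuple


class ExperimentNames(NamedTuple):
    dset: str        # value for the dset= override, e.g. "4-16"
    experiment: str  # value for the experiment= override
    directory: str   # experiment directory that holds checkpoint.th


def experiment_names(lr_khz, hr_khz, n_fft, hop_length):
    """Format the names following the patterns given in the text."""
    dset = f"{lr_khz}-{hr_khz}"
    return ExperimentNames(
        dset=dset,
        experiment=f"aero_{dset}_{n_fft}_{hop_length}",
        directory=f"aero-nfft={n_fft}-hl={hop_length}",
    )


def format_test_command(names):
    """Return the test.py invocation for the given names as a single string."""
    return f"python test.py dset={names.dset} experiment={names.experiment}"
```

For example, experiment_names(4, 16, 512, 64) gives the dset, experiment, and directory names used in the 4 kHz to 16 kHz example.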

Predict (on single sample)

  • Copy or download the appropriate checkpoint.th file to the experiment directory (make sure that its nfft and hop_length parameters match the experiment file)
  • Run predict.py, appending the new filename and output parameters via hydra (the + prefix adds a key that is not already in the config); they specify the input file and the output directory, respectively.

e.g. for upsampling from 4kHz to 16kHz, with n_fft=512 and hop_length=64:

python predict.py \
  dset=4-16 \
  experiment=aero_4-16_512_64 \
  +filename=<absolute path to input file> \
  +output=<absolute path to output directory>

Checkpoints

To use pre-trained models, one can download checkpoints from here.

To point to a checkpoint when testing or predicting, set the path under checkpoint_file:<path> in conf/main_config.yaml, or pass it on the command line.
e.g.

python test.py \
  dset=4-16 \
  experiment=aero_4-16_512_64 \
  +checkpoint_file=<path to appropriate checkpoint.th file>

Alternatively, make sure that the checkpoint file is in its corresponding output folder.
For each low to high resolution setting, hydra creates a folder under outputs/ named lr-hr (e.g. outputs/4-16). Under each such folder, hydra creates a folder named after the experiment together with its n_fft and hop_length hyperparameters (e.g. aero-nfft=512-hl=256).
Make sure that each checkpoint exists beforehand in the appropriate output folder. If you download the outputs folder and place it under the root directory (which contains train.py and /src), it should retain the appropriate structure and no renaming should be necessary. Also make sure that restart: false is set in conf/main_config.yaml.
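The folder layout described above can be checked before running test.py or predict.py. The sketch assumes that the checkpoint is named checkpoint.th and sits directly inside the experiment folder; the helper names are hypothetical.

```python
from pathlib import Path


def expected_checkpoint_path(root, lr_khz, hr_khz, n_fft, hop_length):
    """Build <root>/outputs/<lr>-<hr>/aero-nfft=<n_fft>-hl=<hop_length>/checkpoint.th.

    Assumption: the checkpoint sits directly inside the experiment folder.
    """
    return (
        Path(root)
        / "outputs"
        / f"{lr_khz}-{hr_khz}"
        / f"aero-nfft={n_fft}-hl={hop_length}"
        / "checkpoint.th"
    )


def checkpoint_is_in_place(root, lr_khz, hr_khz, n_fft, hop_length):
    """Return True if the checkpoint file exists at the expected location."""
    path = expected_checkpoint_path(root, lr_khz, hr_khz, n_fft, hop_length)
    return path.is_file()
```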

Contributors

m-mandel
