

Contrastive-Predictive-Coding-PyTorch

This repository contains PyTorch code to reproduce the core results of the paper below.

If you find the code useful, please cite:

```
@article{lai2019contrastive,
  title={Contrastive Predictive Coding Based Feature for Automatic Speaker Verification},
  author={Lai, Cheng-I},
  journal={arXiv preprint arXiv:1904.01575},
  year={2019}
}
```

Getting Started

- `./src/model/model.py` — the CPC model implementations.
- `./src/main.py` — training code for the CPC models.
- `./src/spk_class.py` — trains a neural-network speaker classifier.
- `./ivector/` — scripts for running an i-vector speaker verification system.

An example of CPC and speaker classifier training can be found at

./run.sh

CPC Models

- **CDCK2**: base model from the paper "Representation Learning with Contrastive Predictive Coding".
- **CDCK5**: CDCK2 with a different decoder.
- **CDCK6**: CDCK2 with a shared encoder and two decoders.

Experimental Results

A. CPC Model Training

| CPC model ID | Number of epochs | Model size | Dev NCE loss | Dev acc. (%) |
| --- | --- | --- | --- | --- |
| CDCK2 | 60 | 7.42M | 1.6427 | 26.42 |
| CDCK5 | 60 | 5.58M | 1.7818 | 22.48 |
| CDCK6 | 30 | 7.33M | 1.6484 | 28.24 |

B. Speaker Verification on LibriSpeech test-clean-100 (Average Pooling)

Note: results are reported on the 1st and 2nd trial lists.

| Feature | Feature dim | Summarization | LDA dim | 1st EER (%) | 2nd EER (%) |
| --- | --- | --- | --- | --- | --- |
| MFCC | 24 | average pooling | 24 | 9.211 | 13.48 |
| CDCK2 | 256 | average pooling | 200 | 5.887 | 11.1 |
| CDCK5 | 40 | average pooling | 40 | 7.508 | 12.25 |
| CDCK6 | 256 | average pooling | 200 | 6.809 | 12.73 |

C. CPC applied with PCA

| Feature w/ PCA | Original feature | PCA dim | PCA variance ratio (%) |
| --- | --- | --- | --- |
| CDCK2-36 | CDCK2 | 36 | 76.76 |
| CDCK2-60 | CDCK2 | 60 | 87.40 |
| CDCK5-24 | CDCK5 | 24 | 93.39 |

D. Speaker Verification on LibriSpeech test-clean-100 (i-vectors)

| Feature | Feature dim | Summarization | 1st EER (%) | 2nd EER (%) |
| --- | --- | --- | --- | --- |
| MFCC | 24 | i-vectors | 5.518 | 8.157 |
| CDCK2-60 | 60 | i-vectors | 5.351 | 9.753 |
| CDCK5-24 | 24 | i-vectors | 4.911 | 8.901 |
| CDCK6-60 | 60 | i-vectors | 5.228 | 9.009 |
| MFCC + CDCK2-36 | 60 | i-vectors | 3.62 | 6.898 |
| MFCC + CDCK5-24 | 48 | i-vectors | 3.712 | 6.962 |
| MFCC + CDCK6-36 | 60 | i-vectors | 3.691 | 6.765 |

E. DET Curves of CPC and MFCC Fusion for i-vectors Speaker Verification

Authors

Cheng-I Lai.

If you encounter any problem, feel free to contact me.


Issues

Feed entire input to encoder?

I see in your implementation that you feed the entire signal into the encoder, while the paper notes that each timestep should be fed in separately.
When you feed the entire signal into the encoder, the Conv kernels produce overlapping features (except when the stride equals the kernel size).

Why did you implement it like that? Do you think it does not matter?

Thanks!

Can you provide the train & test dataset?

Thanks for sharing the CPC code.
I read the code and found that the provided Dataset class reads .h5 files. From the open ASR website and the information provided in the paper, I can only download files with the extension .flac or .txt.
Can you explicitly explain the configuration of your dataset?

What is the format of "list" file?

I saw that list files such as "LibriSpeech/list/train.txt" are required parameters for main.py. It seems such files are not provided officially by LibriSpeech. What is their format? Could you provide them, or a script to generate them?

There might be something wrong in validation.py

Hi, thank you again for sharing this code.
I found something that might be wrong in validation.py.
When doing validation, the GRU hidden state is initialized again, which can make the validation loss in the log higher than it really is. And since the GRU hidden state is re-initialized every epoch, I think it might slightly impair performance.
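
As a side note on hidden-state handling, here is a minimal, hypothetical validation loop (with a placeholder loss, not the repository's NCE loss) showing the common pattern of initializing the GRU hidden state once per batch inside `torch.no_grad()`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(8, 8, batch_first=True)

def validate(model, batches):
    """Hypothetical validation loop with a placeholder loss."""
    model.eval()
    losses = []
    with torch.no_grad():
        for x in batches:
            # Fresh hidden state per batch: state from one batch of
            # utterances should not leak into the next.
            hidden = torch.zeros(1, x.size(0), 8)
            output, _ = model(x, hidden)
            losses.append(output.pow(2).mean().item())  # placeholder loss
    return sum(losses) / len(losses)

batches = [torch.randn(4, 10, 8) for _ in range(3)]
avg = validate(gru, batches)
```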

How to combine MFCC and CPC features

Thank you for sharing your code; I have run into a problem.
The CPC features have shape [128, 256], while the MFCCs are [frame, 39].
Following your results, I wonder how to combine them into [frame, 39 + 256] dimensions.
Thanks again!
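
One plausible approach (an assumption, not the author's confirmed method): the CPC encoder downsamples 16 kHz audio by 160x, so its features arrive at a 10 ms frame rate matching a typical MFCC hop, and the two streams can be trimmed to a common length and concatenated frame by frame. `combine_features` below is a hypothetical helper:

```python
import numpy as np

def combine_features(mfcc, cpc):
    """Frame-level concatenation; assumes both streams share a 10 ms hop."""
    T = min(len(mfcc), len(cpc))  # trim to the shorter sequence
    return np.concatenate([mfcc[:T], cpc[:T]], axis=1)

mfcc = np.random.randn(120, 39)   # (frames, 39) MFCCs
cpc = np.random.randn(128, 256)   # (frames, 256) CPC features
combined = combine_features(mfcc, cpc)
print(combined.shape)             # (120, 295)
```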

Some trouble in understanding

I had some trouble understanding the implementation of the InfoNCE loss function. I don't understand how torch.diag() can represent the InfoNCE loss.
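
A minimal sketch of the general pattern (not the repository's exact code): score every prediction against every encoded sample in the batch, normalize, and read the positive-pair scores off the diagonal of the resulting matrix:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch = 8
z = torch.randn(batch, 256)     # encoded future samples z_{t+k}
pred = torch.randn(batch, 256)  # predictions W_k @ c_t, one per utterance

scores = torch.mm(z, pred.t())            # (8, 8) similarity matrix
log_probs = F.log_softmax(scores, dim=0)  # normalize over in-batch candidates
# Each diagonal entry is the log-probability assigned to the true pair
# (prediction i scored against its own z_i); off-diagonal entries act
# as in-batch negatives.
nce_loss = -torch.diag(log_probs).mean()
```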

Softmax uses dimension 1 by default

In the calculation of the NCE loss, the softmax is not given a dimension to compute over, and for 2-D input PyTorch defaults to dim=1.

The loss in the paper holds the context c_t constant and 'matches' it against the actual values of z_t. Using dim=1 instead of dim=0 instead computes the 'match' between a constant z_t and the c_t generated by each example in the batch.

The softmax should be performed over the columns of the 8x8 matrix to capture the loss function defined in the CPC paper.
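
To illustrate the difference, here is a minimal sketch (independent of the repository's code) showing how the choice of `dim` changes what gets normalized:

```python
import torch
import torch.nn.functional as F

scores = torch.randn(8, 8)      # hypothetical prediction-vs-target score matrix
col = F.softmax(scores, dim=0)  # normalize down each column
row = F.softmax(scores, dim=1)  # PyTorch's default behaviour for 2-D input

# With dim=0 every column sums to 1; with dim=1 every row does.
print(col.sum(dim=0))
print(row.sum(dim=1))
```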

Use of Batch Normalisation

The paper does not mention the use of Batch Normalization in the case of the audio task.

In the case of the Vision task, it mentions: "We did not use Batch-Norm [38]."

The implementation of the loss might be wrong

https://arxiv.org/pdf/1807.03748.pdf
If you look at equation 4 of the paper, the log-softmax should be over N-1 negative samples and 1 positive sample. In your implementation, the N-1 negative samples are actually self.time_step - 1. Taking log_softmax over the batch seems wrong. We switched it to log_softmax over time, and training is more stable and accuracy has gone up on our toy dataset. However, that is only a partial fix.

How to format h5 files for input?

Hi, thanks for sharing your implementation of CPC. I've been trying to run it out of the box but am having issues shaping the input data correctly. Is there another script that encodes the wav file directories into .h5?

Second-to-last timestep as c_t in the baseline model?

At line 310, you have the following code:

```python
output, hidden = self.gru(forward_seq, hidden)  # output size e.g. 8*100*256
c_t = output[:,t_samples,:].view(batch, 256)    # c_t e.g. size 8*256
```

So you are using the second-to-last timestep as c_t? The last timestep should be output[:,t_samples+1,:], or simply hidden.

As far as I understand the original paper, c_t should be the last timestep. Am I missing anything here?
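
For reference, a minimal sketch (independent of the repository's code) showing that for a single-layer, unidirectional GRU the final hidden state equals the output at the last timestep:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=4, hidden_size=4, batch_first=True)
x = torch.randn(2, 5, 4)  # (batch, time, features)
output, hidden = gru(x)

# output[:, -1, :] is the output at the last timestep; hidden[-1] is the
# final hidden state of the (single) layer. They are the same tensor values.
same = torch.allclose(output[:, -1, :], hidden[-1])
print(same)
```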

Train and test data not available

Dear Jeff,

Thank you so much for providing this great repository! Sincerely appreciate your great implementation!

However, after reading all the closed issues and trying to start training, I am still a bit confused about the training and test datasets. I tried to run run.sh and the following error was reported:
(screenshot of the error omitted)

May I ask what a possible solution might be? Thank you so much for your clarification!

Sincerely,
Martin
