
OlaWod commented on July 23, 2024

For "Word error rate (WER) and character error rate (CER) between source and converted speech", I used the ASR model's transcriptions of the source speech as the ground truth. I then transcribed the converted speech with the same model and scored it against them.
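For reference, WER is the word-level Levenshtein distance (substitutions + deletions + insertions) divided by the number of reference words, and CER is the same computation over characters. The following is a minimal pure-Python sketch that assumes corpus-level aggregation (total edits over total reference length, which as I understand it is also how jiwer treats a list of sentences) and applies no text normalization. jiwer's own preprocessing may differ in details such as whitespace handling.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences (each edit costs 1)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(prev[j] + 1,            # deletion
                          curr[j - 1] + 1,        # insertion
                          prev[j - 1] + (r != h))  # substitution or match
        prev = curr
    return prev[-1]


def error_rate(refs, hyps, tokenize):
    """Corpus-level rate: total edits divided by total reference tokens."""
    edits = sum(edit_distance(tokenize(r), tokenize(h)) for r, h in zip(refs, hyps))
    total = sum(len(tokenize(r)) for r in refs)
    return edits / total


def word_error_rate(refs, hyps):
    return error_rate(refs, hyps, str.split)  # whitespace-separated words


def char_error_rate(refs, hyps):
    return error_rate(refs, hyps, list)  # characters, spaces included


refs = ["THE CAT SAT", "HELLO WORLD"]
hyps = ["THE CAT SAT DOWN", "HELO WORLD"]
print(word_error_rate(refs, hyps))  # 0.4 (1 insertion + 1 substitution over 5 words)
print(char_error_rate(refs, hyps))  # 6 edits over 22 characters, about 0.273
```

Note that a single misrecognized letter costs a whole word in WER but only one character in CER. This is why both are reported.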

get_gt.py

```python
# get_gt.py: transcribe the source speech with HuBERT-Large (CTC) and write
# "path|transcription" lines to use as the WER/CER ground truth.
import argparse
from glob import glob

import librosa
import torch
from tqdm import tqdm
from transformers import HubertForCTC, Wav2Vec2Processor

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--txtpath", type=str, default="gt.txt", help="path to output ground-truth txt file")
    parser.add_argument("--wavdir", type=str, default="SOURCE", help="dir of source wavs")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load model and processor
    model_text = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device).eval()
    processor_text = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")

    # transcribe each source wav (sorted for a deterministic output order)
    wavs = sorted(glob(f"{args.wavdir}/*.wav"))
    with open(args.txtpath, "w") as f, torch.no_grad():
        for path in tqdm(wavs):
            wav = [librosa.load(path, sr=16000)[0]]  # resample to the 16 kHz the model expects
            input_values = processor_text(wav, sampling_rate=16000, return_tensors="pt").input_values.to(device)
            logits = model_text(input_values).logits
            predicted_ids = torch.argmax(logits, dim=-1)  # greedy CTC decoding
            text = processor_text.batch_decode(predicted_ids)[0]
            f.write(f"{path}|{text}\n")
```
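`torch.argmax` followed by `batch_decode` is greedy CTC decoding. The model picks the most likely token per frame, consecutive repeats are merged, the blank token is dropped, and the word-delimiter token becomes a space. Below is a minimal sketch of that collapse step on a toy vocabulary. The vocabulary is hypothetical, and the sketch assumes, as in the wav2vec 2.0 / HuBERT CTC tokenizers, that `<pad>` (id 0) serves as the blank and `|` as the word delimiter.

```python
from itertools import groupby

# Hypothetical toy vocabulary: id 0 "<pad>" acts as the CTC blank, "|" separates words.
VOCAB = {0: "<pad>", 1: "|", 2: "H", 3: "I", 4: "T", 5: "E", 6: "R"}
BLANK_ID = 0
WORD_DELIMITER = "|"


def ctc_greedy_decode(frame_ids):
    """Collapse per-frame argmax ids into text (greedy CTC decoding)."""
    # 1) merge runs of identical consecutive ids
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    # 2) drop blanks; a blank between two identical ids keeps both letters
    tokens = [VOCAB[t] for t in collapsed if t != BLANK_ID]
    # 3) turn the word delimiter into a space and trim the edges
    return "".join(tokens).replace(WORD_DELIMITER, " ").strip()


# per-frame argmax ids: H H <pad> I | | T H E R E
frames = [2, 2, 0, 3, 1, 1, 4, 2, 5, 6, 5]
print(ctc_greedy_decode(frames))  # HI THERE
```

The blank is what lets the model emit a genuine double letter: `[5, 0, 5]` decodes to "EE", while `[5, 5]` collapses to "E".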

wer.py

```python
# wer.py: transcribe the converted speech and compute WER/CER against gt.txt.
import argparse
import os
from glob import glob

import librosa
import torch
from jiwer import cer, wer
from tqdm import tqdm
from transformers import HubertForCTC, Wav2Vec2Processor

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--wavdir", type=str, default="PROPOSED")
    parser.add_argument("--outdir", type=str, default="result", help="path to output dir")
    parser.add_argument("--use_cuda", default=False, action="store_true")
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    device = "cuda" if args.use_cuda else "cpu"

    # load model and processor
    model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device).eval()
    processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")

    # ground truth: "path|text" lines from get_gt.py, keyed by file name without extension
    gt_dict = {}
    with open("gt.txt", "r") as f:
        for line in f:
            path, text = line.strip().split("|", 1)
            gt_dict[os.path.splitext(os.path.basename(path))[0]] = text

    # transcribe the converted wavs
    wavs = sorted(glob(f"{args.wavdir}/*.wav"))
    trans_dict = {}
    with open(f"{args.outdir}/text.txt", "w") as f, torch.no_grad():
        for path in tqdm(wavs):
            wav = [librosa.load(path, sr=16000)[0]]
            input_values = processor(wav, sampling_rate=16000, return_tensors="pt").input_values.to(device)
            logits = model(input_values).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            text = processor.batch_decode(predicted_ids)[0]
            f.write(f"{path}|{text}\n")
            trans_dict[os.path.splitext(os.path.basename(path))[0]] = text

    # pair each converted utterance with its source transcription;
    # converted files are expected to be named "<source title>-<...>.wav"
    gts, trans = [], []
    for key, text in trans_dict.items():
        trans.append(text)
        gts.append(gt_dict[key.split("-")[0]])

    # new variable names so jiwer's wer/cer functions are not shadowed
    wer_score = wer(gts, trans)
    cer_score = cer(gts, trans)
    with open(f"{args.outdir}/wer.txt", "w") as f:
        f.write(f"wer: {wer_score}\n")
        f.write(f"cer: {cer_score}\n")
    print("WER:", wer_score)
    print("CER:", cer_score)
```
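The scoring loop depends on the converted file names. `key.split("-")[0]` assumes every converted utterance is named `<source title>-<...>.wav`, so a file that does not follow this pattern stops the whole run with a bare `KeyError`. The pattern also breaks if a source title itself contains `-`. Below is a hedged sketch of the same pairing step that skips and reports unmatched files instead. `pair_transcripts` is a hypothetical helper, not part of the scripts above, and the example file names are made up.

```python
def pair_transcripts(gt_dict, trans_dict, sep="-"):
    """Pair each converted transcript with its source transcript.

    Assumes converted titles look like "<source title><sep><anything>",
    mirroring key.split("-")[0]. Returns (refs, hyps, missing).
    """
    refs, hyps, missing = [], [], []
    for key in sorted(trans_dict):
        source_title = key.split(sep)[0]
        if source_title not in gt_dict:
            missing.append(key)  # no ground truth for this file: skip it
            continue
        refs.append(gt_dict[source_title])
        hyps.append(trans_dict[key])
    return refs, hyps, missing


# made-up titles for illustration
gt = {"p225_001": "PLEASE CALL STELLA", "p226_002": "ASK HER TO BRING"}
conv = {"p225_001-p226": "PLEASE CALL STELLA", "p226_002-p225": "ASK HER TO BRIN", "extra": "HELLO"}
refs, hyps, missing = pair_transcripts(gt, conv)
print(len(refs), missing)  # 2 ['extra']
```

You can pass `refs` and `hyps` straight to jiwer's `wer`/`cer`, which then score only the matched pairs. A non-empty `missing` list usually signals a naming mismatch between the source and converted directories.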

from freevc.

SeongYeonPark commented on July 23, 2024

Thank you for answering!
