
multi-band-wavernn's People

Contributors

yanggeng1995

multi-band-wavernn's Issues

Hello, a question about the multi-band input and output of Multi-band-WaveRNN

Hello! I'd like to ask about the multi-band input and output of Multi-band-WaveRNN.
I am using the fatchord_version WaveRNN; the code is in my repository. However, it cannot generate meaningful audio, and I suspect the mel conditioning at the model input is the problem. I would like to ask you about the following:
My setup: the four sub-band audio x has shape (batch, T, 4), the mel is (batch, T, 80), and the auxiliary (residual) features from the upsampling network are (batch, T, 32). I concatenate them and feed the result into the network, so the differences from the original network are that the input audio x changes from (batch, T, 1) to (batch, T, 4), the output uses four fully connected layers to generate the four sub-band audio streams, and the losses of the four sub-bands are summed and backpropagated.
Looking at the figure in the DurIAN paper, the four sub-band audio inputs also seem to be stacked along the feature dimension as (batch, T, 4) rather than expanded along the time axis into (batch, 4T, 1). I would like to confirm this with you and hear your thoughts. Thank you!
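
To make the two layouts concrete, here is a small sketch with dummy tensors (illustrative only; the sub-band analysis step itself is not shown):

    import torch

    # Suppose a sub-band analysis filter bank has already split the full-band
    # waveform into 4 critically-sampled sub-bands of length T each.
    batch, T = 2, 1000
    subbands = torch.randn(batch, 4, T)                       # (batch, 4, T)

    # What I do (and what the DurIAN figure seems to show): stack the
    # sub-bands along the feature dimension.
    x_feat = subbands.transpose(1, 2)                         # (batch, T, 4)

    # The alternative I am asking about: interleave the sub-band samples
    # along the time axis instead.
    x_time = subbands.transpose(1, 2).reshape(batch, -1, 1)   # (batch, 4*T, 1)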

Training part:

    for i, (x, y, m) in enumerate(train_set, 1):
        x, m, y = x.to(device), m.to(device), y.to(device)  # x/y: (Batch, sub_bands, T)

        ######################### MultiBand-WaveRNN #########################
        if hp.voc_multiband:
            y0 = y[:, 0, :].squeeze(0).unsqueeze(-1)  # y0/y1/y2/y3: (Batch, T, 1)
            y1 = y[:, 1, :].squeeze(0).unsqueeze(-1)
            y2 = y[:, 2, :].squeeze(0).unsqueeze(-1)
            y3 = y[:, 3, :].squeeze(0).unsqueeze(-1)

            y_hat = model(x, m)  # (Batch, T, num_classes, sub_bands)

            if model.mode == 'RAW':
                y_hat0 = y_hat[:, :, :, 0].transpose(1, 2).unsqueeze(-1)  # (Batch, num_classes, T, 1)
                y_hat1 = y_hat[:, :, :, 1].transpose(1, 2).unsqueeze(-1)
                y_hat2 = y_hat[:, :, :, 2].transpose(1, 2).unsqueeze(-1)
                y_hat3 = y_hat[:, :, :, 3].transpose(1, 2).unsqueeze(-1)

            elif model.mode == 'MOL':
                y0 = y0.float()
                y1 = y1.float()
                y2 = y2.float()
                y3 = y3.float()

            loss = loss_func(y_hat0, y0) + loss_func(y_hat1, y1) + loss_func(y_hat2, y2) + loss_func(y_hat3, y3)
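
For RAW mode, loss_func is just cross-entropy over the quantized classes; a minimal sketch of what I use per sub-band (my own helper, not necessarily what the original repository calls it):

    import torch.nn.functional as F

    # RAW-mode loss per sub-band (sketch): y_hat is (Batch, num_classes, T, 1)
    # logits, y is (Batch, T, 1) integer class labels.
    def loss_func(y_hat, y):
        return F.cross_entropy(y_hat, y.long())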

Model structure:

def forward(self, x, mels):  # x: (Batch, Subband, T)
    device = next(self.parameters()).device  # use same device as parameters

    # Although we `_flatten_parameters()` on init, when using DataParallel
    # the model gets replicated, making it no longer guaranteed that the
    # weights are contiguous in GPU memory. Hence, we must call it again
    self._flatten_parameters()

    self.step += 1
    bsize = x.size(0)
    h1 = torch.zeros(1, bsize, self.rnn_dims, device=device)
    h2 = torch.zeros(1, bsize, self.rnn_dims, device=device)
    mels, aux = self.upsample(mels)

    aux_idx = [self.aux_dims * i for i in range(5)]
    a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
    a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
    a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
    a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
    # x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)

    x = torch.cat([x.transpose(1,2), mels, a1], dim=2)  # (batch,T,4)  (batch,T,80) (batch,T,32)
    x = self.I(x)  # (batch, T, 116) -> (batch, T, 512)
    res = x
    # x, _ = self.rnn1(x, h1)
    x, _ = self.rnn1(x)   # no initial hidden state passed - Begee  # (batch, T, 512) -> (batch, T, 512)
    x = x + res    # (batch, T, 512)
    res = x
    x = torch.cat([x, a2], dim=2)  # (batch, T, 512) -> (batch, T, 512+128)
    # x, _ = self.rnn2(x, h2)
    x, _ = self.rnn2(x)  # no initial hidden state passed - Begee  (batch, T, 512+128) -> (batch, T, 512)

    x = x + res
    x = torch.cat([x, a3], dim=2) # (batch, T, 512+128)
    x = F.relu(self.fc1(x)) # (batch, T, 512+128) -> (batch, T, 512)

    x = torch.cat([x, a4], dim=2)  # (batch, T, 512+128)
    x = F.relu(self.fc2(x))  # (batch, T, 512+128) -> (batch, T, 512)

    out0 = self.fc30(x).unsqueeze(-1)  # (batch, T, 512) -> (batch, T, num_classes, 1)
    out1 = self.fc31(x).unsqueeze(-1)
    out2 = self.fc32(x).unsqueeze(-1)
    out3 = self.fc33(x).unsqueeze(-1)
    out = torch.cat([out0,out1,out2,out3], dim=3)  # (B, T, num_classes, sub_band)
    return out
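
To sanity-check the tensor layout end to end, I run a quick dummy-shape check like the one below (sizes are placeholders from my setup; I call model.upsample first so that the sub-band audio length matches the upsampled features):

    import torch

    batch, frames, num_mels = 2, 20, 80
    mels = torch.randn(batch, num_mels, frames)     # dummy mel frames

    with torch.no_grad():
        up_mels, _ = model.upsample(mels)           # (batch, T, num_mels)
        T = up_mels.size(1)
        x = torch.randn(batch, 4, T)                # four dummy sub-band waveforms
        y_hat = model(x, mels)                      # model is the modified WaveRNN above
    print(y_hat.shape)                              # expect (batch, T, num_classes, 4)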

Thank you very much for your help!
