Comments (12)
train.py
import os
import json
import argparse
import itertools
import math
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
#import torch.multiprocessing as mp
#import torch.distributed as dist
#from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
import commons
import utils
from data_utils import (
TextAudioSpeakerLoader,
TextAudioSpeakerCollate,
DistributedBucketSampler
)
from models import (
SynthesizerTrn,
MultiPeriodDiscriminator,
)
from losses import (
generator_loss,
discriminator_loss,
feature_loss,
kl_loss
)
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
# Let cuDNN benchmark the available convolution algorithms and reuse the
# fastest one it finds.
torch.backends.cudnn.benchmark = True
# Count of training batches processed so far. It is (re)initialised in run(),
# incremented once per batch in train_and_evaluate(), and read by evaluate()
# as the TensorBoard step.
global_step = 0
#os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'
def main():
    """Load the hyperparameters and start single-process training.

    The distributed launcher (``mp.spawn``) has been removed, so training runs
    in the current process via ``run(hps)`` with its default rank.

    Raises:
        AssertionError: If no CUDA device is available.
    """
    assert torch.cuda.is_available(), "CPU training is not allowed."
    hps = utils.get_hparams()

    # The rendezvous variables are still exported so that anything that reads
    # them keeps working. os.environ only accepts strings, so the port is
    # converted explicitly in case the config stores it as a number.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(hps.train.port)

    run(hps)
#def run(rank, n_gpus, hps):
def run(hps, rank=0):
    """Build the data pipeline, models and optimizers, then run the epoch loop.

    This is the single-process variant: no process group is initialised and
    the networks are not wrapped in DistributedDataParallel.

    Args:
        hps: Hyperparameter object; ``hps.model_dir``, ``hps.data``,
            ``hps.train`` and ``hps.model`` are read here.
        rank: CUDA device index to train on. Logging, TensorBoard writers and
            the evaluation loader are only created when ``rank == 0``.
    """
    global global_step

    # Rank-0-only helpers default to None so the training call below has a
    # single form regardless of rank.
    logger = None
    writers = None
    eval_loader = None

    if rank == 0:
        logger = utils.get_logger(hps.model_dir)
        logger.info(hps)
        utils.check_git_hash(hps.model_dir)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
        writers = [writer, writer_eval]

    torch.manual_seed(hps.train.seed)
    torch.cuda.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps)
    # The bucket sampler groups samples of similar length and, per its
    # docstring, drops samples outside the boundaries.
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [32, 300, 400, 500, 600, 700, 800, 900, 1000],
        num_replicas=1,
        rank=rank,
        shuffle=True)
    collate_fn = TextAudioSpeakerCollate(hps)
    train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True,
                              collate_fn=collate_fn, batch_sampler=train_sampler)
    if rank == 0:
        eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps)
        eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=True,
                                 batch_size=hps.train.batch_size, pin_memory=False,
                                 drop_last=False, collate_fn=collate_fn)

    net_g = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model).cuda(rank)
    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps)
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps)

    # Resume from the latest generator/discriminator checkpoints when present.
    # Only ordinary errors (e.g. no checkpoint file yet) fall back to a fresh
    # start; KeyboardInterrupt and SystemExit are no longer swallowed.
    try:
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g)
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d)
        global_step = (epoch_str - 1) * len(train_loader)
    except Exception as err:
        if logger is not None:
            logger.info('Could not resume from a checkpoint ({!r}); starting from epoch 1.'.format(err))
        epoch_str = 1
        global_step = 0

    # last_epoch = epoch_str - 2 is -1 for a fresh run (epoch_str == 1), which
    # is the scheduler's "not started" value.
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)

    scaler = GradScaler(enabled=hps.train.fp16_run)

    for epoch in range(epoch_str, hps.train.epochs + 1):
        train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d],
                           [scheduler_g, scheduler_d], scaler,
                           [train_loader, eval_loader], logger, writers)
        # Learning rates decay once per epoch.
        scheduler_g.step()
        scheduler_d.step()
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
    """Train both networks for one pass over the training loader.

    For every batch the discriminator is updated first, then the generator.
    On rank 0 the losses are logged every ``hps.train.log_interval`` steps and
    an evaluation plus a checkpoint of both networks happens every
    ``hps.train.eval_interval`` steps. The module-level ``global_step`` is
    incremented once per batch.

    Args:
        rank: CUDA device index the batch tensors are moved to.
        epoch: Current epoch number (used for logging and stored in checkpoints).
        hps: Hyperparameter object (``hps.model``, ``hps.data``, ``hps.train``).
        nets: ``[net_g, net_d]`` generator and discriminator.
        optims: ``[optim_g, optim_d]`` matching optimizers.
        schedulers: ``[scheduler_g, scheduler_d]``; unpacked but not used here.
        scaler: Gradient scaler shared by both optimizers.
        loaders: ``[train_loader, eval_loader]``.
        logger: Logger used on rank 0.
        writers: ``[writer, writer_eval]`` or ``None``.
    """
    net_g, net_d = nets
    optim_g, optim_d = optims
    # Not used below; run() steps the schedulers after this function returns.
    scheduler_g, scheduler_d = schedulers
    train_loader, eval_loader = loaders
    # writer / writer_eval stay unbound when writers is None; they are only
    # referenced inside the rank == 0 branches below.
    if writers is not None:
        writer, writer_eval = writers
    #train_loader.batch_sampler.set_epoch(epoch)
    global global_step
    net_g.train()
    net_d.train()
    for batch_idx, items in enumerate(train_loader):
        # Batches carry a fourth element (speaker input) only when use_spk is set.
        if hps.model.use_spk:
            c, spec, y, spk = items
            g = spk.cuda(rank, non_blocking=True)
        else:
            c, spec, y = items
            g = None
        spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True)
        c = c.cuda(rank, non_blocking=True)
        mel = spec_to_mel_torch(
            spec,
            hps.data.filter_length,
            hps.data.n_mel_channels,
            hps.data.sampling_rate,
            hps.data.mel_fmin,
            hps.data.mel_fmax)
        with autocast(enabled=hps.train.fp16_run):
            y_hat, ids_slice, z_mask,\
            (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(c, spec, g=g, mel=mel)
            # Cut the target mel down to the segment selected by ids_slice and
            # compute the mel of the generated audio for the reconstruction loss.
            y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax
            )
            # Slice the ground-truth waveform to the same segment; ids_slice is
            # scaled by hop_length (presumably frame index -> sample index).
            # NOTE(review): this is the line that failed in the reported
            # traceback when samples were too short -- confirm the sampler
            # filters them.
            y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
            # Discriminator: y_hat is detached so this loss does not
            # back-propagate into the generator.
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
            # The loss itself is computed with autocast disabled.
            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
                loss_disc_all = loss_disc
        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        # Unscale before inspecting gradients.
        scaler.unscale_(optim_d)
        # Clip value is None; presumably this only reports the gradient norm -- TODO confirm.
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)
        with autocast(enabled=hps.train.fp16_run):
            # Generator: run the already-updated discriminator on the
            # non-detached output.
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            with autocast(enabled=False):
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        # One scaler update per iteration, after both optimizer steps.
        scaler.update()
        if rank==0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]['lr']
                losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
                logger.info('Train Epoch: {} [{:.0f}%]'.format(
                    epoch,
                    100. * batch_idx / len(train_loader)))
                logger.info([x.item() for x in losses] + [global_step, lr])
                scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
                scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl})
                scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
                scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
                scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
                # Spectrogram images are taken from the first item of the batch.
                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
                    "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict)
            # Evaluate and checkpoint both networks; file names carry the step.
            if global_step % hps.train.eval_interval == 0:
                evaluate(hps, net_g, eval_loader, writer_eval)
                utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
                utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
        global_step += 1
    if rank == 0:
        logger.info('====> Epoch: {}'.format(epoch))
def evaluate(hps, generator, eval_loader, writer_eval):
    """Synthesize one validation sample and log it through ``utils.summarize``.

    Only the first item of the first batch yielded by ``eval_loader`` is used.
    The generator is put in eval mode for the duration of the call and is
    always switched back to train mode, even if inference or logging fails.
    If the loader yields no batch at all, nothing is logged.

    Args:
        hps: Hyperparameter object; ``hps.model.use_spk`` and ``hps.data`` are read.
        generator: Model exposing ``eval()``, ``train()`` and ``infer(c, g=, mel=)``.
        eval_loader: Iterable of validation batches.
        writer_eval: Writer handed to ``utils.summarize``.
    """
    generator.eval()
    try:
        with torch.no_grad():
            # Only the first batch is needed, so take it directly instead of
            # looping and breaking.
            items = next(iter(eval_loader), None)
            if items is None:
                # Empty validation set: there is nothing to synthesize or log.
                return
            if hps.model.use_spk:
                c, spec, y, spk = items
                g = spk[:1].cuda(0)
            else:
                c, spec, y = items
                g = None
            # Keep only the first item of the batch.
            spec, y = spec[:1].cuda(0), y[:1].cuda(0)
            c = c[:1].cuda(0)
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax)
            y_hat = generator.infer(c, g=g, mel=mel)
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1).float(),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax
            )
        image_dict = {
            "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()),
            "gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())
        }
        audio_dict = {
            "gen/audio": y_hat[0],
            "gt/audio": y[0]
        }
        # global_step is the module-level step counter (read-only here).
        utils.summarize(
            writer=writer_eval,
            global_step=global_step,
            images=image_dict,
            audios=audio_dict,
            audio_sampling_rate=hps.data.sampling_rate
        )
    finally:
        # Never leave the generator in eval mode for the training loop.
        generator.train()
# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()
data_utils.py
...
#class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
class DistributedBucketSampler(torch.utils.data.Sampler):
"""
Maintain similar input lengths in a batch.
Length groups are specified by boundaries.
Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
It removes samples which are not included in the boundaries.
Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
"""
def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
    """Create a length-bucketed batch sampler that works without torch.distributed.

    Args:
        dataset: Dataset exposing a ``lengths`` attribute.
        batch_size: Number of samples per batch.
        boundaries: Length boundaries that delimit the buckets.
        num_replicas: Number of processes sharing the data; ``None`` means a
            single process (1).
        rank: Index of this process; ``None`` means 0.
        shuffle: Whether to shuffle.
    """
    #super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
    super().__init__(dataset)
    # Honour the caller's arguments instead of overwriting them; the None
    # defaults fall back to the single-process values used previously.
    self.num_replicas = 1 if num_replicas is None else num_replicas
    self.rank = 0 if rank is None else rank
    self.shuffle = shuffle
    self.lengths = dataset.lengths
    self.batch_size = batch_size
    self.boundaries = boundaries
    # _create_buckets() reads the attributes set above, so it must run after them.
    self.buckets, self.num_samples_per_bucket = self._create_buckets()
    self.total_size = sum(self.num_samples_per_bucket)
    self.num_samples = self.total_size // self.num_replicas
...
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
#g.manual_seed(self.epoch)
...
from freevc.
CPU fairly level at 60%, random spikes to 80% for 10-20 seconds.
from freevc.
You can try lower num_workers, maybe?
from freevc.
Even lowering it to 1 has no effect on cpu load.
from freevc.
https://pytorch.org/docs/stable/distributed.html?highlight=distributed#module-torch.distributed
The gloo backend puts a heavy load on the CPU.
from freevc.
I guess that supports my theory of the distributed compute being the problem.
Would you be able to post a training script with distributed compute removed, for people using a single GPU or running on Windows? Removing that is outside of my abilities; I'm not a coder.
from freevc.
Code works, speed is up from 1.5 steps/s to 3 steps/s. :)
Gpu load is lower though, only 46% utilization.
from freevc.
Got this error after roughly 200 steps
File "train2.py", line 293, in <module> main() File "train2.py", line 50, in main run(hps) File "train2.py", line 119, in run train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval]) File "train2.py", line 173, in train_and_evaluate y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice File "Z:\FreeVCtrain\commons.py", line 53, in slice_segments ret[i] = x[i, :, idx_str:idx_end] RuntimeError: The expanded size of the tensor (8000) must match the existing size (320) at non-singleton dimension 1. Target sizes: [1, 8000]. Tensor sizes: [320]
Looked over the file and I couldn't spot any errors.
I have also been training at batch size 16, as that is all that was able to fit into the 11 GB of VRAM. I saw that VRAM usage is a lot lower now with the distributed compute removed, so I tried increasing the batch size in the hope of increasing the GPU load. I raised the batch size all the way back to 64, but there was no change in VRAM usage. I don't think it's actually using the setting for some reason.
from freevc.
Modified.
The batch sampler will discard samples that are too short, so not using this will cause that error.
from freevc.
Tried the changes and got this error:
File "Z:\FreeVCtrain\data_utils.py", line 246, in iter
for bucket in self.buckets:
AttributeError: 'DistributedBucketSampler' object has no attribute 'buckets'
Found the problem: the error is on line 214,
self. Buckets, self.num_samples_per_bucket = self._create_buckets()
I changed it to
self.buckets, self.num_samples_per_bucket = self._create_buckets()
It's running now.
from freevc.
...my chrome typo adjustment extension automatically adjusted the code I pasted.
from freevc.
Ahh... Yeah, I just ran into another one, which I saw you also fixed. I just recopied the entire code to be sure.
Thank you so much for doing all this!! I really appreciate it!
from freevc.
Related Issues (20)
- Condition decoder on desired output length to have control over speech rate in inference?
- 基于您现有的模型使用aishell3训练,大概要训练多久,作者有试过吗
- Unseen Male to Male results in Female output HOT 1
- 音色转换程度不一致
- Epoch duration
- 关于算法的类型 HOT 1
- 训练了500个epoch,按照freevc.json配置进行训练,无论wav_tgt使用何种音色,测试出来的音色都是同一个?
- Changing batch size to 16 or 32
- poor performance on seen-to-unseen task while finetuning on Hindi language HOT 2
- 2023.01.10 update: code below can deteriorate model performance HOT 3
- Vocoder version
- Fine tuning with custom (multilingual) data HOT 1
- How to start inference example? HOT 1
- 关于训练问题
- target pitch issue after training (not appearing if using the pretrained checkpoint) HOT 1
- Config file for the FreeVC-24 checkpoint HOT 1
- training a model with 44.1k data
- Why is the speaker embedding g used to condition the Posterior Encoder and the Decoder?
- Poor results with: voice_conversion_models--multilingual--vctk--freevc24.zip CoquiTTS
- About checkpoint HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from freevc.