diff --git a/vits/models.py b/vits/models.py index cdf11342..0b5a7790 100644 --- a/vits/models.py +++ b/vits/models.py @@ -194,6 +194,7 @@ def forward(self, ppg, pit, spec, spk, ppg_l, spec_l): return audio, ids_slice, spec_mask, (z_f, z_r, z_p, m_p, logs_p, z_q, m_q, logs_q, logdet_f, logdet_r), spk_preds def infer(self, ppg, pit, spk, ppg_l): + ppg = ppg + torch.randn_like(ppg) * 0.0001 # Perturbation z_p, m_p, logs_p, ppg_mask, x = self.enc_p( ppg, ppg_l, f0=f0_to_coarse(pit)) z, _ = self.flow(z_p, ppg_mask, g=spk, reverse=True) @@ -241,10 +242,8 @@ def source2wav(self, source): return self.dec.source2wav(source) def inference(self, ppg, pit, spk, ppg_l, source): - ppg = ppg + torch.randn_like(ppg) * 0.0001 # Perturbation z_p, m_p, logs_p, ppg_mask, x = self.enc_p( ppg, ppg_l, f0=f0_to_coarse(pit)) - z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * 0.7 z, _ = self.flow(z_p, ppg_mask, g=spk, reverse=True) o = self.dec.inference(spk, z * ppg_mask, source) return o diff --git a/vits_decoder/__init__.py b/vits_decoder/__init__.py index e69de29b..986a0cfe 100644 --- a/vits_decoder/__init__.py +++ b/vits_decoder/__init__.py @@ -0,0 +1 @@ +from .alias.act import SnakeAlias \ No newline at end of file diff --git a/vits_decoder/alias/__init__.py b/vits_decoder/alias/__init__.py new file mode 100644 index 00000000..a2318b63 --- /dev/null +++ b/vits_decoder/alias/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .filter import * +from .resample import * +from .act import * \ No newline at end of file diff --git a/vits_decoder/alias/act.py b/vits_decoder/alias/act.py new file mode 100644 index 00000000..308344fb --- /dev/null +++ b/vits_decoder/alias/act.py @@ -0,0 +1,129 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch import sin, pow +from torch.nn import Parameter +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. 
+ INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta = x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze( + 0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + return x + + +class Mish(nn.Module): + """ + Mish activation function is proposed in "Mish: A Self + Regularized Non-Monotonic Neural Activation Function" + paper, https://arxiv.org/abs/1908.08681. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class SnakeAlias(nn.Module): + def __init__(self, + channels, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = SnakeBeta(channels, alpha_logscale=True) + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/vits_decoder/alias/filter.py b/vits_decoder/alias/filter.py new file mode 100644 index 00000000..7ad6ea87 --- /dev/null +++ b/vits_decoder/alias/filter.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. 
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + + return out \ No newline at end of file diff --git a/vits_decoder/alias/resample.py b/vits_decoder/alias/resample.py new file mode 100644 index 00000000..750e6c34 --- /dev/null +++ b/vits_decoder/alias/resample.py @@ -0,0 +1,49 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
+ +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx \ No newline at end of file diff --git a/vits_decoder/bigv.py b/vits_decoder/bigv.py index 60dd3bad..029362c3 100644 --- a/vits_decoder/bigv.py +++ b/vits_decoder/bigv.py @@ -1,10 +1,9 @@ import torch -import torch.nn.functional as F import torch.nn as nn -from torch import nn from torch.nn import Conv1d from torch.nn.utils import weight_norm, remove_weight_norm +from .alias.act import SnakeAlias def init_weights(m, mean=0.0, std=0.01): @@ -40,11 +39,20 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): ]) self.convs2.apply(init_weights) + # total number of conv layers + self.num_layers = len(self.convs1) + len(self.convs2) + + # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + SnakeAlias(channels) for _ in range(self.num_layers) + ]) + def forward(self, x): - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, 0.1) + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) xt = c1(xt) - xt = F.leaky_relu(xt, 0.1) + xt = a2(xt) xt = c2(xt) x = xt + x return x diff --git a/vits_decoder/discriminator.py b/vits_decoder/discriminator.py index fefd6a14..764c0ca8 100644 --- a/vits_decoder/discriminator.py +++ b/vits_decoder/discriminator.py @@ -2,7 +2,7 @@ import torch.nn as nn from omegaconf import OmegaConf - +from .msd import ScaleDiscriminator from .mpd import MultiPeriodDiscriminator from .mrd import MultiResolutionDiscriminator @@ -12,13 +12,13 @@ def __init__(self, hp): super(Discriminator, self).__init__() self.MRD = MultiResolutionDiscriminator(hp) self.MPD = MultiPeriodDiscriminator(hp) - + self.MSD = ScaleDiscriminator() def forward(self, x): r = self.MRD(x) p = self.MPD(x) - - return r + p + s = self.MSD(x) + return r + p + s if __name__ == '__main__': diff --git a/vits_decoder/generator.py b/vits_decoder/generator.py index 867baacd..2fcf4645 100644 --- a/vits_decoder/generator.py +++ b/vits_decoder/generator.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.nn.functional as F import numpy as np from 
torch.nn import Conv1d @@ -8,7 +9,7 @@ from torch.nn.utils import remove_weight_norm from .nsf import SourceModuleHnNSF -from .bigv import init_weights, AMPBlock +from .bigv import init_weights, AMPBlock, SnakeAlias class SpeakerAdapter(nn.Module): @@ -105,6 +106,7 @@ def __init__(self, hp): self.resblocks.append(AMPBlock(ch, k, d)) # post conv + self.activation_post = SnakeAlias(ch) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) # weight initialization self.ups.apply(init_weights) @@ -112,15 +114,15 @@ def __init__(self, hp): def forward(self, spk, x, f0): # adapter x = self.adapter(x, spk) + x = self.conv_pre(x) + x = x * torch.tanh(F.softplus(x)) # nsf f0 = f0[:, None] f0 = self.f0_upsamp(f0).transpose(1, 2) har_source = self.m_source(f0) har_source = har_source.transpose(1, 2) - x = self.conv_pre(x) for i in range(self.num_upsamples): - x = nn.functional.leaky_relu(x, 0.1) # upsampling x = self.ups[i](x) # nsf @@ -136,7 +138,7 @@ def forward(self, spk, x, f0): x = xs / self.num_kernels # post conv - x = nn.functional.leaky_relu(x) + x = self.activation_post(x) x = self.conv_post(x) x = torch.tanh(x) return x @@ -172,9 +174,9 @@ def inference(self, spk, x, har_source): # adapter x = self.adapter(x, spk) x = self.conv_pre(x) + x = x * torch.tanh(F.softplus(x)) for i in range(self.num_upsamples): - x = nn.functional.leaky_relu(x, 0.1) # upsampling x = self.ups[i](x) # nsf @@ -190,7 +192,7 @@ def inference(self, spk, x, har_source): x = xs / self.num_kernels # post conv - x = nn.functional.leaky_relu(x) + x = self.activation_post(x) x = self.conv_post(x) x = torch.tanh(x) return x diff --git a/vits_decoder/med.py b/vits_decoder/med.py new file mode 100644 index 00000000..77554d3c --- /dev/null +++ b/vits_decoder/med.py @@ -0,0 +1,65 @@ +import torch +import torchaudio +import typing as T + + +class MelspecDiscriminator(torch.nn.Module): + """mel spectrogram (frequency domain) discriminator""" + + def __init__(self) -> None: + super().__init__() + self.SAMPLE_RATE = 48000 + # mel filterbank transform + self._melspec = torchaudio.transforms.MelSpectrogram( + sample_rate=self.SAMPLE_RATE, + n_fft=2048, + win_length=int(0.025 * self.SAMPLE_RATE), + hop_length=int(0.010 * self.SAMPLE_RATE), + n_mels=128, + power=1, + ) + + # time-frequency 2D convolutions + kernel_sizes = [(7, 7), (4, 4), (4, 4), (4, 4)] + strides = [(1, 2), (1, 2), (1, 2), (1, 2)] + self._convs = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=1 if i == 0 else 32, + out_channels=64, + kernel_size=k, + stride=s, + padding=(1, 2), + bias=False, + ), + torch.nn.BatchNorm2d(num_features=64), + torch.nn.GLU(dim=1), + ) + for i, (k, s) in enumerate(zip(kernel_sizes, strides)) + ] + ) + + # output adversarial projection + self._postnet = torch.nn.Conv2d( + in_channels=32, + out_channels=1, + kernel_size=(15, 3), + stride=(1, 2), + ) + + def forward(self, x: torch.Tensor) -> T.Tuple[torch.Tensor, T.List[torch.Tensor]]: + # apply the log-scale mel spectrogram transform + x = torch.log(self._melspec(x) + 1e-5) + + # compute hidden layers and feature maps + f = [] + for c in self._convs: + x = c(x) + f.append(x) + + # apply the output projection and global average pooling + x = self._postnet(x) + x = x.mean(dim=[-2, -1]) + + return [(f, x)] diff --git a/vits_decoder/nsf.py b/vits_decoder/nsf.py index 473df69f..1e9e6c7e 100644 --- a/vits_decoder/nsf.py +++ b/vits_decoder/nsf.py @@ -364,7 +364,7 @@ def __init__( voiced_threshod=0, ): super(SourceModuleHnNSF, self).__init__() - harmonic_num = 8 
+ harmonic_num = 10 self.sine_amp = sine_amp self.noise_std = add_noise_std @@ -376,9 +376,9 @@ def __init__( # to merge source harmonics into a single excitation self.l_tanh = torch.nn.Tanh() self.register_buffer('merge_w', torch.FloatTensor([[ - -0.1044, -0.4892, -0.4733, 0.4337, -0.2321, - -0.1889, 0.1315, -0.1002, 0.0590,]])) - self.register_buffer('merge_b', torch.FloatTensor([-0.2908])) + 0.2942, -0.2243, 0.0033, -0.0056, -0.0020, -0.0046, + 0.0221, -0.0083, -0.0241, -0.0036, -0.0581]])) + self.register_buffer('merge_b', torch.FloatTensor([0.0008])) def forward(self, x): """ diff --git a/vits_extend/train.py b/vits_extend/train.py index a5053dc4..3abd8e31 100644 --- a/vits_extend/train.py +++ b/vits_extend/train.py @@ -149,7 +149,7 @@ def train(rank, args, chkpt_path, hp, hp_str): scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hp.train.lr_decay, last_epoch=init_epoch-2) stft_criterion = MultiResolutionSTFTLoss(device, eval(hp.mrd.resolutions)) - vpr_loss = nn.CosineEmbeddingLoss() + spkc_criterion = nn.CosineEmbeddingLoss() trainloader = create_dataloader_train(hp, args.num_gpus, rank) @@ -191,7 +191,7 @@ def train(rank, args, chkpt_path, hp, hp_str): audio = commons.slice_segments( audio, ids_slice * hp.data.hop_length, hp.data.segment_size) # slice # Spk Loss - spk_loss = vpr_loss(spk, spk_preds, torch.Tensor(spk_preds.size(0)) + spk_loss = spkc_criterion(spk, spk_preds, torch.Tensor(spk_preds.size(0)) .to(device).fill_(1.0)) # Mel Loss mel_fake = stft.mel_spectrogram(fake_audio.squeeze(1)) @@ -223,7 +223,7 @@ def train(rank, args, chkpt_path, hp, hp_str): loss_kl_r = kl_loss(z_r, logs_p, m_q, logs_q, logdet_r, z_mask) * hp.train.c_kl # Loss - loss_g = score_loss + feat_loss + mel_loss + stft_loss + loss_kl_f + loss_kl_r * 0.5 + spk_loss * 0.5 + loss_g = score_loss + feat_loss + mel_loss + stft_loss + loss_kl_f + loss_kl_r * 0.5 + spk_loss * 2 loss_g.backward() clip_grad_value_(model_g.parameters(), None) optim_g.step()
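
Reviewer notes (illustrative sketches; not part of the patch):

vits/models.py: the 1e-4 Gaussian perturbation of the PPG features moves from inference() into infer(), and inference() no longer resamples z_p around the prior mean with temperature 0.7; it now decodes the encoder's z_p directly. A minimal sketch of the two prior-sampling strategies (shapes are illustrative, not taken from the config):

import torch

# Hypothetical prior statistics as produced by enc_p: [B, C, T].
m_p = torch.zeros(1, 192, 100)
logs_p = torch.zeros(1, 192, 100)

# Old inference() path: resample around the prior mean with temperature 0.7.
z_p_old = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * 0.7

# New path: z_p comes straight from enc_p; the only stochasticity left is
# the tiny input perturbation, which now lives in infer():
ppg = torch.randn(1, 100, 1280)               # hypothetical PPG shape
ppg = ppg + torch.randn_like(ppg) * 0.0001    # perturbation, as in the patch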
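
vits_decoder/alias/act.py: SnakeBeta computes x + (1 / (beta + eps)) * sin^2(alpha * x) with one alpha/beta pair per channel; with alpha_logscale=True the parameters are stored as logs and exponentiated in forward(), so the zero initialization corresponds to alpha = beta = 1. A quick sanity check against that identity, assuming the module as added by this patch:

import torch
from vits_decoder.alias.act import SnakeBeta

act = SnakeBeta(256, alpha_logscale=True)
x = torch.randn(4, 256, 1024)              # [B, C, T]
y = act(x)
assert y.shape == x.shape
# Log-scale zeros give alpha = beta = 1, so the activation reduces to
# x + sin(x)^2 (up to the 1e-9 guard against division by zero).
torch.testing.assert_close(y, x + torch.sin(x) ** 2)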
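
sin^2 is periodic, so it generates harmonics that can fold back over Nyquist; SnakeAlias therefore wraps SnakeBeta in a 2x upsample / 2x downsample pair (the alias-free-torch recipe), which preserves the sequence length. bigv.py then gives every conv in AMPBlock its own SnakeAlias in place of the shared leaky_relu, splitting the ModuleList as [::2] / [1::2] so each (c1, c2) pair gets its own (a1, a2). Usage sketch:

import torch
from vits_decoder import SnakeAlias   # re-exported by vits_decoder/__init__.py

act = SnakeAlias(channels=128)        # 2x upsample -> SnakeBeta -> 2x downsample
x = torch.randn(2, 128, 800)          # [B, C, T]
y = act(x)
assert y.shape == x.shape             # the internal resampling cancels out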
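
vits_decoder/alias/filter.py: kaiser_sinc_filter1d follows the textbook Kaiser design: estimate the stopband attenuation A from the transition width (delta_f = 4 * half_width), pick the window beta from A piecewise, window a sinc at the requested cutoff, then normalize to unit DC gain so constant signals pass through unchanged. cutoff and half_width are in normalized frequency with Nyquist at 0.5, which is why UpSample1d and DownSample1d pass 0.5 / ratio and 0.6 / ratio. A small check:

import torch
from vits_decoder.alias.filter import kaiser_sinc_filter1d

# Same parameters DownSample1d(ratio=2) would use.
filt = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
assert filt.shape == (1, 1, 12)
torch.testing.assert_close(filt.sum(), torch.tensor(1.0))  # unit DC gain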
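
vits_decoder/generator.py: the conv_pre output now passes through x * tanh(softplus(x)), which is exactly the Mish activation (the same formula as the Mish class in act.py); the per-stage leaky_relu disappears because each AMPBlock now opens with its own SnakeAlias; and the final leaky_relu before conv_post is replaced by a dedicated activation_post = SnakeAlias(ch). The Mish identity, for reference:

import torch
import torch.nn.functional as F

x = torch.randn(2, 512, 100)
y = x * torch.tanh(F.softplus(x))         # as written in Generator.forward
torch.testing.assert_close(y, F.mish(x))  # matches PyTorch's built-in Mish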
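
vits_decoder/discriminator.py: forward() returns r + p + s, which is list concatenation, not addition: each sub-discriminator (MRD, MPD, and now MSD) is expected to return a list of (feature_maps, score) pairs, so downstream code can iterate one flat list. A sketch of how such a list is typically consumed; the actual loss lives in vits_extend and is not part of this diff:

import torch

def lsgan_d_loss(disc_real, disc_fake):
    # disc_real / disc_fake: lists of (feature_maps, score) pairs, i.e. the
    # concatenated outputs of MRD + MPD + MSD on real and generated audio.
    loss = 0.0
    for (_, score_real), (_, score_fake) in zip(disc_real, disc_fake):
        loss = loss + torch.mean((1.0 - score_real) ** 2)  # push real scores to 1
        loss = loss + torch.mean(score_fake ** 2)          # push fake scores to 0
    return loss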
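
vits_decoder/med.py: MelspecDiscriminator is added here but not yet wired into Discriminator (only MSD is registered). Each conv stage outputs 64 channels and GLU(dim=1) gates one half against the other, leaving 32, which is why in_channels is "1 if i == 0 else 32" and the postnet takes 32. Standalone usage, assuming raw 48 kHz waveforms shaped [B, 1, T]:

import torch
from vits_decoder.med import MelspecDiscriminator

med = MelspecDiscriminator()
wav = torch.randn(2, 1, 48000)      # one second of 48 kHz audio
[(features, score)] = med(wav)      # same [(feature_maps, score)] convention
print(len(features), score.shape)   # 4 hidden feature maps, score is [B, 1]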
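
vits_decoder/nsf.py: SourceModuleHnNSF mixes the fundamental plus harmonic_num overtones, so the fixed merge weights need harmonic_num + 1 entries; bumping harmonic_num from 8 to 10 is what forces the new 11-element merge_w (the old vector had 9) and the new merge_b. A shape check, assuming merge_w acts as a fixed linear map over the harmonic axis, as the buffer shapes suggest:

import torch

harmonic_num = 10
merge_w = torch.FloatTensor([[0.2942, -0.2243, 0.0033, -0.0056, -0.0020, -0.0046,
                              0.0221, -0.0083, -0.0241, -0.0036, -0.0581]])
merge_b = torch.FloatTensor([0.0008])
assert merge_w.shape == (1, harmonic_num + 1)   # fundamental + 10 harmonics

sines = torch.randn(4, 16000, harmonic_num + 1)     # [B, T, n_harmonics]
merged = torch.tanh(sines @ merge_w.t() + merge_b)  # -> [B, T, 1]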
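
vits_extend/train.py: the criterion is renamed from vpr_loss to spkc_criterion, and its weight in loss_g rises from 0.5 to 2, so speaker consistency now contributes four times as much to the generator objective. CosineEmbeddingLoss with a target of +1 for every pair penalizes 1 - cos(spk, spk_preds), pulling each predicted embedding toward its reference:

import torch
import torch.nn as nn

spkc_criterion = nn.CosineEmbeddingLoss()
spk = torch.randn(8, 256)                # reference speaker embeddings [B, D]
spk_preds = torch.randn(8, 256)          # embeddings recovered from generated audio
target = torch.ones(spk_preds.size(0))   # +1 target: maximize cosine similarity
spk_loss = spkc_criterion(spk, spk_preds, target)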