Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions configs/ft_hifigan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ data_input_path: []
data_out_path: []
val_num: 10

add_watermark: false # when true, the training task adds a watermark-detector loss term to the generator loss
lambda_wm: 1 # weight multiplying the watermark loss before it is added to the generator loss
watermark_loss_tau: 0.5 # confidence threshold: below it the loss is (tau - confidence)^2
watermark_loss_alpha: 0.1 # scale of the (1 - confidence)^2 term applied at or above tau

pe: 'parselmouth' # 'parselmouth' or 'harvest'
f0_min: 65
f0_max: 1100
Expand Down
57 changes: 57 additions & 0 deletions models/wm_d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import torchaudio
from perth.perth_net.perth_net_implicit.perth_watermarker import PerthImplicitWatermarker


class PerthDetectorDiscriminator(nn.Module):
    """Frozen Perth watermark detector used as a differentiable scoring module.

    The wrapped ``perth_net`` has its parameters frozen and is kept in
    evaluation mode. ``forward`` does not disable autograd, so gradients can
    still flow through the detector back to the audio it is given.

    Args:
        audio_sr: Sample rate of the audio passed to ``forward``.
        device: Device handed to ``PerthImplicitWatermarker``.
        loss_tau: Confidence threshold that selects the branch used by ``loss``.
        loss_alpha: Scale of the loss branch used at or above ``loss_tau``.
        **perth_kwargs: Extra keyword arguments forwarded unchanged to
            ``PerthImplicitWatermarker``.
    """

    def __init__(
        self,
        audio_sr: int,
        device: str = "cpu",
        loss_tau: float = 0.75,
        loss_alpha: float = 0.1,
        **perth_kwargs,
    ):
        super().__init__()
        self.watermarker = PerthImplicitWatermarker(device=device, **perth_kwargs)

        self.perth_net = self.watermarker.perth_net
        self.audio_sr = audio_sr
        self.perth_sr = self.perth_net.hp.sample_rate

        # The detector is a fixed critic: evaluation mode, no trainable weights.
        self.perth_net.eval()
        for param in self.perth_net.parameters():
            param.requires_grad = False

        # Only resample when the incoming rate differs from the detector's rate.
        if self.audio_sr != self.perth_sr:
            self.resampler = torchaudio.transforms.Resample(
                orig_freq=self.audio_sr,
                new_freq=self.perth_sr,
                resampling_method="sinc_interp_hann"
            )
        else:
            self.resampler = nn.Identity()

        self.tau = loss_tau
        self.alpha = loss_alpha

    def train(self, mode: bool = True):
        """Set the training flag on this module but keep the detector in eval mode.

        ``nn.Module.train`` recurses into every registered submodule, so a
        parent module switching to training mode would otherwise undo the
        ``eval()`` call made in ``__init__`` on the frozen detector.

        Args:
            mode: Training flag applied to this wrapper (and the resampler).
        Returns:
            This module, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.perth_net.eval()
        return self

    def forward(self, audio_tensor: torch.Tensor) -> torch.Tensor:
        """
        Score audio with the watermark detector.

        Args:
            audio_tensor (torch.Tensor): shape (B, 1, T) or (B, T).
        Returns:
            torch.Tensor: shape (B,).
        """
        # Drop the channel axis so the detector front end sees (B, T).
        if audio_tensor.dim() == 3:
            audio_tensor = audio_tensor.squeeze(1)
        resampled_audio = self.resampler(audio_tensor)
        # Only the first element of the returned pair is used; the second is ignored.
        magspec, _phase = self.perth_net.ap.signal_to_magphase(resampled_audio)
        wmark_pred_vector = self.perth_net.decoder(magspec)
        # Clamp into [0, 1]; values outside that range receive zero gradient.
        wmark_pred_vector = wmark_pred_vector.clamp(0.0, 1.0)
        # Average over every non-leading axis to get one score per leading index.
        if wmark_pred_vector.dim() > 1:
            confidence = torch.mean(wmark_pred_vector, dim=list(range(1, wmark_pred_vector.dim())))
        else:
            confidence = wmark_pred_vector

        return confidence

    def loss(self, confidence_scores: torch.Tensor) -> torch.Tensor:
        """Return the mean piecewise-quadratic penalty over the given scores.

        Scores below ``self.tau`` are penalised with ``(tau - score)^2``; all
        other scores are penalised with ``alpha * (1 - score)^2``.

        Args:
            confidence_scores: Detector confidences, e.g. the output of ``forward``.
        Returns:
            Scalar tensor holding the mean of the per-element penalties.
        """
        is_below_tau = (confidence_scores < self.tau)
        strong_term = (self.tau - confidence_scores).pow(2)
        weak_term = self.alpha * (1.0 - confidence_scores).pow(2)
        batch_loss = torch.where(is_below_tau, strong_term, weak_term)

        return batch_loss.mean()
36 changes: 35 additions & 1 deletion training/nsf_HiFigan_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
from torch.utils.data import Dataset

from models.nsf_HiFigan.models import Generator, AttrDict, MultiScaleDiscriminator, MultiPeriodDiscriminator
from models.wm_d import PerthDetectorDiscriminator
from modules.loss.HiFiloss import HiFiloss
from training.base_task_gan import GanBaseTask
from utils.wav2F0 import PITCH_EXTRACTORS_ID_TO_NAME, get_pitch
from utils.wav2mel import PitchAdjustableMelSpectrogram
from perth.utils import calculate_audio_metrics


def spec_to_figure(spec, vmin=None, vmax=None):
Expand Down Expand Up @@ -254,6 +256,9 @@ def __init__(self, config):
self.logged_gt_wav = set()
self.stft = stftlog()
self.max_f0 = get_max_f0_from_config(config)
self.add_watermark = self.config.get('add_watermark', False)
if self.add_watermark:
self.lambda_wm = self.config.get('lambda_wm', 1.0)

def build_dataset(self):

Expand All @@ -280,6 +285,15 @@ def build_model(self):
'msd': MultiScaleDiscriminator(),
'mpd': MultiPeriodDiscriminator(periods=cfg['discriminator_periods'])
})
if self.add_watermark:
loss_tau = self.config.get('watermark_loss_tau', 0.7)
loss_alpha = self.config.get('watermark_loss_alpha', 0.1)
self.watermark_discriminator = PerthDetectorDiscriminator(
audio_sr=self.config['audio_sample_rate'],
device=self.device,
loss_tau=loss_tau,
loss_alpha=loss_alpha
)

def build_losses_and_metrics(self):
self.mix_loss = HiFiloss(self.config)
Expand Down Expand Up @@ -379,7 +393,15 @@ def _training_step(self, sample, batch_idx):
spec_loss, Auxlog = self.mix_loss.Auxloss(Goutput=Goutput, sample=sample)
Auxloss = spec_loss + pc_wav_loss
log_dict.update(Auxlog)

Gloss = GDloss + Auxloss

if self.add_watermark:
confidence_scores = self.watermark_discriminator(Goutput['audio'])
watermark_loss = self.watermark_discriminator.loss(confidence_scores)
log_dict['watermark_loss'] = watermark_loss.item()
log_dict['confidence'] = confidence_scores.mean().item()
Gloss = Gloss + self.lambda_wm * watermark_loss

opt_g.zero_grad() # clean generator grad
self.manual_backward(Gloss)
Expand All @@ -400,6 +422,15 @@ def _validation_step(self, sample, batch_idx):
stfts_log10 = torch.log10(torch.clamp(stfts, min=1e-7))
Gstfts_log10 = torch.log10(torch.clamp(Gstfts, min=1e-7))

original_np = sample['audio'].squeeze().cpu().numpy()
generated_np = wav.squeeze().cpu().numpy()

min_len = min(len(original_np), len(generated_np))
original_np = original_np[:min_len]
generated_np = generated_np[:min_len]

quality_metrics = calculate_audio_metrics(original_np, generated_np)

if self.global_rank == 0:
self.plot_mel(batch_idx, Gstfts_log10.transpose(1, 2), stfts_log10.transpose(1, 2),
name=f'log10stft_{batch_idx}')
Expand All @@ -412,7 +443,10 @@ def _validation_step(self, sample, batch_idx):
global_step=self.global_step)
self.logged_gt_wav.add(batch_idx)

return {'stft_loss': nn.L1Loss()(Gstfts_log10, stfts_log10)}, 1
val_outputs = {'stft_loss': nn.L1Loss()(Gstfts_log10, stfts_log10)}
val_outputs.update(quality_metrics)

return val_outputs, 1

def plot_mel(self, batch_idx, spec, spec_out, name=None):
name = f'mel_{batch_idx}' if name is None else name
Expand Down