https://github.com/open-mmlab/Amphion
09 September 2024, 06:46:44 UTC
50adafb / models / svc / base / svc_trainer.py
To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
  • content: swh:1:cnt:f2d8f69c115dbb09ee4cde1968ece1c7084d7394
  • directory: swh:1:dir:ac83f270966cb0d559657a76345a4ad728d84e85
  • revision: swh:1:rev:251c6690ae3de6d04454876fbb864e8664951bc8
  • snapshot: swh:1:snp:bef780d851faeac80aef6db569e51e66f505bf34

Tip revision: 251c6690ae3de6d04454876fbb864e8664951bc8, authored by Harry He on 06 September 2024, 13:52:56 UTC
update Amphion/Emilia references (#271)

svc_trainer.py
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os

import torch
import torch.nn as nn
import numpy as np

from models.base.new_trainer import BaseTrainer
from models.svc.base.svc_dataset import (
    SVCOfflineCollator,
    SVCOfflineDataset,
    SVCOnlineCollator,
    SVCOnlineDataset,
)
from processors.audio_features_extractor import AudioFeaturesExtractor
from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema

EPS = 1.0e-12


class SVCTrainer(BaseTrainer):
    r"""The base trainer for all SVC models. It inherits from BaseTrainer and implements
    the ``_build_criterion``, ``_build_dataset``, and ``_build_singer_lut`` methods. You can inherit from this
    class and implement ``_build_model`` and ``_forward_step``.
    """

    def __init__(self, args=None, cfg=None):
        self.args = args
        self.cfg = cfg

        self._init_accelerator()

        # Only for SVC tasks
        with self.accelerator.main_process_first():
            self.singers = self._build_singer_lut()

        # Super init
        BaseTrainer.__init__(self, args, cfg)

        # Only for SVC tasks
        self.task_type = "SVC"
        self.logger.info("Task type: {}".format(self.task_type))

    ### Following are methods only for SVC tasks ###
    def _build_dataset(self):
        self.online_features_extraction = (
            self.cfg.preprocess.features_extraction_mode == "online"
        )

        if not self.online_features_extraction:
            return SVCOfflineDataset, SVCOfflineCollator
        else:
            self.audio_features_extractor = AudioFeaturesExtractor(self.cfg)
            return SVCOnlineDataset, SVCOnlineCollator

    def _extract_svc_features(self, batch):
        """
        Feature extraction during training

        Batch:
            wav: (B, T)
            wav_len: (B)
            target_len: (B)
            mask: (B, n_frames, 1)
            spk_id: (B, 1)

            wav_{sr}: (B, T)
            wav_{sr}_len: (B)

        Elements added to the output batch:
            mel: (B, n_frames, n_mels)
            frame_pitch: (B, n_frames)
            frame_uv: (B, n_frames)
            frame_energy: (B, n_frames)
            frame_{content}: (B, n_frames, D)
        """

        padded_n_frames = torch.max(batch["target_len"])
        final_n_frames = padded_n_frames

        ### Mel Spectrogram ###
        if self.cfg.preprocess.use_mel:
            # (B, n_mels, n_frames)
            raw_mel = self.audio_features_extractor.get_mel_spectrogram(batch["wav"])
            if self.cfg.preprocess.use_min_max_norm_mel:
                # TODO: Remove the hard-coded dataset name ("vctk")

                # Using the empirical mel extrema to denormalize
                if not hasattr(self, "mel_extrema"):
                    # (n_mels)
                    m, M = load_mel_extrema(self.cfg.preprocess, "vctk")
                    # (1, n_mels, 1)
                    m = (
                        torch.as_tensor(m, device=raw_mel.device)
                        .unsqueeze(0)
                        .unsqueeze(-1)
                    )
                    M = (
                        torch.as_tensor(M, device=raw_mel.device)
                        .unsqueeze(0)
                        .unsqueeze(-1)
                    )
                    self.mel_extrema = m, M

                m, M = self.mel_extrema
                mel = (raw_mel - m) / (M - m + EPS) * 2 - 1

            else:
                mel = raw_mel

            final_n_frames = min(final_n_frames, mel.size(-1))

            # (B, n_frames, n_mels)
            batch["mel"] = mel.transpose(1, 2)
        else:
            raw_mel = None

        ### F0 ###
        if self.cfg.preprocess.use_frame_pitch:
            # (B, n_frames)
            raw_f0, raw_uv = self.audio_features_extractor.get_f0(
                batch["wav"],
                wav_lens=batch["wav_len"],
                use_interpolate=self.cfg.preprocess.use_interpolation_for_uv,
                return_uv=True,
            )
            final_n_frames = min(final_n_frames, raw_f0.size(-1))
            batch["frame_pitch"] = raw_f0

            if self.cfg.preprocess.use_uv:
                batch["frame_uv"] = raw_uv

        ### Energy ###
        if self.cfg.preprocess.use_frame_energy:
            # (B, n_frames)
            raw_energy = self.audio_features_extractor.get_energy(
                batch["wav"], mel_spec=raw_mel
            )
            final_n_frames = min(final_n_frames, raw_energy.size(-1))
            batch["frame_energy"] = raw_energy

        ### Semantic Features ###
        if self.cfg.model.condition_encoder.use_whisper:
            # (B, n_frames, D)
            whisper_feats = self.audio_features_extractor.get_whisper_features(
                wavs=batch["wav_{}".format(self.cfg.preprocess.whisper_sample_rate)],
                target_frame_len=padded_n_frames,
            )
            final_n_frames = min(final_n_frames, whisper_feats.size(1))
            batch["whisper_feat"] = whisper_feats

        if self.cfg.model.condition_encoder.use_contentvec:
            # (B, n_frames, D)
            contentvec_feats = self.audio_features_extractor.get_contentvec_features(
                wavs=batch["wav_{}".format(self.cfg.preprocess.contentvec_sample_rate)],
                target_frame_len=padded_n_frames,
            )
            final_n_frames = min(final_n_frames, contentvec_feats.size(1))
            batch["contentvec_feat"] = contentvec_feats

        if self.cfg.model.condition_encoder.use_wenet:
            # (B, n_frames, D)
            wenet_feats = self.audio_features_extractor.get_wenet_features(
                wavs=batch["wav_{}".format(self.cfg.preprocess.wenet_sample_rate)],
                target_frame_len=padded_n_frames,
                wav_lens=batch[
                    "wav_{}_len".format(self.cfg.preprocess.wenet_sample_rate)
                ],
            )
            final_n_frames = min(final_n_frames, wenet_feats.size(1))
            batch["wenet_feat"] = wenet_feats

        ### Align all the audio features to the same frame length ###
        frame_level_features = [
            "mask",
            "mel",
            "frame_pitch",
            "frame_uv",
            "frame_energy",
            "whisper_feat",
            "contentvec_feat",
            "wenet_feat",
        ]
        for k in frame_level_features:
            if k in batch:
                # (B, n_frames, ...)
                batch[k] = batch[k][:, :final_n_frames].contiguous()

        return batch

    @staticmethod
    def _build_criterion():
        criterion = nn.MSELoss(reduction="none")
        return criterion

    @staticmethod
    def _compute_loss(criterion, y_pred, y_gt, loss_mask):
        """
        Args:
            criterion: MSELoss(reduction='none')
            y_pred, y_gt: (B, seq_len, D)
            loss_mask: (B, seq_len, 1)
        Returns:
            loss: Tensor of shape []
        """

        # (B, seq_len, D)
        loss = criterion(y_pred, y_gt)
        # expand loss_mask to (B, seq_len, D)
        loss_mask = loss_mask.repeat(1, 1, loss.shape[-1])

        loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)
        return loss

    def _save_auxiliary_states(self):
        """
        Save the singer look-up table to the checkpoint saving path
        """
        with open(
            os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(self.singers, f, indent=4, ensure_ascii=False)

    def _build_singer_lut(self):
        resumed_singer_path = None
        if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
            resumed_singer_path = os.path.join(
                self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
            )
        if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
            resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)

        if resumed_singer_path:
            with open(resumed_singer_path, "r") as f:
                singers = json.load(f)
        else:
            singers = dict()

        for dataset in self.cfg.dataset:
            singer_lut_path = os.path.join(
                self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
            )
            with open(singer_lut_path, "r") as f:
                singer_lut = json.load(f)
            for singer in singer_lut.keys():
                if singer not in singers:
                    singers[singer] = len(singers)

        with open(
            os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
        ) as singer_file:
            json.dump(singers, singer_file, indent=4, ensure_ascii=False)
        print(
            "singers have been dumped to {}".format(
                os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
            )
        )
        return singers
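For reference, the min-max normalization applied to the mel spectrogram in `_extract_svc_features` maps each mel bin into [-1, 1] using per-bin extrema, with a small epsilon guarding against a zero range. A minimal pure-Python sketch of the same formula (the torch version simply broadcasts `m` and `M` over batch and time; the scalar inputs here are illustrative):

```python
EPS = 1.0e-12


def minmax_normalize(x, m, M, eps=EPS):
    """Map x from the range [m, M] into [-1, 1],
    mirroring: mel = (raw_mel - m) / (M - m + eps) * 2 - 1."""
    return (x - m) / (M - m + eps) * 2 - 1


# The minimum maps to exactly -1, the maximum to (almost exactly) +1,
# and the midpoint to (almost exactly) 0.
print(minmax_normalize(-5.0, -5.0, 3.0))  # -1.0
print(minmax_normalize(3.0, -5.0, 3.0))   # ~1.0
print(minmax_normalize(-1.0, -5.0, 3.0))  # ~0.0
```

The `* 2 - 1` step rescales the usual [0, 1] min-max range to [-1, 1], which keeps the target centered for regression losses such as the masked MSE used by this trainer.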

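`_compute_loss` above averages the element-wise MSE only over positions where the mask is 1, so padded frames do not dilute the loss. A minimal pure-Python sketch of that reduction on a single toy sequence (hypothetical data, no torch):

```python
def masked_mse(y_pred, y_gt, mask):
    """Mean squared error over masked-in (mask == 1) positions only,
    mirroring: sum(loss * mask) / sum(mask) in SVCTrainer._compute_loss."""
    num, den = 0.0, 0.0
    for p, g, m in zip(y_pred, y_gt, mask):
        num += m * (p - g) ** 2
        den += m
    return num / den


# The last position is padding (mask 0) and is excluded from the average.
pred = [1.0, 2.0, 9.9]
gt = [0.0, 4.0, 0.0]
mask = [1, 1, 0]
print(masked_mse(pred, gt, mask))  # (1 + 4) / 2 = 2.5
```

Dividing by `sum(mask)` rather than the sequence length is what makes the result independent of how much padding each batch happens to contain.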

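`_build_singer_lut` merges the per-dataset spk2id tables by assigning each previously unseen singer the next integer ID, while preserving IDs already present (e.g. from a resumed checkpoint). The core merge loop can be sketched as follows (singer names are hypothetical):

```python
def merge_singer_luts(luts, singers=None):
    """Merge look-up tables, giving each new singer the next free ID,
    mirroring the merge loop in SVCTrainer._build_singer_lut."""
    singers = dict(singers or {})
    for lut in luts:
        for singer in lut:
            if singer not in singers:
                singers[singer] = len(singers)
    return singers


merged = merge_singer_luts([{"alice": 0, "bob": 1}, {"bob": 0, "carol": 1}])
print(merged)  # {'alice': 0, 'bob': 1, 'carol': 2}
```

Note that the IDs inside each per-dataset table are ignored; only the order in which unseen singers are encountered determines their global ID, which is why the merged table is persisted alongside the checkpoint by `_save_auxiliary_states`.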
Software Heritage — Copyright (C) 2015–2026, The Software Heritage developers. License: GNU AGPLv3+.
The source code of Software Heritage itself is available on our development forge.
The source code files archived by Software Heritage are available under their own copyright and licenses.