Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

  • c1da1d0
  • /
  • content_extractor.py
Raw File Download

To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.

  • content
  • directory
content badge Iframe embedding
swh:1:cnt:34b54917a8d672a91f25fc6a54453e5a15d1296d
directory badge Iframe embedding
swh:1:dir:c1da1d0162aac626a09ccdc706a0209964275a78

This interface generates software citations, provided that the root directory of the browsed objects contains a citation.cff or codemeta.json file.
Select below a type of object currently browsed in order to generate citations for them.

  • content
  • directory
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
content_extractor.py
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import numpy as np
import yaml
import copy
from tqdm import tqdm
from torchaudio.compliance import kaldi
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from fairseq import checkpoint_utils
from transformers import AutoModel, Wav2Vec2FeatureExtractor

from utils.io_optim import (
    TorchaudioDataset,
    LibrosaDataset,
    FFmpegDataset,
    collate_batch,
)
import whisper
from modules.wenet_extractor.utils.init_model import init_model
from modules.wenet_extractor.utils.checkpoint import load_checkpoint

"""
    Extractor for content features
    1. whisper
    2. contentvec
    3. wenet
    4. mert

    Pipeline:
        in preprocess.py:
            call extract_utt_content_features() to extract content features for each utterance
            extract_utt_content_features() wraps the following steps:
                1. load the model (whisper, contentvec, wenet, mert)
                2. extract the content features
                3. save the content features into files
        in svc_dataset.py:
            call offline_align() to align the content features to the given target length

"""

"""
    Extractor Usage:
        1. initialize an instance of extractor
            extractor = WhisperExtractor(cfg)
        2. load the specified model
            extractor.load_model()
        3. extract the content features from a batch of padded waveforms
            extractor.extract_content_features(wavs)  (WenetExtractor also requires lens)
        4. save the content features
            extractor.save_feature(utt, content_feature) for single utterance
"""


class AudioPretrainedModelFeaturesExtractor:
    """Base class for pretrained-model content feature extractors.

    Stores the preprocessing config and an extractor type tag. The tag selects
    the frame-shift settings and names the output sub-directory. The class also
    precomputes a reduced ``source_hop``/``target_hop`` ratio, which is used to
    resample content features onto the acoustic frame grid
    (``cfg.preprocess.hop_size``).
    """

    def __init__(self, cfg, extractor_type):
        self.cfg = cfg
        self.extractor_type = extractor_type
        self.model = None
        self.init_for_retrans()

    def init_for_retrans(self):
        """Compute the reduced hop ratio between content frames and target frames.

        ``source_hop`` is the content model's frame shift in samples: the
        frame shift (times any downsample rate) multiplied by
        ``cfg.preprocess.sample_rate``. ``target_hop`` is
        ``cfg.preprocess.hop_size``. Both are divided by their GCD, so the
        repeat/average resampling works on as few samples as possible.

        Raises:
            NotImplementedError: if ``extractor_type`` is not supported.
        """
        preprocess = self.cfg.preprocess
        target_hop = preprocess.hop_size

        if self.extractor_type == "whisper":
            source_hop = (
                preprocess.whisper_frameshift
                * preprocess.whisper_downsample_rate
                * preprocess.sample_rate
            )
        elif self.extractor_type == "contentvec":
            source_hop = preprocess.contentvec_frameshift * preprocess.sample_rate
        elif self.extractor_type == "wenet":
            source_hop = (
                preprocess.wenet_frameshift
                * preprocess.wenet_downsample_rate
                * preprocess.sample_rate
            )
        elif self.extractor_type == "mert":
            # Previously rejected by an assertion, which made MertExtractor
            # impossible to construct. mert_frameshift is in seconds, as used
            # in get_valid_features.
            source_hop = preprocess.mert_frameshift * preprocess.sample_rate
        else:
            raise NotImplementedError(
                "unsupported extractor type: {}".format(self.extractor_type)
            )

        source_hop = int(source_hop)
        factor = np.gcd(source_hop, target_hop)
        self.source_hop = int(source_hop // factor)
        self.target_hop = int(target_hop // factor)

    def offline_resolution_transformation(self, content, target_len):
        """Resample a single utterance's content features to ``target_len`` frames.

        Each source frame is repeated ``source_hop`` times, and the result is
        averaged in windows of ``target_hop``. If the output is shorter than
        ``target_len``, the last frame is repeated to fill the gap.

        args:
            content: np.ndarray, (source_len, dim)
            target_len: target length
        return:
            mapped_feature: (target_len, dim)
        """
        source_hop = self.source_hop
        target_hop = self.target_hop

        # (source_len, dim)
        _, width = content.shape
        # slice the content from padded feature
        source_len = min(target_len * target_hop // source_hop + 1, len(content))

        # const ~= target_len * target_hop, rounded down to a multiple of target_hop
        const = source_len * source_hop // target_hop * target_hop

        # (source_len * source_hop, dim)
        up_sampling_feats = np.repeat(content, source_hop, axis=0)
        # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
        down_sampling_feats = np.average(
            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
        )

        err = abs(target_len - len(down_sampling_feats))
        if err > 8:
            self.log_for_ReTrans(err)

        if len(down_sampling_feats) < target_len:
            # (1, dim) -> (err, dim)
            end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
            down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)

        # (target_len, dim)
        mapped_feature = down_sampling_feats[:target_len]

        return mapped_feature

    def log_for_ReTrans(self, err):
        """Store ``err`` in ``<processed_dir>/align_max_err.log`` if it exceeds the stored value.

        A missing or unparsable log file is reset to "0" first.
        """
        err_log_dir = os.path.join(
            self.cfg.preprocess.processed_dir, "align_max_err.log"
        )
        try:
            with open(err_log_dir, "r") as f:
                err_num = int(f.read())
        except (OSError, ValueError):
            # Missing file or non-integer contents: start over from 0.
            with open(err_log_dir, "w") as f:
                f.write("0")
            err_num = 0
        if err > err_num:
            with open(err_log_dir, "w") as f:
                f.write(str(err))

    def ReTrans(self, source_feats, padded_target_len):
        """
        Resolution Transformation for mismatched frames alignment.

        TODO: Merge the offline resolution_transformation into one

        args:
            source_feats: Tensor, (B, padded_source_len, D)
            padded_target_len: int, the maximum target length in a batch
        return:
            mapped_feature: Tensor, (B, padded_target_len, D)
        """
        source_hop = self.source_hop
        target_hop = self.target_hop

        # (B, padded_source_len, D)
        B, padded_source_len, D = source_feats.shape

        # select the valid content from padded feature
        source_len = min(
            padded_target_len * target_hop // source_hop + 1, padded_source_len
        )

        # const ~= padded_target_len * target_hop (padded wav's duration)
        const = source_len * source_hop // target_hop * target_hop

        # (B, padded_source_len, D) -> (B, padded_source_len * source_hop, D) -> (B, const, D)
        up_sampling_feats = torch.repeat_interleave(source_feats, source_hop, dim=1)[
            :, :const
        ]
        # (B, const, D) -> (B, const/target_hop, target_hop, D) -> (B, const/target_hop, D)
        down_sampling_feats = torch.mean(
            up_sampling_feats.reshape(B, -1, target_hop, D), dim=2
        )

        err = abs(padded_target_len - down_sampling_feats.shape[1])
        if err > 8:
            self.log_for_ReTrans(err)

        if down_sampling_feats.shape[1] < padded_target_len:
            # (B, 1, D) -> (B, err, D)
            end = down_sampling_feats[:, -1, :][:, None, :].repeat_interleave(
                err, dim=1
            )
            # -> (B, padded_target_len, D)
            down_sampling_feats = torch.cat([down_sampling_feats, end], dim=1)

        # (B, padded_target_len, D)
        mapped_feature = down_sampling_feats[:, :padded_target_len]
        return mapped_feature

    def get_valid_features(self, utt, content_feature):
        """Trim a 2-D ``content_feature`` to the frames covering ``utt["Duration"]``.

        The frame count is ceil((duration - frameshift) / frameshift) + 1,
        where frameshift is the extractor's effective frame shift (in the same
        unit as ``Duration``).

        Raises:
            NotImplementedError: for an unknown extractor type.
            AssertionError: if ``content_feature`` is not 2-D.
        """
        duration = utt["Duration"]
        if self.extractor_type == "whisper":
            frameshift = (
                self.cfg.preprocess.whisper_frameshift
                * self.cfg.preprocess.whisper_downsample_rate
            )  # 20ms
        elif self.extractor_type == "contentvec":
            frameshift = self.cfg.preprocess.contentvec_frameshift  # 20ms
        elif self.extractor_type == "wenet":
            frameshift = (
                self.cfg.preprocess.wenet_frameshift
                * self.cfg.preprocess.wenet_downsample_rate
            )  # 40ms
        elif self.extractor_type == "mert":
            frameshift = self.cfg.preprocess.mert_frameshift
        else:
            raise NotImplementedError

        # calculate the number of valid frames
        num_frames = int(np.ceil((duration - frameshift) / frameshift)) + 1
        assert (
            len(content_feature.shape) == 2
        ), "content feature shape error, it should be (num_frames, dim)"
        content_feature = content_feature[:num_frames, :]
        return content_feature

    def save_feature(self, utt, content_feature):
        """Save a single utterance to {processed_dir}/{Dataset}/{extractor_type}/{Uid}.npy.

        Args:
            utt (dict): one item in metadata, containing information for one utterance
            content_feature (tensor): content feature of one utterance
        """
        uid = utt["Uid"]
        assert self.extractor_type is not None
        out_dir = os.path.join(
            self.cfg.preprocess.processed_dir, utt["Dataset"], self.extractor_type
        )
        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, uid + ".npy")

        content_feature = self.get_valid_features(utt, content_feature)
        np.save(save_path, content_feature.cpu().detach().numpy())


class WhisperExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content-feature extractor built on an OpenAI Whisper encoder."""

    def __init__(self, config):
        super(WhisperExtractor, self).__init__(config, extractor_type="whisper")
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

    def _resolve_download_root(self):
        """Return the directory to pass to ``whisper.load_model`` as download_root.

        Returns None when ``whisper_model_path`` is not configured. An
        existing file, or a not-yet-existing ``*.pt`` path, maps to its parent
        directory. Any other path (an existing directory, or a missing
        non-``.pt`` path) is used as-is.
        """
        if "whisper_model_path" not in self.cfg.preprocess:
            return None
        model_path = self.cfg.preprocess.whisper_model_path
        if os.path.isfile(model_path):
            # e.g. "pretrained/whisper/medium.pt"
            return os.path.dirname(model_path)
        if os.path.isdir(model_path):
            # e.g. "pretrained/whisper"
            return model_path
        # The path does not exist yet, so the model will be downloaded there.
        if model_path.endswith(".pt"):
            return os.path.dirname(model_path)
        return model_path

    def load_model(self):
        """Load the configured Whisper checkpoint and keep it in eval mode."""
        print("Loading Whisper Model...")

        root_dir = self._resolve_download_root()
        whisper_model = whisper.load_model(
            self.cfg.preprocess.whisper_model, self.device, root_dir
        )

        on_gpu = torch.cuda.is_available()
        print("Using GPU...\n" if on_gpu else "Using CPU...\n")
        if on_gpu:
            whisper_model = whisper_model.cuda()

        self.model = whisper_model.eval()

    def extract_content_features(self, wavs):
        """Return Whisper encoder embeddings for a batch from the dataloader.

        Args:
            wavs: tensor (batch_size, T); padded or trimmed by whisper.pad_or_trim
        """
        fixed_len_wavs = whisper.pad_or_trim(wavs)
        mel = whisper.log_mel_spectrogram(fixed_len_wavs, device=self.model.device)
        with torch.no_grad():
            return self.model.embed_audio(mel)


class ContentvecExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content-feature extractor backed by a fairseq ContentVec checkpoint."""

    def __init__(self, cfg):
        super(ContentvecExtractor, self).__init__(cfg, extractor_type="contentvec")

    def load_model(self):
        """Load ``cfg.preprocess.contentvec_file`` with fairseq and keep it in eval mode."""
        assert self.model is None
        # Load model
        ckpt_path = self.cfg.preprocess.contentvec_file
        print("Load Contentvec Model...")

        models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
            [ckpt_path],
            suffix="",
        )
        model = models[0]
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()

        self.model = model

    def extract_content_features(self, wavs, lens=None):
        """Extract content features from a batch of dataloader.

        Args:
            wavs: tensor (batch, T), zero-padded on the right
            lens: optional per-utterance valid lengths (list or tensor). When
                given, the padding mask is built from these lengths. When
                omitted, the legacy heuristic is used: every sample equal to 0
                is masked, which also masks digital silence inside real audio.
        Returns:
            feats: (batch, frames, dim) output of ``final_proj`` on layer 12
        """
        device = next(self.model.parameters()).device
        wavs = wavs.to(device)  # (batch, max_len)
        if lens is None:
            padding_mask = torch.eq(wavs, torch.zeros_like(wavs))
        else:
            valid_lens = torch.as_tensor(lens, device=device).reshape(-1, 1)
            positions = torch.arange(wavs.shape[-1], device=device).unsqueeze(0)
            # True marks padded positions (beyond each utterance's length).
            padding_mask = positions >= valid_lens
        with torch.no_grad():
            logits = self.model.extract_features(
                source=wavs, padding_mask=padding_mask, output_layer=12
            )
            feats = self.model.final_proj(logits[0])
        return feats


class WenetExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content-feature extractor backed by a WeNet encoder."""

    def __init__(self, config):
        super(WenetExtractor, self).__init__(config, extractor_type="wenet")
        # Set by load_model() from the WeNet YAML "dataset_conf" section.
        self.extract_conf = None

    def load_model(self):
        """Load the WeNet config and checkpoint, move to GPU if available, set eval mode."""
        wenet_cfg = self.cfg.preprocess.wenet_config
        wenet_model_path = self.cfg.preprocess.wenet_model_path
        # load Wenet config (a local config file, not untrusted input)
        with open(wenet_cfg, "r") as w:
            wenet_configs = yaml.load(w, Loader=yaml.FullLoader)
        self.extract_conf = copy.deepcopy(wenet_configs["dataset_conf"])
        print("Loading Wenet Model...")
        self.model = init_model(wenet_configs)
        load_checkpoint(self.model, wenet_model_path)

        if torch.cuda.is_available():
            print("Using GPU...\n")
            self.model = self.model.cuda()
        else:
            print("Using CPU...\n")

        self.model = self.model.eval()

    def extract_content_features(self, wavs, lens):
        """Extract content features from a batch of dataloader.

        Each waveform is trimmed to its length, padded with 160 zero samples,
        scaled by 2**15, and turned into Kaldi fbank/MFCC features (at a
        hard-coded 16 kHz). The padded batch then goes through the WeNet
        encoder.

        Args:
            wavs: tensor, whose shape is (B, T)
            lens: list of valid lengths, indexed per utterance
        """
        assert self.extract_conf is not None, "load model first!"
        device = next(self.model.parameters()).device

        # Extract fbank/mfcc features by kaldi
        feats_type = self.extract_conf.get("feats_type", "fbank")
        assert feats_type in ["fbank", "mfcc"]
        if feats_type == "fbank":
            feat_conf = self.extract_conf.get("fbank_conf", {})
        else:
            # WeNet's dataset_conf calls this section "mfcc_conf"; the legacy
            # "mfcc" key is still honored for older configs.
            feat_conf = self.extract_conf.get(
                "mfcc_conf", self.extract_conf.get("mfcc", {})
            )

        feats_list = []
        for idx, wav in enumerate(wavs):
            # wav: (T)
            wav = wav[: lens[idx]].to(device)

            # pad one frame to compensate for the frame cut off after feature extraction
            pad_tensor = torch.zeros(160, device=wav.device)
            wav = torch.cat((wav, pad_tensor), dim=-1)
            wav *= 1 << 15

            wav = wav.unsqueeze(0)  # (T) -> (1, T)
            if feats_type == "fbank":
                feat = kaldi.fbank(
                    wav,
                    sample_frequency=16000,
                    num_mel_bins=feat_conf["num_mel_bins"],
                    frame_length=feat_conf["frame_length"],
                    frame_shift=feat_conf["frame_shift"],
                    dither=feat_conf["dither"],
                )
            else:
                feat = kaldi.mfcc(
                    wav,
                    sample_frequency=16000,
                    num_mel_bins=feat_conf["num_mel_bins"],
                    frame_length=feat_conf["frame_length"],
                    frame_shift=feat_conf["frame_shift"],
                    dither=feat_conf["dither"],
                    num_ceps=feat_conf.get("num_ceps", 40),
                    high_freq=feat_conf.get("high_freq", 0.0),
                    low_freq=feat_conf.get("low_freq", 20.0),
                )
            feats_list.append(feat)

        feats_lengths = torch.tensor(
            [feat.shape[0] for feat in feats_list], dtype=torch.int32
        ).to(device)
        # (batch, len, num_features)
        feats_tensor = pad_sequence(feats_list, batch_first=True).to(device)

        # Inference only, consistent with the other extractors.
        with torch.no_grad():
            features = self.model.encoder_extractor(
                feats_tensor,
                feats_lengths,
                decoding_chunk_size=-1,
                num_decoding_left_chunks=-1,
                simulate_streaming=False,
            )
        return features


class MertExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content-feature extractor backed by a Hugging Face MERT model."""

    def __init__(self, cfg):
        super(MertExtractor, self).__init__(cfg, extractor_type="mert")
        self.preprocessor = None

    def load_model(self):
        """Load ``cfg.preprocess.mert_model`` and its Wav2Vec2 feature extractor."""
        assert self.model is None
        assert self.preprocessor is None

        print("Loading MERT Model: ...", self.cfg.preprocess.mert_model)

        model_name = self.cfg.preprocess.mert_model
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()
        preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
            model_name, trust_remote_code=True
        )

        self.model = model
        self.preprocessor = preprocessor

    def extract_content_features(self, wavs, lens=None):
        """Extract per-utterance MERT features from a batch of dataloader.

        Args:
            wavs: tensor (batch, T)
            lens: optional per-utterance valid lengths; when given, each
                waveform is trimmed before preprocessing
        Returns:
            list of tensors, one (frame_len, hidden_dim) tensor per utterance,
            taken from hidden layer ``cfg.preprocess.mert_feature_layer``
        """
        with torch.no_grad():
            sample_rate = self.preprocessor.sampling_rate
            device = next(self.model.parameters()).device
            assert (
                sample_rate == self.cfg.preprocess.mert_sample_rate
            ), "mert sample rate mismatch, expected {}, got {}".format(
                self.cfg.preprocess.mert_sample_rate, sample_rate
            )
            layer = self.cfg.preprocess.mert_feature_layer
            mert_features = []
            for idx, wav in enumerate(wavs):
                if lens is not None:
                    wav = wav[: int(lens[idx])]
                # Preprocess only this utterance. The original code passed the
                # whole batch here, which produced 3-D features for batch > 1.
                inputs = self.preprocessor(
                    wav, sampling_rate=sample_rate, return_tensors="pt"
                ).to(device)

                outputs = self.model(**inputs, output_hidden_states=True)
                # (1, frame_len, hidden_dim) -> (frame_len, hidden_dim)
                feature = outputs.hidden_states[layer].squeeze(0)
                mert_features.append(feature)

        return mert_features


def _extract_features_for_dataset(
    cfg,
    metadata,
    dataset_name,
    num_workers,
    feature_name,
    dataset_cls,
    sample_rate_key,
    extractor_cls,
    pass_lens=False,
):
    """Extract and save one content-feature type for every utterance in ``metadata``.

    Skipped when ``{processed_dir}/{dataset_name}/{feature_name}`` already holds
    as many files as there are utterances.
    """
    feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, feature_name)
    os.makedirs(feat_dir, exist_ok=True)
    if len(os.listdir(feat_dir)) == len(metadata):
        return

    waveforms = dataset_cls(
        cfg,
        dataset_name,
        getattr(cfg.preprocess, sample_rate_key),
        metadata=metadata,
    )
    data_loader = DataLoader(
        waveforms,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=cfg.preprocess.pin_memory,
        batch_size=cfg.preprocess.content_feature_batch_size,
        collate_fn=collate_batch,
        drop_last=False,
    )
    extractor = extractor_cls(cfg)
    extractor.load_model()
    for _metadata, wavs, lens in tqdm(data_loader):
        if pass_lens:
            batch_content_features = extractor.extract_content_features(wavs, lens)
        else:
            batch_content_features = extractor.extract_content_features(wavs)
        for index, utt in enumerate(_metadata):
            extractor.save_feature(utt, batch_content_features[index])


def extract_utt_content_features_dataloader(cfg, metadata, num_workers):
    """Extract every content feature type enabled in ``cfg.preprocess`` for one dataset.

    ``metadata`` is a list of utterance dicts from a single dataset; the name
    is taken from ``metadata[0]["Dataset"]``. An empty list is a no-op.
    Features are written by each extractor's ``save_feature``.
    """
    if not metadata:
        return
    dataset_name = metadata[0]["Dataset"]
    with torch.no_grad():
        if cfg.preprocess.extract_whisper_feature:
            _extract_features_for_dataset(
                cfg, metadata, dataset_name, num_workers,
                "whisper", FFmpegDataset, "whisper_sample_rate", WhisperExtractor,
            )

        if cfg.preprocess.extract_contentvec_feature:
            _extract_features_for_dataset(
                cfg, metadata, dataset_name, num_workers,
                "contentvec", LibrosaDataset, "contentvec_sample_rate",
                ContentvecExtractor,
            )

        if cfg.preprocess.extract_wenet_feature:
            # WeNet needs per-utterance lengths to trim padding before fbank/mfcc.
            _extract_features_for_dataset(
                cfg, metadata, dataset_name, num_workers,
                "wenet", TorchaudioDataset, "wenet_sample_rate", WenetExtractor,
                pass_lens=True,
            )

        if cfg.preprocess.extract_mert_feature:
            _extract_features_for_dataset(
                cfg, metadata, dataset_name, num_workers,
                "mert", TorchaudioDataset, "mert_sample_rate", MertExtractor,
            )

back to top

Software Heritage — Copyright (C) 2015–2025, The Software Heritage developers. License: GNU AGPLv3+.
The source code of Software Heritage itself is available on our development forge.
The source code files archived by Software Heritage are available under their own copyright and licenses.
Terms of use: Archive access, API— Content policy— Contact— JavaScript license information— Web API