Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

https://github.com/open-mmlab/Amphion
01 April 2026, 10:44:45 UTC
  • Code
  • Branches (2)
  • Releases (3)
  • Visits
    • Branches
    • Releases
    • HEAD
    • refs/heads/main
    • refs/heads/revert-154-FACodec-readme
    • v0.1.1-alpha
    • v0.1.0-alpha
    • v0.1.0
  • c135c1f
  • /
  • processors
  • /
  • content_extractor.py
Raw File Download Save again
Take a new snapshot of a software origin

If the archived software origin currently browsed is not synchronized with its upstream version (for instance when new commits have been issued), you can explicitly request Software Heritage to take a new snapshot of it.

Use the form below to proceed. Once a request has been submitted and accepted, it will be processed as soon as possible. You can then check its processing state by visiting this dedicated page.
swh spinner

Processing "take a new snapshot" request ...

To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.

  • content
  • directory
  • revision
  • snapshot
origin badgecontent badge
swh:1:cnt:fe2f323b90fa91d4eb64e8f753a15491cdbf8ed8
origin badgedirectory badge
swh:1:dir:7299d913553a074234efa759cb4256e2bcb81da1
origin badgerevision badge
swh:1:rev:26f6883110181f1dbfe95c70a7c7dbaf4de5f42a
origin badgesnapshot badge
swh:1:snp:fc8cecae7657ff85d660f8ea19d9339e78426c63

This interface enables to generate software citations, provided that the root directory of browsed objects contains a citation.cff or codemeta.json file.
Select below a type of object currently browsed in order to generate citations for them.

  • content
  • directory
  • revision
  • snapshot
(requires biblatex-software package)
Generating citation ...
(requires biblatex-software package)
Generating citation ...
(requires biblatex-software package)
Generating citation ...
(requires biblatex-software package)
Generating citation ...
Tip revision: 26f6883110181f1dbfe95c70a7c7dbaf4de5f42a authored by Xueyao Zhang on 25 March 2026, 14:11:57 UTC
Vevo2 Release (#481)
Tip revision: 26f6883
content_extractor.py
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import numpy as np
import yaml
import copy
from tqdm import tqdm
from torchaudio.compliance import kaldi
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from fairseq import checkpoint_utils
from transformers import AutoModel, Wav2Vec2FeatureExtractor
import torch.nn as nn
import torch.nn.functional as F
import joblib
from einops import repeat

from utils.io_optim import (
    TorchaudioDataset,
    LibrosaDataset,
    FFmpegDataset,
    collate_batch,
)
import whisper
from modules.wenet_extractor.utils.init_model import init_model
from modules.wenet_extractor.utils.checkpoint import load_checkpoint

"""
    Extractor for content features
    1. whisper
    2. contentvec
    3. wenet
    4. mert

    Pipeline:
        in preprocess.py:
            call extract_utt_content_features() to extract content features for each utterance
            extract_utt_content_features() envelopes the following steps:
                1. load the model (whisper, contentvec, wenet)
                2. extract the content features
                3. save the content features into files
        in svc_dataset.py:
            call offline_align() to align the content features to the given target length

"""

"""
    Extractor Usage:
        1. initialize an instance of extractor
            extractor = WhisperExtractor(cfg)
        2. load the specified model
            extractor.load_model()
        3. extract the content features
            extractor.extract_content(utt) for single utterance
            extractor.extract_content_batch(utts) for batch utterances
        4. save the content features
            extractor.save_feature(utt, content_feature) for single utterance
"""


class AudioPretrainedModelFeaturesExtractor:
    def __init__(self, cfg, extractor_type):
        self.cfg = cfg
        self.extractor_type = extractor_type
        self.model = None
        self.init_for_retrans()

    def init_for_retrans(self):
        target_hop = self.cfg.preprocess.hop_size

        assert self.extractor_type in ["whisper", "contentvec", "wenet"]
        if self.extractor_type == "whisper":
            source_hop = (
                self.cfg.preprocess.whisper_frameshift
                * self.cfg.preprocess.whisper_downsample_rate
                * self.cfg.preprocess.sample_rate
            )
        elif self.extractor_type == "contentvec":
            source_hop = (
                self.cfg.preprocess.contentvec_frameshift
                * self.cfg.preprocess.sample_rate
            )
        elif self.extractor_type == "wenet":
            source_hop = (
                self.cfg.preprocess.wenet_frameshift
                * self.cfg.preprocess.wenet_downsample_rate
                * self.cfg.preprocess.sample_rate
            )
        source_hop = int(source_hop)
        factor = np.gcd(source_hop, target_hop)
        source_hop //= factor
        target_hop //= factor

        self.source_hop = source_hop
        self.target_hop = target_hop

    def offline_resolution_transformation(self, content, target_len):
        """
        args:
            content: (source_len, dim)
            target_len: target length
        return:
            mapped_feature: (target_len, dim)
        """
        source_hop = self.source_hop
        target_hop = self.target_hop

        # (source_len, 256)
        _, width = content.shape
        # slice the content from padded feature
        source_len = min(target_len * target_hop // source_hop + 1, len(content))

        # const ~= target_len * target_hop
        const = source_len * source_hop // target_hop * target_hop

        # (source_len * source_hop, dim)
        up_sampling_feats = np.repeat(content, source_hop, axis=0)
        # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
        down_sampling_feats = np.average(
            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
        )

        err = abs(target_len - len(down_sampling_feats))
        if err > 8:
            # err_log_dir is indeterminate
            err_log_dir = os.path.join(
                self.cfg.preprocess.processed_dir, "align_max_err.log"
            )
            try:
                with open(err_log_dir, "r") as f:
                    err_num = int(f.read())
            except:
                with open(err_log_dir, "w") as f:
                    f.write("0")
                err_num = 0
            if err > err_num:
                with open(err_log_dir, "w") as f:
                    f.write(str(err))

        if len(down_sampling_feats) < target_len:
            # (1, dim) -> (err, dim)
            end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
            down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)

        # (target_len, dim)
        mapped_feature = down_sampling_feats[:target_len]

        return mapped_feature

    def log_for_ReTrans(self, err):
        err_log_dir = os.path.join(
            self.cfg.preprocess.processed_dir, "align_max_err.log"
        )
        try:
            with open(err_log_dir, "r") as f:
                err_num = int(f.read())
        except:
            with open(err_log_dir, "w") as f:
                f.write("0")
            err_num = 0
        if err > err_num:
            with open(err_log_dir, "w") as f:
                f.write(str(err))

    def ReTrans(self, source_feats, padded_target_len):
        """
        Resolution Transformation for mismatched frames alginment.

        TODO: Merge the offline resolution_transformation into one

        args:
            source_feats: Tensor, (B, padded_source_len, D)
            padded_target_len: int, the maximum target length in a batch
        return:
            mapped_feature: Tensor, (B, padded_target_len, D)
        """
        source_hop = self.source_hop
        target_hop = self.target_hop

        # (B, padded_source_len, D)
        B, padded_source_len, D = source_feats.shape

        # select the valid content from padded feature
        source_len = min(
            padded_target_len * target_hop // source_hop + 1, padded_source_len
        )

        # const ~= padded_target_len * target_hop (padded wav's duration)
        const = source_len * source_hop // target_hop * target_hop

        # (B, padded_source_len, D) -> (B, padded_source_len * source_hop, D) -> (B, const, D)
        up_sampling_feats = torch.repeat_interleave(source_feats, source_hop, dim=1)[
            :, :const
        ]
        # (B, const, D) -> (B, const/target_hop, target_hop, D) -> (B, const/target_hop, D)
        down_sampling_feats = torch.mean(
            up_sampling_feats.reshape(B, -1, target_hop, D), dim=2
        )

        err = abs(padded_target_len - down_sampling_feats.shape[1])
        if err > 8:
            self.log_for_ReTrans(err)

        if down_sampling_feats.shape[1] < padded_target_len:
            # (B, 1, D) -> (B, err, D)
            end = down_sampling_feats[:, -1, :][:, None, :].repeat_interleave(
                err, dim=1
            )
            # -> (B, padded_target_len, D)
            down_sampling_feats = torch.cat([down_sampling_feats, end], dim=1)

        # (B, padded_target_len, D)
        mapped_feature = down_sampling_feats[:, :padded_target_len]
        return mapped_feature

    def get_valid_features(self, utt, content_feature):
        # only keep effective parts
        duration = utt["Duration"]
        if self.extractor_type == "whisper":
            frameshift = (
                self.cfg.preprocess.whisper_frameshift
                * self.cfg.preprocess.whisper_downsample_rate
            )  # 20ms
        elif self.extractor_type == "contentvec":
            frameshift = self.cfg.preprocess.contentvec_frameshift  # 20ms
        elif self.extractor_type == "wenet":
            frameshift = (
                self.cfg.preprocess.wenet_frameshift
                * self.cfg.preprocess.wenet_downsample_rate
            )  # 40ms
        elif self.extractor_type == "mert":
            frameshift = self.cfg.preprocess.mert_frameshift
        else:
            raise NotImplementedError

        # calculate the number of valid frames
        num_frames = int(np.ceil((duration - frameshift) / frameshift)) + 1
        assert (
            len(content_feature.shape) == 2
        ), "content feature shape error, it should be (num_frames, dim)"
        content_feature = content_feature[:num_frames, :]
        return content_feature

    def save_feature(self, utt, content_feature):
        """Save a single utternace to path {cfg.preprocess.processed_dir}

        Args:
            utt (dict): one item in metadata, containing information for one utterance
            content_feature (tensor): content feature of one utterance
        """
        uid = utt["Uid"]
        assert self.extractor_type != None
        out_dir = os.path.join(
            self.cfg.preprocess.processed_dir, utt["Dataset"], self.extractor_type
        )
        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, uid + ".npy")

        content_feature = self.get_valid_features(utt, content_feature)
        np.save(save_path, content_feature.cpu().detach().numpy())


class WhisperExtractor(AudioPretrainedModelFeaturesExtractor):
    def __init__(self, config):
        super(WhisperExtractor, self).__init__(config, extractor_type="whisper")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self):
        # load whisper checkpoint
        print("Loading Whisper Model...")

        if "whisper_model_path" in self.cfg.preprocess:
            if os.path.isfile(self.cfg.preprocess.whisper_model_path):
                # "pretrained/whisper/medium.pt"
                download_root = os.path.dirname(self.cfg.preprocess.whisper_model_path)
            elif os.path.isdir(self.cfg.preprocess.whisper_model_path):
                # "pretrained/whisper"
                download_root = self.cfg.preprocess.whisper_model_path
            else:
                # if the path does not exist, download the model to the path
                download_root = self.cfg.preprocess.whisper_model_path
                if download_root.endswith(".pt"):
                    download_root = os.path.dirname(download_root)
        else:
            download_root = None

        model = whisper.load_model(
            self.cfg.preprocess.whisper_model, self.device, download_root
        )
        if torch.cuda.is_available():
            print("Using GPU...\n")
            model = model.cuda()
        else:
            print("Using CPU...\n")

        self.model = model.eval()

    def extract_content_features(self, wavs):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch_size, T)
        """
        # wavs: (batch, max_len)
        wavs = whisper.pad_or_trim(wavs)
        # batch_mel: (batch, 80, 3000)
        batch_mel = whisper.log_mel_spectrogram(wavs, device=self.model.device)
        with torch.no_grad():
            # (batch, 1500, 1024)
            features = self.model.embed_audio(batch_mel)
        return features


class ContentvecExtractor(AudioPretrainedModelFeaturesExtractor):
    def __init__(self, cfg):
        super(ContentvecExtractor, self).__init__(cfg, extractor_type="contentvec")

    def load_model(self):
        assert self.model == None
        # Load model
        ckpt_path = self.cfg.preprocess.contentvec_file
        print("Load Contentvec Model...")

        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            [ckpt_path],
            suffix="",
        )
        model = models[0]
        model.eval()

        if torch.cuda.is_available():
            # print("Using GPU...\n")
            model = model.cuda()

        self.model = model

    def extract_content_features(self, wavs):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch, T)
        """
        device = next(self.model.parameters()).device
        wavs = wavs.to(device)  # (batch, max_len)
        padding_mask = torch.eq(wavs, torch.zeros_like(wavs)).to(device)
        with torch.no_grad():
            logits = self.model.extract_features(
                source=wavs, padding_mask=padding_mask, output_layer=12
            )
            # feats: (batch, T, 256)
            feats = self.model.final_proj(logits[0])
        return feats


class WenetExtractor(AudioPretrainedModelFeaturesExtractor):
    def __init__(self, config):
        super(WenetExtractor, self).__init__(config, extractor_type="wenet")

    def load_model(self):
        wenet_cfg = self.cfg.preprocess.wenet_config
        wenet_model_path = self.cfg.preprocess.wenet_model_path
        # load Wenet config
        with open(wenet_cfg, "r") as w:
            wenet_configs = yaml.load(w, Loader=yaml.FullLoader)
        self.extract_conf = copy.deepcopy(wenet_configs["dataset_conf"])
        print("Loading Wenet Model...")
        self.model = init_model(wenet_configs)
        load_checkpoint(self.model, wenet_model_path)

        if torch.cuda.is_available():
            print("Using GPU...\n")
            self.model = self.model.cuda()
        else:
            print("Using CPU...\n")

        self.model = self.model.eval()

    def extract_content_features(self, wavs, lens):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor, whose shape is (B, T)
            lens: list
        """
        feats_list = []
        lengths_list = []

        device = next(self.model.parameters()).device
        # Extract fbank/mfcc features by kaldi
        assert self.extract_conf is not None, "load model first!"
        feats_type = self.extract_conf.get("feats_type", "fbank")
        assert feats_type in ["fbank", "mfcc"]

        for idx, wav in enumerate(wavs):
            # wav: (T)
            wav = wav[: lens[idx]].to(device)

            # pad one frame to compensate for the frame cut off after feature extraction
            pad_tensor = torch.zeros(160, device=wav.device)
            wav = torch.cat((wav, pad_tensor), dim=-1)
            wav *= 1 << 15

            wav = wav.unsqueeze(0)  # (T) -> (1, T)
            if feats_type == "fbank":
                fbank_conf = self.extract_conf.get("fbank_conf", {})
                feat = kaldi.fbank(
                    wav,
                    sample_frequency=16000,
                    num_mel_bins=fbank_conf["num_mel_bins"],
                    frame_length=fbank_conf["frame_length"],
                    frame_shift=fbank_conf["frame_shift"],
                    dither=fbank_conf["dither"],
                )
            elif feats_type == "mfcc":
                mfcc_conf = self.extract_conf.get("mfcc", {})
                feat = kaldi.mfcc(
                    wav,
                    sample_frequency=16000,
                    num_mel_bins=mfcc_conf["num_mel_bins"],
                    frame_length=mfcc_conf["frame_length"],
                    frame_shift=mfcc_conf["frame_shift"],
                    dither=mfcc_conf["dither"],
                    num_ceps=mfcc_conf.get("num_ceps", 40),
                    high_freq=mfcc_conf.get("high_freq", 0.0),
                    low_freq=mfcc_conf.get("low_freq", 20.0),
                )
            feats_list.append(feat)
            lengths_list.append(feat.shape[0])

        feats_lengths = torch.tensor(lengths_list, dtype=torch.int32).to(device)
        feats_tensor = pad_sequence(feats_list, batch_first=True).to(
            device
        )  # (batch, len, 80)

        features = self.model.encoder_extractor(
            feats_tensor,
            feats_lengths,
            decoding_chunk_size=-1,
            num_decoding_left_chunks=-1,
            simulate_streaming=False,
        )
        return features


class MertExtractor(AudioPretrainedModelFeaturesExtractor):
    def __init__(self, cfg):
        super(MertExtractor, self).__init__(cfg, extractor_type="mert")
        self.preprocessor = None

    def load_model(self):
        assert self.model == None
        assert self.preprocessor == None

        print("Loading MERT Model: ...", self.cfg.preprocess.mert_model)

        model_name = self.cfg.preprocess.mert_model
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

        if torch.cuda.is_available():
            model = model.cuda()
        preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
            model_name, trust_remote_code=True
        )

        self.model = model
        self.preprocessor = preprocessor

    def extract_content_features(self, wavs):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch, T)
        """
        with torch.no_grad():
            sample_rate = self.preprocessor.sampling_rate
            device = next(self.model.parameters()).device
            assert (
                sample_rate == self.cfg.preprocess.mert_sample_rate
            ), "mert sample rate mismatch, expected {}, got {}".format(
                self.cfg.preprocess.mert_sample_rate, sample_rate
            )
            mert_features = []
            # wav: (len)
            for wav in wavs:
                # {input_values: tensor, attention_mask: tensor}
                inputs = self.preprocessor(
                    wavs, sampling_rate=sample_rate, return_tensors="pt"
                ).to(device)

                outputs = self.model(**inputs, output_hidden_states=True)
                # (25 layers, time steps, 1024 feature_dim)
                all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
                # (1, frame_len, 1024) -> (frame_len, 1024)
                feature = outputs.hidden_states[
                    self.cfg.preprocess.mert_feature_layer
                ].squeeze(0)
                mert_features.append(feature)

        return mert_features


class HubertExtractor(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.load_model()

    def load_model(self):
        kmeans_model_path = self.cfg.preprocess.kmeans_model_path
        hubert_model_path = self.cfg.preprocess.hubert_model_path
        print("Load Hubert Model...")
        checkpoint = torch.load(hubert_model_path)
        load_model_input = {hubert_model_path: checkpoint}
        model, *_ = checkpoint_utils.load_model_ensemble_and_task(load_model_input)
        self.model = model[0]
        self.model.eval()

        # Load KMeans cluster centers
        kmeans = joblib.load(kmeans_model_path)
        self.kmeans = kmeans

        self.register_buffer(
            "cluster_centers", torch.from_numpy(kmeans.cluster_centers_)
        )

    def extract_content_features(self, wav_input):
        """
        Extract content features and quantize using KMeans clustering.

        Args:
            audio_data: tensor (batch_size, T)

        Returns:
            quantize: tensor (batch_size, T, 768)
        """
        # Extract features using HuBERT
        wav_input = F.pad(wav_input, (40, 40), "reflect")
        embed = self.model(
            wav_input,
            features_only=True,
            mask=False,
            output_layer=9,  # adjust based on the hubert model used
        )["x"]

        # make the hop size chaged from 200 to 320
        embed = embed.permute((0, 2, 1))
        embed = F.interpolate(embed, scale_factor=1.6, mode="nearest")
        embed = embed.permute((0, 2, 1))

        batched_cluster_centers = repeat(
            self.cluster_centers, "c d -> b c d", b=embed.shape[0]
        )

        dists = -torch.cdist(embed, batched_cluster_centers, p=2)
        clusters = dists.argmax(dim=-1)  # (batch, seq_len)
        quantize = F.embedding(clusters, self.cluster_centers)

        return clusters, quantize


def extract_utt_content_features_dataloader(cfg, metadata, num_workers):
    dataset_name = metadata[0]["Dataset"]
    with torch.no_grad():
        if cfg.preprocess.extract_whisper_feature:
            feat_dir = os.path.join(
                cfg.preprocess.processed_dir, dataset_name, "whisper"
            )
            os.makedirs(feat_dir, exist_ok=True)
            feat_files_num = len(os.listdir(feat_dir))

            if feat_files_num != len(metadata):
                whisper_waveforms = FFmpegDataset(
                    cfg,
                    dataset_name,
                    cfg.preprocess.whisper_sample_rate,
                    metadata=metadata,
                )
                data_loader = DataLoader(
                    whisper_waveforms,
                    num_workers=num_workers,
                    shuffle=False,
                    pin_memory=cfg.preprocess.pin_memory,
                    batch_size=cfg.preprocess.content_feature_batch_size,
                    collate_fn=collate_batch,
                    drop_last=False,
                )
                extractor = WhisperExtractor(cfg)
                extractor.load_model()
                for batch_idx, items in enumerate(tqdm(data_loader)):
                    _metadata, wavs, lens = items

                    batch_content_features = extractor.extract_content_features(wavs)
                    for index, utt in enumerate(_metadata):
                        extractor.save_feature(utt, batch_content_features[index])

        if cfg.preprocess.extract_contentvec_feature:
            feat_dir = os.path.join(
                cfg.preprocess.processed_dir, dataset_name, "contentvec"
            )
            os.makedirs(feat_dir, exist_ok=True)
            feat_files_num = len(os.listdir(feat_dir))

            if feat_files_num != len(metadata):
                contentvec_waveforms = LibrosaDataset(
                    cfg,
                    dataset_name,
                    cfg.preprocess.contentvec_sample_rate,
                    metadata=metadata,
                )
                data_loader = DataLoader(
                    contentvec_waveforms,
                    num_workers=num_workers,
                    shuffle=False,
                    pin_memory=cfg.preprocess.pin_memory,
                    batch_size=cfg.preprocess.content_feature_batch_size,
                    collate_fn=collate_batch,
                    drop_last=False,
                )
                extractor = ContentvecExtractor(cfg)
                extractor.load_model()
                for batch_idx, items in enumerate(tqdm(data_loader)):
                    _metadata, wavs, lens = items

                    batch_content_features = extractor.extract_content_features(wavs)
                    for index, utt in enumerate(_metadata):
                        extractor.save_feature(utt, batch_content_features[index])

        if cfg.preprocess.extract_wenet_feature:
            feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, "wenet")
            os.makedirs(feat_dir, exist_ok=True)
            feat_files_num = len(os.listdir(feat_dir))

            if feat_files_num != len(metadata):
                wenet_waveforms = TorchaudioDataset(
                    cfg,
                    dataset_name,
                    cfg.preprocess.wenet_sample_rate,
                    metadata=metadata,
                )
                data_loader = DataLoader(
                    wenet_waveforms,
                    num_workers=num_workers,
                    shuffle=False,
                    pin_memory=cfg.preprocess.pin_memory,
                    batch_size=cfg.preprocess.content_feature_batch_size,
                    collate_fn=collate_batch,
                    drop_last=False,
                )
                extractor = WenetExtractor(cfg)
                extractor.load_model()
                for batch_idx, items in enumerate(tqdm(data_loader)):
                    _metadata, wavs, lens = items

                    batch_content_features = extractor.extract_content_features(
                        wavs,
                        lens,
                    )
                    for index, utt in enumerate(_metadata):
                        extractor.save_feature(utt, batch_content_features[index])

        if cfg.preprocess.extract_mert_feature:
            feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, "mert")
            os.makedirs(feat_dir, exist_ok=True)
            feat_files_num = len(os.listdir(feat_dir))

            if feat_files_num != len(metadata):
                mert_waveforms = TorchaudioDataset(
                    cfg,
                    dataset_name,
                    cfg.preprocess.mert_sample_rate,
                    metadata=metadata,
                )
                data_loader = DataLoader(
                    mert_waveforms,
                    num_workers=num_workers,
                    shuffle=False,
                    pin_memory=cfg.preprocess.pin_memory,
                    batch_size=cfg.preprocess.content_feature_batch_size,
                    collate_fn=collate_batch,
                    drop_last=False,
                )
                extractor = MertExtractor(cfg)
                extractor.load_model()
                for batch_idx, items in enumerate(tqdm(data_loader)):
                    _metadata, wavs, lens = items

                    batch_content_features = extractor.extract_content_features(wavs)
                    for index, utt in enumerate(_metadata):
                        extractor.save_feature(utt, batch_content_features[index])

back to top

Software Heritage — Copyright (C) 2015–2026, The Software Heritage developers. License: GNU AGPLv3+.
The source code of Software Heritage itself is available on our development forge.
The source code files archived by Software Heritage are available under their own copyright and licenses.
Terms of use: Archive access, API— Content policy— Contact— JavaScript license information— Web API