# Source: https://github.com/open-mmlab/Amphion
# Tip revision: 9682d0c8ec07ee75b4edd0a174dff3c79a5fb4d8 authored by Xueyao Zhang on 28 November 2023, 09:53:39 UTC
# Amphion Alpha Release (#2)
# Tip revision: 9682d0c
# File: transformer_inference.py
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import numpy as np
import torch
from tqdm import tqdm
import torch.nn as nn
from collections import OrderedDict
from models.svc.base import SVCInference
from modules.encoder.condition_encoder import ConditionEncoder
from models.svc.transformer.transformer import Transformer
from models.svc.transformer.conformer import Conformer
class TransformerInference(SVCInference):
    """Inference wrapper for Transformer/Conformer-based SVC acoustic models."""

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        """Delegate all initialization to the generic SVCInference base class.

        Args:
            args: Runtime/CLI arguments, passed through to the base class.
            cfg: Configuration object; ``_build_model`` reads ``cfg.model``
                and ``cfg.preprocess`` sections from it.
            infer_type: Inference mode string understood by SVCInference.
        """
        super().__init__(args, cfg, infer_type)

    def _build_model(self):
        """Build the condition encoder and acoustic mapper.

        Returns:
            torch.nn.ModuleList: ``[condition_encoder, acoustic_mapper]``.

        Raises:
            NotImplementedError: If ``cfg.model.transformer.type`` is neither
                ``"transformer"`` nor ``"conformer"``.
        """
        # Propagate the preprocessing F0 range into the condition-encoder
        # config so pitch conditioning matches how features were extracted.
        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)

        mapper_type = self.cfg.model.transformer.type
        if mapper_type == "transformer":
            self.acoustic_mapper = Transformer(self.cfg.model.transformer)
        elif mapper_type == "conformer":
            self.acoustic_mapper = Conformer(self.cfg.model.transformer)
        else:
            # Fail loudly with the offending value instead of a bare
            # NotImplementedError, so config mistakes are easy to diagnose.
            raise NotImplementedError(
                f"Unsupported acoustic mapper type: {mapper_type!r} "
                "(expected 'transformer' or 'conformer')"
            )
        return torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

    def _inference_each_batch(self, batch_data):
        """Run a single batch through the condition encoder and mapper.

        Args:
            batch_data: Dict of tensors; must contain the keys consumed by
                the condition encoder plus ``"mask"``.

        Returns:
            The acoustic mapper's prediction for this batch.
        """
        device = self.accelerator.device
        # Move every tensor in the batch to the accelerator device in place.
        for key, value in batch_data.items():
            batch_data[key] = value.to(device)

        condition = self.condition_encoder(batch_data)
        return self.acoustic_mapper(condition, batch_data["mask"])
