https://github.com/open-mmlab/Amphion
Tip revision: a4c23e2e1f15e4be0b0c7194e6b69a82a4bb4a07 authored by Xueyao Zhang on 18 December 2023, 14:14:33 UTC
Amphion v0.1 Release (#39)
Amphion v0.1 Release (#39)
Tip revision: a4c23e2
transformer_inference.py
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import numpy as np
import torch
from tqdm import tqdm
import torch.nn as nn
from collections import OrderedDict
from models.svc.base import SVCInference
from modules.encoder.condition_encoder import ConditionEncoder
from models.svc.transformer.transformer import Transformer
from models.svc.transformer.conformer import Conformer
class TransformerInference(SVCInference):
    """SVC inference driver built on a Transformer or Conformer acoustic mapper.

    The acoustic model is a ``ConditionEncoder`` whose output is passed to an
    acoustic mapper. The mapper class is selected by
    ``cfg.model.transformer.type``.
    """

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        # All setup is delegated to the SVCInference base class.
        super().__init__(args, cfg, infer_type)

    def _build_model(self):
        """Construct the condition encoder and the acoustic mapper.

        Side effects:
            * Copies ``cfg.preprocess.f0_min`` and ``cfg.preprocess.f0_max``
              into ``cfg.model.condition_encoder``, mutating the config.
            * Stores the sub-modules on ``self.condition_encoder`` and
              ``self.acoustic_mapper``.

        Returns:
            torch.nn.ModuleList: ``[condition_encoder, acoustic_mapper]``.

        Raises:
            NotImplementedError: if ``cfg.model.transformer.type`` is neither
                ``"transformer"`` nor ``"conformer"``.
        """
        model_cfg = self.cfg.model
        preprocess_cfg = self.cfg.preprocess

        # The condition encoder is configured with the same F0 range that
        # preprocessing used.
        encoder_cfg = model_cfg.condition_encoder
        encoder_cfg.f0_min = preprocess_cfg.f0_min
        encoder_cfg.f0_max = preprocess_cfg.f0_max
        self.condition_encoder = ConditionEncoder(encoder_cfg)

        mapper_cfg = model_cfg.transformer
        mapper_type = mapper_cfg.type
        if mapper_type == "transformer":
            mapper_cls = Transformer
        elif mapper_type == "conformer":
            mapper_cls = Conformer
        else:
            raise NotImplementedError
        self.acoustic_mapper = mapper_cls(mapper_cfg)

        return torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

    def _inference_each_batch(self, batch_data):
        """Run the condition encoder and acoustic mapper on a single batch.

        ``batch_data`` is modified in place: each value is replaced by the
        result of ``value.to(self.accelerator.device)``. This assumes every
        value supports ``.to`` (presumably all tensors; TODO confirm against
        the collator).

        Returns:
            Whatever ``self.acoustic_mapper(condition, batch_data["mask"])``
            produces for this batch.
        """
        target_device = self.accelerator.device
        # Only existing keys are reassigned, so mutating while iterating is safe.
        for key in batch_data:
            batch_data[key] = batch_data[key].to(target_device)

        condition = self.condition_encoder(batch_data)
        return self.acoustic_mapper(condition, batch_data["mask"])
