https://github.com/freewym/espresso
Tip revision: 660facf088ded9f084cc1a24a1f00f64ce5f6918 authored by freewym on 20 July 2023, 23:05:26 UTC
allows dictionary files w/o the counts column; rename task's
dump_posteriors.py
#!/usr/bin/env python3
# Copyright (c) Yiming Wang
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Dump frame-level posteriors (interpreted as log probabilities) with a trained model
for decoding with Kaldi.
"""
import ast
import logging
import os
import sys
from argparse import Namespace
import numpy as np
import torch
from omegaconf import DictConfig
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter
try:
import kaldi_io
except ImportError:
raise ImportError("Please install kaldi_io with: pip install kaldi_io")


def main(cfg: DictConfig):
if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
assert cfg.common_eval.path is not None, "--path required for decoding!"
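    # Logging goes to stderr because stdout carries the dumped posteriors
    # (they are piped through Kaldi's copy-matrix in _main below).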
return _main(cfg, sys.stderr)


def _main(cfg, output_file):
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=output_file,
)
logger = logging.getLogger("espresso.dump_posteriors")
print_options_meaning_changes(cfg, logger)
utils.import_user_module(cfg.common)
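    # For this task, --max-tokens is a per-batch budget on input frames
    # (see print_options_meaning_changes below), not on target tokens.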
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
cfg.dataset.max_tokens = 12000
logger.info(cfg)
# Fix seed for stochastic decoding
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
np.random.seed(cfg.common.seed)
utils.set_torch_seed(cfg.common.seed)
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
task = tasks.setup_task(cfg.task)
overrides = ast.literal_eval(cfg.common_eval.model_overrides)
# Load ensemble
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
models, saved_cfg = checkpoint_utils.load_model_ensemble(
utils.split_paths(cfg.common_eval.path),
arg_overrides=overrides,
task=task,
suffix=cfg.checkpoint.checkpoint_suffix,
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
# Load state prior for cross-entropy trained systems decoding
if cfg.generation.state_prior_file is not None:
prior = torch.from_numpy(kaldi_io.read_vec_flt(cfg.generation.state_prior_file))
else:
prior = []
# Optimize ensemble for generation
for model in models:
if model is None:
continue
if cfg.common.fp16:
model.half()
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
model.cuda()
model.prepare_for_inference_(cfg)
if isinstance(prior, list) and getattr(model, "state_prior", None) is not None:
prior.append(model.state_prior.unsqueeze(0))
if isinstance(prior, list) and len(prior) > 0:
prior = torch.cat(prior, 0).mean(0) # average priors across models
prior = prior / prior.sum() # re-normalize
elif isinstance(prior, list):
prior = None
if prior is not None:
if cfg.common.fp16:
prior = prior.half()
if use_cuda:
prior = prior.cuda()
log_prior = prior.log()
else:
log_prior = None
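    # If available, log_prior is subtracted from the model's frame-level
    # log-posteriors below, i.e. the posteriors are divided by the state prior
    # to obtain scaled log-likelihoods for hybrid decoding with Kaldi.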
# Load dataset (possibly sharded)
itr = task.get_batch_iterator(
dataset=task.dataset(cfg.dataset.gen_subset),
max_tokens=cfg.dataset.max_tokens,
max_sentences=cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions(
task.max_positions(), *[m.max_positions() for m in models]
),
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
seed=cfg.common.seed,
num_shards=cfg.distributed_training.distributed_world_size,
shard_id=cfg.distributed_training.distributed_rank,
num_workers=cfg.dataset.num_workers,
data_buffer_size=cfg.dataset.data_buffer_size,
).next_epoch_itr(shuffle=False)
progress = progress_bar.progress_bar(
itr,
log_format=cfg.common.log_format,
log_interval=cfg.common.log_interval,
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
)
# Initialize generator
gen_timer = StopwatchMeter()
generator = task.build_generator(models, cfg.generation)
# Generate and dump
num_sentences = 0
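    # Tasks that split utterances into fixed-size chunks expose chunk_width;
    # in that case each element yielded by `progress` is a list of chunk
    # batches covering the same utterances from left to right.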
chunk_width = getattr(task, "chunk_width", None)
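    # Write the matrices as a Kaldi archive to stdout by piping through
    # copy-matrix (the Kaldi binary must be on PATH).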
lprobs_wspecifier = "ark:| copy-matrix ark:- ark:-"
with kaldi_io.open_or_fd(lprobs_wspecifier, "wb") as f:
if chunk_width is None: # normal dumping (i.e., no chunking)
for sample in progress:
sample = utils.move_to_cuda(sample) if use_cuda else sample
if "net_input" not in sample:
continue
gen_timer.start()
lprobs, padding_mask = task.inference_step(generator, models, sample)
if log_prior is not None:
assert lprobs.size(-1) == log_prior.size(0)
lprobs = lprobs - log_prior
out_lengths = (
(~padding_mask).long().sum(dim=1).cpu()
if padding_mask is not None
else None
)
num_processed_frames = sample["ntokens"]
gen_timer.stop(num_processed_frames)
num_sentences += (
sample["nsentences"]
if "nsentences" in sample
else sample["id"].numel()
)
if out_lengths is not None:
for i in range(sample["nsentences"]):
length = out_lengths[i]
kaldi_io.write_mat(
f,
lprobs[i, :length, :].cpu().numpy(),
key=sample["utt_id"][i],
)
else:
for i in range(sample["nsentences"]):
kaldi_io.write_mat(
f, lprobs[i, :, :].cpu().numpy(), key=sample["utt_id"][i]
)
else: # dumping chunks within the same utterance from left to right
for sample in progress: # sample is actually a list of batches
sample = utils.move_to_cuda(sample) if use_cuda else sample
utt_id = sample[0]["utt_id"]
id = sample[0]["id"]
whole_lprobs = None
for i, chunk_sample in enumerate(sample):
if "net_input" not in chunk_sample:
continue
assert (
chunk_sample["utt_id"] == utt_id
and (chunk_sample["id"] == id).all()
)
gen_timer.start()
lprobs, _ = task.inference_step(generator, models, chunk_sample)
if log_prior is not None:
assert lprobs.size(-1) == log_prior.size(0)
lprobs = lprobs - log_prior
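                    # stitch chunks of the same utterances back together along the time axis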
if whole_lprobs is None:
whole_lprobs = lprobs.cpu()
else:
whole_lprobs = torch.cat((whole_lprobs, lprobs.cpu()), 1)
num_processed_frames = chunk_sample["ntokens"]
gen_timer.stop(num_processed_frames)
if i == len(sample) - 1:
num_sentences += len(utt_id)
for j in range(len(utt_id)):
truncated_length = models[0].output_lengths(
task.dataset(cfg.dataset.gen_subset).src_sizes[id[j]]
) # length is after possible subsampling by the model
mat = whole_lprobs[j, :truncated_length, :]
kaldi_io.write_mat(f, mat.numpy(), key=utt_id[j])
logger.info(
"Dumped {:,} utterances ({} frames) in {:.1f}s ({:.2f} sentences/s, {:.2f} frames/s)".format(
num_sentences,
gen_timer.n,
gen_timer.sum,
num_sentences / gen_timer.sum,
1.0 / gen_timer.avg,
)
)
return


def print_options_meaning_changes(cfg, logger):
"""Options that have different meanings than those in the translation task
are explained here.
"""
logger.info("--max-tokens is the maximum number of input frames in a batch")


def cli_main():
parser = options.get_generation_parser(default_task="speech_recognition_hybrid")
args = options.parse_args_and_arch(parser)
main(args)


if __name__ == "__main__":
cli_main()