#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Yiming Wang
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Recognize pre-processed speech with a trained model.
"""

import ast
import logging
import math
import os
import sys
from argparse import Namespace
from itertools import chain

import numpy as np
import torch
from omegaconf import DictConfig

from espresso.models.external_language_model import MultiLevelLanguageModel
from espresso.models.tensorized_lookahead_language_model import (
    TensorizedLookaheadLanguageModel,
)
from espresso.tools import wer
from espresso.tools.utils import plot_attention, sequence_mask
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.models import FairseqLanguageModel


def main(cfg: DictConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    assert cfg.common_eval.path is not None, "--path required for recognition!"
    assert (
        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
    ), "--sampling requires --nbest to be equal to --beam"

    if cfg.common_eval.results_path is not None:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(cfg.common_eval.results_path, "decode.log")
        with open(output_path, "w", buffering=1, encoding="utf-8") as h:
            return _main(cfg, h)
    return _main(cfg, sys.stdout)


def get_symbols_to_strip_from_output(generator):
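    """Return the set of symbol indices (e.g., EOS and PAD) to strip from the
    emitted hypothesis strings."""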
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    else:
        return {generator.eos, generator.pad}


def _main(cfg, output_file):
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
        force=True,
    )
    logger = logging.getLogger("espresso.speech_recognize")
    if output_file is not sys.stdout:  # also print to stdout
        logger.addHandler(logging.StreamHandler(sys.stdout))

    print_options_meaning_changes(cfg, logger)

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    task = tasks.setup_task(cfg.task)
    task.build_tokenizer(cfg.tokenizer)
    task.build_bpe(cfg.bpe)

    # Set dictionary
    dictionary = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
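    # (--model-overrides is a Python-literal dict string; an illustrative
    # value would be "{'data': '/path/to/dict_dir'}".)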

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            logger.info(
                "loading language model(s) from {}".format(cfg.generation.lm_path)
            )
            lms, _ = checkpoint_utils.load_model_ensemble(
                utils.split_paths(cfg.generation.lm_path),
                arg_overrides=overrides,
                task=None,
            )
        except Exception:
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) in (1, 2)  # a multi-level LM expects two LMs
    else:
        lms = [None]

    for i, m in enumerate(lms):
        if m is None:
            continue
        if hasattr(m, "is_wordlm") and m.is_wordlm:
            # assume subword LM comes before word LM
            if i > 0 and isinstance(lms[i - 1], FairseqLanguageModel):
                lms[i - 1] = MultiLevelLanguageModel(
                    m,
                    lms[i - 1],
                    subwordlm_weight=cfg.generation.subwordlm_weight,
                    oov_penalty=cfg.generation.oov_penalty,
                    open_vocab=not cfg.generation.disable_open_vocab,
                )
                del lms[i]  # safe: the word LM is the last of at most two LMs
                logger.info("LM shallow fusion with Multi-level LM")
            else:
                lms[i] = TensorizedLookaheadLanguageModel(
                    m,
                    dictionary,
                    oov_penalty=cfg.generation.oov_penalty,
                    open_vocab=not cfg.generation.disable_open_vocab,
                )
                logger.info("LM shallow fusion with Look-ahead Word LM")
        else:
            assert isinstance(m, FairseqLanguageModel)
            logger.info("LM fusion with Subword LM")
    if cfg.generation.lm_weight != 0.0:
        logger.info(
            "using LM shallow fusion with lm-weight={:.2f}".format(
                cfg.generation.lm_weight
            )
        )

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
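    # Each distributed rank decodes its own disjoint shard of the dataset
    # (num_shards = world size, shard_id = rank).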
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    if cfg.generation.match_source_len:
        logger.warning(
            "The option match_source_len is not applicable to speech recognition. Ignoring it."
        )
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {
        "lm_model": lms[0],
        "lm_weight": cfg.generation.lm_weight,
        "eos_factor": cfg.generation.eos_factor,
    }
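    # Under LM shallow fusion the generator scores hypotheses roughly as
    #   log p(y|x) = log p_AM(y|x) + lm_weight * log p_LM(y)
    # (standard log-linear interpolation; the exact handling depends on the
    # generator implementation).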
    cfg.generation.score_reference = False  # not applicable for ASR
    is_attention_model = cfg.task.criterion_name not in ["transducer_loss", "ctc_loss"]
    save_attention_plot = (
        cfg.generation.print_alignment is not None and is_attention_model
    )
    if is_attention_model:
        # Disable fairseq's alignment extraction for attention-based
        # encoder-decoder ASR models; attention is plotted separately below.
        cfg.generation.print_alignment = None
    generator = task.build_generator(
        models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
    )

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)
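    # decode_fn undoes the encoding steps in reverse order: BPE decoding
    # first, then detokenization. E.g., with a sentencepiece model,
    # "▁a ▁cat" would become "a cat" (illustrative).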

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = wer.Scorer(dictionary, wer_output_filter=cfg.task.wer_output_filter)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
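        # With --prefix-size N, the first N reference tokens are force-decoded
        # as a fixed prefix of each hypothesis (fairseq prefix decoding).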
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        # obtain the non-pad mask over encoder output frames so padded
        # positions can be dropped from the attention before plotting
        if save_attention_plot:
            net_input = sample["net_input"]
            src_tokens = net_input["src_tokens"]
            output_lengths = models[0].encoder.output_lengths(net_input["src_lengths"])
            nonpad_idxs = sequence_mask(
                output_lengths, models[0].encoder.output_lengths(src_tokens.size(1))
            )

        for i in range(len(sample["id"])):
            has_target = sample["target"] is not None
            utt_id = sample["utt_id"][i]

            # Retrieve the original sentences
            if has_target:
                target_str = dictionary.wordpiece_encode(sample["text"][i])
                if not cfg.common_eval.quiet:
                    print(
                        "T-{}\t{}".format(utt_id, sample["text"][i]), file=output_file
                    )

            # Process top predictions
            for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
                hypo_str = dictionary.string(
                    hypo["tokens"].int().cpu(),
                    bpe_symbol=None,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )  # not removing bpe at this point
                detok_hypo_str = decode_fn(hypo_str)
                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    print(
                        "H-{}\t{}\t{}".format(utt_id, detok_hypo_str, score),
                        file=output_file,
                    )

                # Score and obtain attention only for the top hypothesis
                if j == 0:
                    # src_len x tgt_len
                    attention = (
                        hypo["attention"][nonpad_idxs[i]].float().cpu()
                        if save_attention_plot and hypo["attention"] is not None
                        else None
                    )
                    if save_attention_plot and attention is not None:
                        save_dir = os.path.join(
                            cfg.common_eval.results_path, "attn_plots"
                        )
                        os.makedirs(save_dir, exist_ok=True)
                        plot_attention(attention, detok_hypo_str, utt_id, save_dir)
                    scorer.add_prediction(utt_id, hypo_str)
                    if has_target:
                        scorer.add_evaluation(utt_id, target_str, hypo_str)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (
            sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
        )

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info(
        "Recognized {:,} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        )
    )
    if save_attention_plot:
        logger.info(
            "Saved attention plots in "
            + os.path.join(cfg.common_eval.results_path, "attn_plots")
        )

    if has_target and hasattr(task.datasets[cfg.dataset.gen_subset], "tgt"):
        scorer.add_ordered_utt_list(task.datasets[cfg.dataset.gen_subset].tgt.utt_ids)

    fn = "decoded_char_results.txt"
    with open(
        os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8"
    ) as f:
        f.write(scorer.print_char_results())
        logger.info("Decoded char results saved as " + f.name)

    fn = "decoded_results.txt"
    with open(
        os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8"
    ) as f:
        f.write(scorer.print_results())
        logger.info("Decoded results saved as " + f.name)

    if has_target:
        header = "Recognize {} with beam={}: ".format(
            cfg.dataset.gen_subset, cfg.generation.beam
        )
        fn = "wer"
        with open(
            os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8"
        ) as f:
            res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #words={:d}".format(
                *(scorer.wer()), scorer.tot_word_count()
            )
            logger.info(header + res)
            f.write(res + "\n")
            logger.info("WER saved in " + f.name)

        fn = "cer"
        with open(
            os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8"
        ) as f:
            res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #chars={:d}".format(
                *(scorer.cer()), scorer.tot_char_count()
            )
            logger.info(" " * len(header) + res)
            f.write(res + "\n")
            logger.info("CER saved in " + f.name)

        fn = "aligned_results.txt"
        with open(
            os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8"
        ) as f:
            f.write(scorer.print_aligned_results())
            logger.info("Aligned results saved as " + f.name)
    return scorer


def print_options_meaning_changes(cfg, logger):
    """Options that have different meanings than those in the translation task
    are explained here.
    """
    logger.info("--max-tokens is the maximum number of input frames in a batch")
    is_attention_model = cfg.task.criterion_name not in ["transducer_loss", "ctc_loss"]
    if cfg.generation.print_alignment and is_attention_model:
        logger.info("--print-alignment is set to True to plot attentions")


def cli_main():
    parser = options.get_generation_parser(default_task="speech_recognition_espresso")
    args = options.parse_args_and_arch(parser)
    assert args.results_path is not None, "please specify --results-path"
    main(args)


if __name__ == "__main__":
    cli_main()