https://github.com/Microsoft/CNTK
Raw File
Tip revision: 2d32149257c46a1fba1ae3fe0c424d876233e55e authored by Alexey Reznichenko on 30 June 2017, 14:29:35 UTC
Refactor and simplify MLFIndexBuilder
Tip revision: 2d32149
msra_mgram.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// msra_mgram.h -- simple ARPA LM read and access function
//

#pragma once

#include "Basics.h"
#include "fileutil.h" // for opening/reading the ARPA file
#include <vector>
#include <string>
#include <unordered_map>
#include <algorithm> // for various sort() calls
#include <math.h>

namespace msra { namespace lm {

// ===========================================================================
// core LM interface -- LM scores are accessed through this exclusively
// ===========================================================================

class ILM // generic interface -- mostly the score() function
{
public:
    virtual double score(const int *mgram, int m) const = 0;
    virtual bool oov(int w) const = 0; // needed for perplexity calculation
    // ... TODO (?): return true/false to indicate whether anything changed.
    // Intended as a signal to derived LMs that cache values.
    virtual void adapt(const int *data, size_t m) = 0; // (NULL,M) to reset, (!NULL,0) to flush

    // iterator for composing models --iterates in increasing order w.r.t. w
    class IIter
    {
    public:
        virtual operator bool() const = 0; // has iterator not yet reached end?
        // ... TODO: ensure iterators do not return OOVs w.r.t. user symbol table
        // (It needs to be checked which LM type's iterator currently does.)
        virtual void operator++() = 0; // advance by one
        // ... TODO: change this to key() or something like this
        virtual std::pair<const int *, int> operator*() const = 0; // current m-gram (mgram,m)
        virtual std::pair<double, double> value() const = 0;       // current (logP, logB)
    };
    virtual IIter *iter(int minM = 0, int maxM = INT_MAX) const = 0;
    virtual int order() const = 0;        // order, e.g. 3 for trigram
    virtual size_t size(int m) const = 0; // return #m-grams

    // diagnostics functions -- not all models implement these
    virtual int getLastLongestHistoryFound() const = 0;
    virtual int getLastLongestMGramFound() const = 0;
};

// ===========================================================================
// log-add helpers
// ===========================================================================

const double logzero = -1e30;

static inline double logadd(double x, double y)
{
    double diff = y - x;
    double sum = x; // x no longer used after this
    if (diff > 0)
    {
        sum = y;      // y no longer used after this
        diff = -diff; // that means we need to negate diff
    }
    if (diff > -24.0) // approx. from a constant from fmpe.h
        sum += log(1.0 + exp(diff));
    return sum;
}

// take the log, but clip to logzero
template <class FLOATTYPE> // float or double
static inline FLOATTYPE logclip(FLOATTYPE x)
{
    // ... TODO: use the proper constants here (slightly inconsistent)
    return x > (FLOATTYPE) 1e-30 ? log(x) : (FLOATTYPE) logzero;
}

// compute 1-P in logarithmic representation
static inline double invertlogprob(double logP)
{
    return logclip(1.0 - exp(logP));
}

// ===========================================================================
// CSymbolSet -- a simple symbol table
// ===========================================================================

// compare function to allow char* as keys (without, unordered_map will correctly
// compute a hash key from the actual strings, but then compare the pointers
// -- duh!)
struct less_strcmp : public std::binary_function<const char *, const char *, bool>
{ // this implements operator<
    bool operator()(const char *const &_Left, const char *const &_Right) const
    {
        return strcmp(_Left, _Right) < 0;
    }
};

class CSymbolSet : public std::unordered_map<const char *, int, std::hash<const char *>, less_strcmp>
{
    std::vector<const char *> symbols; // the symbols

    CSymbolSet(const CSymbolSet &);
    CSymbolSet &operator=(const CSymbolSet &);

public:
    CSymbolSet()
    {
        symbols.reserve(1000);
    }
    ~CSymbolSet()
    {
        clear();
    }

    void clear()
    {
        foreach_index (i, symbols)
            free((void *) symbols[i]);
        unordered_map::clear();
    }

    // operator[key] on a 'const' object
    // get id for an existing word, returns -1 if not existing
    int operator[](const char *key) const
    {
        unordered_map<const char *, int>::const_iterator iter = find(key);
        return (iter != end()) ? iter->second : -1;
    }

    // operator[key] on a non-'const' object
    // determine unique id for a word ('key')
    int operator[](const char *key)
    {
        unordered_map<const char *, int>::const_iterator iter = find(key);
        if (iter != end())
            return iter->second;

        // create
        const char *p = _strdup(key);
        if (!p)
            RuntimeError("CSymbolSet:id string allocation failure");

        try
        {
            int id = (int) symbols.size();
            symbols.push_back(p); // we own the memory--remember to free it
            insert(std::make_pair(p, id));
            return id;
        }
        catch (...)
        {
            free((void *) p);
            throw;
        }
    }

    // return symbol string for a given id
    // Returned pointer is owned by this object.
    inline const char *operator[](int id) const
    {
        return symbols[id];
    }

    // overloads to be compatible with C++ strings and CSymMap
    int sym2existingId(const std::string &key) const
    {
        return (*this)[key.c_str()];
    }
    int sym2id(const std::string &key)
    {
        return (*this)[key.c_str()];
    }
    inline const char *id2sym(int id)
    {
        return (*this)[id];
    }

    // some helpers for writing and reading back a symbol set
    void write(FILE *f)
    {
        fputTag(f, "SYMS");       // header
        fputint(f, (int) size()); // symbol set
        foreach_index (k, symbols)
            fputstring(f, symbols[k]);
    }

    void read(FILE *f)
    {
        clear(); // clear out what was there before (typically nothing)
        fcheckTag(f, "SYMS");
        int numWords = fgetint(f);
        char buf[1000];
        for (int k = 0; k < numWords; k++)
        {
            fgetstring(f, buf);
            int id = (*this)[buf];
            if (id != k)
                RuntimeError("plsa: sequence error while reading vocabulary");
        }
    }
};

// ===========================================================================
// mgram_map -- lookup table for mgrams
// ===========================================================================

// variable naming convention for word ids:
//  - w   a word in user space
//        Defined by userSymMap::operator[](string) passed to read().
//        Data passed to score() and adapt() functions are in 'w' space.
//  - id  an id in internal LM space
//        E.g. defined by vocabulary in LM input file.
// All external LM accesses involve an implicit mapping, including:
//  w -> id  --for calls to score() and adapt()
//  id -> w  --for iterators (IIter++ orders by and *IIter returns keys in 'w' space)

// representation of LM in memory
// LMs are stored sparsely, i.e. only used elements are stored.
// For each m-gram, a score is stored. For each history, a back-off weight is stored.
// Both are stored in flat arrays, one per order, that are concatenations of
// individual arrays per history.
// The mgram_map provides a measure of locating these entries. For each level,
// it stores a flat array of 'firsts' which point to the first child entry in
// the next level (the next 'firsts' value denotes the end).
// The mgram_map also stores word ids, which are the indexes of the sparse
// elements.
// To access an m-gram score of back-off weight, the mgram_map structure is
// traversed, involving a binary search operation at each level.

// a compact vector to hold 24-bit vaulues
class int24_vector : std::vector<unsigned char>
{
public:
    // basic (non-tricky) operations --just multiply anything by 3
    int24_vector()
    {
    }
    int24_vector(size_t n)
        : std::vector<unsigned char>(n * 3)
    {
    }
    void resize(size_t n)
    {
        std::vector<unsigned char> &base = *this;
        base.resize(n * 3);
    }
    void reserve(size_t n)
    {
        std::vector<unsigned char> &base = *this;
        base.reserve(n * 3);
    }
    void swap(int24_vector &other)
    {
        std::vector<unsigned char> &base = *this;
        base.swap(other);
    }
    size_t size() const
    {
        const std::vector<unsigned char> &base = *this;
        return base.size() / 3;
    }
    bool empty() const
    {
        const std::vector<unsigned char> &base = *this;
        return base.empty();
    }

    // a reference to a 3-byte int (not a naked pointer as we cannot just assign to it)
    template <class T>
    class uint24_ref_t
    {
    protected:
        T p;
        friend class int24_vector; // only int24_vector may instantiate this
        __forceinline uint24_ref_t(T p)
            : p(p)
        {
        }

    public:
        // access
        __forceinline operator int() const
        {
            return (((((signed char) p[2]) << 8) + p[1]) << 8) + p[0];
        }
    };
    typedef uint24_ref_t<const unsigned char *> const_uint24_ref; // const version (only read)
    class uint24_ref : public uint24_ref_t<unsigned char *>       // non-const (read and assign)
    {
        static void overflow()
        {
            RuntimeError("uint32_ref: attempting to store value > 24 bits");
        }

    protected:
        friend class int24_vector; // only int24_vector may instantiate this
        __forceinline uint24_ref(unsigned char *p)
            : uint24_ref_t(p)
        {
        }

    public:
        // assignment operator
        __forceinline int operator=(int value)
        {
            if ((unsigned int) (value + 0x800000) > 0xffffff)
                overflow();
            p[0] = (unsigned char) value;
            p[1] = (unsigned char) (value >> 8);
            p[2] = (unsigned char) (value >> 16);
            assert(value == (int) *this);
            return value;
        }
    };

    // reading and writing
    __forceinline uint24_ref operator[](size_t i)
    {
        std::vector<unsigned char> &base = *this;
        return uint24_ref(&base[i * 3]);
    }
    __forceinline const_uint24_ref operator[](size_t i) const
    {
        const std::vector<unsigned char> &base = *this;
        return const_uint24_ref(&base[i * 3]);
    }
    __forceinline int back() const
    {
        const std::vector<unsigned char> &base = *this;
        return const_uint24_ref(&base[base.size() - 3]);
    }
    void push_back(int value)
    {
        std::vector<unsigned char> &base = *this;
        size_t cursize = base.size();
        size_t newsize = cursize + 3;
        if (newsize > base.capacity())
            base.reserve(newsize * 2); // double the size to ensure constant-time
        base.resize(newsize);
        uint24_ref r = uint24_ref(&base[cursize]);
        r = value;
        assert(value == back());
    }
};

// maps from m-grams to m-gram storage locations.
class mgram_map
{
    typedef unsigned int index_t; // (-> size_t when we really need it)
    // typedef size_t index_t;                   // (tested once, seems to work)
    static const index_t nindex; // invalid index
    // entry [m][i] is first index of children in level m+1, entry[m][i+1] the end.
    int M;                                    // order, e.g. M=3 for trigram
    std::vector<std::vector<index_t>> firsts; // [M][i] ([0] = zerogram = root)
    std::vector<int24_vector> ids;            // [M+1][i] ([0] = not used)
    bool level1nonsparse;                     // true: level[1] can be directly looked up
    std::vector<index_t> level1lookup;        // id->index for unigram level
    static void fail(const char *msg)
    {
        RuntimeError("mgram_map::%s", msg);
    }

    // mapping from w -> i -- users pass 'w', internally we use our own 'ids'
    std::vector<int> w2id; // w -> id
    std::vector<int> id2w; // id -> w
    int idmax;             // max id ever encountered by create()
    inline int map(int w) const
    {
        if (w < 0 || w >= (int) w2id.size())
            return -1;
        else
            return w2id[w];
    }

    // get index for 'id' in level m+1, as a child of index i in level m.
    // Returns -1 if not found.
    // This is a relatively generic binary search.
    inline index_t find_child(int m, index_t i, int id) const
    {
        // unigram level is a special case where we can avoid searching
        if (m == 0)
        {
            if (id < 0)
                return nindex;

            if (level1nonsparse)
                i = (index_t) id;
            else // sparse: use a look-up table
            {
                if ((size_t) id >= level1lookup.size())
                    return nindex;
                i = level1lookup[id];
            }
            assert(i == nindex || ids[1][i] == id);
            return i;
        }
        index_t beg = firsts[m][i];
        index_t end = firsts[m][i + 1];
        const int24_vector &ids_m1 = ids[m + 1];
        while (beg < end)
        {
            i = (beg + end) / 2;
            int v = ids_m1[i];
            if (id == v)
                return i; // found it
            else if (id < v)
                end = i; // id is left of i
            else
                beg = i + 1; // id is right of i
        }
        return nindex; // not found
    }

public:
    // --- allocation

    mgram_map()
    {
    }
    mgram_map(int p_M)
    {
        init(p_M);
    }

    // construct
    void init(int p_M)
    {
        clear();
        M = p_M;
        firsts.assign(M, std::vector<index_t>(1, 0));
        ids.assign(M + 1, int24_vector());
        ids[0].resize(1); // fake zerogram entry for consistency
        ids[0][0] = -1;
    }
    // reserve memory for a level
    void reserve(int m, size_t size)
    {
        if (m == 0)
            return; // cannot reserve level 0
        ids[m].reserve(size);
        if (m < M)
            firsts[m].reserve(size + 1);
        if (m == 1)
            level1lookup.reserve(size);
    }
    // allow to reduce M after the fact
    void resize(int newM)
    {
        if (newM > M)
            fail("resize() can only shrink");
        M = newM;
        firsts.resize(M);
        ids.resize(M + 1);
    }
    // destruct
    void clear()
    {
        M = 0;
        firsts.clear();
        ids.clear();
        w2id.clear();
        id2w.clear();
        idmax = -1;
    }
    // size
    inline int size(int m) const
    {
        return (int) ids[m].size();
    }
    // swap --used e.g. in merging
    void swap(mgram_map &other)
    {
        std::swap(M, other.M);
        firsts.swap(other.firsts);
        ids.swap(other.ids);
        std::swap(level1nonsparse, other.level1nonsparse);
        level1lookup.swap(other.level1lookup);
        w2id.swap(other.w2id);
        id2w.swap(other.id2w);
        std::swap(idmax, other.idmax);
    }

    // --- id mapping

    // test whether a word id is known in this model
    inline bool oov(int w) const
    {
        return map(w) < 0;
    }

    // return largest used word id (=last entry in unigram ids[])
    int maxid() const
    {
        return idmax;
    }

    // return largest used w (only after created())
    int maxw() const
    {
        return -1 + (int) w2id.size();
    }

    // map is indexed with a 'key'.
    // A key represents an m-gram by storing a pointer to the original array.
    // The key allows to remove predicted word (pop_w()) or history (pop_h()).
    class key
    {
    protected:
        friend class mgram_map;
        const int *mgram; // pointer to mgram array --key does not own that memory!
        int m;            // elements in mgram array
    public:
        // constructors
        inline key()
            : mgram(NULL), m(0)
        {
        } // required for use in std::vector
        inline key(const int *mgram, int m)
            : mgram(mgram), m(m)
        {
        }
        // manipulations
        inline key pop_h() const
        {
            if (m == 0)
                fail("key::pop_h() called on empty key");
            return key(mgram + 1, m - 1);
        }
        inline key pop_w() const
        {
            if (m == 0)
                fail("key::pop_w() called on empty key");
            return key(mgram, m - 1);
        }
        // access
        inline int back() const
        {
            if (m == 0)
                fail("key::back() called on empty key");
            return mgram[m - 1];
        }
        inline const int &operator[](int n) const
        {
            if (n < 0 || n >= m)
                fail("key::operator[] out of bounds");
            return mgram[n];
        }
        inline int order() const
        {
            return m;
        }
        // key comparison (used in sorting and merging)
        inline bool operator<(const key &other) const
        {
            for (int k = 0; k < m && k < other.m; k++)
                if (mgram[k] != other.mgram[k])
                    return mgram[k] < other.mgram[k];
            return m < other.m;
        }
        inline bool operator>(const key &other) const
        {
            return other < *this;
        }
        inline bool operator<=(const key &other) const
        {
            return !(*this > other);
        }
        inline bool operator>=(const key &other) const
        {
            return !(*this < other);
        }
        inline bool operator==(const key &other) const
        {
            if (m != other.m)
                return false;
            for (int k = 0; k < m; k++)
                if (mgram[k] != other.mgram[k])
                    return false;
            return true;
        }
        inline bool operator!=(const key &other) const
        {
            return !(*this == other);
        }
    };

    // 'coord' is an abstract coordinate of an m-gram. This is returned by
    // operator[], and is used as an index in our sister structure, mgram_data.
    struct coord
    {
        index_t i;        // index in that level -- -1 means not found
        unsigned short m; // level
        inline bool valid() const
        {
            return i != nindex;
        }
        inline void validate() const
        {
            if (!valid())
                fail("coord used but invalid");
        }
        void invalidate()
        {
            i = nindex;
        }
        inline int order() const
        {
            validate();
            return m;
        }
        inline coord(int m, index_t i)
            : m((unsigned short) m), i(i)
        {
        } // valid coord
        // ^^ this is where we'd test for index_t overflow if we ever need it
        inline coord(bool valid = true)
            : m(0), i(valid ? 0 : nindex)
        {
        } // root or invalid
    };

    // 'foundcoord' is an extended 'coord' as returned by operator[], with
    // information on whether it is valid or not, and whether it refers to
    // an m-gram or to a history only.
    class foundcoord : public /*<-want to get rid of this*/ coord
    {
        const short type;
        foundcoord &operator=(const foundcoord &);

    public:
        inline bool valid_w() const
        {
            return type > 0;
        }
        inline bool valid_h() const
        {
            return type == 0;
        }
        inline bool valid() const
        {
            return type >= 0;
        }
        inline operator const coord &() const
        {
            return *this;
        }
        inline foundcoord(short type, int m, index_t i)
            : type(type), coord(m, i)
        {
        }
        inline foundcoord(short type)
            : type(type), coord(type >= 0)
        {
        }
    };

    // search for an mgram -- given a 'key', return its 'coord.'
    // If m-gram is found, type=1. If only history found then type=0, and
    // coord represents the history token instead.
    // The given key may not be longer than our storage (we do not automatically
    // truncate because that would not be detectable by caller).
    __forceinline foundcoord operator[](const key &k) const
    {
        if (k.m > M) // call truncate() first with too long keys
            fail("operator[] called with too long key");
        if (k.m == 0)
            return foundcoord(1); // zerogram -> root

        // We traverse history one by one.
        index_t i = 0;
        for (int n = 1; n < k.m; n++)
        {
            int w = k[n - 1]; // may be -1 for unknown word
            int id = map(w);  // may still be -1
            // const char * sym = idToSymbol (id); sym;   // (debugging)
            i = find_child(n - 1, i, id);
            if (i == nindex)           // unknown history: fall back
                return foundcoord(-1); // indicates failure
            // found it: advance search by one history token
        }

        // Found history. Do we also find the prediced word?
        int w = k[k.m - 1]; // may be -1 for unknown word
        int id = map(w);    // may still be -1
        index_t i_m = find_child(k.m - 1, i, id);
        if (i_m == nindex) // not found
            return foundcoord(0, k.m - 1, i);
        else // found
            return foundcoord(1, k.m, i_m);
    }

    // truncate a key to the m-gram length supported by this
    inline key truncate(const key &k) const
    {
        if (k.m <= M)
            return k;
        else
            return key(k.mgram + (k.m - M), M);
    }

    // --- iterators
    //  - iterating over children of a history
    //  - deep-iterating over the entire tree

    // for (iterator iter (mgram_map, parent_coord); iter; ++iter) { mgram_data[iter]; w=*iter; }
    class iterator : public coord
    {
        index_t end;          // end index: i is invalid when it reaches this
        const mgram_map &map; // remembered for operator*
        void operator=(const iterator &);

    public:
        // bool: true if can use or increment
        inline operator bool() const
        {
            return i < end;
        }
        // increment
        inline void operator++()
        {
            if (i < end)
                i++;
            else
                fail("iterator used beyond end");
        }
        // retrieve word -- returns -1 if not used in user's w->id map, e.g. skipped word
        inline int operator*() const
        {
            if (i >= end)
                fail("iterator used beyond end");
            return map.id2w[map.ids[m][i]];
        }
        // construct 'coord' as first element
        iterator(const mgram_map &map, const coord &c)
            : map(map)
        {
            c.validate();
            // get the range
            index_t beg = map.firsts[c.m][c.i]; // first element of child
            end = map.firsts[c.m][c.i + 1];     // end = first of next entry
            // set the first child coordinate
            m = c.m + 1; // we iterate over the child level
            i = beg;     // first element
        }
        // alternative to loop over all m-grams of a level
        iterator(const mgram_map &map, int m)
            : map(map), coord(m, 0)
        {
            end = (m > 0) ? (index_t) map.ids[m].size() : 1; // loop over entire vector
        }
    };

    // for (deep_iterator iter (mgram_map, maxM); iter; ++iter) { mgram_data[iter]; key=*iter; }
    class deep_iterator : public coord
    {
    protected:
        int maxM;
        std::vector<index_t> pos; // current position [0..m]
        std::vector<int> mgram;   // current m-gram corresponding to 'pos'
        const mgram_map &map;     // remembered for operator*
        void operator=(const deep_iterator &);
        void validate() const
        {
            if (!valid())
                fail("iterator used beyond end");
        }

    public:
        // constructor
        deep_iterator(const mgram_map &map, int p_maxM = -1)
            : map(map), maxM(p_maxM), coord(map.firsts[0].size() >= 2)
        {
            if (maxM == -1)
                maxM = map.M;
            else if (maxM > map.M)
                fail("deep_iterator instantiated for invalid maximum depth");
            mgram.resize(maxM, -1);
            pos.resize(maxM + 1, 0);
        }
        // bool: true if can use or increment
        inline operator bool() const
        {
            return valid();
        }
        // increment
        inline void operator++()
        {
            validate();
            // if current position has a child then enter it
            if (m < maxM && m < map.M && map.firsts[m][pos[m]] < map.firsts[m][pos[m] + 1])
            {
                i = map.firsts[m][pos[m]];
                m++;
                pos[m] = i;
                mgram[m - 1] = map.id2w[map.ids[m][i]];
                return;
            }
            // advance vertically or step up one level
            for (; m > 0;)
            {
                // advance current position if still elements left
                i++;
                if (i < map.firsts[m - 1][pos[m - 1] + 1]) // not hit the end yet
                {
                    pos[m] = i;
                    mgram[m - 1] = map.id2w[map.ids[m][i]];
                    return;
                }
                // cannot enter or advance: step back one
                m--;
                i = pos[m]; // parent position
            }
            // reached the end
            invalidate(); // invalidates 'coord'--next call to bool() will return false
            return;
        }
        // retrieve keys -- returns -1 if not used in user's w->id map, e.g. skipped word
        // The key points into the iterator structure, i.e. it operator++ invalidates it!
        inline key operator*() const
        {
            validate();
            return key(&mgram[0], m);
        }
    };

    // for (reordering_iterator iter (mgram_map, wrank[], maxM); iter; ++iter) { mgram_data[iter]; key=*iter; }
    // Like deep_iterator, but iterates the map such that ws are returned in
    // increasing wrank[w] rather than in the original storage order.
    // Used for merging multiple models such as linear interpolation.
    class reordering_iterator : public deep_iterator
    {
        const std::vector<int> &wrank;             // assigns a rank to each w
        const char *i;                             // hide coord::i against accidental access
        std::vector<std::vector<index_t>> indexes; // coord::i <- indexes[m2][this->i]
        std::vector<index_t> indexbase;            // indexes[m2] is indexbase[m2]-based
        inline index_t &index_at(int m2, index_t i2)
        {
            return indexes[m2][i2 - indexbase[m2]];
        }
        std::vector<std::pair<int, int>> sortTemp; // temp for creating indexes
        void operator=(const reordering_iterator &);

    public:
        // constructor
        reordering_iterator(const mgram_map &map, const std::vector<int> &wrank, int p_maxM = -1)
            : deep_iterator(map, p_maxM), wrank(wrank)
        {
            if (wrank.size() < map.w2id.size())
                fail("reordering_iterator: wrank has wrong dimension");
            indexes.resize(maxM + 1);
            indexes[0].push_back(0); // look-up table for root: only one item
            indexbase.resize(maxM + 1, 0);
            pos[0] = coord::i; // zerogram level: same i because no mapping there
            if (map.M >= 1)
                sortTemp.reserve(map.size(1));
        }
        // increment
        // We iterate through the map using (m, pos[m]) while user consumes (m, i)
        // i.e. for operator++(), coord::i is not iterated but a return value.
        inline void operator++()
        {
            validate();
            // if current position has a child then enter it
            // Note: We enter the item that coord::i points to, which is not pos[m]
            // but the mapped pos[m].
            if (m < maxM && m < map.M && map.firsts[m][index_at(m, pos[m])] < map.firsts[m][index_at(m, pos[m]) + 1])
            {
                // enter the level
                index_t beg = map.firsts[m][index_at(m, pos[m])]; // index range of sub-level
                index_t end = map.firsts[m][index_at(m, pos[m]) + 1];
                m++;
                pos[m] = beg;
                // build look-up table for returned values
                size_t num = end - beg;
                // we sort i by rank (and i, keeping original order for identical rank)
                sortTemp.resize(end - beg);
                foreach_index (k, sortTemp)
                {
                    index_t i2 = beg + k;
                    int id = map.ids[m][i2];
                    int w = map.id2w[id];
                    sortTemp[k] = std::make_pair(wrank[w], i2);
                }
                std::sort(sortTemp.begin(), sortTemp.end());
                // remember sorted i's
                indexbase[m] = beg; // used by index_at (m, *)
                indexes[m].resize(num);
                foreach_index (k, sortTemp)
                    index_at(m, k + beg) = sortTemp[k].second;
                // set up return values
                coord::i = index_at(m, pos[m]);
                mgram[m - 1] = map.id2w[map.ids[m][coord::i]];
                return;
            }
            // advance vertically or step up one level
            for (; m > 0;)
            {
                // advance current position if still elements left
                // use our own i (in pos[m]), then map to coord::i using sorted list
                pos[m]++;
                if (pos[m] < map.firsts[m - 1][index_at(m - 1, pos[m - 1]) + 1]) // not hit the end yet
                {
                    coord::i = index_at(m, pos[m]);
                    mgram[m - 1] = map.id2w[map.ids[m][coord::i]];
                    return;
                }
                // cannot enter or advance: step back one
                m--;
            }
            // reached the end
            invalidate(); // invalidates 'coord'--next call to bool() will return false
            return;
        }
    };

    // --- functions for building

    // 'unmapped_key' contains original 'id' rather than 'w' values. It is only
    // used for create()--at creation time, we use our private mapping.
    typedef key unmapped_key;

// create a new key (to be called in sequence).
// Only the last word given in the key is added. The history of the given
// mgram must already exist and must be the last.
// Important: Unlike operator[], create() takes an unmapped_key, i.e. the
// mapping is not applied.
// 'cache' is used for speed-up, it must be as large as key.m-1 and
// initialized to 0.
#pragma warning(push)           // known compiler bug: size_t (marked _w64) vs. unsigned...
#pragma warning(disable : 4267) // ...int (not marked) incorrectly flagged in templates
    typedef std::vector<index_t> cache_t;
    coord create(const unmapped_key &k, cache_t &cache)
    {
        if (k.m < 1)
            return coord(); // (root need not be created)
        // locate history (must exist), also updates cache[]
        bool prevValid = true;
        index_t i = 0; // index of history in level k.m-1
        if (cache.empty())
            cache.resize(M, nindex); // lazy initialization
        for (int m = 1; m < k.m; m++)
        {
            int thisid = k[m - 1];
            if (prevValid && cache[m - 1] != nindex && ids[m][cache[m - 1]] == thisid)
            {
                i = cache[m - 1]; // get from cache
                continue;
            }
            // need to actually search
            i = find_child(m - 1, i, thisid);
            if (i == nindex)
                fail("create() called with unknown history");
            cache[m - 1] = i;
            prevValid = false;
        }
        for (int m = k.m; m < M && cache[m - 1] != nindex; m++)
            cache[m - 1] = nindex; // clear upper entries (now invalid)
        // now i is the index of the id of the last history item
        // make the firsts entry if not there yet
        bool newHist = (firsts[k.m - 1].size() < (size_t) i + 2);
        while (firsts[k.m - 1].size() < (size_t) i + 2) // [i+1] is the end for this array
            firsts[k.m - 1].push_back((mgram_map::index_t) ids[k.m].size());
        if (firsts[k.m - 1].size() != (size_t) i + 2)
            fail("create() called out of order (history)");
        // create new word id
        int thisid = k[k.m - 1];
        if (!newHist && thisid <= ids[k.m].back())
            fail("create() called out of order");
        // keep track of idmax
        if (thisid > idmax)
            idmax = thisid;

        coord c(k.m, (index_t) ids[k.m].size());

        assert(firsts[k.m - 1].back() == (index_t) ids[k.m].size());
        ids[k.m].push_back(thisid); // create value
        firsts[k.m - 1].back() = (index_t) ids[k.m].size();
        if (firsts[k.m - 1].back() != (index_t) ids[k.m].size())
            fail("create() numeric overflow--index_t too small");
        assert(k.m == M || firsts[k.m].back() == (index_t) ids[k.m + 1].size());

        // optimization: level1nonsparse flag
        // If unigram level is entirely non-sparse, we can save the search
        // operation at that level, which is significantly slower than for the
        // much sparser higher levels.
        if (c.m == 1)
        {
            if (c.i == 0)
                level1nonsparse = true;                   // first entry
            level1nonsparse &= (c.i == (index_t) thisid); // no search needed
            level1lookup.resize(thisid + 1, nindex);
            level1lookup[thisid] = c.i;
        }

        return c;
    }
#pragma warning(pop)

    // call this at the end
    //  - establish the w->id mapping that is used in operator[]
    //  - finalize the firsts arrays
    // This function swaps the user-provided map and our current one.
    // We use swapping to avoid the memory allocation (noone else outside should
    // have to keep the map).
    // This function also builds our internal reverse map used in the iterator.
    void created(std::vector<int> &userToLMSymMap)
    {
        // finalize firsts arrays
        foreach_index (m, firsts)
            firsts[m].resize(ids[m].size() + 1, (int) ids[m + 1].size());
        foreach_index (m, firsts)
        {
            assert(firsts[m][0] == 0);
            foreach_index (i, ids[m])
                assert(firsts[m][i] <= firsts[m][i + 1]);
            assert((size_t) firsts[m].back() == ids[m + 1].size());
        }
        // id mapping
        // user-provided w->id map
        std::swap(w2id, userToLMSymMap);
        // reverse map
        id2w.assign(maxid() + 1, nindex);
        foreach_index (w, w2id)
        {
            int id = w2id[w];
            if (id < 0)
                continue; // invalid word
            if (id > maxid())
                continue; // id not in use
            id2w[id] = w;
        }
    }

    // helper for created()--return an identical map, as we have several
    // occasions where such a map is passed as userToLMSymMap to created().
    std::vector<int> identical_map(size_t n = SIZE_MAX) const
    {
        if (n == SIZE_MAX)
            n = maxid() + 1;
        std::vector<int> v(n);
        foreach_index (i, v)
            v[i] = i;
        return v;
    }

    // decide whether iterator will return in increasing w order
    bool inorder() const
    {
#if 0 // fix this: need access to w2id, or have an inorder() function in mgram_map
        bool inorder = true;
        for (int i = 1; inorder && i < (int) map.w2id.size(); i++)
            inorder &= (map.w2id[i+1] >= map.w2id[i]);
#endif
        return false;
    }
};

// ===========================================================================
// mgram_data -- data stored according to mgram_map
// Separate from mgram_map, so that we can share the same map for multiple data.
// ===========================================================================

template <class DATATYPE>
class mgram_data
{
    std::vector<std::vector<DATATYPE>> data;
    static void fail(const char *msg)
    {
        RuntimeError("mgram_data::%s", msg);
    }

public:
    mgram_data()
    {
    }
    mgram_data(int M)
    {
        init(M);
    }
    // for an M-gram, indexes [0..M] are valid thus data[] has M+1 elements
    void init(int M)
    {
        data.assign(M + 1, std::vector<DATATYPE>());
    }
    void reserve(int m, size_t size)
    {
        data[m].reserve(size);
    }
    void resize(int M)
    {
        if ((size_t) M + 1 <= data.size())
            data.resize(M + 1);
        else
            fail("resize() can only shrink");
    }
    size_t size(int m) const
    {
        return data[m].size();
    }
    size_t size() const
    {
        size_t sz = 0;
        foreach_index (m, data)
            sz += size(m);
        return sz;
    }
    void clear()
    {
        data.clear();
    }
    void swap(mgram_data &other)
    {
        data.swap(other.data);
    }
    // access existing elements. Usage:
    // DATATYPE & element = mgram_data[mgram_map[mgram_map::key (mgram, m)]]
    __forceinline DATATYPE &operator[](const mgram_map::coord &c)
    {
        c.validate();
        return data[c.m][c.i];
    }
    __forceinline const DATATYPE &operator[](const mgram_map::coord &c) const
    {
        c.validate();
        return data[c.m][c.i];
    }
    // create entire vector (for random-access situations).
    void assign(int m, size_t size, const DATATYPE &value)
    {
        data[m].assign(size, value);
    }
    // create an element. We can only append.
    inline void push_back(const mgram_map::coord &c, const DATATYPE &val)
    {
        c.validate();
        if (data[c.m].size() != (size_t) c.i)
            fail("push_back() only allowed for last entry");
        data[c.m].push_back(val);
    }
};

// ===========================================================================
// CMGramLM -- a back-off M-gram language model in memory, loaded from an ARPA file
// ===========================================================================

class CMGramLM : public ILM
{
protected:
#if 0
    void clear()        // release all memory --object unusable after this
    {
        M = -1;
        map.clear();
        logP.clear();
        logB.clear();
    }
#endif
    int M; // e.g. M=3 for trigram
    // ^^ TODO: can we do away with this entirely and replace it by map.order()/this->order()
    mgram_map map;
    mgram_data<float> logP; // [M+1][i] probabilities
    mgram_data<float> logB; // [M][i] back-off weights (stored for histories only)
    friend class CMGramLMIterator;

    // diagnostics of previous score() call
    mutable int longestMGramFound;   // longest m-gram (incl. predicted token) found
    mutable int longestHistoryFound; // longest history (excl. predicted token) found

    // this function is for reducing M after the fact, e.g. during estimation
    // ... TODO: rethink the resize business. It is for shrinking only.
    void resize(int newM)
    {
        M = newM;
        map.resize(M);
    }

public:
    CMGramLM()
        : M(-1)
    {
    } // needs explicit initialization through read() or init()

    virtual int getLastLongestHistoryFound() const
    {
        return longestHistoryFound;
    }
    virtual int getLastLongestMGramFound() const
    {
        return longestMGramFound;
    }

    // -----------------------------------------------------------------------
    // score() -- compute an m-gram score (incl. back-off and fallback)
    // -----------------------------------------------------------------------
    // mgram[m-1] = word to predict, tokens before that are history
    // m=3 means trigram
    virtual double score(const int *mgram, int m) const
    {
        longestHistoryFound = 0; // (diagnostics)

        double totalLogB = 0.0; // accumulated back-off

        for (mgram_map::key key = map.truncate(mgram_map::key(mgram, m));; key = key.pop_h())
        {
            // look up the m-gram
            const mgram_map::foundcoord c = map[key];

            // (diagnostics -- can be removed if not used)
            if (c.valid() && key.order() - 1 > longestHistoryFound)
                longestHistoryFound = key.order() - 1;
            if (c.valid_w())
                longestMGramFound = key.order();

            // full m-gram found -> return it (zerogram always considered found)
            if (c.valid_w())
                return totalLogB + logP[c];

            // history found but predicted word not -> back-off
            if (c.valid_h())          // c is coordinate of parent instead
                totalLogB += logB[c]; // and continue like fall back

            // history not found -> fall back
        } // and go again with the shortened history
    }

    // same as score() but without optimizations (for reference)
    // ... this is really no longer needed
    virtual double score_unoptimized(const int *mgram, int m) const
    {
        return score_unoptimized(map.truncate(mgram_map::key(mgram, m)));
    }

    inline double score_unoptimized(const mgram_map::key &key) const
    {
        // look up the m-gram
        const mgram_map::foundcoord c = map[key];

        // full m-gram found -> return it
        if (c.valid_w())
            return logP[c];

        // history found but predicted word not -> back-off
        else if (c.valid_h()) // c is coordinate of patent instead
            return logB[c] + score_unoptimized(key.pop_h());

        // history not found -> fall back
        else
            return score_unoptimized(key.pop_h());
    }

    // test for OOV word (OOV w.r.t. LM)
    virtual bool oov(int w) const
    {
        return map.oov(w);
    }

    virtual void adapt(const int *, size_t)
    {
    } // this LM does not adapt

private:
    // keep this for debugging
    std::wstring filename; // input filename
    struct SYMBOL
    {
        std::string symbol; // token
        int id;        // numeric id in LM space (index of word read)
        bool operator<(const SYMBOL &other) const
        {
            return symbol < other.symbol;
        }
        SYMBOL(int p_id, const char *p_symbol)
            : id(p_id), symbol(p_symbol)
        {
        }
    };
    std::vector<SYMBOL> lmSymbols; // (id, word) symbols used in LM
    std::vector<int> idToSymIndex; // map LM id to index in lmSymbols[] array

    // search for a word in the sorted word array.
    // Only use this after sorting, i.e. after full 1-gram section has been read.
    // Only really used in read().
    inline int symbolToId(const char *word) const
    {
        int beg = 0;
        int end = (int) lmSymbols.size();
        while (beg < end)
        {
            int i = (beg + end) / 2;
            const char *v = lmSymbols[i].symbol.c_str();
            int cmp = strcmp(word, v);
            if (cmp == 0)
                return lmSymbols[i].id; // found it
            else if (cmp < 0)
                end = i; // id is left of i
            else
                beg = i + 1; // id is right of i
        }
        return -1; // not found
    }

    inline const char *idToSymbol(int id) const
    {
        if (id < 0)
            return NULL; // empty string for unknown ids
        int i = idToSymIndex[id];
        return lmSymbols[i].symbol.c_str();
    }

private:
    // type cast to const char*, to allow write() to use both const char* and string
    static const char *const_char_ptr(const char *p)
    {
        return p;
    }
    static const char *const_char_ptr(const std::string &s)
    {
        return s.c_str();
    }

public:
    // write model out as an ARPA (text) file.
    // symbols can be anything that has symbols[w] -> std::string& or const char*
    template <class SYMMAP>
    void write(FILE *outf, const SYMMAP &symbols, int M = INT_MAX) const
    {
        if (M > this->M)
            M = this->M; // clip; also covers default value
        if (M < 1 || map.size(1) == 0)
            RuntimeError("write: attempting to write empty model");

        // output header
        //  \data\
        //  ngram 1=58289
        //  ngram 2=956100
        //  ...
        fprintfOrDie(outf, "\\data\\\n");
        for (int m = 1; m <= M; m++)
        {
            fprintfOrDie(outf, "ngram %d=%d\n", m, map.size(m));
        }
        fflushOrDie(outf);

        // output m-grams themselves
        // M-gram sections
        const double log10 = log(10.0);
        for (int m = 1; m <= M; m++)
        {
            fprintf(stderr, "estimate: writing %d %d-grams..", map.size(m), m);
            int step = (int) logP.size(m) / 100;
            if (step == 0)
                step = 1;
            int numMGramsWritten = 0;

            // output m-gram section
            fprintfOrDie(outf, "\n\\%d-grams:\n", m);
            for (mgram_map::deep_iterator iter(map, m); iter; ++iter)
            {
                if (iter.order() != m) // a parent
                    continue;

                const mgram_map::key key = *iter;
                assert(m == key.order());

                // --- output m-gram to ARPA file
                fprintfOrDie(outf, "%.4f", logP[iter] / log10);
                for (int k = 0; k < m; k++)
                { // the M-gram words
                    int wid = key[k];
                    const char *w = const_char_ptr(symbols[wid]);
                    fprintfOrDie(outf, " %s", w);
                }

                if (m < M)
                { // back-off weight (not for highest order)
                    fprintfOrDie(outf, " %.4f", logB[iter] / log10);
                }
                fprintfOrDie(outf, "\n");

                // progress
                if (numMGramsWritten % step == 0)
                {
                    fprintf(stderr, ".");
                }
                numMGramsWritten++;
            }
            fflushOrDie(outf);
            assert(numMGramsWritten == map.size(m));
            fprintf(stderr, "\n");
        }

        fprintfOrDie(outf, "\n\\end\\\n");
        fflushOrDie(outf);
    }

    // get TopM Ngram probability
    // GangLi add this function to do probability pruning
    double KeepTopMNgramThreshold(int topM, int ngram)
    {
        // initial return as a very low value
        double probThrshold = -99;

        // check if nessary to prune
        if (map.size(ngram) > topM)
        {
            std::vector<std::pair<int, float>> probArray;
            probArray.reserve(map.size(ngram));
        }

        return probThrshold;
    }

protected:
    // replace zerogram prob by one appropriate for OOVs
    // We use the minimum of all unigram scores (assuming they represent singleton
    // events, which are closest to a zerogram--a better choice may be a leaving-
    // one-out estimate?).
    // Back-off weight is reset to 1.0 such that there is no extra penalty on it.
    void updateOOVScore()
    {
        float unknownLogP = 0.0f;
        for (mgram_map::iterator iter(map, mgram_map::coord()); iter; ++iter)
        {
            if (logP[iter] < -98.0f)
                continue; // disabled token, such as <s>, does not count
            if (logP[iter] < unknownLogP)
                unknownLogP = logP[iter];
        }
        logP[mgram_map::coord()] = unknownLogP;
        logB[mgram_map::coord()] = 0.0f;
    }

public:
    // read an ARPA (text) file.
    // Words do not need to be sorted in the unigram section, but the m-gram
    // sections have to be in the same order as the unigrams.
    // The 'userSymMap' defines the vocabulary space used in score().
    // If 'filterVocabulary' then LM entries for words not in userSymMap are skipped.
    // Otherwise the userSymMap is updated with the words from the LM.
    // 'maxM' allows to restrict the loading to a smaller LM order.
    // SYMMAP can be e.g. CSymMap or CSymbolSet.
    template <class SYMMAP>
    void read(const std::wstring &pathname, SYMMAP &userSymMap, bool filterVocabulary, int maxM)
    {
        int lineNo = 0;
        auto_file_ptr f(fopenOrDie(pathname, L"rbS"));
        fprintf(stderr, "read: reading %ls", pathname.c_str());
        filename = pathname; // (keep this info for debugging)

        // --- read header information

        // search for header line
        char buf[1024];
        lineNo++, fgetline(f, buf);
        while (strcmp(buf, "\\data\\") != 0 && !feof(f))
            lineNo++, fgetline(f, buf);
        lineNo++, fgetline(f, buf);

        // get the dimensions
        std::vector<int> dims;
        dims.reserve(4);

        while (buf[0] == 0 && !feof(f))
            lineNo++, fgetline(f, buf);

        int n, dim;
        dims.push_back(1); // dummy zerogram entry
        while (sscanf(buf, "ngram %d=%d", &n, &dim) == 2 && n == (int) dims.size())
        {
            dims.push_back(dim);
            lineNo++, fgetline(f, buf);
        }

        M = (int) dims.size() - 1;
        if (M == 0)
            RuntimeError("read: mal-formed LM file, no dimension information (%d): %ls", lineNo, pathname.c_str());
        int fileM = M;
        if (M > maxM)
            M = maxM;

        // allocate main storage
        map.init(M);
        logP.init(M);
        logB.init(M - 1);
        for (int m = 0; m <= M; m++)
        {
            map.reserve(m, dims[m]);
            logP.reserve(m, dims[m]);
            if (m < M)
                logB.reserve(m, dims[m]);
        }
        lmSymbols.reserve(dims[0]);

        logB.push_back(mgram_map::coord(), 0.0f); // dummy logB for backing off to zg
        logP.push_back(mgram_map::coord(), 0.0f); // zerogram score -- gets updated later

        std::vector<bool> skipWord; // true: skip entry containing this word
        skipWord.reserve(lmSymbols.capacity());

        // --- read main sections

        const double ln10xLMF = log(10.0);                // ARPA scores are strangely scaled
        msra::strfun::tokenizer tokens(" \t\n\r", M + 1); // used in tokenizing the input line
        for (int m = 1; m <= M; m++)
        {
            while (buf[0] == 0 && !feof(f))
                lineNo++, fgetline(f, buf);

            if (sscanf(buf, "\\%d-grams:", &n) != 1 || n != m)
                RuntimeError("read: mal-formed LM file, bad section header (%d): %ls", lineNo, pathname.c_str());
            lineNo++, fgetline(f, buf);

            std::vector<int> mgram(m + 1, -1);     // current mgram being read ([0]=dummy)
            std::vector<int> prevmgram(m + 1, -1); // cache to speed up symbol lookup
            mgram_map::cache_t mapCache;           // cache to speed up map.create()

            // read all the m-grams
            while (buf[0] != '\\' && !feof(f))
            {
                if (buf[0] == 0)
                {
                    lineNo++, fgetline(f, buf);
                    continue;
                }

                // -- parse the line
                tokens = &buf[0];
                if ((int) tokens.size() != ((m < fileM) ? m + 2 : m + 1))
                    RuntimeError("read: mal-formed LM file, incorrect number of tokens (%d): %ls", lineNo, pathname.c_str());
                double scoreVal = atof(tokens[0]);     // ... use sscanf() instead for error checking?
                double thisLogP = scoreVal * ln10xLMF; // convert to natural log

                bool skipEntry = false;
                for (int n2 = 1; n2 <= m; n2++)
                {
                    const char *tok = tokens[n2];
                    // map to id
                    int id;
                    if (m == 1) // unigram: build vocab table
                    {
                        id = (int) lmSymbols.size(); // unique id for this symbol
                        lmSymbols.push_back(SYMBOL(id, tok));
                        bool toSkip = false;
                        if (userSymMap.sym2existingId(lmSymbols.back().symbol) == -1)
                        {
                            if (filterVocabulary)
                                toSkip = true; // unknown word
                            else
                                userSymMap.sym2id(lmSymbols.back().symbol); // create it in user's space
                        }
                        skipWord.push_back(toSkip);
                    }
                    else // mgram: look up word in vocabulary
                    {
                        if (prevmgram[n2] >= 0 && strcmp(idToSymbol(prevmgram[n2]), tok) == 0)
                            id = prevmgram[n2]; // optimization: most of the time, it's the same
                        else
                        {
                            id = symbolToId(tok);
                            if (id == -1)
                                RuntimeError("read: mal-formed LM file, m-gram contains unknown word (%d): %ls", lineNo, pathname.c_str());
                        }
                    }
                    mgram[n2] = id;             // that's our id
                    skipEntry |= skipWord[id]; // skip entry if any token is unknown
                }

                double thisLogB = 0.0;
                if (m < M && !skipEntry)
                {
                    double boVal = atof(tokens[m + 1]); // ... use sscanf() instead for error checking?
                    thisLogB = boVal * ln10xLMF;        // convert to natural log
                }

                lineNo++, fgetline(f, buf);

                if (skipEntry) // word contained unknown vocabulary: skip entire entry
                    goto skipMGram;

                // -- enter the information into our data structure
                // Note that the mgram_map/mgram_data functions are highly efficient
                // because they can only be called in sorted order.

                // locate the corresponding entries
                {                                                   // (local block because we 'goto' over this)
                    mgram_map::key key(&mgram[1], m);               // key to locate this m-gram
                    mgram_map::coord c = map.create(key, mapCache); // create it & gets its location

                    // enter into data structure
                    logP.push_back(c, (float) thisLogP); // prob value
                    if (m < M)                           // back-off weight
                        logB.push_back(c, (float) thisLogB);
                }

            skipMGram:
                // remember current mgram for next iteration
                std::swap(mgram, prevmgram);
            }

            // fix the symbol set -- now we can binary-search in them with symbolToId()
            if (m == 1)
            {
                std::sort(lmSymbols.begin(), lmSymbols.end());
                idToSymIndex.resize(lmSymbols.size(), -1);
                for (int i = 0; i < (int) lmSymbols.size(); i++)
                {
                    idToSymIndex[lmSymbols[i].id] = i;
                }
            }

            fprintf(stderr, ", %d %d-grams", map.size(m), m);
        }
        fprintf(stderr, "\n");

        // check end tag
        if (M == fileM)
        { // only if caller did not restrict us to a lower order
            while (buf[0] == 0 && !feof(f))
                lineNo++, fgetline(f, buf);
            if (strcmp(buf, "\\end\\") != 0)
                RuntimeError("read: mal-formed LM file, no \\end\\ tag (%d): %ls", lineNo, pathname.c_str());
        }

        // update zerogram score by one appropriate for OOVs
        updateOOVScore();

        // establish mapping of word ids from user to LM space.
        // map's operator[] maps mgrams using this map.
        std::vector<int> userToLMSymMap(userSymMap.size());
        for (int i = 0; i < (int) userSymMap.size(); i++)
        {
            const char *sym = userSymMap.id2sym(i);
            int id = symbolToId(sym); // may be -1 if not found
            userToLMSymMap[i] = id;
        }
        map.created(userToLMSymMap);
    }

protected:
    // sort LM such that iterators will iterate in increasing order w.r.t. w2id[w]
    // This is achieved by replacing all internal ids by w2id[w].
    // This function is expensive: it makes a full temporary copy and involves sorting.
    // w2id[] gets destroyed by this function.
    void sort(std::vector<int> &w2id)
    {
        // create a full copy of logP and logB in the changed order
        mgram_map sortedMap(M);
        mgram_data<float> sortedLogP(M);
        mgram_data<float> sortedLogB(M - 1);

        for (int m = 1; m <= M; m++)
        {
            sortedMap.reserve(m, map.size(m));
            sortedLogP.reserve(m, logP.size(m));
            if (m < M)
                sortedLogB.reserve(m, logB.size(m));
        }

        // iterate in order of w2id
        // Order is determined by w2id[], i.e. entries with lower new id are
        // returned first.
        std::vector<int> mgram(M + 1, -1); // unmapped key in new id space
        mgram_map::cache_t createCache;
        for (mgram_map::reordering_iterator iter(map, w2id); iter; ++iter)
        {
            int m = iter.order();
            mgram_map::key key = *iter; // key in old 'w' space
            // keep track of an unmapped key in new id space
            if (m > 0)
            {
                int w = key.back();
                int newid = w2id[w]; // map to new id space
                mgram[m - 1] = newid;
            }
            for (int k = 0; k < m; k++)
                assert(mgram[k] == w2id[key[k]]);
            // insert new key into sortedMap
            mgram_map::coord c = sortedMap.create(mgram_map::unmapped_key(&mgram[0], m), createCache);
            // copy over logP and logB
            sortedLogP.push_back(c, logP[iter]);
            if (m < M)
                sortedLogB.push_back(c, logB[iter]);
        }

        // finalize sorted map
        sortedMap.created(w2id);

        // replace LM by sorted LM
        map.swap(sortedMap);
        logP.swap(sortedLogP);
        logB.swap(sortedLogB);
    }

public:
    // sort LM such that internal ids are in lexical order
    // After calling this function, iterators will iterate in lexical order,
    // and writing to an ARPA file creates a lexicographically sorted file.
    // Having sorted files is useful w.r.t. efficiency when iterating multiple
    // models in parallel, e.g. interpolating or otherwise merging models,
    // because then IIter can use the efficient deep_iterator (which iterates
    // in our internal order and therefore does not do any sorting) rather than
    // the reordering_iterator (which involves sort operations).
    template <class SYMMAP>
    void sort(const SYMMAP &userSymMap)
    {
        // deterine sort order
        // Note: This code copies all strings twice.
        std::vector<std::pair<std::string, int>> sortTemp(userSymMap.size()); // (string, w)
        foreach_index (w, sortTemp)
            sortTemp[w] = make_pair(userSymMap[w], w);
        std::sort(sortTemp.begin(), sortTemp.end());
        std::vector<int> w2id(userSymMap.size(), -1); // w -> its new id
        foreach_index (id, w2id)
            w2id[sortTemp[id].second] = id;

        // sort w.r.t. new id space
        sort(w2id);
    }

    // iterator to enumerate all known m-grams
    // This is used when creating whole models at once.
    template <class ITERATOR>
    class TIter : public ILM::IIter
    {
        int minM;               // minimum M we want to iterate (skip all below)
        const CMGramLM &lm;     // the underlying LM (for value())
        std::vector<int> wrank; // sorting criterion
        ITERATOR iter;          // the iterator used in this interface
        void findMinM()
        {
            while (iter && iter.order() < minM)
                ++iter;
        }

    public:
        // constructors
        TIter(const CMGramLM &lm, int minM, int maxM)
            : minM(minM), lm(lm), iter(lm.map, maxM)
        {
            findMinM();
        }
        TIter(const CMGramLM &lm, bool, int minM, int maxM)
            : minM(minM), lm(lm), wrank(lm.map.identical_map(lm.map.maxw() + 1)), iter(lm.map, wrank, maxM)
        {
            findMinM();
        }
        // has iterator not yet reached end?
        virtual operator bool() const
        {
            return iter;
        }
        // advance by one
        virtual void operator++()
        {
            ++iter;
            findMinM();
        }
        // current m-gram (mgram,m)
        virtual std::pair<const int *, int> operator*() const
        {
            mgram_map::key key = *iter;
            return std::make_pair(key.order() == 0 ? NULL : &key[0], key.order());
        }
        // current value (logP, logB)
        // No processing here--read out the logP/logB values directly from the data structure.
        virtual std::pair<double, double> value() const
        {
            if (iter.order() < lm.M)
                return std::make_pair(lm.logP[iter], lm.logB[iter]);
            else
                return std::make_pair(lm.logP[iter], 0.0);
        }
    };
    virtual IIter *iter(int minM, int maxM) const
    {
        if (maxM == INT_MAX)
            maxM = M; // default value
        // if no sorting needed, then we can use the efficient deep_iterator
        if (map.inorder())
            return new TIter<mgram_map::deep_iterator>(*this, minM, maxM);
        // sorting needed: use reordering_iterator
        return new TIter<mgram_map::reordering_iterator>(*this, true, minM, maxM);
    }

    virtual int order() const
    {
        return M;
    }
    virtual size_t size(int m) const
    {
        return (int) logP.size(m);
    }

protected:
    // computeSeenSums -- compute sum of seen m-grams, store at their history coord
    // If islog then P is logP, otherwise linear (non-log) P.
    template <class FLOATTYPE>
    static void computeSeenSums(const mgram_map &map, int M, const mgram_data<float> &P,
                                mgram_data<FLOATTYPE> &PSum, mgram_data<FLOATTYPE> &backoffPSum,
                                bool islog)
    {
        // dimension the accumulators and initialize them to 0
        PSum.init(M - 1);
        for (int m = 0; m <= M - 1; m++)
            PSum.assign(m, map.size(m), 0);

        backoffPSum.init(M - 1);
        for (int m = 0; m <= M - 1; m++)
            backoffPSum.assign(m, map.size(m), 0);

        // iterate over all seen m-grams
        msra::basetypes::fixed_vector<mgram_map::coord> histCoord(M); // index of history mgram
        for (mgram_map::deep_iterator iter(map, M); iter; ++iter)
        {
            int m = iter.order();
            if (m < M)
                histCoord[m] = iter;
            if (m == 0)
                continue;

            const mgram_map::key key = *iter;
            assert(m == key.order());

            float thisP = P[iter];
            if (islog)
            {
                if (thisP <= logzero)
                    continue; // pruned or otherwise lost
                thisP = exp(thisP);
            }
            else
            {
                if (thisP == 0.0f)
                    continue; // a pruned or otherwise lost m-gram
            }

            // parent entry
            const mgram_map::coord j = histCoord[m - 1]; // index of parent entry

            // accumulate prob in B field (temporarily misused)
            PSum[j] += thisP;

            // the mass of the back-off distribution covered by higher-order seen m-grams.
            // This must exist, as any sub-sequence of any seen m-mgram exists
            // due to the way we count the tokens.
            const mgram_map::key boKey = key.pop_h();
            const mgram_map::foundcoord c = map[boKey];
            if (!c.valid_w())
                RuntimeError("estimate: malformed data: back-off value not found"); // must exist
            // look it up
            float Pc = P[c];
            backoffPSum[j] += islog ? exp(Pc) : Pc;
        }
    }

    // computeBackoff -- compute back-off weights
    // Set up or update logB[] based on P[].
    // logB[] is an output from this function only.
    // If islog then P is logP, otherwise linear (non-log) P.
    static void computeBackoff(const mgram_map &map, int M,
                               const mgram_data<float> &P, mgram_data<float> &logB,
                               bool islog)
    {
        mgram_data<float> backoffPSum; // accumulator for the probability mass covered by seen m-grams

        // sum up probabilities of seen m-grams
        //  - we temporarily use the B field for the actual seen probs
        //  - and backoffSum for their prob pretending we are backing off
        computeSeenSums(map, M, P, logB, backoffPSum, islog);
        // That has dimensioned logB as we need it.

        // derive the back-off weight from it
        for (mgram_map::deep_iterator iter(map, M - 1); iter; ++iter)
        {
            double seenMass = logB[iter]; // B field misused: sum over all seen children
            if (seenMass > 1.0)
            {
                if (seenMass > 1.0001) // (a minor round-off error is acceptable)
                    fprintf(stderr, "estimate: seen mass > 1.0: %8.5f --oops??\n", seenMass);
                seenMass = 1.0; // oops?
            }

            // mass covered by seen m-grams is unused -> take out
            double coveredBackoffMass = backoffPSum[iter];
            if (coveredBackoffMass > 1.0)
            {
                if (coveredBackoffMass > 1.0001) // 1.0 for unigrams, sometimes flags this
                    fprintf(stderr, "estimate: unseen backoff mass < 0: %8.5f --oops??\n", 1.0 - coveredBackoffMass);
                coveredBackoffMass = 1.0; // oops?
            }

            // redistribute such that
            //      seenMass + bow * usedBackoffMass = 1
            //  ==> bow = (1 - seenMass) / usedBackoffMass
            double freeMass = 1.0 - seenMass;
            double accessibleBackoffMass = 1.0 - coveredBackoffMass; // sum of all backed-off items

            // back-off weight is just the free probability mass
            double bow = (accessibleBackoffMass > 0) ? freeMass / accessibleBackoffMass : 1.0;
            // A note on the curious choice of bow=1.0 for accessibleBackoffMass==0:
            // If accessibleBackoffMass==0, we are in undefined territory.
            // Because this means we never back off. Problem is that we have
            // already discounted the probabilities, i.e. there is probability
            // mass missing (distribution not normalized). Possibilities for
            // remedying the normalization issue are:
            //  1. use linear interpolation instead generally
            //  2. use linear interpolation only for such distributions
            //  3. push mass into <UNK> class if available
            //  4. ignore the normalization problem.
            // We choose 2. for the unigram distribution (enforced outside of this
            // function), and 4. for all other cases.
            // A second question arises for OOV words in this case. With OOVs,
            // accessibleBackoffMass should no longer be 0, but we don't know its
            // value. Be Poov the mass of all OOV words, then
            //  bow = (1 - seenMass) / Poov
            // Further, if seenMass was not discounted (as in our unigram case),
            // it computes to 1, but if we had accounted for Poov, it would
            // compute as (1-Poov) instead. Thus,
            //  bow = (1 - (1-Poov)) / Poov = 1
            // Realistically, this case happens for the unigram distribution.
            // Practically it means fallback instead of back-off for OOV words.
            // Also, practically, Poov is very small, so is the error.
            logB[iter] = logclip((float) bow);
        }
    }
};

// ===========================================================================
// CMGramLMIterator -- a special-purpose class that allows for direct iteration.
// ===========================================================================

class CMGramLMIterator : public msra::lm::mgram_map::iterator
{
    const CMGramLM &lm;

public:
    CMGramLMIterator(const CMGramLM &lm, mgram_map::coord c)
        : lm(lm), msra::lm::mgram_map::iterator(lm.map, c)
    {
    }
    float logP() const
    {
        return lm.logP[*this];
    }
    float logB() const
    {
        return lm.logB[*this];
    }
    float logB(mgram_map::coord c) const
    {
        return lm.logB[c];
    }
    msra::lm::mgram_map::coord locate(const int *mgram, int m2) const
    {
        msra::lm::mgram_map::foundcoord c = lm.map[msra::lm::mgram_map::key(mgram, m2)];
        if (!c.valid_w())
            LogicError("locate: attempting to locate a non-existing history");
        return c;
    }
};

}; }; // namespace
back to top