# Snapshot of tip revision d68dc7775ff53745fa898450c6a6306f35b8821d (d68dc77),
# authored by Frank Seide on 30 September 2016, 20:26:46 UTC.
# Commit message (truncated in this copy): "lots of additions to"
# Macros to include: `load` names the section (defined below) whose macro
# definitions are made available to the network description.
load = ndlMacroDefine

# The actual NDL that defines the network: `run` names the section (defined
# below) that builds the network.
run = ndlCreateNetwork_LSTMP_c1024_p256_x3

ndlMacroDefine = [
    # Macro definitions

    # Per-dimension mean/variance normalization of the input x, built from
    # Mean(x) and InvStdDev(x). The network section calls this as
    # MeanVarNorm(features).
    MeanVarNorm(x) = [
        xMean = Mean(x);
        xStdDev = InvStdDev(x)
        xNorm = PerDimMeanVarNormalization(x, xMean, xStdDev)
    ]

    # Log of the label prior: Log(Mean(labels)).
    # The network section subtracts this from the output-layer scores to form
    # scaledLogLikelihood.
    LogPrior(labels) = [
        prior = Mean(labels)
        logPrior = Log(prior)
    ]

    # LSTM layer with a projected output, using one stacked weight matrix per
    # source (input, recurrent output) for all four gate pre-activations.
    #
    # Parameters:
    #   inputDim   - not referenced in the body
    #   outputDim  - row count of the projection Wmr, and the dimension given
    #                to the delayed output dh
    #   cellDim    - size of the cell state ct and of each gate slice
    #   inputx     - input node of the layer
    #   cellDimX2, cellDimX3 - row offsets of the 3rd and 4th gate slices
    #   cellDimX4  - row count of the stacked weights wx, Wh and bias b
    #   (the network section passes cellDim, 2x, 3x and 4x cellDim respectively)
    #
    # The network section uses the macro's value as the layer output;
    # presumably this is the last node, `output` - confirm against the NDL
    # macro rules.
    LSTMPComponent(inputDim, outputDim, cellDim, inputx, cellDimX2, cellDimX3, cellDimX4) = [
        # Stacked parameters for the four gates. The second dimension is given
        # as 0 - presumably inferred from the operand; TODO confirm.
        wx = Parameter(cellDimX4, 0, init="uniform", initValueScale=1);
        b  = Parameter(cellDimX4, 1, init="fixedValue", value=0.0);
        Wh = Parameter(cellDimX4, 0, init="uniform", initValueScale=1);

        # Cell-to-gate weights, applied with DiagTimes below.
        Wci = Parameter(cellDim, init="uniform", initValueScale=1);
        Wcf = Parameter(cellDim, init="uniform", initValueScale=1);
        Wco = Parameter(cellDim, init="uniform", initValueScale=1);

        # Recurrence: output and cell state delayed by one time step.
        # `output` and `ct` are defined further down in this macro.
        dh = PastValue(outputDim, output, timeStep=1);
        dc = PastValue(cellDim, ct, timeStep=1);

        # Stacked pre-activation: wx*inputx + b + Wh*dh.
        wxx = Times(wx, inputx);
        wxxpb = Plus(wxx, b);
        whh = Times(Wh, dh);

        # Split the stacked pre-activation into four cellDim-row slices.
        wxxpbpwhh = Plus(wxxpb, whh)
        G1 = RowSlice(0, cellDim, wxxpbpwhh)
        G2 = RowSlice(cellDim, cellDim, wxxpbpwhh)
        G3 = RowSlice(cellDimX2, cellDim, wxxpbpwhh);
        G4 = RowSlice(cellDimX3, cellDim, wxxpbpwhh);

        # Input gate (slice G1 plus the delayed cell state term).
        Wcidc = DiagTimes(Wci, dc);
        it = Sigmoid(Plus(G1, Wcidc));

        # Gated candidate value (slice G2).
        bit = ElementTimes(it, Tanh(G2));

        # Forget gate (slice G3 plus the delayed cell state term).
        Wcfdc = DiagTimes(Wcf, dc);
        ft = Sigmoid(Plus(G3, Wcfdc));

        bft = ElementTimes(ft, dc);

        # New cell state.
        ct = Plus(bft, bit);

        # Output gate (slice G4 plus the current cell state term).
        Wcoct = DiagTimes(Wco, ct);
        ot = Sigmoid(Plus(G4, Wcoct));

        mt = ElementTimes(ot, Tanh(ct));

        # Projection from cellDim to outputDim.
        Wmr = Parameter(outputDim, cellDim, init="uniform", initValueScale=1);
        output = Times(Wmr, mt);
    ]

    # Variant of the projected LSTM layer that slices the stacked input
    # product (wxx) and the stacked recurrent product (whh) separately and sums
    # the slices per gate.
    #
    # Parameters:
    #   inputDim   - not referenced in the body
    #   outputDim  - row count of the projection Wmr, and the dimension given
    #                to the delayed output dh
    #   cellDim    - size of the cell state ct and of each gate slice
    #   inputx     - input node of the layer
    #   cellDimX2, cellDimX3 - row offsets of the forget- and output-gate slices
    #   cellDimX4  - row count of the stacked weights wx, Wh and bias b
    #
    # NOTE(review): b and wxxpb are computed but never used below - every gate
    # slices wxx, not wxxpb - so this variant applies no bias. Confirm whether
    # that is intended before using this macro.
    LSTMPComponentBetter(inputDim, outputDim, cellDim, inputx, cellDimX2, cellDimX3, cellDimX4) = [
        # Stacked parameters for the four gates. The second dimension is given
        # as 0 - presumably inferred from the operand; TODO confirm.
        wx = Parameter(cellDimX4, 0, init="uniform", initValueScale=1);
        b  = Parameter(cellDimX4, 1, init="fixedValue", value=0.0);
        Wh = Parameter(cellDimX4, 0, init="uniform", initValueScale=1);

        # Cell-to-gate weights, applied with DiagTimes below.
        Wci = Parameter(cellDim, init="uniform", initValueScale=1);
        Wcf = Parameter(cellDim, init="uniform", initValueScale=1);
        Wco = Parameter(cellDim, init="uniform", initValueScale=1);

        # Recurrence: output and cell state delayed by one time step.
        dh = PastValue(outputDim, output, timeStep=1);
        dc = PastValue(cellDim, ct, timeStep=1);

        wxx = Times(wx, inputx);
        wxxpb = Plus(wxx, b);
        whh = Times(Wh, dh);

        # Input gate: rows [0, cellDim).
        Wxix = RowSlice(0, cellDim, wxx); #Times(Wxi, inputx);
        Whidh = RowSlice(0, cellDim, whh); #Times(Whi, dh);
        Wcidc = DiagTimes(Wci, dc);

        it = Sigmoid(Plus(Plus(Wxix, Whidh), Wcidc));

        # Candidate value: cellDim rows starting at cellDim.
        Wxcx = RowSlice(cellDim, cellDim, wxx); #Times(Wxc, inputx);
        Whcdh = RowSlice(cellDim, cellDim, whh); #Times(Whc, dh);
        bit = ElementTimes(it, Tanh(Plus(Wxcx, Whcdh)));

        # Forget gate: cellDim rows starting at cellDimX2.
        Wxfx = RowSlice(cellDimX2, cellDim, wxx); #Times(Wxf, inputx);
        Whfdh = RowSlice(cellDimX2, cellDim, whh); #Times(Whf, dh);
        Wcfdc = DiagTimes(Wcf, dc);

        ft = Sigmoid(Plus(Plus(Wxfx, Whfdh), Wcfdc));

        bft = ElementTimes(ft, dc);

        # New cell state.
        ct = Plus(bft, bit);

        # Output gate: cellDim rows starting at cellDimX3.
        Wxox  = RowSlice(cellDimX3, cellDim, wxx); #Times(Wxo, inputx);
        Whodh = RowSlice(cellDimX3, cellDim, whh); #Times(Who, dh);
        Wcoct = DiagTimes(Wco, ct);

        ot = Sigmoid(Plus(Plus(Wxox, Whodh), Wcoct));

        mt = ElementTimes(ot, Tanh(ct));

        # Projection from cellDim to outputDim.
        Wmr = Parameter(outputDim, cellDim, init="uniform", initValueScale=1);
        output = Times(Wmr, mt);
    ]

    # Projected LSTM layer written with a separate weight matrix and bias for
    # every gate (no stacked parameters, no row slicing).
    #
    # Parameters:
    #   inputDim   - not referenced in the body
    #   outputDim  - row count of the projection Wmr, and the dimension given
    #                to the delayed output dh
    #   cellDim    - size of the cell state ct, the gates and all per-gate
    #                parameters
    #   inputx     - input node of the layer
    LSTMPComponentNaive(inputDim, outputDim, cellDim, inputx) = [
        # Input-to-gate weights. The second dimension is given as 0 -
        # presumably inferred from the operand; TODO confirm.
        Wxo = Parameter(cellDim, 0, init="uniform", initValueScale=1);
        Wxi = Parameter(cellDim, 0, init="uniform", initValueScale=1);
        Wxf = Parameter(cellDim, 0, init="uniform", initValueScale=1);
        Wxc = Parameter(cellDim, 0, init="uniform", initValueScale=1);

        # Per-gate biases, initialised to 0.
        bo = Parameter(cellDim, init="fixedValue", value=0.0);
        bc = Parameter(cellDim, init="fixedValue", value=0.0);
        bi = Parameter(cellDim, init="fixedValue", value=0.0);
        bf = Parameter(cellDim, init="fixedValue", value=0.0);

        # Recurrent (Wh*) and cell-to-gate (Wc*) weights.
        Whi = Parameter(cellDim, 0, init="uniform", initValueScale=1);

        Wci = Parameter(cellDim, init="uniform", initValueScale=1);

        Whf = Parameter(cellDim, 0, init="uniform", initValueScale=1);
        Wcf = Parameter(cellDim, init="uniform", initValueScale=1);
        Who = Parameter(cellDim, 0, init="uniform", initValueScale=1);
        Wco = Parameter(cellDim, init="uniform", initValueScale=1);
        Whc = Parameter(cellDim, 0, init="uniform", initValueScale=1);

        # Recurrence: output and cell state delayed by one time step.
        dh = PastValue(outputDim, output, timeStep=1);
        dc = PastValue(cellDim, ct, timeStep=1);

        # Input gate.
        Wxix = Times(Wxi, inputx);
        Whidh = Times(Whi, dh);
        Wcidc = DiagTimes(Wci, dc);

        it = Sigmoid(Plus(Plus(Plus(Wxix, bi), Whidh), Wcidc));

        # Gated candidate value.
        Wxcx = Times(Wxc, inputx);
        Whcdh = Times(Whc, dh);
        bit = ElementTimes(it, Tanh(Plus(Wxcx, Plus(Whcdh, bc))));

        # Forget gate.
        Wxfx = Times(Wxf, inputx);
        Whfdh = Times(Whf, dh);
        Wcfdc = DiagTimes(Wcf, dc);

        ft = Sigmoid(Plus(Plus(Plus(Wxfx, bf), Whfdh), Wcfdc));

        bft = ElementTimes(ft, dc);

        # New cell state.
        ct = Plus(bft, bit);

        # Output gate (uses the current cell state ct).
        Wxox  = Times(Wxo, inputx);
        Whodh = Times(Who, dh);
        Wcoct = DiagTimes(Wco, ct);

        ot = Sigmoid(Plus(Plus(Plus(Wxox, bo), Whodh), Wcoct));

        mt = ElementTimes(ot, Tanh(ct));

        # Projection from cellDim to outputDim.
        Wmr = Parameter(outputDim, cellDim, init="uniform", initValueScale=1);
        output = Times(Wmr, mt);
    ]
# End of the ndlMacroDefine section.
]

# Network description: normalized features -> three stacked LSTMPComponent
# layers -> linear output layer, with cross-entropy training criterion,
# classification-error evaluation and a prior-scaled output.
ndlCreateNetwork_LSTMP_c1024_p256_x3 = [
    # define basic i/o
    # basefeatDim is only passed as inputDim of layer 1, which LSTMPComponent
    # does not reference.
    basefeatDim = 363
    # NOTE(review): rowSliceStart is not referenced anywhere in this section.
    rowSliceStart = 1200
    featDim = 363
    labelDim = 132
    cellDim = 1024
    # Multiples of cellDim, used by LSTMPComponent as slice offsets (X2, X3)
    # and as the stacked row count (X4).
    cellDimX2 = 2048
    cellDimX3 = 3072
    cellDimX4 = 4096
    # NOTE(review): the section name says p256 but the projection size used
    # here is 512 - confirm which is intended.
    hiddenDim = 512

    features = Input(featDim)
    labels = Input(labelDim)

    featNorm = MeanVarNorm(features)

    # layer 1
    LSTMoutput1 = LSTMPComponent(basefeatDim, hiddenDim, cellDim, featNorm, cellDimX2, cellDimX3, cellDimX4);
    # layer 2
    LSTMoutput2 = LSTMPComponent(hiddenDim, hiddenDim, cellDim, LSTMoutput1, cellDimX2, cellDimX3, cellDimX4);
    # layer 3
    LSTMoutput3 = LSTMPComponent(hiddenDim, hiddenDim, cellDim, LSTMoutput2, cellDimX2, cellDimX3, cellDimX4);

    # Linear output layer: W * LSTMoutput3 + b, with labelDim rows.
    W = Parameter(labelDim, 0, init="uniform", initValueScale=1);
    b = Parameter(labelDim, 1, init="fixedValue", value=0);
    LSTMoutputW = Plus(Times(W, LSTMoutput3), b);

    ce = CrossEntropyWithSoftmax(labels, LSTMoutputW);
    err = ClassificationError(labels, LSTMoutputW);
    # Output scores minus the log label prior.
    logPrior = LogPrior(labels)
    scaledLogLikelihood = Minus(LSTMoutputW, logPrior)
    # Special Nodes
    FeatureNodes = (features)
    LabelNodes = (labels)
    CriterionNodes = (ce)
    EvalNodes = (err)
    OutputNodes = (scaledLogLikelihood)
]