# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE file in the project root for full license information.

# G2P.cntk                                                                 #
#                                                                          #
# Example for sequence-to-sequence modeling for grapheme-to-phoneme        #
# (aka letter-to-sound) conversion on the CMUDict                          #

# directory defaults (if not overridden)

# All other paths in this file are derived from RunRootDir through $...$ substitution.
RunRootDir = "."      # default if not overridden
DataDir    = "$RunRootDir$/../Data"   # the corpus files and the label mapping file are read from here
OutDir     = "$RunRootDir$/Out"       # root of the per-experiment output directories (see ExpId)

# command to execute

# 'command' names the section of this file to run; the commented-out alternatives
# correspond to the "write" and "dump" sections defined further below.
command = train
#command = write
#command = dump

makeMode = false          # set this to true to enable restarting from checkpoint
traceLevel = 1

# experiment id

deviceId = 0              # set the GPU device here, or "auto" to auto-select; or override from the command line.
ExpId = g2p-01-$deviceId$ # choose a meaningful id here. This is used for unique directory and filenames.
                          # note: the id embeds deviceId, so runs with different deviceId values get separate output directories

# model

# both paths live in the per-experiment directory $OutDir$/$ExpId$
modelPath  = "$OutDir$/$ExpId$/G2P.dnn"
stderr     = "$OutDir$/$ExpId$/G2P"

# decoding config  --used by the "write" command ("write" decodes and writes the result)

beamDepth = 3                                      # 0=predict; 1=greedy; >1=beam
decodeModel = 9                                    # epoch id of the model to decode (appended to modelPath below)
decodeModelPath = "$modelPath$.$decodeModel$"      # note: epoch to decode is appended to the model path
decodeOutputPath = "$decodeModelPath$.$beamDepth$" # results are written next to the model, with beamDepth appended

# dump config  --used by the "dump" command, for inspecting the model parameters

# same naming scheme as decodeModelPath: model path with the epoch id appended (here: 2)
dumpModelPath = "$modelPath$.2" # put the epoch id here

# top-level model configuration
# These values are substituted into the network definition below via $name$.
hiddenDim = 512          # width of every encoder and decoder layer (encoderDims/decoderDims)
precision  = "float"
maxLayer = 2             # highest layer index; layers are indexed 0..maxLayer
isBidirectional = false  # true selects the bidirectional encoder stack in the network definition

# vocab
# inputVocabSize/labelVocabSize are the dimensions of the rawInput/rawLabels streams and network inputs.
inputVocabSize = 69     # shared vocab (even though they are separate)
labelVocabSize = 69     

# corpus
maxLength = 20           # becomes the network's attentionSpan; 0 disables attention
startSymbol = "<s>"      # (need to override the default which is </s>)
# the file names below are resolved relative to $DataDir$ by the train/write sections
trainFile   = "cmudict-0.7b.train-dev-20-21.ctf"
validFile   = "cmudict-0.7b.train-dev-1-21.ctf"
testFile    = "cmudict-0.7b.test.ctf"
mappingFile = "cmudict-0.7b.mapping"

# some reader variables that occur multiple times
# Stream definitions shared by the 'reader', 'cvreader' and decoding reader sections (inserted as $cntkReaderInputDef$).
# The stream names rawInput/rawLabels match the Input nodes of the same names in the network definition;
# "S0"/"S1" are the aliases under which the two streams appear in the data files, and both are sparse.
cntkReaderInputDef = { rawInput = { alias = "S0" ; dim = $inputVocabSize$ ; format = "sparse" } ; rawLabels = { alias = "S1" ;  dim = $labelVocabSize$ ;  format = "sparse" } }

#  network definition                 #

BrainScriptNetworkBuilder = (new ComputationNetwork {

    # Sequence-to-sequence network: an LSTM encoder over the input tokens feeds an LSTM decoder
    # over the label tokens, either through attention, through a thought vector appended to every
    # decoder input, or through the decoder's initial state (see 'decoderParams' below).
    # Values written as $name$ are substituted from the top-level settings of this file.

    inputVocabDim = $inputVocabSize$
    labelVocabDim = $labelVocabSize$

    attentionSpan = $maxLength$         # attention window, must be large enough for largest input sequence. 0 to disable. Exactly 20 is needed for the g2p CMUDict task
    useBidirectionalEncoder = $isBidirectional$ # bi-directional LSTM for encoder

    hiddenDim       = $hiddenDim$
    attentionDim    = 128               # dim of attention  projection
    maxLayer        = $maxLayer$        # e.g. 2 for 3 hidden layers

    useStabilizer = true
    useEncoder    = true                # if false, this becomes a regular RNN
    useNYUStyle   = false               # if true use thought vector for all inputs, NYU-style

    # dimensions
    # A vocabulary smaller than 200 is not embedded at all: the embedding dim is set equal to the
    # vocab dim, which makes EmbedInput/EmbedLabels below take their pass-through branch.
    embeddingDim = 200
    inputEmbeddingDim = if inputVocabDim < 200 then inputVocabDim else embeddingDim
    labelEmbeddingDim = if labelVocabDim < 200 then labelVocabDim else embeddingDim

    encoderDims[i:0..maxLayer] = hiddenDim # one entry per encoder layer, i.e. maxLayer+1 layers of width hiddenDim
    decoderDims[i:0..maxLayer] = hiddenDim # same layout for the decoder

    # inputs

    # inputs and axes must be defined on top-scope level in order to get a clean node name from BrainScript.
    inputAxis = DynamicAxis()
    rawInput  = Input (inputVocabDim, dynamicAxis=inputAxis, tag='feature')
    rawLabels = Input (labelVocabDim, tag='label')

    # inputs and labels are expected to be surrounded by sentence delimiters, e.g. <s> A B C </s>  ==>  <s> D E F </s>
    # The encoder uses all tokens of 'input', while for the target labels we must exclude the initial sentence start, which is only used as the LM history.
    inputSequence = Pass (rawInput)                             # e.g. <s> A   B   C    </s>
    labelSequence = Pass (Slice (1,  0, rawLabels,  axis=-1))   # e.g. D   E   F   </s>
    labelSentenceStart = Pass (BS.Sequences.First (rawLabels))  # e.g. <s>
    inputSequenceDim = inputVocabDim
    labelSequenceDim = labelVocabDim

    isFirstLabel = BS.Loop.IsFirst (labelSequence)

    # embeddings

    # Note: when reading input and labels from a single text file, we share the token mapping and embedding.
    # Note: Embeddings are linear. Should we use BatchNormalization?

    # note: this is assumed to be applied transposed, hence the swapped dimensions. Actually--why? Still needed?
    Einput  = BS.Parameters.WeightParam (inputSequenceDim, inputEmbeddingDim)
    Elabels = BS.Parameters.WeightParam (labelSequenceDim, labelEmbeddingDim)
    EmbedInput (x)  = if inputSequenceDim == inputEmbeddingDim then Pass(x) else TransposeTimes (Einput, x)
    EmbedLabels (x) = if labelSequenceDim == labelEmbeddingDim then Pass(x) else TransposeTimes (Elabels, x)

    inputEmbedded  = EmbedInput  (inputSequence)
    labelsEmbedded = EmbedLabels (labelSequence)
    labelSentenceStartEmbedded = Pass (EmbedLabels (labelSentenceStart))  # TODO: remove Pass() if not actually needed in decoder
    labelSentenceStartEmbeddedScattered = BS.Sequences.Scatter (isFirstLabel, labelSentenceStartEmbedded) # unfortunately needed presently

    # shorthand: stabilizer wrapper, switched on/off by useStabilizer
    S(x) = BS.Parameters.Stabilize (x, enabled=useStabilizer)

    # encoder (processes inputEmbedded)

    # Note: We reverse our input by running the recurrence from right to left.

    # NOTE(review): confirm that the bidirectional stack's name matches the spelling in the BS.RNNs library in use;
    # that branch is only selected when useBidirectionalEncoder is true.
    encoderFunction = if useBidirectionalEncoder then BS.RNNs.RecurrentBidirectionalLSTMPStack else BS.RNNs.RecurrentLSTMPStack
    # The argument list was left open after 'previousHook' (dangling comma, no closing parenthesis).
    # It is closed here with the same stabilizer switch that the attention hook below receives.
    encoder = encoderFunction (encoderDims, cellDims=encoderDims, S(inputEmbedded), inputDim=inputEmbeddingDim,
        previousHook=if useBidirectionalEncoder then BS.RNNs.PreviousHC else BS.RNNs.NextHC,
        enableSelfStabilization=useStabilizer)
    encoderOutput = encoder[Length (encoderDims)-1]   # state record of the topmost encoder layer

    # get the final encoder state for use as the initial state (not used with attention model)
    # Since we run right-to-left, the final state is the first, not the last.
    # For beam decoding, we will also inject a second dimension.
    thoughtVector = {
        h = ReshapeDimension (BS.Sequences.First (encoderOutput.h), 1, (dim:1))
        c = ReshapeDimension (BS.Sequences.First (encoderOutput.c), 1, (dim:1))
        dim = encoderOutput.dim
    }

    thoughtVectorBroadcast = { # broadcast to all time steps of the target sequence
        h = BS.Sequences.BroadcastSequenceAs (labelsEmbedded, thoughtVector.h)
        c = BS.Sequences.BroadcastSequenceAs (labelsEmbedded, thoughtVector.c)
        dim = thoughtVector.dim
    }

    # decoder reordering hook: propagation of beam hypotheses

    # The LSTMs multiply h and c with the 'beamSearchReorderHook' matrix, which is
    # a dummy in training but will be patched through model editing for beam decoding.
    # Specifically, the decoder will replace this by a per-sample matrix that reorders hypotheses according to
    # how they propagate. E.g. the 2nd best in a frame may be the history of the 3rd best in the subsequent frame

    beamSearchReorderHook = Pass (BS.Constants.OnesTensor (1:1))

    # helper functions to delay h and c that apply beam-search reordering, if so configured

    PreviousHCWithReorderingHook (lstmState, layerIndex=0) = {
       h = BS.Loop.Previous (lstmState.h * beamSearchReorderHook)             // hidden state(t-1)
       c = BS.Loop.Previous (lstmState.c * beamSearchReorderHook)             // cell(t-1)
       dim = lstmState.dim
    }

    # Same as above, except that layer 0 starts from the thought vector at the first step of the label sequence.
    PreviousHCFromThoughtVectorWithReorderingHook (lstmState, layerIndex=0) =
        if layerIndex > 0 then PreviousHCWithReorderingHook (lstmState, layerIndex=1)
        else { # with both thought vector and beam-search hook
            isFirst = BS.Loop.IsFirst (labelsEmbedded)
            h = BS.Boolean.If (isFirst, thoughtVectorBroadcast.h, BS.Loop.Previous (lstmState.h * beamSearchReorderHook))
            c = BS.Boolean.If (isFirst, thoughtVectorBroadcast.c, BS.Loop.Previous (lstmState.c * beamSearchReorderHook))
            dim = lstmState.dim
        }

    # decoder history hook: LM history, from ground truth vs. output

    # these are the two choices for the input to the decoder network
    decoderHistoryFromGroundTruth = labelsEmbedded              # for training, decoder input is ground truth...
    decoderHistoryFromOutput = Pass (EmbedLabels (Hardmax (z))) # ...but for (greedy) decoding, the decoder's output is its previous input

    # during training, we use ground truth. For decoding, we will rewire decoderHistoryHook = decoderHistoryFromOutput
    decoderHistoryHook = Pass (decoderHistoryFromGroundTruth) # this gets redirected in decoding to feed back decoding output instead

    # decoder

    # There are three ways of passing encoder state:
    #  1. as initial state for decoder (Google style)
    #  2. as side information for every decoder step (NYU style)
    #  3. attention

    # first step: embedded sentence-start symbol; every later step: the previous step's history
    decoderInput    = Pass (BS.Boolean.If (isFirstLabel, labelSentenceStartEmbeddedScattered, BS.Loop.Previous (decoderHistoryHook)))
    decoderInputDim = labelEmbeddingDim

    decoderDynamicAxis = labelsEmbedded
    FixedWindowAttentionHook = BS.Seq2Seq.CreateAugmentWithFixedWindowAttentionHook (attentionDim, attentionSpan, decoderDynamicAxis, encoderOutput, enableSelfStabilization=useStabilizer)

    # some parameters to the decoder stack depend on the mode
    decoderParams =
        # with attention
        if useEncoder && attentionSpan > 0 then {
            previousHook = PreviousHCWithReorderingHook # add reordering for beam search
            augmentInputHook = FixedWindowAttentionHook # input gets augmented by the attention window
            augmentInputDim = encoderOutput.dim
        }
        # with thought vector appended to every frame
        else if useEncoder && useNYUStyle then {
            previousHook = PreviousHCWithReorderingHook
            augmentInputHook (input, lstmState) = S(thoughtVectorBroadcast.h) # each input frame gets augmented by the thought vector
            augmentInputDim = thoughtVector.dim
        }
        # thought vector as initial state for decoder
        else {
            previousHook = PreviousHCFromThoughtVectorWithReorderingHook # Previous() function with thought vector as initial state
            augmentInputHook = BS.RNNs.NoAuxInputHook
            augmentInputDim = 0
        }

    # this is the decoder LSTM stack
    # The argument list was left open after 'augmentInputDim' (dangling comma, no closing parenthesis).
    # decoderParams.previousHook is passed here: without it, the hook selected above (beam-search
    # reordering / thought-vector initial state) would never reach the decoder.
    decoder = BS.RNNs.RecurrentLSTMPStack (decoderDims, cellDims=decoderDims,
                                           S(decoderInput), inputDim=decoderInputDim,
                                           augmentInputHook=decoderParams.augmentInputHook, augmentInputDim=decoderParams.augmentInputDim,
                                           previousHook=decoderParams.previousHook,
                                           enableSelfStabilization=useStabilizer)

    decoderOutputLayer = Length (decoder)-1            # index of the topmost decoder layer
    decoderOutput = decoder[decoderOutputLayer].h
    decoderDim = decoderDims[decoderOutputLayer]

    # softmax output layer

    W = BS.Parameters.WeightParam (labelSequenceDim, decoderDim)
    B = BS.Parameters.BiasParam (labelSequenceDim)

    z = W * S(decoderOutput) + B;  // top-level input to Softmax

    # training criteria

    # Both are written out explicitly in terms of z and the label sequence:
    #  ce   = log-sum over z minus the z entry selected by the label
    #  errs = 1 minus the Hardmax(z) entry selected by the label
    ce   = Pass (ReduceLogSum (z) - TransposeTimes (labelSequence,          z),  tag='criterion')
    errs = Pass (BS.Constants.One - TransposeTimes (labelSequence, Hardmax (z)), tag='evaluation')

    # score output for decoding
    scoreSequence = Pass (z)
})


#  TRAINING CONFIG                    #

# Training section: selected by 'command = train'. Trains the network defined by the
# top-level BrainScriptNetworkBuilder on $trainFile$ and cross-validates on $validFile$.
train = {
    action = "train"
    traceLevel = 1
    epochSize = 0               # (for quick tests, this can be overridden with something small)

    # BrainScriptNetworkBuilder is defined in outer scope

    SGD = {
        minibatchSize = 72
        learningRatesPerSample = 0.007
        momentumAsTimeConstant = 1100
        gradientClippingWithTruncation = true   # (as opposed to clipping the Frobenius norm of the matrix)
        clippingThresholdPerSample = 2.3   # visibly impacts objectives, but not final result, so keep it for safety
        maxEpochs = 50
        numMBsToShowResult = 100
        firstMBsToShowResult = 10
        gradUpdateType = FSAdaGrad # NOTE: the former "TODO: Try FSAdaGrad?" is stale -- it is already selected here
        loadBestModel = false      # true # broken for some models (rereading overwrites something that got set by validation)

        dropoutRate = 0.0

        # settings for Auto Adjust Learning Rate
        AutoAdjust = {
            autoAdjustLR = "adjustAfterEpoch"
            reduceLearnRateIfImproveLessThan = 0.001
            continueReduce = false
            increaseLearnRateIfImproveMoreThan = 1000000000   # very large threshold -- presumably to keep the rate from ever being increased; TODO confirm
            learnRateDecreaseFactor = 0.5
            learnRateIncreaseFactor = 1.382
            numMiniBatch4LRSearch = 100
            numPrevLearnRates = 5
            numBestSearchEpoch = 1
        }
    }

    # reader definitions
    # training data: randomized
    reader = {
        randomize = true
        deserializers = ({
            type = "CNTKTextFormatDeserializer" ; module = "CNTKTextFormatReader"
            file = "$DataDir$/$trainFile$"
            input = $cntkReaderInputDef$
        })
    }

    # cross-validation data: same stream layout, read in file order
    cvreader = {
        randomize = false
        deserializers = ({
            type = "CNTKTextFormatDeserializer" ; module = "CNTKTextFormatReader"
            file = "$DataDir$/$validFile$"
            input = $cntkReaderInputDef$
        })
    }
}


#  DUMP CONFIG                        #

# dumps the model, specifically the learnable parameters

# Dump section: selected by 'command = dump'. Reads the model at $dumpModelPath$ and
# writes its dump next to it, with ".txt" appended.
dump = {
    action = "dumpnode"
    modelPath = "$dumpModelPath$"
    outputFile = "$dumpModelPath$.txt"
}

#  WRITE CONFIG                       #

# This will decode the test set. The beamDepth parameter specifies the decoding mode:
#  beamDepth = 0: word prediction given ground truth history (only useful for perplexity measurement)
#  beamDepth = 1: greedy decoding: At each time step, choose a word greedily
#  beamDepth > 1: beam decoder. Keep 'beamDepth' best hypotheses, and output their globally best at the end.

# Write section: selected by 'command = write'. Loads the model at $decodeModelPath$, wraps it
# in the decoder chosen by $beamDepth$, and writes the decoded test set to $decodeOutputPath$.
write = {
    action = "write"

    # select the decoder
    BrainScriptNetworkBuilder = (
        # beamDepth = 0 will decode with the unmodified model.
        # beamDepth = 1 will modify the model to use the decoding output as the decoder's input.
        # beamDepth > 1 will modify the model to track multiple hypotheses and select the globally best at the end.
        if      $beamDepth$ == 0 then BS.Network.Load ("$decodeModelPath$")
        else if $beamDepth$ == 1 then BS.Seq2Seq.GreedySequenceDecoderFrom (BS.Network.Load ("$decodeModelPath$"))
        else                          BS.Seq2Seq.BeamSearchSequenceDecoderFrom (BS.Network.Load ("$decodeModelPath$"), $beamDepth$)
    )

    outputPath = $decodeOutputPath$
    #outputPath = "-"                    # "-" will write to stdout; useful for debugging

    # declare the nodes we want to write out
    # not all decoder configs have the same node names, so we just list them all
    #outputNodeNames = inputsOut:labelsOut:decodeOut:network.beamDecodingModel.inputsOut:network.beamDecodingModel.labelsOut:network.beamDecodingModel.decodeOut

    # output format
    # We configure the output to emit a flat sequence of token strings.
    format = {
        type = "category"
        transpose = false
        labelMappingFile = "$DataDir$/$mappingFile$"
    }

    minibatchSize = 1024                # choose this to be big enough for the longest sentence
    traceLevel = 1
    epochSize = 0

    # test data: same stream layout as in training, read in file order
    reader = {
        randomize = false
        deserializers = ({
            type = "CNTKTextFormatDeserializer" ; module = "CNTKTextFormatReader"
            file = "$DataDir$/$testFile$"
            input = $cntkReaderInputDef$
        })
    }
}

