We are hiring ! See our job offers.
Revision e1467a79dc6580ae009d827b5e6f274faff3b339 authored by liqunfu on 27 March 2020, 21:42:04 UTC, committed by GitHub on 27 March 2020, 21:42:04 UTC
2 parents c7bc93f + a2055f6
Raw File
RegrSimple_CIFAR10.cntk
# RegrSimple_CIFAR10.cntk shows how to train a regression model on CIFAR-10 image data.
# It uses a very simple network and a composite reader using both the ImageReader
# and CNTKTextFormatReader and defines the RMSE (root mean square error) as the
# loss function. The values that the network learns to predict are simply the
# average RGB values of an image, normalized to [0, 1].

# Actions to execute; each name matches a top-level section defined further down
# in this file (TrainConvNet, Write, Test).
command = TrainConvNet:Write:Test

# Global run settings. deviceId = "auto" leaves the device choice to CNTK.
# NOTE(review): makeMode = false presumably forces training from scratch instead
# of resuming from an existing checkpoint - confirm against the CNTK docs.
makeMode = false ; traceLevel = 1 ; deviceId = "auto"

# Directory layout. Later entries are built from earlier ones by variable
# substitution, so references are spelled exactly like their definitions.
rootDir = ".." ; dataDir  = "$rootDir$/DataSets/CIFAR-10" ;
outputDir = "Output" ; modelDir = "$outputDir$/Models"

# Location of the model file.
modelPath = "$modelDir$/RegrSimple_CIFAR10.cmf"

# Training action. Note: despite the section name TrainConvNet, the model defined
# below contains no convolution - it is two stacked linear layers.
TrainConvNet = {
    action = "train"

    BrainScriptNetworkBuilder = [
        # input dimensions: 32 x 32 with 3 channels
        imageShape = 32:32:3
        # constant factor applied to the input values before the first layer
        featScale = Constant(1/256)
        # number of regression targets per image (average r, g and b, see file header)
        labelDim = 3

        # Network function: scales the input by featScale, applies a 100-unit
        # linear layer followed by a labelDim-unit linear layer, and returns the
        # output of the latter (the .ol member of the record).
        model (features) = {
            featNorm = Scale(features, featScale)
            h1 = LinearLayer {100,      init="gaussian", initValueScale=1.5} (featNorm)
            ol = LinearLayer {labelDim, init="gaussian", initValueScale=1.5} (h1)
        }.ol

        # inputs: the image and its labelDim regression targets; the names must
        # match the stream names declared in the reader sections
        features = Input {imageShape}
        regrLabels = Input {labelDim}
        
        # apply model to features
        ol = model (features)

        # define regression loss: root of the mean (over the labelDim outputs)
        # of the squared differences between target and prediction
        diff = regrLabels - ol
        sqerr = ReduceSum (diff.*diff, axis=1)
        rmse =  Sqrt (sqerr / labelDim)

        # node roles: rmse serves both as training criterion and as evaluation metric
        featureNodes    = (features)
        labelNodes      = (regrLabels)
        criterionNodes  = (rmse)
        evaluationNodes = (rmse)
        OutputNodes     = (ol)
    ]

    SGD = {
        # NOTE(review): 0 presumably means that one epoch is a full pass over the
        # training data - confirm against the CNTK SGD block documentation
        epochSize = 0

        maxEpochs = 2
        minibatchSize = 128
        learningRatesPerSample = 0.0005
        momentumAsTimeConstant = 1024

        # progress logging intervals, counted in minibatches (presumably: the
        # first 5 individually, then every 50th - TODO confirm)
        firstMBsToShowResult = 5 ; numMBsToShowResult = 50
    }

    # The reader is a composite reader that uses the ImageReader to read images
    # and the CNTKTextFormatReader to read the regression ground truth labels.
    # It does so by defining an array of deserializers (using {...} : {...} below)
    # and assigning the inputs as defined in the network above (cf. features and regrLabels).
    # The stream named ignored is declared with labelDim = 10 but has no matching
    # network input; presumably it is the class label column of the map file.
    reader = {
        verbosity = 0 ; randomize = true
        deserializers = ({
            type = "ImageDeserializer" ; module = "ImageReader"
            file = "$dataDir$/train_map.txt"
            input = {
                features = { transforms = (
                    { type = "Scale" ; width = 32 ; height = 32 ; channels = 3 ; interpolations = "linear" } :
                    { type = "Transpose" }
                )}
                ignored = { labelDim = 10 }
            }
        } : {
            type = "CNTKTextFormatDeserializer" ; module = "CNTKTextFormatReader"
            file = "$dataDir$/train_regrLabels.txt"
            input = {
                regrLabels = { dim = 3 ; format = "dense" }
            }
        })
    }
}

# Write action: feeds the test images through the model and writes the values of
# the nodes listed in outputNodeNames (prediction ol, target regrLabels and the
# rmse loss), using outputPath as the output location.
# NOTE(review): how the individual output file names are derived from outputPath
# should be confirmed against the CNTK write-action documentation.
Write = {
    action = "write"
    # one sample per minibatch
    minibatchSize = 1
    outputNodeNames = (ol, regrLabels, rmse)
    outputPath = "$outputDir$/RegrSimple_CIFAR10"

    # Composite reader over the test set, not randomized: images come from the
    # ImageDeserializer (scaled to 32 x 32 x 3, then transposed), regression
    # targets from the CNTKTextFormatDeserializer as dense vectors of size 3.
    # The stream named ignored is declared with labelDim = 10 and is not among
    # the written nodes; presumably it is the class label column of the map file.
    reader = {
        verbosity = 0 ; randomize = false
        deserializers = ({
            type = "ImageDeserializer" ; module = "ImageReader"
            file = "$dataDir$/test_map.txt"
            input = {
                features = { transforms = (
                    { type = "Scale" ; width = 32 ; height = 32 ; channels = 3 ; interpolations = "linear" } :
                    { type = "Transpose" }
                )}
                ignored = { labelDim = 10 }
            }
        } : {
            type = "CNTKTextFormatDeserializer" ; module = "CNTKTextFormatReader"
            file = "$dataDir$/test_regrLabels.txt"
            input = {
                regrLabels = { dim = 3 ; format = "dense" }
            }
        })
    }
}

# Test action: evaluates the model on the test set in minibatches of 512 samples.
# NOTE(review): outputNodeNames and outputPath mirror the Write section; they are
# presumably not consumed by a test action - confirm before removing them.
Test = {
    action = "test"
    minibatchSize = 512
    outputNodeNames = (ol, regrLabels, rmse)
    outputPath = "$outputDir$/RegrSimple_CIFAR10"

    # Composite reader over the test set, not randomized: images come from the
    # ImageDeserializer (scaled to 32 x 32 x 3, then transposed), regression
    # targets from the CNTKTextFormatDeserializer as dense vectors of size 3.
    # The stream named ignored is declared with labelDim = 10; presumably it is
    # the class label column of the map file and is not used by the network.
    reader = {
        verbosity = 0 ; randomize = false
        deserializers = ({
            type = "ImageDeserializer" ; module = "ImageReader"
            file = "$dataDir$/test_map.txt"
            input = {
                features = { transforms = (
                    { type = "Scale" ; width = 32 ; height = 32 ; channels = 3 ; interpolations = "linear" } :
                    { type = "Transpose" }
                )}
                ignored = { labelDim = 10 }
            }
        } : {
            type = "CNTKTextFormatDeserializer" ; module = "CNTKTextFormatReader"
            file = "$dataDir$/test_regrLabels.txt"
            input = {
                regrLabels = { dim = 3 ; format = "dense" }
            }
        })
    }
}
back to top