https://github.com/Microsoft/CNTK
Tip revision: 476a60cc2c353d657f61923e92c2806a680c412c authored by Bowen Bao on 02 July 2018, 17:47:37 UTC
small tweak in seq conv to avoid additional gpu memory allocation and increase performance.
Tip revision: 476a60c
CNTK.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// CNTK.cpp : Defines the entry point for the console application.
//
#define _CRT_NONSTDC_NO_DEPRECATE // make VS accept POSIX functions without _
#include "stdafx.h"
#ifdef _WIN32
#include <crtdbg.h>
#endif
#include "Basics.h"
#include "Globals.h"
#include "Actions.h"
#include "ComputationNetwork.h"
#include "ComputationNode.h"
#include "DataReader.h"
#include "DataWriter.h"
#include "SimpleNetworkBuilder.h"
#include "NDLNetworkBuilder.h"
#include "ModelEditLanguage.h"
#include "CPUMatrix.h" // used for SetNumThreads()
#include "CommonMatrix.h"
#include "SGD.h"
#include "MPIWrapper.h"
#include "EnvironmentUtil.h"
#include "Config.h"
#include "SimpleEvaluator.h"
#include "SimpleOutputWriter.h"
#include "BestGpu.h"
#include "ProgressTracing.h"
#include "fileutil.h"
#include "ScriptableObjects.h"
#include "BrainScriptEvaluator.h"
#include "BrainScriptParser.h"
#include "PerformanceProfiler.h"
#include "CNTKLibrary.h"
#include <string>
#include <chrono>
#include <algorithm>
#if defined(_WIN32)
#include "io.h"
#include <DelayImp.h>
#pragma comment(lib, "Delayimp.lib")
#pragma comment(lib, "shlwapi.lib")
#endif
#include "buildinfo.h"
#include "hostname.h"
#ifdef LEAKDETECT
#include "vld.h" // for memory leak detection
#endif
#include <vector>
#include <iostream>
#include <queue>
#include <set>
#include <memory>
#ifndef let
#define let const auto
#endif
using namespace std;
using namespace Microsoft::MSR;
using namespace Microsoft::MSR::CNTK;
// internal test routine forward declaration
template <typename ElemType>
void TestCn(const ConfigParameters& config);
// Setup profiling
// Initialize the performance profiler from configuration, if enabled.
// Templated so it works with both ConfigParameters and the BrainScript IConfigRecord.
// Output goes to <WorkDir>/profiler, tagged with the MPI node rank.
template <typename ConfigParamType>
void SetupProfiling(ProfilerContext& profilerContext, const ConfigParamType& config, int nodeRank)
{
    bool profilerEnabled = config(L"profilerEnabled", false);
    if (!profilerEnabled)
        return; // profiling not requested

    wstring workDir = config(L"WorkDir", L".");
    profilerContext.Init(workDir + L"/profiler",
                         config(L"profilerBufferSize", static_cast<uint64_t>(32 * 1024 * 1024)),
                         std::to_wstring(nodeRank),
                         config(L"profilerSyncGpu", true));
}
// Redirect the process's stderr (fd 2) into the given log file so that all
// subsequent LOGPRINTF/fprintf(stderr, ...) output lands in the log.
// 'appendLogFile' selects append vs. truncate semantics for an existing file.
void RedirectStdErr(wstring logpath, bool appendLogFile = false)
{
    // TODO: if there is already a file, rename it
    LOGPRINTF(stderr, "Redirecting stderr to file %S\n", logpath.c_str());
    auto fileOption = appendLogFile ? fileOptionsAppend : fileOptionsWrite;
    auto f = make_shared<File>(logpath.c_str(), fileOption | fileOptionsText);
    // Make fd 2 an alias of the log file's descriptor; stderr now writes to the log.
    if (dup2(fileno(*f), 2) == -1)
    {
        RuntimeError("unexpected failure to redirect stderr to log file");
    }
    setvbuf(stderr, NULL, _IONBF, 16384); // unbuffer it
    // Hold a reference so the File (and its descriptor) outlives this call.
    // NOTE(review): a function-local static is initialized on the FIRST call only,
    // so a second redirect does not update fKept -- confirm that is intended.
    static auto fKept = f; // keep it around (until it gets changed)
}
// Narrow a wide string to a std::string by truncating each wchar_t to char.
// NOTE: this is a lossy conversion for any non-ASCII characters; it is intended
// for ASCII-only configuration/tag strings.
std::string WCharToString(const wchar_t* wst)
{
    std::wstring ws(wst);
    // Construct directly from the range. (The previous code constructed the string
    // from the range and then immediately re-assigned the exact same range -- the
    // assign() was redundant work.)
    return std::string(ws.begin(), ws.end());
}
size_t GetMaxEpochs(const ConfigParameters& configParams)
{
ConfigParameters configSGD(configParams("SGD"));
size_t maxEpochs = configSGD("maxEpochs");
return maxEpochs;
}
// Currently we force determinism by setting compatibility mode for different CPU versions
// and limiting computation to a single CPU thread.
// TODO: Clarify how a single thread restriction can be lifted.
// Make CPU math reproducible: restrict math to a single CPU thread and enable
// the CPU compatibility mode (both are static settings shared across element types).
void ForceDeterministicAlgorithmsOnCPU()
{
    LOGPRINTF(stderr, "WARNING: forceDeterministicAlgorithms flag is specified. Using 1 CPU thread for processing.\n");
    CPUMatrix<float /*any type will do*/>::SetNumThreads(1);
    CPUMatrix<float /*any type will do*/>::SetCompatibleMode();
}
#ifndef CPUONLY
// abort execution is GPU is not supported (e.g. compute capability not supported)
// Abort execution (via InvalidArgument) if the given GPU cannot be used:
// either its compute capability is below 3.0 or the device is unrecognized.
void CheckSupportForGpu(DEVICEID_TYPE deviceId)
{
    const auto gpuData = GetGpuData(deviceId);
    switch (gpuData.validity)
    {
    case GpuValidity::ComputeCapabilityNotSupported:
        InvalidArgument("CNTK: The GPU (%s) has compute capability %d.%d. CNTK is only supported on GPUs with compute capability 3.0 or greater",
                        gpuData.name.c_str(), gpuData.versionMajor, gpuData.versionMinor);
        break;
    case GpuValidity::UnknownDevice:
        InvalidArgument("CNTK: Unknown GPU with Device ID %d.", gpuData.deviceId);
        break;
    default:
        break; // device is valid; nothing to do
    }
}
#endif
// special temporary function to guard against a now invalid usage of "truncated" which exists in some IPG production setups
// Special temporary function to guard against a now-invalid usage of "Truncated"
// which exists in some IPG production setups: a train/trainRNN action that sets
// Truncated in its reader section but neither in its SGD section nor at top level
// fails early with a clear message.
static void DisableLegacyTruncationSettings(const ConfigParameters& TopLevelConfig, const ConfigParameters& commandConfig)
{
    // A top-level Truncated setting applies everywhere, so the legacy pattern is harmless then.
    if (TopLevelConfig.ExistsCurrent(L"Truncated"))
    {
        return;
    }
    // if any of the actions has set a reader/SGD section and has a Truncated value in the reader section only
    ConfigArray actions = commandConfig(L"action");
    for (size_t i = 0; i < actions.size(); i++)
    {
        if (actions[i] == "train" || actions[i] == "trainRNN")
        {
            ConfigParameters sgd = ConfigParameters(commandConfig(L"SGD"));
            ConfigParameters reader = ConfigParameters(commandConfig(L"reader"));
            // reader and SGD sections are two must-have sections in train/trainRNN
            if (reader.ExistsCurrent(L"Truncated") && !sgd.ExistsCurrent(L"Truncated"))
            {
                // BUGFIX: corrected grammar of the user-facing message ("are" -> "is").
                InvalidArgument("DisableLegacyUsage: setting Truncated only in reader section is not allowed. Please move Truncated=true/false to the top level section.");
            }
        }
    }
}
static void DisableLegacyUsage(const ConfigParameters& TopLevelConfig, const ConfigArray& commands)
{
for (size_t i = 0; i < commands.size(); i++)
{
ConfigParameters cfgParameters(TopLevelConfig(commands[i]));
DisableLegacyTruncationSettings(TopLevelConfig, cfgParameters);
}
}
// When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should
// be run in parallel across multiple ranks. Others should only run on rank 0
// (DoCommands below consults this set before dispatching each action).
const std::set<std::string> commandstoRunOnAllRanks = { "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest", "bnstat" };
// process the command
// Execute every command listed in the top-level 'command' config value, in order.
// For each command, every entry of its 'action' list is dispatched to the matching
// Do*() implementation. 'mpi' is null when not running parallel training; actions
// not listed in commandstoRunOnAllRanks execute on the main rank only.
template <typename ElemType>
void DoCommands(const ConfigParameters& config, const shared_ptr<MPIWrapper>& mpi)
{
    ConfigArray command = config(L"command", "train");
    if (Globals::ShouldForceDeterministicAlgorithms())
        ForceDeterministicAlgorithmsOnCPU();
    else
    {
        // Setting specified number of threads.
        // (Default is the string "0", parsed to int by ConfigParameters; 0 leaves the thread count unchanged.)
        int numCPUThreads = config(L"numCPUThreads", "0");
        numCPUThreads = CPUMatrix<ElemType>::SetNumThreads(numCPUThreads);
        if (numCPUThreads > 0)
        {
            LOGPRINTF(stderr, "Using %d CPU threads.\n", numCPUThreads);
        }
    }
    bool progressTracing = config(L"progressTracing", false);
    // temporary hack to prevent users from failing due to a small breaking change related to the "truncated" flag (will be redone bigger and better some day)
    DisableLegacyUsage(config, command);
    // summarize command info upfront in the log and stdout
    // First pass: sum maxEpochs over all train actions so progress tracing can report a grand total.
    size_t fullTotalMaxEpochs = 0;
    for (int i = 0; i < command.size(); i++)
    {
        // get the configuration parameters that match the command
        ConfigParameters commandParams(config(command[i]));
        ConfigArray action = commandParams("action", "train");
        // determine the action to perform, and do it
        for (int j = 0; j < action.size(); j++)
        {
            if (action[j] == "train" || action[j] == "trainRNN")
            {
                wstring modelPath = commandParams("modelPath");
                size_t maxEpochs = GetMaxEpochs(commandParams);
                if (progressTracing)
                {
                    LOGPRINTF(stderr, "CNTKModelPath: %ls\n", modelPath.c_str());
                    LOGPRINTF(stderr, "CNTKCommandTrainInfo: %s : %d\n", command[i].c_str(), (int)maxEpochs);
                }
                fullTotalMaxEpochs += maxEpochs;
            }
        }
    }
    if (progressTracing)
    {
        LOGPRINTF(stderr, "CNTKCommandTrainInfo: CNTKNoMoreCommands_Total : %d\n", (int)fullTotalMaxEpochs);
    }
    // set up progress tracing for compute cluster management
    if (progressTracing && (!mpi || mpi->IsMainNode()))
    {
        ProgressTracing::SetTracingFlag();
        ProgressTracing::TraceTotalNumberOfSteps(fullTotalMaxEpochs); // enable tracing, using this as the total number of epochs
    }
    // Number of epochs completed by earlier train commands; lets SGD log global epoch numbers.
    size_t fullEpochsOffset = 0;
    // execute the commands
    for (int i = 0; i < command.size(); i++)
    {
        // get the configuration parameters that match the command
        const string thisCommand = command[i];
        ConfigParameters commandParams(config(thisCommand));
        ConfigArray action = commandParams("action", "train");
        int traceLevel = commandParams("traceLevel", "0");
        if (progressTracing && ((mpi == nullptr) || mpi->IsMainNode()))
        {
            ProgressTracing::SetStepOffset(fullEpochsOffset); // this is the epoch number that SGD will log relative to
        }
        // determine the action to perform, and do it
        for (int j = 0; j < action.size(); j++)
        {
            const string thisAction = action[j];
            // print a banner to visually separate each action in the log
            const char* delim = "##############################################################################";
            string showActionAs = thisCommand + " command (" + thisAction + " action)";
            fprintf(stderr, "\n");
            LOGPRINTF(stderr, "%s\n", delim);
            LOGPRINTF(stderr, "#%*s#\n", (int)(strlen(delim) - 2), "");
            LOGPRINTF(stderr, "# %s%*s #\n", showActionAs.c_str(), (int)(strlen(delim) - showActionAs.size() - 4), "");
            LOGPRINTF(stderr, "#%*s#\n", (int)(strlen(delim) - 2), "");
            LOGPRINTF(stderr, "%s\n\n", delim);
            // Under MPI, only whitelisted actions run on all ranks; everything else runs on the main node only.
            if ((mpi == nullptr) || (commandstoRunOnAllRanks.find(thisAction) != commandstoRunOnAllRanks.end()) || mpi->IsMainNode())
            {
                if (thisAction == "train" || thisAction == "trainRNN")
                {
                    if (progressTracing)
                    {
                        LOGPRINTF(stderr, "CNTKCommandTrainBegin: %s\n", command[i].c_str());
                    }
                    DoTrain<ConfigParameters, ElemType>(commandParams);
                    if (progressTracing)
                    {
                        LOGPRINTF(stderr, "CNTKCommandTrainEnd: %s\n", command[i].c_str());
                    }
                    fullEpochsOffset += GetMaxEpochs(commandParams);
                }
                else if (thisAction == "bnstat")
                {
                    DoBatchNormalizationStat<ElemType>(commandParams);
                }
                else if (thisAction == "adapt")
                {
                    DoAdapt<ElemType>(commandParams);
                }
                else if (thisAction == "test" || thisAction == "eval")
                {
                    DoEval<ElemType>(commandParams);
                }
                else if (thisAction == "edit")
                {
                    DoEdit<ElemType>(commandParams);
                }
                else if (thisAction == "cv")
                {
                    DoCrossValidate<ElemType>(commandParams);
                }
                else if (thisAction == "write")
                {
                    DoWriteOutput<ElemType>(commandParams);
                }
                else if (thisAction == "devtest")
                {
                    TestCn<ElemType>(config); // for "devtest" action pass the root config instead
                }
                else if (thisAction == "dumpNodes" /*deprecated:*/ || thisAction == "dumpNode" || thisAction == "dumpnode")
                {
                    DoDumpNodes<ElemType>(commandParams);
                }
                else if (thisAction == "convertdbn")
                {
                    DoConvertFromDbn<ElemType>(commandParams);
                }
                else if (thisAction == "exportdbn")
                {
                    DoExportToDbn<ElemType>(commandParams);
                }
                else if (thisAction == "createLabelMap")
                {
                    DoCreateLabelMap<ElemType>(commandParams);
                }
                else if (thisAction == "writeWordAndClass")
                {
                    DoWriteWordAndClassInfo<ElemType>(commandParams);
                }
                else if (thisAction == "plot")
                {
                    DoTopologyPlot<ElemType>(commandParams);
                }
                else if (thisAction == "SVD")
                {
                    DoParameterSVD<ElemType>(commandParams);
                }
                else
                {
                    RuntimeError("unknown action: %s in command set: %s", thisAction.c_str(), command[i].c_str());
                }
            }
            fprintf(stderr, "\n");
            if (traceLevel > 0)
            {
                LOGPRINTF(stderr, "Action \"%s\" complete.\n\n", thisAction.c_str());
            }
            NDLScript<ElemType> ndlScript;
            ndlScript.ClearGlobal(); // clear global macros between commands
            // Synchronize all ranks before proceeding to next action/command
            if (mpi)
                mpi->WaitAll();
        }
    }
}
// Return the current local time formatted as "YYYY/MM/DD HH:MM:SS".
std::string TimeDateStamp()
{
    time_t t = time(NULL);
    struct tm now = *localtime(&t);
    char buf[30];
    // snprintf bounds the write to the buffer (the formatted text is 19 chars + NUL
    // for any 4-digit year); the previous sprintf had no bound at all.
    snprintf(buf, sizeof(buf), "%04d/%02d/%02d %02d:%02d:%02d", now.tm_year + 1900, now.tm_mon + 1, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec);
    return buf;
}
// Print a short usage banner to stderr; shown when the tool is invoked with no arguments.
void PrintUsageInfo()
{
    LOGPRINTF(stderr, "-------------------------------------------------------------------\n");
    LOGPRINTF(stderr, "Usage: cntk configFile=yourConfigFile\n");
    LOGPRINTF(stderr, "For detailed information please consult the CNTK book\n");
    LOGPRINTF(stderr, "\"An Introduction to Computational Networks and the Computational Network Toolkit\"\n");
    LOGPRINTF(stderr, "-------------------------------------------------------------------\n");
}
// ---------------------------------------------------------------------------
// main() for use with BrainScript as entire config language (this is experimental)
// ---------------------------------------------------------------------------
// Pop and return the first remaining command-line argument; fail via
// InvalidArgument if none are left.
wstring ConsumeArg(vector<wstring>& args)
{
    if (args.empty())
        InvalidArgument("Unexpected end of command line.");
    auto front = args.begin();
    wstring arg = *front;
    args.erase(front);
    return arg;
}
// Append all elements of 'what' (any iterable range of wstrings) to the end of 'toWhat'.
template <class WHAT>
static void Append(std::vector<std::wstring>& toWhat, const WHAT& what)
{
    for (const auto& item : what)
        toWhat.push_back(item);
}
// Quote a pathname as a BrainScript string literal. Paths containing a double
// quote are wrapped in single quotes; all other paths (including ones containing
// a single quote) are wrapped in double quotes. A path containing both kinds of
// quote cannot be represented and is rejected.
static wstring PathToBSStringLiteral(const wstring& path) // quote a pathname for BS
{
    // BUGFIX: the previous code called path.find(path, L'\''), which invokes the
    // (string, startPos) overload -- it searched for the whole path within itself
    // starting at position 39/34 and therefore never detected either quote character.
    const bool hasSingleQuote = path.find(L'\'') != wstring::npos;
    const bool hasDoubleQuote = path.find(L'"') != wstring::npos;
    if (hasSingleQuote && hasDoubleQuote)
        InvalidArgument("Pathname cannot contain both single (') and double (\") quote at the same time: %ls", path.c_str());
    else if (hasDoubleQuote)
        return L"'" + path + L"'"; // BUGFIX: must wrap in single quotes when the path contains a double quote
    else
        return L'"' + path + L'"';
}
// TODO: There is a lot of duplication between this function and the NDL version.
// The code here should be properly refactored to enable sharing.
// Entry point for the (experimental) BrainScript configuration path. Parses the
// -f/-I/-D/--cd command-line options, assembles a BrainScript expression, and
// evaluates it -- evaluation of the 'actions' entries is what executes the work.
int wmainWithBS(int argc, wchar_t* argv[]) // called from wmain which is a wrapper that catches & reports Win32 exceptions
{
    vector<wstring> args(argv, argv + argc);
    let exePath = ConsumeArg(args);
    // startup message
    // In case of a redirect of stderr, this will be printed twice, once upfront, and once again after the redirect so that it goes into the log file
    wstring startupMessage = msra::strfun::wstrprintf(L"running on %ls at %ls\n", msra::strfun::utf16(GetHostName()).c_str(), msra::strfun::utf16(TimeDateStamp()).c_str());
    startupMessage += msra::strfun::wstrprintf(L"command line: %ls", exePath.c_str());
    for (const auto& arg : args)
        startupMessage += L" " + arg;
    LOGPRINTF(stderr, "%ls\n", startupMessage.c_str());
    // parse command-line options
    vector<wstring> sourceFiles;  // accumulated from -f/--file
    vector<wstring> includePaths; // accumulated from -I
    vector<wstring> overrides;    // accumulated from -D
    wstring workingDir;           // from --cd
    while (!args.empty())
    {
        let option = ConsumeArg(args);
        if (option == L"-f" || option == L"--file") // -f defines source files
            Append(sourceFiles, msra::strfun::split(ConsumeArg(args), L";"));
        else if (option == L"-I") // -I declares an include search path
            Append(includePaths, msra::strfun::split(ConsumeArg(args), L";"));
        else if (option == L"-D") // -D defines variables inline on the command line (which may override BS)
            overrides.push_back(ConsumeArg(args));
        else if (option == L"--cd") // --cd sets the working directory
            workingDir = ConsumeArg(args);
        else
            InvalidArgument("Invalid command-line option '%ls'.", option.c_str());
    }
    // change working directory
    if (workingDir != L"")
        _wchdir(workingDir.c_str());
    // compile the BrainScript
    wstring bs = L"[\n";
    bs += L"include \'cntk.core.bs'"; // start with including the standard macros
    // Note: Using lowercase ^^ here to match the Linux name of the CNTK exe.
    // NOTE(review): no newline is appended after the include above, so the next
    // token follows the closing quote directly -- presumably the BS tokenizer
    // treats the quote as a delimiter; confirm.
    for (const auto& sourceFile : sourceFiles)
        bs += L"include " + PathToBSStringLiteral(sourceFile) + L"\n";
    bs += L"\n]\n";
    for (const auto& over : overrides)
        bs += L"with [ " + over + L" ]\n";
    fprintf(stderr, "\n\n");
    LOGPRINTF(stderr, "BrainScript -->\n\n%ls\n\n", bs.c_str());
    let expr = BS::ParseConfigExpression(bs, move(includePaths)); // parse
    let valp = BS::Evaluate(expr);                                // evaluate parse into a dictionary
    let& config = valp.AsRef<ScriptableObjects::IConfigRecord>(); // this is the dictionary
    if (config(L"forceDeterministicAlgorithms", false))
        Globals::ForceDeterministicAlgorithms();
    if (config(L"forceConstantRandomSeed", false))
        Globals::ForceConstantRandomSeed();
#ifndef CPUONLY
    auto valpp = config.Find(L"deviceId");
    if (valpp)
    {
        auto valp2 = *valpp;
        if (!valp2.Is<ScriptableObjects::String>()) // if it's not string 'auto' or 'cpu', then it's a gpu
        {
            if (static_cast<int>(valp2) >= 0) // gpu (id >= 0)
            {
                CheckSupportForGpu(valp2); // throws if gpu is not supported
            }
        }
    }
#endif
    // legacy parameters that have changed spelling
    if (config.Find(L"DoneFile")) // variables follow camel case (start with lower-case letters)
        InvalidArgument("Legacy spelling of 'DoneFile' no longer allowed. Use 'doneFile'.");
    if (config.Find(L"command")) // spelling error, should be plural. Using 'actions' instead to match the data type.
        InvalidArgument("Legacy spelling of 'command' no longer allowed. Use 'actions'.");
    if (config.Find(L"type"))
        InvalidArgument("Legacy name 'type' no longer allowed. Use 'precision'.");
    // parallel training
    shared_ptr<Microsoft::MSR::CNTK::MPIWrapper> mpi;
    auto ensureMPIWrapperCleanup = MakeScopeExit(&MPIWrapper::DeleteInstance);
    // when running under MPI with more than one node, use 'true' as the default value for parallelTrain,
    // 'false' otherwise.
    bool paralleltrain = config(L"parallelTrain", (EnvironmentUtil::GetTotalNumberOfMPINodes() > 1));
    if (paralleltrain)
    {
        mpi = MPIWrapper::GetInstance(true /*create*/);
    }
    Globals::SetShareNodeValueMatrices(config(L"shareNodeValueMatrices", true));
    Globals::SetGradientAccumulationOptimization(config(L"optimizeGradientAccumulation", true));
    TracingGPUMemoryAllocator::SetTraceLevel(config(L"traceGPUMemoryAllocations", 0));
    // logging
    wstring logpath = config(L"stderr", L"");
    if (logpath != L"")
    {
        // non-main MPI ranks write to their own '<log>.rankN' file
        if (paralleltrain && mpi->CurrentNodeRank() != 0)
            logpath += msra::strfun::wstrprintf(L".rank%d", (int) mpi->CurrentNodeRank());
        RedirectStdErr(logpath, config(L"appendLogFile", false));
        LOGPRINTF(stderr, "%ls\n", startupMessage.c_str());
        ::CNTK::Internal::PrintBuiltInfo();
    }
    // echo gpu info to log
#ifndef CPUONLY
    ::CNTK::Internal::PrintGpuInfo(GetAllGpusData());
#endif
    // Setup profiling
    ProfilerContext profilerContext;
    SetupProfiling(profilerContext, config, paralleltrain ? (int)mpi->CurrentNodeRank() : 0);
    // execute the actions
    // std::string type = config(L"precision", "float");
    if (Globals::ShouldForceDeterministicAlgorithms())
        ForceDeterministicAlgorithmsOnCPU();
    else
    {
        int numCPUThreads = config(L"numCPUThreads", 0);
        numCPUThreads = CPUMatrix<float /*any will do*/>::SetNumThreads(numCPUThreads);
        if (numCPUThreads > 0)
            LOGPRINTF(stderr, "Using %d CPU threads.\n", numCPUThreads);
    }
    bool progressTracing = config(L"progressTracing", false);
    size_t fullTotalMaxEpochs = 1; // BUGBUG: BS does not allow me to read out the max epochs parameters, as that would instantiate and thus execute the objects
    // set up progress tracing for compute cluster management
    if (progressTracing && ((mpi == nullptr) || mpi->IsMainNode()))
        ProgressTracing::TraceTotalNumberOfSteps(fullTotalMaxEpochs); // enable tracing, using this as the total number of epochs
    // MAIN LOOP that executes the actions
    auto actionsVal = config[L"actions"];
    // Note: weird behavior. If 'actions' is a scalar value (rather than an array) then it will have been resolved already after the above call. That means, it has already completed its action!
    // Not pretty, but a direct consequence of the lazy evaluation. The only good solution would be to have a syntax for arrays including length 0 and 1.
    // Since this in the end behaves indistinguishable from the array loop below, we will keep it for now.
    if (actionsVal.Is<ScriptableObjects::ConfigArray>())
    {
        const ScriptableObjects::ConfigArray& actions = actionsVal;
        for (int i = actions.GetIndexBeginEnd().first; i < actions.GetIndexBeginEnd().second; i++)
        {
            // TODO: When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should
            // be run in parallel across multiple ranks. Others should only run on rank 0
            actions.At(i, [](const wstring&){}); // this will evaluate and thus execute the action
        }
    }
    // else action has already been executed, see comment above
    // write a doneFile if requested
    wstring doneFile = config(L"doneFile", L"");
    if (doneFile != L"")
    {
        FILE* fp = fopenOrDie(doneFile.c_str(), L"w");
        fprintf(fp, "successfully finished at %s on %s\n", TimeDateStamp().c_str(), GetHostName().c_str());
        fcloseOrDie(fp);
    }
    // TODO: change this back to COMPLETED, double underscores don't look good in output
    LOGPRINTF(stderr, "__COMPLETED__\n");
    fflush(stderr);
    // In case of success, finalizing the mpi if necessary.
    if (mpi)
        mpi->Finalize();
    return EXIT_SUCCESS;
}
// ---------------------------------------------------------------------------
// main() for CNTK config language (this is the current way of using CNTK)
// ---------------------------------------------------------------------------
// Print a startup banner (version, optional git branch/SHA, build time, and the
// full invocation) to stderr.
static void PrintBanner(int argc, wchar_t* argv[], const string& timestamp)
{
#ifndef CNTK_VERSION_BANNER
#error CNTK_VERSION_BANNER must be set
#endif
// BUGFIX: stringizing requires two levels of macro expansion. With a single-level
// '#s', the argument is NOT macro-expanded, so the banner printed the literal text
// "CNTK_VERSION_BANNER" instead of the version value the build system defined.
#define MACRO_TO_STRING_HELPER(s) #s
#define MACRO_TO_STRING(s) MACRO_TO_STRING_HELPER(s)
    fprintf(stderr, "CNTK %s (", MACRO_TO_STRING(CNTK_VERSION_BANNER));
#ifdef _GIT_EXIST
    fprintf(stderr, "%s %.6s, ", _BUILDBRANCH_, _BUILDSHA1_);
#endif
    fprintf(stderr, "%s %s", __DATE__, __TIME__); // build time
    fprintf(stderr, ") at %s\n\n", timestamp.c_str());
    for (int i = 0; i < argc; i++)
        fprintf(stderr, "%*s%ls", i > 0 ? 2 : 0, "", argv[i]); // use 2 spaces for better visual separability
    fprintf(stderr, "\n");
}
// called from wmain which is a wrapper that catches & reports Win32 exceptions
// Entry point for the classic CNTK-config-language path: parses the command line
// into a ConfigParameters tree, applies global flags, sets up logging/profiling,
// and runs all commands via DoCommands<float|double>.
int wmainOldCNTKConfig(int argc, wchar_t* argv[])
{
    std::string timestamp = TimeDateStamp();
    PrintBanner(argc, argv, timestamp);
    ConfigParameters config;
    std::string rawConfigString = ConfigParameters::ParseCommandLine(argc, argv, config); // get the command param set they want
    int traceLevel = config(L"traceLevel", 0);
#ifndef CPUONLY
    ConfigValue val = config("deviceId", "auto");
    if (!EqualCI(val, "cpu") && !EqualCI(val, "auto"))
    {
        if (static_cast<int>(val) >= 0) // gpu (id >= 0)
        {
            CheckSupportForGpu(static_cast<int>(val)); // throws if gpu is not supported
        }
    }
#endif
    if (config(L"timestamping", false))
        ProgressTracing::SetTimestampingFlag();
    if (config(L"forceDeterministicAlgorithms", false))
        Globals::ForceDeterministicAlgorithms();
    if (config(L"forceConstantRandomSeed", false))
        Globals::ForceConstantRandomSeed();
    // get the command param set they want
    wstring logpath = config(L"stderr", L"");
    wstring doneFile = config(L"doneFile", L"");
    ConfigArray command = config(L"command", "train");
    // parallel training
    // The top-level 'parallelTrain' is a bool, not to be confused with the parallelTrain block inside SGD.
    shared_ptr<Microsoft::MSR::CNTK::MPIWrapper> mpi;
    auto ensureMPIWrapperCleanup = MakeScopeExit(&MPIWrapper::DeleteInstance);
    // when running under MPI with more than one node, use 'true' as the default value for parallelTrain,
    // 'false' otherwise.
    bool paralleltrain = config(L"parallelTrain", (EnvironmentUtil::GetTotalNumberOfMPINodes() > 1));
    if (paralleltrain)
    {
        mpi = MPIWrapper::GetInstance(true /*create*/);
    }
    Globals::SetShareNodeValueMatrices(config(L"shareNodeValueMatrices", true));
    Globals::SetGradientAccumulationOptimization(config(L"optimizeGradientAccumulation", true));
    TracingGPUMemoryAllocator::SetTraceLevel(config(L"traceGPUMemoryAllocations", 0));
    if (logpath != L"")
    {
#if 1 // keep the ability to do it how it was done before 1.8; delete if noone needs it anymore
        let useOldWay = ProgressTracing::GetTimestampingFlag(); // enable it when running in our server farm
        if (useOldWay)
        {
            // legacy naming: '<stderr>_<command1>_<command2>....log'
            for (int i = 0; i < command.size(); i++) // append all 'command' entries
            {
                logpath += L"_";
                logpath += (wstring)command[i];
            }
            logpath += L".log"; // append .log
        }
        if (paralleltrain && useOldWay)
        {
            std::wostringstream oss;
            oss << mpi->CurrentNodeRank();
            logpath += L"rank" + oss.str();
        }
        else
#endif
        // for MPI workers except main, append .rankN
        if (paralleltrain && mpi->CurrentNodeRank() != 0)
            logpath += msra::strfun::wstrprintf(L".rank%d", mpi->CurrentNodeRank());
        RedirectStdErr(logpath, config(L"appendLogFile", false));
        if (traceLevel == 0)
            PrintBanner(argc, argv, timestamp); // repeat simple banner into log file
    }
    // full config info
    ::CNTK::Internal::PrintBuiltInfo();
#ifndef CPUONLY
    ::CNTK::Internal::PrintGpuInfo(GetAllGpusData());
#endif
#ifdef _DEBUG
    if (traceLevel > 0)
    {
        // This simply merges all the different config parameters specified (eg, via config files or via command line directly),
        // and prints it.
        fprintf(stderr, "\nConfiguration, Raw:\n\n");
        LOGPRINTF(stderr, "%s\n", rawConfigString.c_str());
        // Same as above, but all variables are resolved. If a parameter is set multiple times (eg, set in config, overridden at command line),
        // All of these assignments will appear, even though only the last assignment matters.
        fprintf(stderr, "\nConfiguration After Variable Resolution:\n\n");
        LOGPRINTF(stderr, "%s\n", config.ResolveVariables(rawConfigString).c_str());
    }
#endif
    SetMathLibTraceLevel(traceLevel);
    // This outputs the final value each variable/parameter is assigned to in config (so if a parameter is set multiple times, only the last
    // value it is set to will appear).
    if (traceLevel > 0)
    {
        fprintf(stderr, "\nConfiguration After Processing and Variable Resolution:\n\n");
        config.dumpWithResolvedVariables();
        LOGPRINTF(stderr, "Commands:");
        for (int i = 0; i < command.size(); i++)
            fprintf(stderr, " %s", command[i].c_str());
        fprintf(stderr, "\n");
    }
    // Setup profiling
    ProfilerContext profilerContext;
    SetupProfiling(profilerContext, config, paralleltrain ? (int)mpi->CurrentNodeRank() : 0);
    // run commands
    std::string type = config(L"precision", "float");
    // accept old precision key for backward compatibility
    if (config.Exists("type"))
        InvalidArgument("CNTK: Use of 'type' parameter is deprecated, it is called 'precision' now.");
    if (traceLevel > 0)
    {
        LOGPRINTF(stderr, "precision = \"%s\"\n", type.c_str());
    }
    if (type == "float")
        DoCommands<float>(config, mpi);
    else if (type == "double")
        DoCommands<double>(config, mpi);
    else
        RuntimeError("CNTK: Invalid precision string: \"%s\", must be \"float\" or \"double\"", type.c_str());
    // if completed then write a doneFile if requested
    if (!doneFile.empty())
    {
        FILE* fp = fopenOrDie(doneFile.c_str(), L"w");
        fprintf(fp, "Successfully finished at %s on %s\n", TimeDateStamp().c_str(), GetHostName().c_str());
        fcloseOrDie(fp);
    }
    if (ProgressTracing::GetTimestampingFlag())
    {
        LOGPRINTF(stderr, "__COMPLETED__\n"); // running in server environment which expects this string
    }
    else
        fprintf(stderr, "COMPLETED.\n");
    fflush(stderr);
    if (mpi)
        mpi->Finalize();
    return EXIT_SUCCESS;
}
// new_handler to print call stack upon allocation failure
void AllocationFailureHandler()
{
    // Print diagnostics first, then uninstall ourselves so that any allocation
    // performed while handling bad_alloc cannot re-enter this handler.
    Microsoft::MSR::CNTK::DebugUtil::PrintCallStack();
    std::set_new_handler(nullptr);
    throw std::bad_alloc();
}
// ---------------------------------------------------------------------------
// main wrapper that catches C++ exceptions and prints them
// ---------------------------------------------------------------------------
int wmain1(int argc, wchar_t* argv[]) // called from wmain which is a wrapper that catches & reports Win32 exceptions
{
    // Print a call stack on out-of-memory before std::bad_alloc propagates.
    std::set_new_handler(AllocationFailureHandler);
    try
    {
        if (argc <= 1)
        {
            ::CNTK::Internal::PrintBuiltInfo(); // print build info directly in case that user provides zero argument (convenient for checking build type)
            LOGPRINTF(stderr, "No command-line argument given.\n");
            PrintUsageInfo();
            fflush(stderr);
            return EXIT_FAILURE;
        }
        // detect legacy CNTK configuration
        // Any argument starting with "configFile=" (case-insensitive) selects the old config language.
        bool isOldCNTKConfig = false;
        for (int i = 0; i < argc && !isOldCNTKConfig; i++)
            isOldCNTKConfig |= !_wcsnicmp(L"configFile=", argv[i], 11);
        if (isOldCNTKConfig)
            return wmainOldCNTKConfig(argc, argv);
        // run from BrainScript
        return wmainWithBS(argc, argv);
    }
    catch (const ScriptableObjects::ScriptingException& err)
    {
        // BrainScript errors know how to print themselves (with source locations).
        fprintf(stderr, "\n");
        err.PrintError(ProgressTracing::GetTimeStampPrefix() + L"EXCEPTION occurred.");
    }
    catch (const IExceptionWithCallStackBase& err)
    {
        // Exceptions that captured a call stack print it before the message.
        fprintf(stderr, "\n");
        fprintf(stderr, "%s", err.CallStack());
        LOGPRINTF(stderr, "EXCEPTION occurred: %s\n", dynamic_cast<const std::exception&>(err).what());
    }
    catch (const std::exception& err)
    {
        fprintf(stderr, "\n");
        LOGPRINTF(stderr, "EXCEPTION occurred: %s\n", err.what());
    }
    catch (...)
    {
        fprintf(stderr, "\n");
        LOGPRINTF(stderr, "Unknown ERROR occurred.\n");
    }
    fflush(stderr);
    return EXIT_FAILURE;
}
#ifdef __WINDOWS__
// std::terminate handler: flush stderr before exiting so pending log output is not lost.
void TerminateThis()
{
    LOGPRINTF(stderr, "terminate_this: aborting.\n");
    fflush(stderr);
    exit(EXIT_FAILURE);
}
#define EXCEPTION_DLL_NOT_FOUND VcppException(ERROR_SEVERITY_ERROR, ERROR_MOD_NOT_FOUND)
// SEH filter helper: if the exception is a delay-load failure, log which DLL could not be loaded.
static void LogDelayLoadError(PEXCEPTION_POINTERS pExcPointers)
{
    if (pExcPointers->ExceptionRecord->ExceptionCode == EXCEPTION_DLL_NOT_FOUND)
    {
        // For delay-load failures, ExceptionInformation[0] points at the DelayLoadInfo record.
        const auto & pDelayLoadInfo = *PDelayLoadInfo(pExcPointers->ExceptionRecord->ExceptionInformation[0]);
        LOGPRINTF(stderr, "CNTK: Failed to load DLL '%s'.\n", pDelayLoadInfo.szDll);
    }
}
#if _DEBUG
// in case of asserts in debug mode, print the message into stderr and throw exception
// Installed via _CrtSetReportHook2: routes CRT debug reports (asserts etc.) to
// stderr instead of popping up a message box, then continues execution.
int HandleDebugAssert(int,     // reportType - ignoring reportType, printing message and aborting for all reportTypes
    char *message,             // message - fully assembled debug user message
    int * returnValue)         // returnValue - retVal value of zero continues execution
{
    fprintf(stderr, "C-Runtime: %s\n", message);
    if (returnValue) {
        *returnValue = 0; // return value of 0 will continue operation and NOT start the debugger
    }
    return TRUE; // returning TRUE will make sure no message box is displayed
}
#endif
// wmain wrapper that reports Win32 (SEH) exceptions, which C++ try/catch cannot
// see -- e.g. access violations, stack overflow, and delay-load DLL failures.
int wmain(int argc, wchar_t* argv[])
{
    set_terminate(TerminateThis); // insert a termination handler to ensure stderr gets flushed before actually terminating
    __try
    {
#if _DEBUG
        // in case of asserts in debug mode, print the message into stderr and throw exception
        // BUGFIX: HandleDebugAssert is only compiled under _DEBUG (see above), so its
        // use must be guarded too -- otherwise release builds fail to compile.
        if (_CrtSetReportHook2(_CRT_RPTHOOK_INSTALL, HandleDebugAssert) == -1) {
            LOGPRINTF(stderr, "CNTK: _CrtSetReportHook2 failed.\n");
            return -1;
        }
#endif
        int mainReturn = wmain1(argc, argv);
#if _DEBUG
        _CrtSetReportHook2(_CRT_RPTHOOK_REMOVE, HandleDebugAssert);
#endif
        return mainReturn;
    }
    __except (LogDelayLoadError(GetExceptionInformation()), EXCEPTION_EXECUTE_HANDLER)
    {
        auto code = GetExceptionCode();
        const char * msg = "";
        if (code == EXCEPTION_ACCESS_VIOLATION) msg = ": Access violation"; // the famous 0xc0000005 error
        else if (code == EXCEPTION_INT_DIVIDE_BY_ZERO) msg = ": Integer division by zero";
        else if (code == EXCEPTION_STACK_OVERFLOW) msg = ": Stack overflow";
        else if (code == EXCEPTION_DLL_NOT_FOUND) msg = ": Module not found";
        LOGPRINTF(stderr, "CNTK: Caught Win32 exception 0x%08x%s.\n", (unsigned int)code, msg);
        fflush(stderr);
        exit(EXIT_FAILURE);
    }
}
#endif
#ifdef __UNIX__
/// UNIX main function converts arguments from UTF-8 encoding to wide strings and
/// passes them to the Visual-Studio-style wmain1() which takes wchar_t strings.
int main(int argc, char* argv[])
{
    // Use STL containers (RAII) instead of raw new[]/delete[] so the buffers are
    // released even if wmain1 throws, and check mbstowcs for conversion failure
    // (previously only a debug-mode assert).
    std::vector<std::wstring> wargStorage(argc); // owns the wide copies
    std::vector<wchar_t*> wargs(argc);           // argv-style pointer array over the storage
    for (int i = 0; i < argc; ++i)
    {
        const size_t len = strlen(argv[i]);
        wargStorage[i].resize(len + 1); // room for worst case: one wchar_t per byte, plus NUL
        size_t converted = ::mbstowcs(&wargStorage[i][0], argv[i], len + 1);
        if (converted == static_cast<size_t>(-1)) // invalid multibyte sequence in argv[i]
        {
            fprintf(stderr, "Failed to convert command-line argument %d to a wide string.\n", i);
            return EXIT_FAILURE;
        }
        wargStorage[i].resize(converted); // shrink to actual length (keeps the buffer; wstring stays NUL-terminated)
        wargs[i] = &wargStorage[i][0];
    }
    return wmain1(argc, wargs.data());
}
#endif