// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // #pragma once #include "stdafx.h" #include "CNTKLibrary.h" #include "Utils.h" namespace CNTK { const std::wstring versionKey = L"version"; const std::wstring typeKey = L"type"; const std::wstring uidKey = L"uid"; const std::wstring kindKey = L"kind"; const std::wstring dataTypeKey = L"data_type"; const std::wstring dynamicAxisKey = L"dynamic_axis"; const std::wstring isSparseKey = L"is_sparse"; const std::wstring nameKey = L"name"; const std::wstring needsGradientKey = L"needs_gradient"; const std::wstring shapeKey = L"shape"; const std::wstring valueKey = L"value"; const std::wstring opKey = L"op"; const std::wstring attributesKey = L"attributes"; const std::wstring inputsKey = L"inputs"; const std::wstring rootKey = L"root"; const std::wstring functionsKey = L"primitive_functions"; const std::wstring sampleCountKey = L"sample_count"; const std::wstring minibatchCountKey = L"minibatchCount"; // TODO: Python-style spelling const std::wstring sweepCountKey = L"sweepCount"; const std::wstring unitKey = L"unit"; const std::wstring refMBSizeKey = L"ref_mb_size"; const std::wstring epochSizeKey = L"epoch_size"; const std::wstring scheduleKey = L"schedule"; const std::wstring learningRateScheduleKey = L"learnig_rate_schedule"; const std::wstring smoothedGradientsKey = L"smoothed_gradients"; const std::wstring noiseInjectionSeedKey = L"noise_injection_seed"; const std::wstring smoothedCountKey = L"smoothed_count"; const std::wstring stateKey = L"state"; const std::wstring rngSeedKey = L"rng_seed"; const std::wstring rngOffsetKey = L"rng_offset"; const std::wstring blockFunctionCompositeKey = L"block_function_composite"; const std::wstring blockFunctionOpNameKey = L"block_function_op_name"; const std::wstring blockFunctionCompositeArgumentsMapKeysKey = L"block_function_composite_arguments_map_keys"; const std::wstring blockFunctionCompositeArgumentsMapValuesKey = L"block_function_composite_arguments_map_values"; const std::wstring internalWorkerStateKey = L"internal_worker_state"; const std::wstring externalWorkerStateKey = L"external_worker_state"; const std::wstring userDefinedStateKey = L"user_defined_state"; const std::wstring udfModuleNameKey = L"module"; const std::wstring udfFactoryMethodNameKey = L"deserialize_method"; const std::wstring nativeUDFKey = L"native"; template inline std::string GetVersionsString(size_t currentVersion, size_t dictVersion) { std::stringstream info; info << "Current " << Typename() << " version = " << currentVersion << ", Dictionary version = " << dictVersion; return info.str(); } inline size_t GetVersion(const Dictionary& dict) { if (!dict.Contains(versionKey)) LogicError("Required key '%ls' is not found in the dictionary.", versionKey.c_str()); return dict[versionKey].Value(); } template inline void ValidateType(const Dictionary& dict, const std::wstring& typeValue, size_t currentVersion) { if (!dict.Contains(typeKey)) { const auto& version = GetVersion(dict); LogicError("Required key '%ls' is not found in the dictionary (%s).", typeKey.c_str(), GetVersionsString(currentVersion, version).c_str()); } const auto& type = dict[typeKey].Value(); if (type != typeValue) { const auto& version = GetVersion(dict); LogicError("Unexpected '%ls':'%ls' in place of '%ls':'%ls' (%s).", typeKey.c_str(), type.c_str(), typeKey.c_str(), typeValue.c_str(), GetVersionsString(currentVersion, version).c_str()); } } // Make sure that the dictionary contains all required keys, and if it does, return version value // from the dictionary. template inline size_t ValidateDictionary(const Dictionary& dict, const std::vector& requiredKeys, const std::wstring& typeValue, size_t currentVersion) { const auto& version = GetVersion(dict); for (const auto& key : requiredKeys) { if (!dict.Contains(key)) { LogicError("Required key '%ls' is not found in the dictionary (%s).", key.c_str(), GetVersionsString(currentVersion, version).c_str()); } } ValidateType(dict, typeValue, currentVersion); return version; } }