// Basics.h -- some shared generally useful pieces of code used by CNTK // // We also include a simple "emulation" layer for some proprietary MSVC CRT functions. #pragma once #ifndef _BASICS_H_ #define _BASICS_H_ #include "Platform.h" #include "ExceptionWithCallStack.h" #include #include #include #include #include #ifdef _WIN32 #include #undef max #endif #if __unix__ #include // for Plugin #endif #include #include #define TWO_PI 6.283185307f // TODO: find the official standards-confirming definition of this and use it instead #define EPSILON 1e-5 #define ISCLOSE(a, b, threshold) (std::abs(a - b) < threshold) ? true : false #define DLCLOSE_SUCCESS 0 #define UNUSED(x) (void)(x) // for variables that are, e.g., only used in _DEBUG builds #pragma warning(disable : 4702) // disable some incorrect unreachable-code warnings #define DISABLE_COPY_AND_MOVE(TypeName) \ TypeName(const TypeName&) = delete; \ TypeName& operator=(const TypeName&) = delete; \ TypeName(TypeName&&) = delete; \ TypeName& operator=(TypeName&&) = delete #ifndef let #define let const auto // let x = ... ; let & r = ... #endif namespace Microsoft { namespace MSR { namespace CNTK { using namespace std; // ----------------------------------------------------------------------- // ThrowFormatted() - template function to throw a std::exception with a formatted error string // ----------------------------------------------------------------------- template __declspec_noreturn static inline void ThrowFormattedVA(const char* format, va_list args) { // Note: The call stack will skip 2 levels to suppress this function and its call sites (XXXError()). // If more layers are added here, it would have to be adjusted. auto callstack = DebugUtil::GetCallStack(/*skipLevels=*/2, /*makeFunctionNamesStandOut=*/true); va_list args_copy; va_copy(args_copy, args); auto size = vsnprintf(nullptr, 0, format, args) + 1; string buffer("Unknown error."); if (size > 0) { buffer = string(size, ' '); if (vsnprintf(&buffer[0], size, format, args_copy) < 0) { buffer = string("Unknown error."); } } va_end(args_copy); #ifdef _DEBUG // print this to log, so we can see what the error is before throwing fprintf(stderr, "\nAbout to throw exception '%s'\n", buffer.c_str()); #endif throw ExceptionWithCallStack(buffer, callstack); } #ifndef _MSC_VER // TODO: what is the correct trigger for gcc? template __declspec_noreturn void ThrowFormatted(const char* format, ...) __attribute__((format(printf, 1, 2))); #endif template __declspec_noreturn static inline void ThrowFormatted(const char* format, ...) { va_list args; va_start(args, format); ThrowFormattedVA(format, args); va_end(args); }; // RuntimeError - throw a std::runtime_error with a formatted error string #ifndef _MSC_VER // gcc __attribute__((format(printf())) does not percolate through variadic templates; so must go the macro route #ifndef RuntimeError #define RuntimeError ThrowFormatted #endif #ifndef LogicError #define LogicError ThrowFormatted #endif #ifndef InvalidArgument #define InvalidArgument ThrowFormatted #endif #else template __declspec_noreturn static inline void RuntimeError(const char* format, _Types&&... _Args) { ThrowFormatted(format, forward<_Types>(_Args)...); } template __declspec_noreturn static inline void LogicError(const char* format, _Types&&... _Args) { ThrowFormatted(format, forward<_Types>(_Args)...); } template __declspec_noreturn static inline void InvalidArgument(const char* format, _Types&&... _Args) { ThrowFormatted(format, forward<_Types>(_Args)...); } #endif // Warning - warn with a formatted error string #pragma warning(push) #pragma warning(disable : 4996) static inline void Warning(const char* format, ...) { va_list args; char buffer[1024]; va_start(args, format); vsprintf(buffer, format, args); va_end(args); }; #pragma warning(pop) static inline void Warning(const string& message) { Warning("%s", message.c_str()); } #ifndef NOT_IMPLEMENTED #define NOT_IMPLEMENTED \ \ { \ fprintf(stderr, "Inside File: %s Line: %d Function: %s -> Feature Not Implemented.\n", __FILE__, __LINE__, __FUNCTION__); \ LogicError("Inside File: %s Line: %d Function: %s -> Feature Not Implemented.", __FILE__, __LINE__, __FUNCTION__); \ \ } #endif // Computes the smallest multiple of k greater or equal to n inline size_t AsMultipleOf(size_t n, size_t k) { return n + (k - n % k) % k; } }}} #ifndef _MSC_VER using Microsoft::MSR::CNTK::ThrowFormatted; #else using Microsoft::MSR::CNTK::RuntimeError; using Microsoft::MSR::CNTK::LogicError; using Microsoft::MSR::CNTK::InvalidArgument; #endif #ifdef _MSC_VER #include // std::codecvt_utf8 #endif namespace msra { namespace strfun { // TODO: rename this // ---------------------------------------------------------------------------- // (w)cstring -- helper class like std::string but with auto-cast to char* // and also implements an sprintf variant for STL strings // ---------------------------------------------------------------------------- // a class that can return a std::string with auto-convert into a const char* template struct basic_cstring : public std::basic_string { template basic_cstring(S p) : std::basic_string(p) { } operator const C*() const { return this->c_str(); } }; typedef basic_cstring cstring; typedef basic_cstring wcstring; // [w]strprintf() -- like sprintf() but resulting in a C++ string template struct _strprintf : public std::basic_string<_T> { // works for both wchar_t* and char* _strprintf(const _T* format, ...) { va_list args; va_start(args, format); // varargs stuff size_t n = _cprintf(format, args); // num chars excl. '\0' va_end(args); va_start(args, format); const int FIXBUF_SIZE = 128; // incl. '\0' if (n < FIXBUF_SIZE) { _T fixbuf[FIXBUF_SIZE]; this->assign(_sprintf(&fixbuf[0], sizeof(fixbuf) / sizeof(*fixbuf), format, args), n); } else // too long: use dynamically allocated variable-size buffer { std::vector<_T> varbuf(n + 1); // incl. '\0' this->assign(_sprintf(&varbuf[0], varbuf.size(), format, args), n); } va_end(args); } private: // helpers inline size_t _cprintf(const wchar_t* format, va_list args) { #ifdef _MSC_VER return _vscwprintf(format, args); #elif defined(__UNIX__) // TODO: Really??? Write to file in order to know the length of a string? FILE* dummyf = fopen("/dev/null", "w"); if (dummyf == NULL) perror("The following error occurred in basetypes.h:cprintf"); int n = vfwprintf(dummyf, format, args); if (n < 0) perror("The following error occurred in basetypes.h:cprintf"); fclose(dummyf); return n; #endif } inline size_t _cprintf(const char* format, va_list args) { #ifdef _MSC_VER return _vscprintf(format, args); #elif defined(__UNIX__) // TODO: Really??? Write to file in order to know the length of a string? FILE* dummyf = fopen("/dev/null", "wb"); if (dummyf == NULL) perror("The following error occurred in basetypes.h:cprintf"); int n = vfprintf(dummyf, format, args); if (n < 0) perror("The following error occurred in basetypes.h:cprintf"); fclose(dummyf); return n; #endif } inline const wchar_t* _sprintf(wchar_t* buf, size_t bufsiz, const wchar_t* format, va_list args) { vswprintf(buf, bufsiz, format, args); return buf; } inline const char* _sprintf(char* buf, size_t bufsiz, const char* format, va_list args) { #ifdef _MSC_VER vsprintf_s(buf, bufsiz, format, args); #else vsprintf(buf, format, args); #endif return buf; } }; // ---------------------------------------------------------------------------- // (w)strprintf() -- sprintf() that returns an STL string // ---------------------------------------------------------------------------- typedef ::msra::strfun::_strprintf strprintf; // char version typedef ::msra::strfun::_strprintf wstrprintf; // wchar_t version // ---------------------------------------------------------------------------- // utf8(), utf16() -- convert between narrow and wide strings // ---------------------------------------------------------------------------- #ifdef _MSC_VER // string-encoding conversion functions struct utf8 : std::string { utf8(const std::wstring& p) // utf-16 to -8 { size_t len = p.length(); if (len == 0) { return; } // empty string std::vector buf(3 * len + 1); // max: 1 wchar => up to 3 mb chars // ... TODO: this fill() should be unnecessary (a 0 is appended)--but verify std::fill(buf.begin(), buf.end(), (char)0); int rc = WideCharToMultiByte(CP_UTF8, 0, p.c_str(), (int) len, &buf[0], (int) buf.size(), NULL, NULL); if (rc == 0) RuntimeError("WideCharToMultiByte"); (*(std::string*) this) = &buf[0]; } }; struct utf16 : std::wstring { utf16(const std::string& p) // utf-8 to -16 { size_t len = p.length(); if (len == 0) { return; } // empty string std::vector buf(len + 1); // ... TODO: this fill() should be unnecessary (a 0 is appended)--but verify std::fill(buf.begin(), buf.end(), (wchar_t) 0); int rc = MultiByteToWideChar(CP_UTF8, 0, p.c_str(), (int) len, &buf[0], (int) buf.size()); if (rc == 0) RuntimeError("MultiByteToWideChar"); assert(rc < buf.size()); (*(std::wstring*) this) = &buf[0]; } }; #endif #ifndef _MSC_VER // these are needed by the gcc conversion functions // Note: generally, 8-bit strings in this codebase are UTF-8. // One exception are functions that take 8-bit pathnames. Those will be interpreted by the OS as MBS. Best use wstring pathnames for all file accesses. #pragma warning(push) #pragma warning(disable : 4996) // Reviewed by Yusheng Li, March 14, 2006. depr. fn (wcstombs, mbstowcs) static inline std::string wcstombs(const std::wstring& p) // output: MBCS { size_t len = p.length(); std::vector buf(2 * len + 1); // max: 1 wchar => 2 mb chars std::fill(buf.begin(), buf.end(), (char)0); ::wcstombs(&buf[0], p.c_str(), 2 * len + 1); return std::string(&buf[0]); } static inline std::wstring mbstowcs(const std::string& p) // input: MBCS { size_t len = p.length(); std::vector buf(len + 1); // max: >1 mb chars => 1 wchar std::fill(buf.begin(), buf.end(), (wchar_t) 0); // OACR_WARNING_SUPPRESS(UNSAFE_STRING_FUNCTION, "Reviewed OK. size checked. [rogeryu 2006/03/21]"); ::mbstowcs(&buf[0], p.c_str(), len + 1); return std::wstring(&buf[0]); } #pragma warning(pop) #endif #ifdef _MSC_VER static inline cstring utf8(const std::wstring& p) { return std::wstring_convert>().to_bytes(p); } // utf-16 to -8 static inline wcstring utf16(const std::string& p) { return std::wstring_convert>().from_bytes(p); } // utf-8 to -16 #else // BUGBUG: we cannot compile the above on Cygwin GCC, so for now fake it using the mbs functions, which will only work for 7-bit ASCII strings static inline std::string utf8(const std::wstring& p) { return msra::strfun::wcstombs(p.c_str()); } // output: UTF-8... not really static inline std::wstring utf16(const std::string& p) { return msra::strfun::mbstowcs(p.c_str()); } // input: UTF-8... not really #endif static inline cstring utf8(const std::string& p) { return p; } // no conversion (useful in templated functions) static inline wcstring utf16(const std::wstring& p) { return p; } // ---------------------------------------------------------------------------- // charpath() -- convert a wchar_t path to what gets passed to CRT functions that take narrow characters // This is needed for the Linux CRT which does not accept wide-char strings for pathnames anywhere. // Always use this function for mapping the paths. // TODO: This does not seem to work well, most places use wtocharpath() instead. Maybe we can remove this. // ---------------------------------------------------------------------------- static inline cstring charpath(const std::wstring& p) { #ifdef _MSC_VER return std::wstring_convert>().to_bytes(p); #else // old version, delete once we know it works size_t len = p.length(); std::vector buf(2 * len + 1, 0); // max: 1 wchar => 2 mb chars ::wcstombs(buf.data(), p.c_str(), 2 * len + 1); return msra::strfun::cstring(&buf[0]); #endif } // ---------------------------------------------------------------------------- // split and join -- tokenize a string like strtok() would, join() strings together // ---------------------------------------------------------------------------- template static inline std::vector> split(const std::basic_string<_T>& s, const _T* delim) { std::vector> res; for (size_t st = s.find_first_not_of(delim); st != std::basic_string<_T>::npos;) { size_t en = s.find_first_of(delim, st + 1); if (en == std::basic_string<_T>::npos) en = s.length(); res.push_back(s.substr(st, en - st)); st = s.find_first_not_of(delim, en + 1); // may exceed } return res; } template static inline std::basic_string<_T> join(const std::vector>& a, const _T* delim) { std::basic_string<_T> res; for (int i = 0; i < (int) a.size(); i++) { if (i > 0) res.append(delim); res.append(a[i]); } return res; } // ---------------------------------------------------------------------------- // find and replace // ---------------------------------------------------------------------------- template // actual operations that we perform static String ReplaceAll(const String& s, const String& what, const String& withwhat) { String res = s; auto pos = res.find(what); while (pos != String::npos) { res = res.substr(0, pos) + withwhat + res.substr(pos + what.size()); pos = res.find(what, pos + withwhat.size()); } return res; } // ---------------------------------------------------------------------------- // parsing strings to numbers // ---------------------------------------------------------------------------- static inline int toint(const wchar_t* s) { return _wtoi(s); } static inline int toint(const char* s) { return atoi(s); } static inline int toint(const std::wstring& s) { return toint(s.c_str()); } static inline double todouble(const char* s) { char* ep; // will be set to point to first character that failed parsing double value = strtod(s, &ep); if (*s == 0 || *ep != 0) RuntimeError("todouble: invalid input string '%s'", s); return value; } // TODO: merge this with todouble(const char*) above static inline double todouble(const std::string& s) { double value = 0.0; size_t* idx = 0; value = std::stod(s, idx); if (idx) RuntimeError("todouble: invalid input string '%s'", s.c_str()); return value; } static inline double todouble(const std::wstring& s) { wchar_t* endptr; double value = wcstod(s.c_str(), &endptr); if (*endptr) RuntimeError("todouble: invalid input string '%ls'", s.c_str()); return value; } // ---------------------------------------------------------------------------- // tokenizer -- utility for white-space tokenizing strings in a character buffer // This simple class just breaks a string, but does not own the string buffer. // ---------------------------------------------------------------------------- class tokenizer : public std::vector { const char* delim; public: tokenizer(const char* delim, size_t cap) : delim(delim) { reserve(cap); } // Usage: tokenizer tokens (delim, capacity); tokens = buf; tokens.size(), tokens[i] void operator=(char* buf) { resize(0); // strtok_s not available on all platforms - so backoff to strtok on those #ifdef _MSC_VER char* context; // for strtok_s() for (char* p = strtok_s(buf, delim, &context); p; p = strtok_s(NULL, delim, &context)) push_back(p); #else for (char* p = strtok(buf, delim); p; p = strtok(NULL, delim)) push_back(p); #endif } }; }} // ---------------------------------------------------------------------------- // iscalpha(), iscspace(), etc.: saner version of ctype helpers that do not blow up upon negative signed char values // Name differs in using 'c' between is- and -what(). Also returns bool instead of int. // TODO: switch all uses if isspace() etc. to this once tested well // ---------------------------------------------------------------------------- #define DefineIsCType(pred) \ static inline bool isc ## pred(char c) { return !!is ## pred((unsigned char) c); } \ static inline bool isc ## pred(wchar_t c) { return !!isw ## pred(c); } DefineIsCType(alpha); DefineIsCType(upper); DefineIsCType(lower); DefineIsCType(cntrl); DefineIsCType(digit); DefineIsCType(punct); DefineIsCType(space); // ---------------------------------------------------------------------------- // functional-programming style helper macros (...do this with templates?) // ---------------------------------------------------------------------------- #define foreach_index(_i, _dat) for (int _i = 0; _i < (int) (_dat).size(); _i++) #define map_array(_x, _expr, _y) \ { \ _y.resize(_x.size()); \ foreach_index (_i, _x) \ _y[_i] = _expr(_x[_i]); \ } #define reduce_array(_x, _expr, _y) \ { \ foreach_index (_i, _x) \ _y = (_i == 0) ? _x[_i] : _expr(_y, _x[_i]); \ } namespace Microsoft { namespace MSR { namespace CNTK { // ---------------------------------------------------------------------------- // case-insensitive string-comparison helpers // ---------------------------------------------------------------------------- // normalize between char* and string // Note: Intended for use within a single expression. Otherwise be aware of memory ownership; same restrictions apply as for string::c_str(). static inline const char * c_str(const char * p) { return p; } static inline const char * c_str(const string & p) { return p.c_str(); } static inline const wchar_t * c_str(const wchar_t * p) { return p; } static inline const wchar_t * c_str(const wstring & p) { return p.c_str(); } // compare strings static inline int CompareCI(const char * a, const char * b) { return _stricmp(a, b); } static inline int CompareCI(const wchar_t * a, const wchar_t * b) { return _wcsicmp(a, b); } template static inline int CompareCI(const S1 & a, const S2 & b) { return CompareCI(c_str(a), c_str(b)); } // compare for equality template static inline bool EqualCI(const S1 & a, const S2 & b) { return CompareCI(a, b) == 0; } // comparer class for defining maps with case-insensitive key lookup struct nocase_compare { // std::string version of 'less' function template bool operator()(const S1& left, const S2& right) const { return CompareCI(left, right) < 0; } }; // ---------------------------------------------------------------------------- // random collection of stuff we needed at some place // ---------------------------------------------------------------------------- // Array class // Wrapper that holds pointer to data, as well as size template class ArrayRef { T* elements; // Array of type T size_t count; public: ArrayRef(T* elementsIn, size_t sizeIn) { elements = elementsIn; count = sizeIn; } // TODO: Copy Constructor ArrayRef(const ArrayRef& other) = delete; // TODO: Move Constructor ArrayRef(ArrayRef&& other) = delete; // TODO: Assignment operator ArrayRef& operator=(const ArrayRef& rhs) = delete; // TODO: Move assignment operator ArrayRef& operator=(ArrayRef&& rhs) = delete; size_t size() const { return count; } void setSize(size_t size) { count = size; } T* data() const { return elements; } T operator[](size_t i) const { if (i >= size()) LogicError("ArrayRef: index overflow"); return elements[i]; } T& operator[](size_t i) { if (i >= count) LogicError("ArrayRef: index overflow"); return elements[i]; } const T* begin() const { return data(); } const T* end() const { return data() + size(); } }; // TODO: maybe change to type id of an actual thing we pass in // TODO: is this header appropriate? template static wstring TypeId() { return msra::strfun::utf16(typeid(C).name()); } // ---------------------------------------------------------------------------- // dynamic loading of modules --TODO: not Basics, should move to its own header // ---------------------------------------------------------------------------- #ifdef _WIN32 class Plugin { HMODULE m_hModule; // module handle for the writer DLL std::wstring m_dllName; // name of the writer DLL public: Plugin() : m_hModule(NULL) { } template // accepts char (UTF-8) and wide string FARPROC Load(const STRING& plugin, const std::string& proc, bool isCNTKPlugin = true) { return LoadInternal(msra::strfun::utf16(plugin), proc, isCNTKPlugin); } ~Plugin() { } // we do not unload because this causes the exception messages to be lost (exception vftables are unloaded when DLL is unloaded) // ~Plugin() { if (m_hModule) FreeLibrary(m_hModule); } private: FARPROC LoadInternal(const std::wstring& plugin, const std::string& proc, bool isCNTKPlugin); }; #else class Plugin { private: void* handle; public: Plugin() : handle(NULL) { } template // accepts char (UTF-8) and wide string void *Load(const STRING& plugin, const std::string& proc, bool isCNTKPlugin = true) { return LoadInternal(msra::strfun::utf8(plugin), proc, isCNTKPlugin); } ~Plugin() { if (handle != NULL) { int rc = dlclose(handle); if ((rc != DLCLOSE_SUCCESS) && !std::uncaught_exception()) { RuntimeError("Plugin: Failed to decrements the reference count."); } } } private: void *LoadInternal(const std::string& plugin, const std::string& proc, bool isCNTKPlugin); }; #endif template struct ScopeExit { explicit ScopeExit(EF &&f) : m_exitFunction(std::move(f)), m_exitOnDestruction(true) {} ~ScopeExit() { if (m_exitOnDestruction) m_exitFunction(); } ScopeExit(ScopeExit&& other) : m_exitFunction(std::move(other.m_exitFunction)), m_exitOnDestruction(other.m_exitOnDestruction) { other.m_exitOnDestruction = false; } private: // Disallow copy construction, assignment ScopeExit(const ScopeExit&) = delete; ScopeExit& operator=(const ScopeExit&) = delete; // Disallow move assignment ScopeExit& operator=(ScopeExit&&) = delete; EF m_exitFunction; bool m_exitOnDestruction; }; template ScopeExit::type> MakeScopeExit(EF&& exitFunction) { return ScopeExit::type>(std::forward(exitFunction)); } }}} #ifdef _WIN32 // ---------------------------------------------------------------------------- // frequently missing Win32 functions // ---------------------------------------------------------------------------- // strerror() for Win32 error codes static inline std::wstring FormatWin32Error(DWORD error) { wchar_t buf[1024] = {0}; ::FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM, "", error, 0, buf, sizeof(buf) / sizeof(*buf) - 1, NULL); std::wstring res(buf); // eliminate newlines (and spaces) from the end size_t last = res.find_last_not_of(L" \t\r\n"); if (last != std::string::npos) res.erase(last + 1, res.length()); return res; } #endif // _WIN32 #endif // _BASICS_H_