https://github.com/Microsoft/CNTK
Raw File
Tip revision: d16c5c2b9e766d703fec5deb90c78858dd8c3786 authored by thhoens on 21 January 2016, 20:07:36 UTC
Added a small patch to SGD to turn off Batch Normalization during validation.
Tip revision: d16c5c2
Basics.h
// Basics.h -- some shared generally useful pieces of code used by CNTK
//
// We also include a simple "emulation" layer for some proprietary MSVC CRT functions.

#pragma once

#ifndef _BASICS_H_
#define _BASICS_H_

#include "Platform.h"
#include "DebugUtil.h"
#include <string>
#include <vector>
#include <assert.h>
#if __unix__
#include <dlfcn.h> // for Plugin
#endif

#define TWO_PI 6.283185307f // TODO: find the official standards-conforming definition of this and use it instead

#define EPSILON 1e-5
// ISCLOSE(a, b, threshold) - true if |a - b| < threshold.
// Every argument and the expansion as a whole are parenthesized, so that the macro gives the
// expected result for compound arguments (e.g. ISCLOSE(x, y + 1, t)) and when it is combined
// with other operators (e.g. ISCLOSE(x, y, t) && other).
#define ISCLOSE(a, b, threshold) (abs((a) - (b)) < (threshold))

#define UNUSED(x) (void)(x) // for variables that are, e.g., only used in _DEBUG builds

#pragma warning(disable : 4702) // disable some incorrect unreachable-code warnings

// DISABLE_COPY_AND_MOVE(TypeName) - use inside the definition of class TypeName to declare
// its copy/move constructors and its copy/move assignment operators as deleted.
// The expansion does not end in a semicolon; the user of the macro supplies it.
#define DISABLE_COPY_AND_MOVE(TypeName)            \
    TypeName(const TypeName&) = delete;            \
    TypeName(TypeName&&) = delete;                 \
    TypeName& operator=(const TypeName&) = delete; \
    TypeName& operator=(TypeName&&) = delete

namespace Microsoft { namespace MSR { namespace CNTK {

using namespace std;

// -----------------------------------------------------------------------
// ThrowFormatted() - template function to throw a std::exception with a formatted error string
// -----------------------------------------------------------------------

#pragma warning(push)
#pragma warning(disable : 4996)
#ifndef _MSC_VER // TODO: what is the correct trigger for gcc?
template <class E>
__declspec_noreturn void ThrowFormatted(const char* format, ...) __attribute__((format(printf, 1, 2)));
#endif
// ThrowFormatted<E>(format, ...) - format a printf-style message and throw it as an exception of type E.
//  - E must be constructible from a const char*.
//  - The message is formatted into a fixed buffer of 1024 bytes; a longer message is truncated
//    instead of being written past the end of the buffer.
//  - The call stack is printed through DebugUtil::PrintCallStack() before throwing; in _DEBUG
//    builds the message is additionally written to stderr.
// This function never returns.
template <class E>
__declspec_noreturn static inline void ThrowFormatted(const char* format, ...)
{
    char buffer[1024];

    va_list args;
    va_start(args, format);
    vsnprintf(buffer, sizeof(buffer), format, args); // bounded: never writes more than sizeof(buffer) bytes
    va_end(args);
    buffer[sizeof(buffer) - 1] = 0; // guarantee 0-termination even if the CRT does not terminate on truncation

    Microsoft::MSR::CNTK::DebugUtil::PrintCallStack();
#ifdef _DEBUG // print this to log before throwing, so we can see what the error is
    fprintf(stderr, "About to throw exception '%s'\n", buffer);
#endif
    throw E(buffer);
}
#pragma warning(pop)

// RuntimeError(), LogicError(), InvalidArgument(), BadExceptionError() -
// throw the corresponding std:: exception through ThrowFormatted<>() with a printf-style formatted message.
// Without _MSC_VER these are macros; note that the BadExceptionError() macro ignores its
// arguments and throws a default-constructed std::bad_exception.
#ifndef _MSC_VER // gcc __attribute__((format(printf())) does not percolate through variadic templates; so must go the macro route
#define RuntimeError ThrowFormatted<std::runtime_error>
#define LogicError ThrowFormatted<std::logic_error>
#define InvalidArgument ThrowFormatted<std::invalid_argument>
#define BadExceptionError(...) throw std::bad_exception() // ThrowFormatted<std::bad_exception> does not exist on gcc
#else
// each wrapper perfectly forwards the format arguments to ThrowFormatted<>() for its exception type
template <class... Args>
__declspec_noreturn static inline void RuntimeError(const char* format, Args&&... args)
{
    ThrowFormatted<std::runtime_error>(format, std::forward<Args>(args)...);
}
template <class... Args>
__declspec_noreturn static inline void LogicError(const char* format, Args&&... args)
{
    ThrowFormatted<std::logic_error>(format, std::forward<Args>(args)...);
}
template <class... Args>
__declspec_noreturn static inline void InvalidArgument(const char* format, Args&&... args)
{
    ThrowFormatted<std::invalid_argument>(format, std::forward<Args>(args)...);
}
template <class... Args>
__declspec_noreturn static inline void BadExceptionError(const char* format, Args&&... args)
{
    ThrowFormatted<std::bad_exception>(format, std::forward<Args>(args)...);
}
#endif

// Warning - warn with a formatted error string
// Warning(format, ...) writes the printf-style formatted message to stderr.
// The message is written as given; no prefix or line break is added.
// It is streamed directly with vfprintf(), so there is no intermediate fixed-size buffer
// that a long message could overflow.
#pragma warning(push)
#pragma warning(disable : 4996)
static inline void Warning(const char* format, ...)
{
    va_list args;
    va_start(args, format);
    vfprintf(stderr, format, args);
    va_end(args);
}
#pragma warning(pop)
// Warning(message) - same as above for a ready-made message; the text is passed as an argument
// to "%s" so that any '%' characters it contains are not interpreted as format specifiers.
static inline void Warning(const std::string& message)
{
    Warning("%s", message.c_str());
}

// NOT_IMPLEMENTED - expands to a braced block that writes the current file, line and function
// to stderr and then calls LogicError() with the same text.
// Only defined here if no definition of NOT_IMPLEMENTED exists yet.
#ifndef NOT_IMPLEMENTED
#define NOT_IMPLEMENTED                                                                                                              \
    \
{                                                                                                                             \
        fprintf(stderr, "Inside File: %s  Line: %d  Function: %s  -> Feature Not Implemented.\n", __FILE__, __LINE__, __FUNCTION__); \
        LogicError("Inside File: %s  Line: %d  Function: %s  -> Feature Not Implemented.\n", __FILE__, __LINE__, __FUNCTION__);      \
    \
}
#endif
}
}
}

#ifndef _MSC_VER
using Microsoft::MSR::CNTK::ThrowFormatted;
#else
using Microsoft::MSR::CNTK::RuntimeError;
using Microsoft::MSR::CNTK::LogicError;
using Microsoft::MSR::CNTK::InvalidArgument;
using Microsoft::MSR::CNTK::BadExceptionError;
#endif

#ifdef _MSC_VER
#include <codecvt> // std::codecvt_utf8
#endif

namespace msra {

namespace strfun
{ // TODO: rename this

// ----------------------------------------------------------------------------
// (w)cstring -- helper class like std::string but with auto-cast to char*
// and also implements an sprintf variant for STL strings
// ----------------------------------------------------------------------------

// basic_cstring<C> - a std::basic_string<C> that additionally converts implicitly to const C*,
// so that it can be passed directly where a C string is expected.
// The pointer obtained through the conversion is c_str() of this object.
template <typename C>
struct basic_cstring : public std::basic_string<C>
{
    // construct from any single value that std::basic_string<C> can be constructed from
    template <typename Source>
    basic_cstring(Source init)
        : std::basic_string<C>(init)
    {
    }
    // implicit conversion to a 0-terminated C string
    operator const C*() const
    {
        return std::basic_string<C>::c_str();
    }
};
typedef basic_cstring<char> cstring;     // char version
typedef basic_cstring<wchar_t> wcstring; // wchar_t version

// [w]strprintf() -- like sprintf() but resulting in a C++ string
// _strprintf<C>(format, ...) formats its printf-style arguments and holds the result as a
// std::basic_string<C>. The length of the result is determined in a first pass; the actual
// formatting is done in a second pass, into a stack buffer for short results and into a
// heap buffer of exactly the required size otherwise.
// If the length cannot be determined outside of MSVC builds, the error is reported through
// perror() and the result is an empty string.
template <class _T>
struct _strprintf : public std::basic_string<_T>
{ // works for both wchar_t* and char*
    _strprintf(const _T* format, ...)
    {
        va_list args;
        va_start(args, format);                  // first pass: determine the length
        const size_t n = _cprintf(format, args); // num chars excl. '\0'
        va_end(args);
        if (n == 0) // empty result (or length could not be determined): nothing to format
            return;

        va_start(args, format);         // second pass: the argument list must be restarted after it was consumed
        const size_t FIXBUF_SIZE = 128; // incl. '\0'
        if (n < FIXBUF_SIZE)
        {
            _T fixbuf[FIXBUF_SIZE];
            this->assign(_sprintf(&fixbuf[0], FIXBUF_SIZE, format, args), n);
        }
        else // too long: use dynamically allocated variable-size buffer
        {
            std::vector<_T> varbuf(n + 1); // incl. '\0'
            this->assign(_sprintf(&varbuf[0], varbuf.size(), format, args), n);
        }
        va_end(args);
    }

private:
    // helpers

    // _cprintf() -- number of characters (excl. '\0') that formatting would produce.
    // Outside of MSVC builds, 0 is returned if the length cannot be determined.
    inline size_t _cprintf(const wchar_t* format, va_list args)
    {
#ifdef _MSC_VER
        return vswprintf(nullptr, 0, format, args);
#else
        // vswprintf() cannot be used to query the required length (it fails if the buffer is
        // too small), so the text is printed to the null device and the characters are counted.
        FILE* dummyf = fopen("/dev/null", "w");
        if (dummyf == NULL)
        {
            perror("The following error occurred in basetypes.h:cprintf");
            return 0; // must not pass a NULL stream to vfwprintf()
        }
        const int n = vfwprintf(dummyf, format, args);
        if (n < 0)
            perror("The following error occurred in basetypes.h:cprintf");
        fclose(dummyf);
        return n < 0 ? 0 : (size_t) n; // a negative error code must not turn into a huge size
#endif
    }
    inline size_t _cprintf(const char* format, va_list args)
    {
#ifdef _MSC_VER
        return vsprintf_s(nullptr, 0, format, args);
#else
        // with a zero-size buffer, vsnprintf() writes nothing and returns the length that would be needed
        const int n = vsnprintf(nullptr, 0, format, args);
        if (n < 0)
        {
            perror("The following error occurred in basetypes.h:cprintf");
            return 0;
        }
        return (size_t) n;
#endif
    }

    // _sprintf() -- format into buf, which holds bufsiz characters incl. '\0'; returns buf
    inline const wchar_t* _sprintf(wchar_t* buf, size_t bufsiz, const wchar_t* format, va_list args)
    {
        vswprintf(buf, bufsiz, format, args);
        return buf;
    }
    inline const char* _sprintf(char* buf, size_t bufsiz, const char* format, va_list args)
    {
#ifdef _MSC_VER
        vsprintf_s(buf, bufsiz, format, args);
#else
        vsnprintf(buf, bufsiz, format, args); // bounded by the size of the destination buffer
#endif
        return buf;
    }
};

// ----------------------------------------------------------------------------
// (w)strprintf() -- sprintf() that returns an STL string
// ----------------------------------------------------------------------------

typedef strfun::_strprintf<char> strprintf;     // char version
typedef strfun::_strprintf<wchar_t> wstrprintf; // wchar_t version

// ----------------------------------------------------------------------------
// utf8(), utf16() -- convert between narrow and wide strings
// ----------------------------------------------------------------------------

#ifdef _MSC_VER
// string-encoding conversion functions
// utf8 - a std::string that is initialized from a wide string by converting it with
// WideCharToMultiByte(CP_UTF8). Calls RuntimeError() if the conversion reports failure.
struct utf8 : std::string
{
    utf8(const std::wstring& p) // utf-16 to -8
    {
        if (p.empty())
            return; // empty string: nothing to convert, *this stays empty
        const size_t numWide = p.length();
        // max: 1 wchar => up to 3 mb chars; the buffer is created all-0, and one extra
        // element is reserved so that the converted text is always 0-terminated
        std::vector<char> bytes(3 * numWide + 1, 0);
        const int rc = WideCharToMultiByte(CP_UTF8, 0, p.c_str(), (int) numWide,
                                           bytes.data(), (int) bytes.size(), NULL, NULL);
        if (rc == 0)
            RuntimeError("WideCharToMultiByte");
        this->assign(bytes.data()); // copies up to the first 0
    }
};
// utf16 - a std::wstring that is initialized from a narrow string by converting it with
// MultiByteToWideChar(CP_UTF8). Calls RuntimeError() if the conversion reports failure.
struct utf16 : std::wstring
{
    utf16(const std::string& p) // utf-8 to -16
    {
        if (p.empty())
            return; // empty string: nothing to convert, *this stays empty
        const size_t numBytes = p.length();
        // one wide character per input byte plus one for the terminator; the buffer is created all-0
        std::vector<wchar_t> wide(numBytes + 1, (wchar_t) 0);
        const int rc = MultiByteToWideChar(CP_UTF8, 0, p.c_str(), (int) numBytes,
                                           wide.data(), (int) wide.size());
        if (rc == 0)
            RuntimeError("MultiByteToWideChar");
        assert(rc < wide.size()); // the terminating 0 must not have been overwritten
        this->assign(wide.data()); // copies up to the first 0
    }
};
#endif

#ifndef _MSC_VER // these are needed by the gcc conversion functions
// Note: generally, 8-bit strings in this codebase are UTF-8.
// One exception are functions that take 8-bit pathnames. Those will be interpreted by the OS as MBS. Best use wstring pathnames for all file accesses.
#pragma warning(push)
#pragma warning(disable : 4996)                           // Reviewed by Yusheng Li, March 14, 2006. depr. fn (wcstombs, mbstowcs)
// wcstombs(p) - convert a wide string to a multi-byte string using ::wcstombs(), i.e. in the
// encoding of the current C locale.
// The return value of ::wcstombs() is not checked; the result is whatever has been written to
// the zero-initialized conversion buffer, up to the first 0.
static inline std::string wcstombs(const std::wstring& p) // output: MBCS
{
    const size_t len = p.length();
    // One wide character can take up to MB_CUR_MAX bytes in the current locale (more than 2
    // e.g. in UTF-8 locales), so the buffer is sized for the worst case.
    std::vector<char> buf(len * MB_CUR_MAX + 1, 0);
    // The limit excludes the last byte, which therefore stays 0: the result is 0-terminated
    // even if the conversion fills all of the space it is given.
    ::wcstombs(buf.data(), p.c_str(), buf.size() - 1);
    return std::string(buf.data());
}
// mbstowcs(p) - convert a multi-byte string to a wide string using ::mbstowcs(), i.e. the input
// is interpreted in the encoding of the current C locale.
// The return value of ::mbstowcs() is not checked; the result is whatever has been written to
// the zero-initialized conversion buffer, up to the first 0.
static inline std::wstring mbstowcs(const std::string& p) // input: MBCS
{
    // max: >1 mb chars => 1 wchar, i.e. one element per input byte plus the terminator is enough
    std::vector<wchar_t> wide(p.length() + 1, (wchar_t) 0);
    //OACR_WARNING_SUPPRESS(UNSAFE_STRING_FUNCTION, "Reviewed OK. size checked. [rogeryu 2006/03/21]");
    ::mbstowcs(wide.data(), p.c_str(), wide.size());
    return std::wstring(wide.data());
}
#pragma warning(pop)
#endif

// utf8(wide string) / utf16(narrow string) -- conversion between wide and narrow strings.
// With _MSC_VER, the conversion is done by std::wstring_convert with std::codecvt_utf8_utf16,
// and the results are returned as cstring/wcstring.
// Otherwise, the calls are routed to msra::strfun::wcstombs()/mbstowcs(), and plain
// std::string/std::wstring are returned.
#ifdef _MSC_VER
static inline cstring utf8(const std::wstring& p)
{
    return std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>().to_bytes(p);
} // utf-16 to -8
static inline wcstring utf16(const std::string& p)
{
    return std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>().from_bytes(p);
} // utf-8 to -16
#else // BUGBUG: we cannot compile the above on Cygwin GCC, so for now fake it using the mbs functions, which will only work for 7-bit ASCII strings
static inline std::string utf8(const std::wstring& p)
{
    // p.c_str() is converted back to a temporary std::wstring for the call
    return msra::strfun::wcstombs(p.c_str());
} // output: UTF-8... not really
static inline std::wstring utf16(const std::string& p)
{
    // p.c_str() is converted back to a temporary std::string for the call
    return msra::strfun::mbstowcs(p.c_str());
} // input: UTF-8... not really
#endif
// utf8(narrow string) -- returns a copy of its argument as a cstring
static inline cstring utf8(const std::string& p)
{
    return p;
} // no conversion (useful in templated functions)
// utf16(wide string) -- returns a copy of its argument as a wcstring; counterpart of the above
static inline wcstring utf16(const std::wstring& p)
{
    return p;
}

// ----------------------------------------------------------------------------
// charpath() -- convert a wchar_t path to what gets passed to CRT functions that take narrow characters
// This is needed for the Linux CRT which does not accept wide-char strings for pathnames anywhere.
// Always use this function for mapping the paths.
// TODO: This does not seem to work well, most places use wtocharpath() instead. Maybe we can remove this.
// ----------------------------------------------------------------------------

// With _MSC_VER, the path is converted by std::wstring_convert with std::codecvt_utf8_utf16;
// otherwise by ::wcstombs(), i.e. into the encoding of the current C locale.
static inline cstring charpath(const std::wstring& p)
{
#ifdef _MSC_VER
    return std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>().to_bytes(p);
#else // old version, delete once we know it works
    // One wide character can take up to MB_CUR_MAX bytes in the current locale (more than 2
    // e.g. in UTF-8 locales), so the buffer is sized for the worst case.
    std::vector<char> buf(p.length() * MB_CUR_MAX + 1, 0);
    // The limit excludes the last byte, which therefore stays 0: the result is 0-terminated
    // even if the conversion fills all of the space it is given.
    ::wcstombs(buf.data(), p.c_str(), buf.size() - 1);
    return msra::strfun::cstring(buf.data());
#endif
}

// ----------------------------------------------------------------------------
// split and join -- tokenize a string like strtok() would, join() strings together
// ----------------------------------------------------------------------------

// split(s, delim) - break s into the maximal runs of characters that are not contained in delim.
// delim is a 0-terminated set of delimiter characters. Consecutive, leading and trailing
// delimiters do not produce empty tokens; the tokens are returned in their order of occurrence.
template <class CharT>
static inline std::vector<std::basic_string<CharT>> split(const std::basic_string<CharT>& s, const CharT* delim)
{
    typedef std::basic_string<CharT> String;

    std::vector<String> tokens;
    size_t begin = s.find_first_not_of(delim); // start of the first token, if any
    while (begin != String::npos)
    {
        size_t end = s.find_first_of(delim, begin + 1); // first delimiter after the token
        if (end == String::npos)
            end = s.length(); // the last token extends to the end of the string
        tokens.push_back(s.substr(begin, end - begin));
        begin = s.find_first_not_of(delim, end + 1); // end + 1 may exceed the length, which yields npos
    }
    return tokens;
}

// join(a, delim) - concatenate the strings in a, with the 0-terminated string delim inserted
// between each two neighbors. An empty vector gives an empty string.
template <class CharT>
static inline std::basic_string<CharT> join(const std::vector<std::basic_string<CharT>>& a, const CharT* delim)
{
    std::basic_string<CharT> joined;
    bool first = true;
    for (const auto& piece : a)
    {
        if (!first) // no delimiter in front of the first element
            joined += delim;
        joined += piece;
        first = false;
    }
    return joined;
}

// ----------------------------------------------------------------------------
// parsing strings to numbers
// ----------------------------------------------------------------------------

// toint() -- convert a string to an int
// These are plain wrappers: the wide-character version returns the result of _wtoi(), the
// narrow-character version returns the result of atoi(). No validation is done here.
static inline int toint(const wchar_t* s)
{
    return _wtoi(s);
}
static inline int toint(const char* s)
{
    return atoi(s);
}
// std::wstring version; forwards to the const wchar_t* version
static inline int toint(const std::wstring& s)
{
    return toint(s.c_str());
}

// todouble(s) - convert a 0-terminated string to a double using strtod().
// Calls RuntimeError() if s is empty or if strtod() did not consume the string up to its end.
static inline double todouble(const char* s)
{
    char* end = nullptr; // strtod() sets this to the first character that it did not parse
    const double parsed = strtod(s, &end);
    const bool isEmpty = (*s == 0);
    const bool hasUnparsedRest = (*end != 0);
    if (isEmpty || hasUnparsedRest)
        RuntimeError("todouble: invalid input string '%s'", s);
    return parsed;
}

// TODO: merge this with todouble(const char*) above
// todouble(s) - convert a std::string to a double.
// Calls RuntimeError() if the number is followed by characters that are not part of it.
// In the std::stod() branch, input that does not start with a number or that is out of range
// makes std::stod() itself throw (std::invalid_argument / std::out_of_range).
// In the strtod() branch, a result of +/-HUGE_VAL is rejected with RuntimeError() as well.
// NOTE(review): the strtod() branch accepts an empty string and returns 0, while
// todouble(const char*) rejects it -- confirm which behavior is intended.
static inline double todouble(const std::string& s)
{
    double value = 0.0;

// stod supposedly exists in VS2010, but some folks have compilation errors
// If this causes errors again, change the #if into the respective one for VS 2010.
#if _MSC_VER > 1400 // VS 2010+
    size_t numParsed = 0; // std::stod() stores the number of characters it consumed here
    value = std::stod(s, &numParsed);
    if (numParsed != s.length()) // trailing characters that are not part of the number
        RuntimeError("todouble: invalid input string '%s'", s.c_str());
#else
    char* ep = 0; // will be updated by strtod to point to first character that failed parsing
    value = strtod(s.c_str(), &ep);

    // strtod documentation says ep points to first unconverted character OR
    // return value will be +/- HUGE_VAL for overflow/underflow
    if (ep != s.c_str() + s.length() || value == HUGE_VAL || value == -HUGE_VAL)
        RuntimeError("todouble: invalid input string '%s'", s.c_str());
#endif

    return value;
}

// todouble(s) - convert a std::wstring to a double using wcstod().
// Calls RuntimeError() if wcstod() stopped in front of a character other than the terminating 0.
static inline double todouble(const std::wstring& s)
{
    const wchar_t* const text = s.c_str();
    wchar_t* stop = nullptr; // wcstod() sets this to the first character that it did not parse
    const double parsed = wcstod(text, &stop);
    if (*stop != 0)
        RuntimeError("todouble: invalid input string '%ls'", text);
    return parsed;
}

// ----------------------------------------------------------------------------
// tokenizer -- utility for white-space tokenizing strings in a character buffer
// This simple class just breaks a string, but does not own the string buffer.
// ----------------------------------------------------------------------------

// tokenizer - a std::vector<char*> of the tokens found in a character buffer.
// Assigning a buffer splits it in place: the first delimiter character after each token is
// overwritten with 0, and the vector is filled with pointers to the token starts inside the
// buffer. The buffer therefore must stay alive and unmodified while the tokens are used.
// Runs of delimiters, and delimiters at the start or the end, do not produce empty tokens.
// The splitting keeps all of its state in local variables (unlike strtok()), so it does not
// interfere with a tokenization that the caller has in progress.
class tokenizer : public std::vector<char*>
{
    const char* delim; // 0-terminated set of delimiter characters; not owned

public:
    // delim: set of delimiter characters; cap: number of tokens to reserve space for up front
    tokenizer(const char* delim, size_t cap)
        : delim(delim)
    {
        reserve(cap);
    }
    // Usage: tokenizer tokens (delim, capacity); tokens = buf; tokens.size(), tokens[i]
    void operator=(char* buf)
    {
        resize(0);
        if (buf == nullptr) // no buffer, no tokens
            return;

        char* cursor = buf;
        for (;;)
        {
            cursor += strspn(cursor, delim); // skip the delimiters in front of the next token
            if (*cursor == 0)
                break; // end of buffer reached
            push_back(cursor);
            cursor += strcspn(cursor, delim); // advance to the end of the token
            if (*cursor == 0)
                break; // the token ends where the buffer ends
            *cursor++ = 0; // terminate the token and continue behind it
        }
    }
};
}
}

// ----------------------------------------------------------------------------
// functional-programming style helper macros (...do this with templates?)
// ----------------------------------------------------------------------------

// foreach_index(_i, _dat) -- head of a for loop in which the int variable _i runs from 0 to
// (int) _dat.size() - 1; the statement that follows the macro is the loop body.
#define foreach_index(_i, _dat) for (int _i = 0; _i < (int) (_dat).size(); _i++)
// map_array(_x, _expr, _y) -- resizes _y to _x.size() and sets each element of _y to _expr
// applied to the element of _x with the same index.
// The loop variable is literally named _i. The arguments are not parenthesized in the expansion.
#define map_array(_x, _expr, _y)    \
    {                               \
        _y.resize(_x.size());       \
        foreach_index (_i, _x)      \
            _y[_i] = _expr(_x[_i]); \
    }
// reduce_array(_x, _expr, _y) -- folds the elements of _x into _y: _y is set to the first
// element, and is then replaced by _expr(_y, element) for each further element.
// _y is not touched if _x has no elements.
#define reduce_array(_x, _expr, _y)                      \
    {                                                    \
        foreach_index (_i, _x)                           \
            _y = (_i == 0) ? _x[_i] : _expr(_y, _x[_i]); \
    }

namespace Microsoft { namespace MSR { namespace CNTK {

// ----------------------------------------------------------------------------
// string comparison class, so we do case insensitive compares
// E.g. to define maps with case-insensitive key lookup
// ----------------------------------------------------------------------------

// nocase_compare -- 'less' function object that is based on _stricmp()/_wcsicmp()
struct nocase_compare
{
    // std::string version of 'less' function
    // returns true if _stricmp() reports that left orders before right, false otherwise
    // (that is, also false if the two compare as equal)
    bool operator()(const string& left, const string& right) const
    {
        return _stricmp(left.c_str(), right.c_str()) < 0;
    }
    // std::wstring version of 'less' function, used in non-config classes
    // same as above, based on _wcsicmp()
    bool operator()(const wstring& left, const wstring& right) const
    {
        return _wcsicmp(left.c_str(), right.c_str()) < 0;
    }
};

// ----------------------------------------------------------------------------
// random collection of stuff we needed at some place
// ----------------------------------------------------------------------------

// TODO: maybe change to type id of an actual thing we pass in
// TODO: is this header appropriate?
// TypeId<C>() -- returns typeid(C).name(), passed through msra::strfun::utf16(), as a wstring
template <class C>
static wstring TypeId()
{
    return msra::strfun::utf16(typeid(C).name());
}

// ----------------------------------------------------------------------------
// dynamic loading of modules  --TODO: not Basics, should move to its own header
// ----------------------------------------------------------------------------

#ifdef _WIN32
// Plugin (Win32 version) -- loads a DLL and looks up a function in it.
// The DLL is never unloaded, see the comment at the destructor.
class Plugin
{
    HMODULE m_hModule;      // module handle for the writer DLL
    std::wstring m_dllName; // name of the writer DLL
public:
    Plugin()
        : m_hModule(NULL)
    {
    }
    // we do not unload because this causes the exception messages to be lost (exception vftables are unloaded when DLL is unloaded)
    // ~Plugin() { if (m_hModule) FreeLibrary(m_hModule); }
    ~Plugin()
    {
    }
    // Load() -- load the DLL named <plugin>.dll and return the address of its function 'proc'.
    // Calls RuntimeError() if the DLL cannot be loaded. The result of GetProcAddress() is
    // returned as is.
    template <class STRING> // accepts char (UTF-8) and wide string
    FARPROC Load(const STRING& plugin, const std::string& proc)
    {
        m_dllName = msra::strfun::utf16(plugin);
        m_dllName.append(L".dll");
        m_hModule = LoadLibrary(m_dllName.c_str());
        if (m_hModule == NULL)
        {
            RuntimeError("Plugin not found: %s", msra::strfun::utf8(m_dllName).c_str());
        }
        const FARPROC entryPoint = GetProcAddress(m_hModule, proc.c_str());
        return entryPoint;
    }
};
#else
// Plugin (non-Win32 version) -- loads a shared object with dlopen() and looks up a function in it.
// The handle of the loaded module is kept in the object and is closed with dlclose() by the
// destructor.
// NOTE(review): copy construction/assignment are the compiler-generated ones, so a copy holds
// the same handle and would close it a second time -- confirm that Plugin objects are not copied.
class Plugin
{
private:
    void* handle; // handle returned by dlopen(); NULL as long as no module has been loaded

public:
    Plugin()
        : handle(NULL)
    {
    }
    // Load() -- load the shared object named <plugin>.so and return the address of its symbol 'proc'.
    // Calls RuntimeError() if the shared object cannot be loaded. The result of dlsym() is
    // returned as is. A handle stored by an earlier call of Load() is overwritten without being closed.
    template <class STRING> // accepts char (UTF-8) and wide string
    void* Load(const STRING& plugin, const std::string& proc)
    {
        string soName = msra::strfun::utf8(plugin);
        soName = soName + ".so";
        void* loaded = dlopen(soName.c_str(), RTLD_LAZY);
        if (loaded == NULL)
            RuntimeError("Plugin not found: %s (error: %s)", soName.c_str(), dlerror());
        // store in the member (not in a local variable of the same name), so that the
        // destructor gets to see the handle
        handle = loaded;
        return dlsym(handle, proc.c_str());
    }
    ~Plugin()
    {
        if (handle != NULL)
            dlclose(handle);
    }
};
#endif
}
}
}

#ifdef _WIN32
// ----------------------------------------------------------------------------
// frequently missing Win32 functions
// ----------------------------------------------------------------------------

// strerror() for Win32 error codes
// FormatWin32Error(error) -- returns the system message text for a Win32 error code, with
// trailing spaces, tabs and line breaks removed. The text is obtained from FormatMessageW()
// through a buffer of 1024 wide characters; the result is empty if no text was produced.
static inline std::wstring FormatWin32Error(DWORD error)
{
    wchar_t buf[1024] = {0};
    // FORMAT_MESSAGE_IGNORE_INSERTS: no insert arguments are passed, so insert sequences such as %1
    // in the message text must be copied through instead of being expanded.
    // The message source parameter is not used with FORMAT_MESSAGE_FROM_SYSTEM, hence NULL.
    ::FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, error, 0, buf, sizeof(buf) / sizeof(*buf) - 1, NULL);
    std::wstring res(buf);
    // eliminate newlines (and spaces) from the end
    // If there is nothing but such characters, find_last_not_of() returns npos, and npos + 1
    // wraps around to 0, so that the entire string gets erased.
    res.erase(res.find_last_not_of(L" \t\r\n") + 1);
    return res;
}
#endif // _WIN32

#endif // _BASICS_H_
back to top