https://github.com/Microsoft/CNTK
Tip revision: 57bc5fb448c0b74b13d0134b55e40d195e7f38f0 authored by Eldar Akchurin on 26 June 2017, 11:47:29 UTC
More fixes
More fixes
Tip revision: 57bc5fb
ScriptableObjects.h
// ScriptableObjects.h -- objects that the config parser operates on
#pragma once
#include "Basics.h"
#include "Platform.h" // for noexcept workaround on VS2013
#include <functional> // for function<>
#include <map>
#include <memory> // for shared_ptr<>
#include <set>
#include <type_traits> // for is_signed<>
namespace Microsoft { namespace MSR { namespace ScriptableObjects {
using namespace std;
using namespace msra::strfun; // for wstrprintf()
using namespace Microsoft::MSR::CNTK; // for stuff from Basics.h
// -----------------------------------------------------------------------
// ScriptingException -- base class for any errors thrown by scripting
// It is a runtime_error that additionally requires derived classes to
// implement GetError() and PrintError(); both take a line prefix.
// -----------------------------------------------------------------------
class ScriptingException : public std::runtime_error
{
public:
    // Forwards 'msg' unchanged to the runtime_error constructor, so M can be
    // anything runtime_error accepts.
    template <typename M>
    ScriptingException(const M &msg)
        : std::runtime_error(msg)
    {
    }

    // Pure virtual: returns the error as a wide string.
    virtual std::wstring GetError(const std::wstring & /*linePrefix*/) const = 0;

    // Pure virtual: prints the error.
    virtual void PrintError(const std::wstring & /*linePrefix*/) const = 0;
};
// -----------------------------------------------------------------------
// Object -- common base class for objects that can be used in config files
// -----------------------------------------------------------------------
// All values that can be used in config files
// - are heap objects
// - primitives are wrapped
// - object pointers are ref-counted shared_ptr, wrapped in ConfigValuePtr (see BrainScriptEvaluator.h)
// - derive from Object (outside classes get wrapped)
//
// This code supports three kinds of value types:
// - self-defined classes -> derive from Object, e.g. Expression
// - classes defined outside -> wrap in a BoxOf object, e.g. String = BoxOf<std::wstring>
// - C++ primitives like 'double' -> wrap in a Wrapper first then in a BoxOf, e.g. Number = BoxOf<Wrapped<double>>
// Common polymorphic base of all config values; the virtual destructor makes
// deletion through an Object pointer run the derived destructor and enables
// dynamic_cast on Object pointers.
struct Object
{
    virtual ~Object() = default;
};
// Trait: indicates that the object has a name that should be set from the expression path.
struct HasName
{
    // Implementers receive the name through this call.
    virtual void SetName(const std::wstring &name) = 0;
};
// -----------------------------------------------------------------------
// Wrapped<T> -- turns a non-class primitive C++ type, like 'double', into a class.
// (It also works for class types, but for those prefer BoxOf<> below directly.)
// -----------------------------------------------------------------------
template <typename T>
class Wrapped
{
    T value; // the wrapped value; meant to be a primitive type

public:
    // implicit construction from the underlying value
    Wrapped(T value)
        : value(value)
    {
    }

    // read-only view of the wrapped value
    operator const T &() const { return value; }

    // mutable view of the wrapped value
    operator T &() { return value; }

    // overwrite the wrapped value; yields a reference to the stored value
    T &operator=(const T &newValue)
    {
        value = newValue;
        return value;
    }
};
typedef Wrapped<double> Double;
typedef Wrapped<bool> Bool;
// -----------------------------------------------------------------------
// BoxOf<C> -- wraps a pre-defined type, e.g. std::wstring, so that it derives from Object.
// BoxOf<C> also derives publicly from C (e.g. BoxOf<std::wstring> is a std::wstring),
// so an Object pointer to a BoxOf<C> can be dynamic_cast to C.
// -----------------------------------------------------------------------
template <class C>
class BoxOf : public Object, public C
{
public:
    // Perfect-forwards any argument list to the matching constructor of C
    // (this includes the default constructor when called with no arguments).
    template <class... Args>
    BoxOf(Args &&... args)
        : C(std::forward<Args>(args)...)
    {
    }
};
// -----------------------------------------------------------------------
// String -- a std::wstring in config files
// Can cast to std::wstring (done in a way that ConfigValuePtr can also cast to std::wstring).
// -----------------------------------------------------------------------
typedef BoxOf<std::wstring> String;
// -----------------------------------------------------------------------
// ComputationNodeObject -- the 'magic' class that our parser understands for infix operations
// It adds no members of its own; it only serves as a common, non-templated base type.
// TODO: unify with ComputationNodeBase
// -----------------------------------------------------------------------
class ComputationNodeObject : public Object
{
}; // a base class for all nodes (that has no template parameter)
// -----------------------------------------------------------------------
// HasToString -- trait to indicate an object can print its content
// Derive from HasToString and implement the ToString() method.
// FormatConfigValue() will then return ToString().
// -----------------------------------------------------------------------
struct HasToString
{
    // returns a textual representation of the object
    virtual std::wstring ToString() const = 0;

    // some std::wstring helpers useful for ToString() operations of nested structures
    // TODO: move these out from this header into some more general place (I had to move them here because otherwise CNTKEval failed to compile)

    // Returns 's' with 'indent' spaces inserted at the start of the string and after every
    // newline character (so a trailing newline is followed by an indent as well).
    static std::wstring IndentString(std::wstring s, size_t indent)
    {
        const std::wstring prefix(indent, L' ');
        size_t pos = 0;
        for (;;)
        {
            s.insert(pos, prefix);
            // Resume searching right behind the prefix just inserted. Skipping a fixed number of
            // characters would step over real content (and miss empty lines) whenever indent < 2.
            pos = s.find(L'\n', pos + prefix.size());
            if (pos == std::wstring::npos)
                return s;
            pos++; // next line starts behind the newline
        }
    }

    // Encloses 's' between 'open' and 'close', indenting the content by two spaces.
    // If 'newline' is set, the content is placed on lines of its own between the two symbols;
    // otherwise 'open' takes the place of the first indent space and 'close' directly follows the content.
    static std::wstring NestString(std::wstring s, wchar_t open, bool newline, wchar_t close)
    {
        std::wstring result = IndentString(s, 2);
        if (newline) // have a new line after the open symbol
            result = L" \n" + result + L"\n ";
        else
            result.append(L" ");
        // first and last characters are placeholder spaces at this point; overwrite them
        result.front() = open;
        result.back() = close;
        return result;
    }
};
// -----------------------------------------------------------------------
// WithTags -- trait to give an object a set of tag strings
// The tags are kept in a privately inherited std::set<std::wstring>.
// -----------------------------------------------------------------------
class WithTags : std::set<std::wstring>
{
public:
    WithTags()
    {
    }

    // adds a tag; returns true if the tag was not there before
    bool SetTag(const std::wstring &tag)
    {
        return insert(tag).second;
    }

    // removes a tag; returns true if we used to have this tag
    bool ClearTag(const std::wstring &tag)
    {
        return erase(tag) != 0;
    }

    // tests whether a tag is currently set
    bool HasTag(const std::wstring &tag) const
    {
        return count(tag) != 0;
    }

    // read-only access to the complete set of tags
    const std::set<std::wstring> &GetTags() const
    {
        return *this;
    }
};
// =======================================================================
// ConfigValuePtr -- shared pointer to a config value
// =======================================================================
// A ConfigValuePtr holds the value of a configuration variable.
// - specifically, it holds a shared_ptr to a strongly typed C++ object
// - ConfigValuePtrs are immutable when consumed.
//
// All configuration values, that is, values that can be held by a ConfigValuePtr, derive from BS::Object.
// To get a shared_ptr<T> of an expected type T, type-cast the ConfigValuePtr to it.
// To get the value of a copyable type like T=double or std::wstring, type-cast to T directly.
//
// ConfigValuePtrs are evaluated on-demand upon first retrieval:
// - initially, a ConfigValuePtr would hold a Thunk; that is, a lambda that computes (resolves) the value
// - upon first use, the Thunk is invoked to compute the value, which will then *replace* the Thunk
// - any consumer of a ConfigValuePtr will only ever see the resolved value, since any access for consumption will force it to be resolved
// - a resolved ConfigValuePtr is immutable
//
// On-demand evaluation is critical to the semantics of this entire configuration system.
// A configuration is but one big expression (of nested records), but some evaluations cause side effects (such as saving a model), and some expressions may not even be in use at all.
// Thus, we must use on-demand evaluation in order to ensure that side effects are only executed when desired.
//
// Further, to ensure a Thunk is executed at most once (otherwise we may get the same side-effect multiple times),
// an unresolved ConfigValuePtr can only live in a single place. This means,
// - an unresolved ConfigValuePtr (i.e. one holding a Thunk) cannot be copied (while resolved ones are immutable and can be copied freely)
// - it can be moved (std::move()) during creation
// - after creation, it should only live in a known location from which it can be retrieved; specifically:
// - ConfigRecord entries
// - ConfigArrays elements
// - ConfigLambdas (default values of named arguments)
struct IConfigRecord;
class ConfigArray;
// TODO: separate this out from BrainScript to an interface that still does type casts--possible?
// ConfigValuePtr -- a shared_ptr<Object> that additionally carries a failure callback and the
// expression name. The pointee is either a resolved value or a Thunk (a deferred computation);
// accessors other than ResolveValue() require the value to be resolved already.
class ConfigValuePtr : public shared_ptr<Object>
{
    function<void(const std::wstring &)> failfn; // function to call in case of failure due to this value
    std::wstring expressionName;                 // the expression name reflects the path to reach this expression in the (possibly dynamically macro-expanded) expression tree. Used for naming ComputationNodes.

    // Thunk for resolving a value. This Object represents a function that returns a ConfigValuePtr; call to resolve a deferred value
    class Thunk : public Object
    {
        function<ConfigValuePtr()> f;                // the function to compute the value
        bool currentlyResolving;                     // set during resolution phase, to detect circular references
        function<void(const std::wstring &)> failfn; // function to call in case of failure due to this value

    public:
        // (initializers are listed in member declaration order, which is the order they run in)
        Thunk(function<ConfigValuePtr()> f, const function<void(const std::wstring &)> &failfn)
            : f(f), currentlyResolving(false), failfn(failfn)
        {
        }

        // runs the stored function and returns its result; reports a circular reference through failfn
        // if called again while a resolution is already in progress
        ConfigValuePtr ResolveValue()
        {
            if (currentlyResolving) // detect circular references (infinite recursion)
                failfn(L"circular reference (expression to compute identifier's value uses the identifier's value)");
            currentlyResolving = true; // can't run from inside ourselves
            return f();
            // no need to reset currentlyResolving because this object gets replaced and thus deleted anyway
        }
    };

    // get Thunk object or nullptr if already resolved
    Thunk *GetThunk() const
    {
        return dynamic_cast<Thunk *>(get());
    }

public:
    // --- assignment and copy/move constructors

    ConfigValuePtr()
    {
    } // (formally needed somehow)

    ConfigValuePtr(const shared_ptr<Object> &p, const function<void(const std::wstring &)> &failfn, const std::wstring &expressionName)
        : shared_ptr<Object>(p), failfn(failfn), expressionName(expressionName)
    {
    }

    // ConfigValuePtr(const function<ConfigValuePtr()> & f, TextLocation location, const std::wstring & expressionName) : shared_ptr<Object>(make_shared<Thunk>(f, location)), location(location), expressionName(expressionName) { }

    // creates an unresolved ConfigValuePtr that will call 'f' when resolved
    // TODO: somehow the constructor overload from Thunk function fails to compile, so for now use MakeThunk instead
    static ConfigValuePtr MakeThunk(const function<ConfigValuePtr()> &f, const function<void(const std::wstring &)> &failfn, const std::wstring &expressionName)
    {
        return ConfigValuePtr(make_shared<Thunk>(f, failfn), failfn, expressionName);
    }

    // copy construction goes through copy assignment, hence it also rejects unresolved sources
    ConfigValuePtr(const ConfigValuePtr &other)
    {
        *this = other;
    }

    ConfigValuePtr(ConfigValuePtr &&other) noexcept
    {
        *this = move(other);
    }

    // copy assignment; fails with a LogicError if 'other' still holds a Thunk
    void operator=(const ConfigValuePtr &other)
    {
        if (other.GetThunk()) // unresolved ConfigValuePtrs are not copyable, only movable
        {
            // report the name of the unresolved source; the target's name says nothing about
            // the offending value (and is always empty when called from the copy constructor)
            Microsoft::MSR::CNTK::LogicError("ConfigValuePtr::operator=() on unresolved object '%ls'; ConfigValuePtr is not assignable until resolved", other.expressionName.empty() ? L"(unassigned)" : other.expressionName.c_str());
        }
        (shared_ptr<Object> &) *this = other;
        failfn = other.failfn;
        expressionName = other.expressionName;
    }

    // move assignment; this is the only way to transfer an unresolved value
    void operator=(ConfigValuePtr &&other) noexcept
    {
        failfn = move(other.failfn);
        expressionName = move(other.expressionName);
        (shared_ptr<Object> &) *this = move(other);
    }

    // report a failure related to this value through its fail function
    void Fail(const std::wstring &msg) const
    {
        failfn(msg);
    }

    // if you need to pass on the fail function
    const function<void(const std::wstring &)> &GetFailFn() const
    {
        return failfn;
    }

    // --- retrieving values by type cast

    // access as a reference, that is, as a shared_ptr<T> --use this for Objects
    template <typename T>
    operator shared_ptr<T>() const
    {
        return AsPtr<T>();
    }

    // access as a (const & to) value --use this for primitive types (also works to get a const std::wstring & from a String)
    template <typename T>
    operator T() const
    {
        return AsRef<T>();
    }
    // Linux gcc barfs on this ^^ for 'us = (double)((std::wstring)arg).size();' due to some ambiguity error (while it works fine with Visual Studio).
    // If you encounter this, instead say 'us = (double)((const std::wstring&)arg).size();' with a &. Don't forget the const (I have seen broken typecasts without).

    operator const IConfigRecord &() const
    {
        return AsRef<IConfigRecord>();
    }

    operator const ConfigArray &() const
    {
        return AsRef<ConfigArray>();
    }

    // somehow operator T() does not work here, still giving ambiguous messages. This makes it work. Probably not generic. Need to fix this.
    operator const std::wstring &() const
    {
        return AsRef<std::wstring>();
    }

    operator double() const
    {
        return AsRef<Double>();
    }

    operator float() const
    {
        return (float) AsRef<Double>();
    }

    operator bool() const
    {
        return AsRef<Bool>();
    }

    // Reads the value as a Double and converts it to the integer type INT.
    // Calls Fail() if the conversion does not round-trip (e.g. the value has a fractional part).
    template <typename INT>
    INT AsInt() const
    {
        double val = AsRef<Double>();
        INT ival = (INT) val;
        // name the expected type according to the requested integer type
        // (operator int() and operator size_t() are the callers in this class)
        const wchar_t *type = std::is_signed<INT>::value ? L"int" : L"size_t";
        // TODO: there is some duplication of type checking; can we unify that?
        if (ival != val)
            Fail(wstrprintf(L"expected expression of type %ls instead of floating-point value %f", type, val));
        return ival;
    }

    operator size_t() const
    {
        return AsInt<size_t>();
    }

    operator int() const
    {
        return AsInt<int>();
    }

    // --- access functions

    // tests whether the value is of type C
    template <class C>
    bool Is() const // note: also works with null pointers (will return false)
    {
        EnsureIsResolved();
        const auto p = dynamic_cast<C *>(get());
        return p != nullptr;
    }

    // returns reference to the 'value' member. Configs are considered immutable, so return a const&
    // Calls Fail() if the value is not a C.
    template <class C>
    const C &AsRef() const
    {
        // TODO: factor these lines into a separate function
        // Note: since this returns a reference into 'this', you must keep the object you call this on around as long as you use the returned reference
        EnsureIsResolved();
        // const C * wanted = (C *) nullptr; const auto * got = get(); wanted; got; // allows to see C in the debugger
        const auto p = dynamic_cast<C *>(get());
        if (p == nullptr) // TODO: can we make this look the same as TypeExpected in BrainScriptEvaluator.cpp? We'd need the type name
            Fail(L"config member has wrong type (" + msra::strfun::utf16(typeid(*get()).name()) + L"), expected a " + TypeId<C>());
        return *p;
    }

    // returns a shared_ptr cast to the 'value' member
    // Calls Fail() if the value is not a C.
    template <class C>
    shared_ptr<C> AsPtr() const
    {
        EnsureIsResolved();
        const auto p = dynamic_pointer_cast<C>(*this);
        if (!p) // TODO: can we make this look the same as TypeExpected in BrainScriptEvaluator.cpp? We'd need the type name
            Fail(L"config member has wrong type (" + msra::strfun::utf16(typeid(*get()).name()) + L"), expected a " + TypeId<C>());
        return p;
    }

    // --- properties

    // implementation-defined name of the dynamic type of the held object
    const char *TypeName() const
    {
        return typeid(*get()).name();
    }

    const std::wstring &GetExpressionName() const
    {
        return expressionName;
    }
    // TODO: ^^ it seems by saving the name in the ConfigValuePtr itself, we don't gain anything; maybe remove again in the future

    // --- methods for resolving the value

    // call this when a member might be as-of-yet unresolved, to evaluate it on-demand
    // (this is const but mutates the value if it resolves)
    const ConfigValuePtr &ResolveValue() const
    {
        // get() is a pointer to a Thunk in that case, that is, a function object that yields the value
        const auto thunkp = GetThunk(); // is it a Thunk?
        if (thunkp)                     // value is a Thunk: we need to resolve
        {
            auto value = thunkp->ResolveValue();
            // Completely replace ourselves with the actual result. This also releases the Thunk object.
            // The result must be moved in: copy assignment rejects unresolved values, which would make
            // it impossible for a Thunk to return another Thunk.
            const_cast<ConfigValuePtr &>(*this) = move(value);
            ResolveValue(); // allow it to return another Thunk...
        }
        return *this; // return ourselves so we can access a value as p_resolved = p->ResolveValue()
    }

    // fails with a LogicError if this value still holds a Thunk
    void EnsureIsResolved() const
    {
        if (GetThunk())
            Microsoft::MSR::CNTK::LogicError("ConfigValuePtr: unexpected access to unresolved object; ConfigValuePtrs can only be accessed after resolution");
    }
}; // ConfigValuePtr
// Creates a resolved ConfigValuePtr for a primitive value (double and bool):
// the value is wrapped in a Wrapped<T>, which in turn is boxed in a BoxOf<> so that it derives from Object.
template <typename T>
static inline ConfigValuePtr MakePrimitiveConfigValuePtr(const T &val, const function<void(const std::wstring &)> &failfn, const std::wstring &exprPath)
{
    const shared_ptr<Object> boxedValue = make_shared<BoxOf<Wrapped<T>>>(val);
    return ConfigValuePtr(boxedValue, failfn, exprPath);
}
// -----------------------------------------------------------------------
// IConfigRecord -- config record
// Inside BrainScript, this would be a BS::ConfigRecord, but outside of the
// evaluator, we will only pass it through this interface, to allow for
// extensibility (e.g. Python interfacing).
// Also, Objects themselves can expose this interface to make something accessible.
// -----------------------------------------------------------------------
// Interface to a record of named config values. Implementers provide operator[], Find() and
// GetMemberIds(); the remaining members are convenience accessors, some of which are only
// declared here (their definitions are not part of this section of the file).
struct IConfigRecord // any class that exposes config can derive from this
{
virtual const ConfigValuePtr &operator[](const std::wstring &id) const = 0; // e.g. confRec[L"message"]
virtual const ConfigValuePtr *Find(const std::wstring &id) const = 0; // returns nullptr if not found
virtual std::vector<std::wstring> GetMemberIds() const = 0; // returns the names of all members in this record (but not including parent scopes)
// prettier access if config record is a pointer
const ConfigValuePtr &Get(const std::wstring &id) const
{
return operator[](id);
} // e.g. confRecPtr->Get(L"message")
// access with default values
// Returns the member 'id' converted to ValueType if Find() locates it, otherwise 'defaultValue'.
template <class ValueType>
ValueType operator()(const std::wstring &id, const ValueType &defaultValue) const // e.g. confRec("message", "hello")
{
const auto *valp = Find(id);
return valp ? *valp : defaultValue;
}
// Same for nested records: returns a reference, either to the found record or to 'defaultValue' itself.
inline const IConfigRecord &operator()(const std::wstring &id, const IConfigRecord &defaultValue) const // retrieve a nested ConfigRecord
{
const auto *valp = Find(id);
return valp ? valp->AsRef<IConfigRecord>() : defaultValue;
}
// The default is converted to a wide string for the lookup, and the result is converted back to a narrow string.
std::string operator()(const std::wstring &id, const char *defaultValue) const
{
return msra::strfun::utf8(operator()(id, (std::wstring) msra::strfun::utf16(defaultValue)));
} // special case for narrow strings
// Wide string literal defaults are forwarded to the std::wstring instantiation of the template above.
std::wstring operator()(const std::wstring &id, const wchar_t *defaultValue) const
{
return operator()(id, std::wstring(defaultValue));
}
// -----------------------------------------------------------------------
// emulation of old CNTK config/NL
// This allows code written for CNTK config to simply turn ConfigParameters into a template parameter and accept an IConfigRecord.
// TODO: change all id args to wide strings, then update the code.
// -----------------------------------------------------------------------
const ConfigValuePtr &operator()(const std::wstring &id) const;
template <class T>
std::vector<T> operator()(const std::wstring &id, const std::vector<T> &defaultValue) const;
bool ExistsCurrent(const std::wstring &id) const;
// true if Find() locates a member of this name
bool Exists(const std::wstring &id) const
{
return Find(id) != nullptr;
}
bool Match(const std::wstring &id, const std::wstring &compareValue) const;
// Note: the two CanBe...() tests look the member up with operator[], not Find().
bool CanBeConfigRecord(const std::wstring &id) const
{
return operator[](id).Is<IConfigRecord>();
}
bool CanBeString(const std::wstring &id) const
{
return operator[](id).Is<std::wstring>();
}
const std::string ConfigName() const;
static const IConfigRecord &Record();
template <class V>
static const std::vector<typename V::value_type> &Array(const V &vec);
};
typedef shared_ptr<struct IConfigRecord> IConfigRecordPtr;
// -----------------------------------------------------------------------
// ConfigRecord -- collection of named config values
// -----------------------------------------------------------------------
class ConfigRecord : public Object, public IConfigRecord // all configuration arguments to class construction, resolved into ConfigValuePtrs
{
function<void(const std::wstring &)> failfn; // function to call in case of failure due to this value
// change to ContextInsensitiveMap<ConfigValuePtr>
std::map<std::wstring, ConfigValuePtr> members;
IConfigRecordPtr parentScope; // we look up the chain
ConfigRecord()
{
} // forbidden (private) to instantiate without a scope
public:
// --- creation phase
ConfigRecord(IConfigRecordPtr parentScope, const function<void(const std::wstring &)> &failfn)
: parentScope(parentScope), failfn(failfn)
{
}
void Add(const std::wstring &id, const function<void(const std::wstring &)> & /*failfn*/, const ConfigValuePtr &value)
{
members[id] = value;
}
void Add(const std::wstring &id, const function<void(const std::wstring &)> & /*failfn*/, ConfigValuePtr &&value)
{
members[id] = move(value);
} // use this for unresolved ConfigPtrs
// TODO: Add() does not yet correctly handle the failfn. It is meant to flag the location of the variable identifier
// --- usage phase
// regular lookup: just use record[id] or record(id, L"helpful message what 'id' does")
// Any unresolved value is resolved at this time, as it is being consumed. Only after resolving a ConfigValuePtr, it can be copied.
const ConfigValuePtr & /*IConfigRecord::*/ operator[](const std::wstring &id) const // e.g. confRec[L"name"]
{
const auto memberIter = members.find(id);
if (memberIter != members.end())
return memberIter->second.ResolveValue(); // resolve upon access
if (!parentScope) // not found: if at top scope, we fail
failfn(L"required parameter '" + id + L"' not found");
// The failfn will report the location where the dictionary itself was formed.
// This is because this function is meant to be used by C++ code.
// When we look up a name by a BrainScript ".FIELD" expression, we will use Find() so we can report the error for the offending FIELD itself.
return (*parentScope)[id]; // have parent: look it up there
}
const ConfigValuePtr * /*IConfigRecord::*/ Find(const std::wstring &id) const // returns nullptr if not found
{
auto memberIter = members.find(id);
if (memberIter == members.end())
if (parentScope)
return parentScope->Find(id);
else
return nullptr;
else
return &memberIter->second.ResolveValue();
}
// get member ids; use this when you intend to consume all record entries and do not know the names
// Note that unlike Find() and operator[], which return parent matches, this only returns entries in this record.
virtual std::vector<std::wstring> /*IConfigRecord::*/ GetMemberIds() const
{
std::vector<std::wstring> ids;
for (auto &member : members)
ids.push_back(member.first);
return ids;
}
};
typedef shared_ptr<ConfigRecord> ConfigRecordPtr;
// TODO: can ConfigRecordPtr be IConfigRecordPtr?
// Creates a runtime object of type C from a config record --general case:
// C is constructed from 'config' and returned as a shared_ptr<Object>.
// There can be specializations of this that instantiate objects that do not take ConfigRecords or involve mapping like ComputationNode.
template <typename C>
shared_ptr<Object> MakeRuntimeObject(const IConfigRecordPtr config)
{
    shared_ptr<Object> created = make_shared<C>(config);
    return created;
}
// -----------------------------------------------------------------------
// ConfigArray -- an array of config values
// -----------------------------------------------------------------------

// an array is just a std::vector of config values, plus the index of its first element
class ConfigArray : public Object
{
    std::vector<ConfigValuePtr> values;
    int firstIndex; // index of values[0]

public:
    ConfigArray()
        : firstIndex(0)
    {
    }

    // (initializers are listed in member declaration order, which is the order they run in)
    ConfigArray(int firstIndex, std::vector<ConfigValuePtr> &&values)
        : values(move(values)), firstIndex(firstIndex)
    {
    }
    // ConfigArray(ConfigValuePtr && val) : firstIndex(0), values(std::vector<ConfigValuePtr>{ move(val) }) { }

    // returns the index range as (first index, one past the last index)
    pair<int, int> GetIndexBeginEnd() const
    {
        return make_pair(firstIndex, firstIndex + (int) values.size());
    }

    // for use as a plain array: get size and verify that index range starts with 0
    template <typename FAILFN>
    size_t GetSize(const FAILFN &Fail) const
    {
        if (firstIndex != 0)
            Fail(L"This array is expected to begin with index 0.");
        return values.size();
    }

    // building the array from expressions: append an element or an array
    void Append(const ConfigValuePtr &value)
    {
        values.push_back(value);
    }

    // this appends an unresolved ConfigValuePtr
    void Append(ConfigValuePtr &&value)
    {
        values.push_back(move(value));
    }

    // appends copies of all elements of another array
    void Append(const ConfigArray &other)
    {
        values.insert(values.end(), other.values.begin(), other.values.end());
    }

    // get element at index, including bounds check
    // 'Fail' is called with a message if 'index' is outside the array's index range.
    template <typename FAILFN>
    const ConfigValuePtr &At(int index, const FAILFN &Fail /*should report location of the index*/) const
    {
        // Compute the end of the range in 'int' (as GetIndexBeginEnd() does). Comparing against
        // 'firstIndex + values.size()' directly would convert everything to size_t, which rejects
        // valid negative indices of arrays that begin at a negative index.
        const int endIndex = firstIndex + (int) values.size();
        if (index < firstIndex || index >= endIndex)
            Fail(msra::strfun::wstrprintf(L"Index %d out of bounds [%d..%d].", index, firstIndex, endIndex - 1));
        return values[(size_t)(index - firstIndex)].ResolveValue(); // resolve upon access
    }

    // get element when knowing that the bounds are correct, e.g. looping over the index range returned by GetIndexBeginEnd()
    const ConfigValuePtr &At(int index) const
    {
        return At(index, [](const std::wstring &) { LogicError("ConfigArray::At(): Index unexpectedly out of bounds."); });
    }

    // get an entire array into a std::vector. Note that this will force all values to be evaluated.
    // If 'flatten' is set, elements that are ConfigArrays themselves are expanded in place (recursively).
    template <typename C, typename FAILFN>
    std::vector<C> AsVector(const FAILFN &Fail, bool flatten = false) const
    {
        std::vector<C> res;
        res.reserve(GetSize(Fail)); // also verifies that the array begins with index 0
        for (const auto &valp : values)
        {
            valp.ResolveValue(); // resolve upon access
            if (!flatten || !valp.Is<ConfigArray>())
            {
                const C &type = valp;
                res.push_back(type);
            }
            else // special case: flatten nested vectors (only if 'flatten')
            {
                std::vector<C> subVector = valp.AsRef<ConfigArray>().AsVector<C>(Fail, flatten);
                res.insert(res.end(), subVector.begin(), subVector.end());
            }
        }
        return res;
    }

    // helper function: get a vector from config that may be a scalar, a ConfigArray, or nested ConfigArrays meant to be flattened
    template <typename E>
    static vector<E> FlattenedVectorFrom(const ConfigValuePtr &valp)
    {
        if (valp.Is<vector<E>>())
            return valp.AsRef<vector<E>>(); // UNTESTED
        else if (valp.Is<ConfigArray>())
            return valp.AsRef<ConfigArray>().AsVector<E>([&](const wstring &msg) { valp.Fail(msg); }, /*flatten=*/true);
        else
            return std::vector<E>(1, (const E &) valp); // single element
    }
};
typedef shared_ptr<ConfigArray> ConfigArrayPtr;
// -----------------------------------------------------------------------
// ConfigLambda -- a lambda
// Holds a C++ function together with the names of its positional parameters and
// its optional (named) parameters including their default values.
// -----------------------------------------------------------------------
class ConfigLambda : public Object
{
public:
    typedef std::map<std::wstring, ConfigValuePtr> NamedParams; // TODO: maybe even not use a typedef, just use the type

private:
    // the function itself is a C++ lambda
    function<ConfigValuePtr(std::vector<ConfigValuePtr> &&, NamedParams &&, const std::wstring &exprName)> f;
    // inputs. This defines the interface to the function. Very simple in our case though.
    // We pass rvalue references because that allows to pass Thunks.
    std::vector<std::wstring> paramNames; // #parameters and parameter names (names are used for naming expressions only)
    NamedParams namedParams;              // lists named parameters with their default values. Named parameters are optional and thus always must have a default.

public:
    // (initializers are listed in member declaration order, which is the order they run in)
    template <typename F>
    ConfigLambda(std::vector<std::wstring> &&paramNames, NamedParams &&namedParams, const F &f)
        : f(f), paramNames(move(paramNames)), namedParams(move(namedParams))
    {
    }

    size_t GetNumParams() const { return paramNames.size(); }
    const std::vector<std::wstring> &GetParamNames() const { return paramNames; } // used for expression naming and function composition
    const NamedParams &GetNamedParams() const { return namedParams; }             // used for function composition

    // what this function does is call f() held in this object with the given arguments except optional arguments are verified and fall back to their defaults if not given
    // The arguments are rvalue references, which allows us to pass Thunks, which is important to allow stuff with circular references like CNTK's DelayedNode.
    ConfigValuePtr Apply(std::vector<ConfigValuePtr> &&args, NamedParams &&namedArgs, const std::wstring &exprName) const
    {
        NamedParams actualNamedArgs;
        // actualNamedArgs is a filtered version of namedArgs that contains all optional args listed in namedParams,
        // falling back to their default if not given in namedArgs.
        // On the other hand, any name in namedArgs that is not found in namedParams should be rejected.
        for (const auto &namedParam : namedParams)
        {
            const auto &id = namedParam.first;      // id of expected named parameter
            const auto valuei = namedArgs.find(id); // was such parameter passed?
            if (valuei == namedArgs.end())          // named parameter not passed
            {                                       // if not given then fall back to default
                // we pass a lambda that resolves it upon first use, in our original location
                // Note: the lambda refers to the entry inside this object's namedParams by reference.
                auto f2 = [&namedParam]()
                {
                    return namedParam.second.ResolveValue();
                };
                actualNamedArgs[id] = ConfigValuePtr::MakeThunk(f2, namedParam.second.GetFailFn(), exprName);
            }
            else // named parameter was passed
                actualNamedArgs[id] = move(valuei->second); // move it, possibly remaining unresolved
            // BUGBUG: we should pass in the location of the identifier, not that of the expression
        }
        for (const auto &namedArg : namedArgs) // make sure there are no extra named args that the macro does not take
            if (namedParams.find(namedArg.first) == namedParams.end())
                namedArg.second.Fail(L"function does not have an optional argument '" + namedArg.first + L"'");
        return f(move(args), move(actualNamedArgs), exprName);
    }
    // TODO: define an overload that takes const & for external users (which will then take a copy and pass it on to Apply &&)
};
typedef shared_ptr<ConfigLambda> ConfigLambdaPtr;
// -----------------------------------------------------------------------
// CustomConfigRecord -- helper for implementors of IConfigRecord
// Custom classes that implement IConfigRecord can derive from this to make
// it easier to manage the simulated config record.
// -----------------------------------------------------------------------
struct CustomConfigRecord : public IConfigRecord // any class that exposes config can derive from this
{
const ConfigValuePtr& /*IConfigRecord::*/ operator[](const std::wstring& id) const override // e.g. confRec[L"message"]
{
const auto* valuep = Find(id);
if (!valuep)
RuntimeError("Unknown configuration-record member '%ls'", id.c_str());
return *valuep;
}
const ConfigValuePtr* /*IConfigRecord::*/ Find(const std::wstring& id) const // returns nullptr if not found
{
const auto& mapIter = members.find(id);
if (mapIter != members.end())
return &mapIter->second;
LazyCreateConfigMember(id);
const auto& mapIter2 = members.find(id);
if (mapIter2 != members.end())
return &mapIter2->second;
else
return nullptr;
}
void InsertConfigMember(const std::wstring& id, ConfigValuePtr&& valuep) const/*because it is called from Find() which is const*/
{
const auto res = members.insert(make_pair(id, move(valuep)));
assert(&res.first->second == &members.find(id)->second);
assert(res.second); // this says whether it has been inserted. It better be.
}
// call this whenever anything changes about this node
// Once we use a ComputationNode as a config record, we are in immutable BS world.
// So we should not really need to ever call this.
// However, we *must* call it at least in DetachInputs() in order to break cyclic dependencies.
void ClearConfigMemberCache()
{
members.clear();
}
// user of this class must implement LazyCreateConfigMember() and GetMemberIds()
virtual void LazyCreateConfigMember(const std::wstring &id) const = 0;
protected:
// cached return values from IConfigRecord implementation
mutable std::map<std::wstring, ScriptableObjects::ConfigValuePtr> members; // [id] -> cached ConfigValuePtr
};
// -----------------------------------------------------------------------
// ConfigurableRuntimeType -- interface to scriptable runtime types
// -----------------------------------------------------------------------
// helper for configurableRuntimeTypes initializer below
// This is an info structure for one runtime type that consists of
// - a lambda that is a constructor for a given runtime type and
// - a bool saying whether T derives from IConfigRecord
struct ConfigurableRuntimeType // TODO: rename to ScriptableObjects::Factory or something like that
{
bool isConfigRecord; // exposes IConfigRecord --in this case the expression name is computed differently, namely relative to this item
// TODO: is this ^^ actually still used anywhere?
function<shared_ptr<Object>(const IConfigRecordPtr)> construct; // lambda to construct an object of this class; takes the config record and returns the new object as a shared_ptr<Object>
// TODO: we should pass the expression name to construct() as well
};
// scriptable runtime types must be exposed by this function
// TODO: should this be a static member of above class?
// Only the declaration is visible in this part of the header; the definition has to be supplied elsewhere.
// NOTE(review): presumably returns nullptr for an unknown typeId, like ConfigurableRuntimeTypeRegister::Find() -- confirm against the definition.
const ConfigurableRuntimeType *FindExternalRuntimeTypeInfo(const std::wstring &typeId);
// -----------------------------------------------------------------------
// ConfigurableRuntimeTypeRegister -- static table of all configurable runtime types
// -----------------------------------------------------------------------
class ConfigurableRuntimeTypeRegister
{
// we wrap the static variable in a function so that we don't need a CPP file
static std::map<std::wstring, ConfigurableRuntimeType> &GetTheRegister()
{
// the one static variable that contains all configurable runtime types
static std::map<std::wstring, ConfigurableRuntimeType> reg;
return reg;
}
static void Register(const wchar_t *typeId, ConfigurableRuntimeType &&rtInfo)
{
auto ® = GetTheRegister();
auto res = reg.insert(std::make_pair((std::wstring) typeId, std::move(rtInfo)));
if (!res.second)
LogicError("RegisterConfigurableRuntimeType: Attempted to register type '%ls' twice.", typeId);
}
public:
// to instantiate a ConfigurableRuntimeType object, use this function to find its constructor
static const ConfigurableRuntimeType *Find(const std::wstring &typeId)
{
auto ® = GetTheRegister();
auto iter = reg.find(typeId);
if (iter == reg.end())
return nullptr;
else
return &iter->second;
}
// to register a runtime type, use an instance of this class in each library
// ConfigurableRuntimeTypeRegister::Add<ClassName> registerClassName(L"ClassName")l
template <class C>
struct Add
{
Add(const wchar_t *typeId)
{
// create the runtime info
ConfigurableRuntimeType rtInfo;
rtInfo.construct = [](const IConfigRecordPtr config) // lambda to construct--this lambda can construct both the <float> and the <double> variant based on config parameter 'precision'
{
return MakeRuntimeObject<C>(config);
};
rtInfo.isConfigRecord = is_base_of<IConfigRecord, C>::value;
// insert it into the static table
Register(typeId, std::move(rtInfo));
}
};
// to register a class that exists in dual precisions (Something<ElemType>>, use this one instead
// ConfigurableRuntimeTypeRegister::AddFloatDouble<ClassName<float>,ClassName<double>> registerClassName(L"ClassName")l
template <class Cfloat, class Cdouble>
struct AddFloatDouble
{
AddFloatDouble(const wchar_t *typeId)
{
// create the runtime info
ConfigurableRuntimeType rtInfo;
rtInfo.construct = [](const IConfigRecordPtr config) // lambda to construct--this lambda can construct both the <float> and the <double> variant based on config parameter 'precision'
{
std::wstring precision = (*config)[L"precision"]; // dispatch on ElemType
if (precision == L"float")
return MakeRuntimeObject<Cfloat>(config);
else if (precision == L"double")
return MakeRuntimeObject<Cdouble>(config);
else
RuntimeError("invalid value '%ls' for 'precision', must be 'float' or 'double'", precision.c_str());
};
rtInfo.isConfigRecord = is_base_of<IConfigRecord, Cfloat>::value;
static_assert(is_base_of<IConfigRecord, Cfloat>::value == is_base_of<IConfigRecord, Cdouble>::value, ""); // we assume that both float and double have the same behavior
// insert it into the static table
Register(typeId, std::move(rtInfo));
}
};
};
// -----------------------------------------------------------------------
// IConfigRecord emulation of old CNTK config/NL
// This allows code written for CNTK config to simply turn ConfigParameters into a template parameter and accept an IConfigRecord.
// -----------------------------------------------------------------------
// Function-call syntax for member access, for code that was written against the old CNTK config classes:
// confRec(L"message") can be written instead of confRec[L"message"].
inline const ConfigValuePtr &IConfigRecord::operator()(const std::wstring &id) const
{
    return (*this)[id]; // simply forwards to the subscript operator
}
// Retrieve member 'id' as a vector of T (an argvector, which derives from std::vector).
//  - member not found       -> a copy of 'defaultValue'
//  - member is a ConfigArray -> its elements, converted by ConfigArray::AsVector<T>(); failures are reported through the value's Fail()
//  - any other value        -> a one-element vector holding that value converted to T
template <class T>
inline std::vector<T> IConfigRecord::operator()(const std::wstring &id, const std::vector<T> &defaultValue) const
{
    const ConfigValuePtr *entry = Find(id);
    if (entry == nullptr)
        return defaultValue; // default value
    if (entry->Is<ConfigArray>())
    {
        const ConfigArray &elements = *entry; // actual array
#if 1 // TODO: test whether this works correctly w.r.t. typecasting
        return elements.AsVector<T>([&](const std::wstring &msg) { entry->Fail(msg); });
#else
        const auto count = elements.GetSize([&](const std::wstring &msg) { entry->Fail(msg); });
        std::vector<T> converted(count);
        for (int i = 0; i < count; i++)
            converted[i] = (const T &) elements.At(i);
        return converted;
#endif
    }
    return std::vector<T>(1, (const T &) *entry); // scalar value
}
inline bool IConfigRecord::ExistsCurrent(const std::wstring &id) const // this is inefficient, but we can optimize it if it ever turns out to be a problem. I rather think, this function is misguided. The name is bad, too.
{
for (const auto &idIter : GetMemberIds()) // linear scan. Not using STL algorithm to avoid pulling in a big header at this level
if (idIter == id)
return true;
return false;
}
inline bool IConfigRecord::Match(const std::wstring &id, const std::wstring &compareValue) const
{
auto *valp = Find(id);
std::wstring val = valp ? *valp : std::wstring();
return EqualCI(compareValue, val);
}
// Not available for BrainScript config records: every call is reported through LogicError() instead of returning a name.
// NOTE(review): there is no return statement, so this relies on LogicError() never returning (presumably it throws) -- confirm in Basics.h.
inline const std::string IConfigRecord::ConfigName() const
{
LogicError("ConfigName not supported by BrainScript."); // needed in BinaryWriter
}
// Returns an empty record to be passed as a default to operator() when retrieving a nested ConfigRecord.
// The record is a function-local static, so every call hands out a reference to the same object.
/*static*/ inline const IConfigRecord &IConfigRecord::Record()
{
    static struct EmptyConfigRecord : public IConfigRecord
    {
        // the empty record has no values, so any indexed access is reported as an error
        virtual const ScriptableObjects::ConfigValuePtr &operator[](const std::wstring &) const override final
        {
            InvalidArgument("EmptyConfigRecord: Attempted to return a value from the empty record.");
        }
        // no member is ever found
        virtual const ScriptableObjects::ConfigValuePtr *Find(const std::wstring &) const override final
        {
            return nullptr;
        }
        // there are no member ids
        // Marked 'override final' like its two siblings, so that a signature change in IConfigRecord is caught at compile time.
        virtual std::vector<std::wstring> GetMemberIds() const override final
        {
            return std::vector<std::wstring>();
        }
    } emptyParameters;
    return emptyParameters;
}
// View 'vec' as a plain std::vector of its element type.
// Use this specifically for XXXargvector types (which derive from std::vector).
template <class V>
/*static*/ const std::vector<typename V::value_type> &IConfigRecord::Array(const V &vec)
{
    typedef std::vector<typename V::value_type> PlainVector;
    return static_cast<const PlainVector &>(vec);
}
}}} // end namespaces