swh:1:snp:f50ab94432af916b5fb8b4ad831e8dddded77084
Tip revision: 4967e82976319db7d622971fbf508c4b7f6e7874 authored by Cheng Tang on 26 May 2017, 23:17:01 UTC
update interface name. update python api to access lambda with parameters and gradients
update interface name. update python api to access lambda with parameters and gradients
Tip revision: 4967e82
Reader.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "Sequences.h"
#include "TensorShape.h"
#include "ReaderConstants.h"
namespace Microsoft { namespace MSR { namespace CNTK {
// Index type; aliases GPUSPARSE_INDEX_TYPE, which is declared outside this header.
using IndexType = GPUSPARSE_INDEX_TYPE;

// Shared-ownership handle to a TensorShape.
using TensorShapePtr = std::shared_ptr<TensorShape>;

// MBLayout is only forward-declared here; its definition lives elsewhere.
struct MBLayout;

// Shared-ownership handle to an MBLayout.
using MBLayoutPtr = std::shared_ptr<MBLayout>;
// Reader settings for the current epoch.
// Every time an epoch starts, CNTK is expected to pass this structure to the
// reader through the StartEpoch method.
struct ReaderConfiguration
{
    // Number of Open MPI workers for the current epoch.
    size_t m_numberOfWorkers{ 0 };

    // Rank of this Open MPI worker; it has to be less than the number of workers.
    size_t m_workerRank{ 0 };

    // Maximum minibatch size for the epoch, in samples.
    size_t m_minibatchSizeInSamples{ 0 };

    // Truncation size, in samples, for truncated BPTT mode.
    size_t m_truncationSize{ 0 };

    // When true, minibatches are allowed to overlap the boundary between sweeps
    // (and may therefore contain data from different sweeps). When false, they
    // need to be trimmed at the sweep end.
    bool m_allowMinibatchesToCrossSweepBoundaries{ false };

    // Every field takes its default from the member initializers above.
    // The constructor is kept user-provided on purpose, so the type stays a
    // non-aggregate exactly as before.
    ReaderConfiguration() {}
};
// Extends ReaderConfiguration with the epoch size (in samples and in sweeps)
// and the epoch index.
// TODO: Should be deprecated.
struct EpochConfiguration : public ReaderConfiguration
{
    // The sample count and the epoch index are zero-initialized so that a
    // default-constructed object never carries indeterminate values; callers
    // are still expected to assign the real values before starting an epoch.
    size_t m_totalEpochSizeInSamples{ 0 };         // Total size of the epoch in samples
    size_t m_totalEpochSizeInSweeps{ g_infinity }; // Total size of the epoch in sweeps (default = no limit).
    size_t m_epochIndex{ 0 };                      // Current epoch index [0 .. max number of epochs)
};
// Primitive element types currently supported; more will be added in the future.
// The numeric values are spelled out explicitly and match the implicit numbering.
enum class ElementType
{
    // Used by a stream definition when the deserializer can expose sequences
    // of different types. Before such sequences enter the network, a transform
    // should cast all of them to the same type (i.e. tdouble or tfloat).
    tvariant = 0,

    tfloat = 1,  // single precision
    tdouble = 2, // double precision
    tuchar = 3,  // unsigned char
};
// Storage formats currently supported; more will be added in the future.
// The numeric values are spelled out explicitly and match the implicit numbering.
enum class StorageType
{
    dense = 0,
    sparse_csc = 1,
};
// Identifier type for streams.
using StreamId = size_t;
// Describes a particular stream: its name, element type, storage, etc.
//
// The scalar fields carry default member initializers so that a
// default-initialized object (e.g. a plain local `StreamDescription d;`) never
// holds indeterminate values. The defaults are the same values that
// value-initialization already produced (zero / first enumerator / false).
struct StreamDescription
{
    std::wstring m_name;                                // Unique name of the stream
    StreamId m_id{ 0 };                                 // Unique identifier of the stream
    StorageType m_storageType{ StorageType::dense };    // Storage type of the stream
    ElementType m_elementType{ ElementType::tvariant }; // Element type of the stream
    TensorShapePtr m_sampleLayout;                      // Layout of the sample for the stream.
                                                        // If not specified - can be specified per sequence
    bool m_definesMbSize{ false };                      // Flag indicating whether the stream is defining the minibatch size
};

// Shared-ownership handle to a StreamDescription.
using StreamDescriptionPtr = std::shared_ptr<StreamDescription>;
// Represents the minibatch data for a single stream, formatted according to the minibatch layout.
// This data is returned per stream as a part of Minibatch from the ReadMinibatch function.
// All raw non-owned pointers are valid until the next call to the ReadMinibatch function.
struct StreamMinibatch
{
    // Contiguous array of data. Can be encoded in dense or sparse formats depending on the stream description.
    // The size is (the number of rows * number of columns in the layout) * the element size of the stream (float/double/etc.).
    // Defaults to nullptr so a default-initialized object never exposes an indeterminate pointer.
    void* m_data{ nullptr };

    // Layout of the data.
    MBLayoutPtr m_layout;
};

// Shared-ownership handle to a StreamMinibatch.
using StreamMinibatchPtr = std::shared_ptr<StreamMinibatch>;
// A single minibatch, holding the data of all streams together with
// sweep/epoch boundary flags.
struct Minibatch
{
    // True when this minibatch is either adjacent to the data sweep boundary
    // (-----<minibatch>|---) or crosses the boundary (-----<mini|batch>---).
    bool m_endOfSweep;

    // True once the end of the epoch has been reached, i.e. for the last
    // minibatch. The m_data field can still contain data when this flag is set.
    bool m_endOfEpoch;

    // Per-stream minibatch data.
    std::vector<StreamMinibatchPtr> m_data;

    // Maps a sequence id from the minibatch layout to the string
    // representation of the sequence key.
    std::function<std::string(const size_t)> m_getKeyById;

    // Sets the two boundary flags (both default to false); m_data and
    // m_getKeyById are default-constructed.
    Minibatch(bool endOfSweep = false, bool endOfEpoch = false)
        : m_endOfSweep{ endOfSweep },
          m_endOfEpoch{ endOfEpoch }
    {}
};
//////////////////////////////////////////////////////////////////////////////////////////////////
// Main Reader interface: the boundary between CNTK and the reader libraries.
// The declaration order of the virtual members is intentionally left untouched.
//////////////////////////////////////////////////////////////////////////////////////////////////
class Reader
{
public:
    // Returns the descriptions of the streams this reader produces.
    virtual std::vector<StreamDescriptionPtr> GetStreamDescriptions() = 0;

    // Starts a new epoch with the provided configuration.
    // TODO: should be deprecated, SetConfiguration should be used instead.
    virtual void StartEpoch(const EpochConfiguration& config,
                            const std::map<std::wstring, int>& inputDescriptions) = 0;

    // Sets a new configuration for the reader.
    virtual void SetConfiguration(const ReaderConfiguration& config,
                                  const std::map<std::wstring, int>& inputDescriptions) = 0;

    // Returns the current position in the global timeline, measured in samples.
    // TODO: Currently in case of sequence to sequence training,
    // TODO: the logical sequence size in samples = max(constituting sequences among all streams)
    // TODO: This will change in the future.
    virtual size_t GetCurrentSamplePosition() = 0;

    // Sets the current global position.
    virtual void SetCurrentSamplePosition(size_t currentSamplePosition) = 0;

    // Reads a minibatch that contains data across all streams.
    virtual Minibatch ReadMinibatch() = 0;

    // Virtual so that implementations can be destroyed through a Reader pointer.
    virtual ~Reader() = default;
};

// Shared-ownership handle to a Reader.
using ReaderPtr = std::shared_ptr<Reader>;
}}}