swh:1:snp:f50ab94432af916b5fb8b4ad831e8dddded77084
Raw File
Tip revision: 7bf32d16f58ec13366aad67136e92ca4d85a53c9 authored by liqfu on 22 October 2018, 01:34:23 UTC
ready for seq ops
Tip revision: 7bf32d1
SequenceEnumerator.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include <vector>
#include "DataDeserializer.h"
#include "Reader.h"

namespace CNTK {

// A set of sequences covering a set of streams.
// Returned by the sequence enumerator.
struct Sequences
{
    // Sequence data, holding up to the requested number of sequences.
    // The outer vector is indexed by stream: its indices have to correspond
    // to the stream ids returned from GetStreamDescriptions().
    std::vector<std::vector<SequenceDataPtr>> m_data;

    // True when the returned data comes from the end of a sweep, or when it
    // crosses a sweep boundary (in which case it includes sequences from
    // different sweeps).
    bool m_endOfSweep = false;

    // True when the epoch ends with the returned data.
    bool m_endOfEpoch = false;
};

// Forward declaration, so the shared pointer alias can be introduced ahead of the class definition.
class SequenceEnumerator;
using SequenceEnumeratorPtr = std::shared_ptr<SequenceEnumerator>;

// SequenceEnumerator is an internal interface that the packer uses to obtain a set of new sequences.
// It is implemented either by the different randomizers or by TransformController, which can wrap
// a randomizer and apply different transforms on top of the data.
//
// The interface is internal to CNTK and is not exposed to the developers of deserializers/plugins.
class SequenceEnumerator
{
public:
    // Returns the descriptions of the streams this sequence enumerator produces.
    virtual std::vector<StreamInformation> GetStreamDescriptions() const = 0;

    // Applies the configuration of the current epoch.
    // TODO: should be deprecated.
    virtual void StartEpoch(const EpochConfiguration& config) = 0;

    // Applies the current configuration.
    virtual void SetConfiguration(const ReaderConfiguration& config) = 0;

    // Sets the current sample position from the given state.
    virtual void SetState(const std::map<std::wstring, size_t>& state) = 0;

    // Returns the current state of the enumerator.
    virtual std::map<std::wstring, size_t> GetState() = 0;

    // Returns the next sequences, up to a maximum count of global and local samples.
    virtual Sequences GetNextSequences(size_t globalSampleCount, size_t localSampleCount) = 0;

    // Virtual, so implementations can be destroyed through a pointer to this interface.
    virtual ~SequenceEnumerator() = default;
};

}
back to top