https://github.com/GPflow/GPflow
Raw File
Tip revision: 3065dee5fed25d5dd06692be470244ecf260cb20 authored by Mark van der Wilk on 16 August 2017, 09:00:37 UTC
Remove pandas (#486)
Tip revision: 3065dee
session.py
import os
import warnings
import tensorflow as tf
from tensorflow.python.client import timeline
from ._settings import settings


class TracerSession(tf.Session):
    """``tf.Session`` that writes a timeline trace file after every ``run``.

    Each :meth:`run` call is executed with ``FULL_TRACE`` run options. The
    resulting step statistics are converted with
    ``timeline.Timeline(...).generate_chrome_trace_format()`` and written to a
    ``.json`` file whose path is given by :meth:`get_filename`.

    Args:
        output_file_name: Base name of the trace file, without extension.
        output_directory: Directory that receives the trace files, or ``None``
            to write relative to the current working directory. It is
            created, including missing parent directories, if it does not
            exist.
        each_time: If truthy, every ``run`` writes its own numbered file
            ``<output_file_name>_<counter>.json``. Otherwise a single
            ``<output_file_name>.json`` is overwritten on each ``run``.
        **kwargs: Forwarded unchanged to ``tf.Session.__init__``.

    Raises:
        IOError: If ``output_directory`` is an existing regular file.
    """

    def __init__(self, output_file_name, output_directory, each_time, **kwargs):
        self.output_file_name = output_file_name
        self.output_directory = output_directory
        self.eachTime = each_time
        # Metadata collected by the most recent ``run`` call.
        self.local_run_metadata = None
        if self.eachTime:
            warnings.warn("Outputting a trace for each run. May result in large disk usage.")

        super(TracerSession, self).__init__(**kwargs)
        # Index used in the next trace file name; only advanced when each_time is set.
        self.counter = 0
        # Baseline options for every run: FULL_TRACE so the run metadata carries
        # the step statistics that the timeline is built from.
        self.profiler_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        if self.output_directory is not None:
            if os.path.isfile(self.output_directory):
                raise IOError("In tracer: given directory name is a file.")
            if not os.path.isdir(self.output_directory):
                # makedirs rather than mkdir so that nested paths whose parents
                # do not exist yet also work (mkdir would raise OSError).
                os.makedirs(self.output_directory)

    def get_filename(self):
        """Return the path of the trace file for the current run.

        With ``each_time`` set, the current counter value is appended to the
        base name (``<output_file_name>_<counter>.json``). Otherwise the
        result is always ``<output_file_name>.json``. Either name is joined
        onto ``output_directory`` when one was given.
        """
        dir_stub = self.output_directory if self.output_directory is not None else ''
        if self.eachTime:
            return os.path.join(dir_stub, self.output_file_name + '_' + str(self.counter) + '.json')
        return os.path.join(dir_stub, self.output_file_name + '.json')

    def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
        """Run ``fetches`` with full tracing and dump the timeline to disk.

        Args:
            fetches: As for ``tf.Session.run``.
            feed_dict: As for ``tf.Session.run``.
            options: Optional ``tf.RunOptions`` merged on top of the tracer's
                ``FULL_TRACE`` options for this call only. Its
                ``trace_level`` must equal the tracer's.
            run_metadata: Optional ``tf.RunMetadata`` that receives this
                call's metadata, mirroring ``tf.Session.run``. A fresh one is
                created when omitted. Either way it is also stored in
                ``self.local_run_metadata``.

        Returns:
            Whatever ``tf.Session.run`` returns for ``fetches``.

        Raises:
            ValueError: If ``options.trace_level`` differs from the tracer's.
        """
        # Build per-call options so a caller's extra settings do not leak into
        # later runs. Protobuf messages are combined via CopyFrom/MergeFrom
        # (they have no dict-style ``update`` method).
        run_options = tf.RunOptions()
        run_options.CopyFrom(self.profiler_options)
        if options is not None:
            # Make sure there is no disagreement doing this.
            if options.trace_level != self.profiler_options.trace_level:  # pragma: no cover
                raise ValueError('In profiler session. Inconsistent trace level from run call')  # pragma: no cover
            run_options.MergeFrom(options)  # pragma: no cover

        self.local_run_metadata = run_metadata if run_metadata is not None else tf.RunMetadata()
        output = super(TracerSession, self).run(fetches, feed_dict=feed_dict, options=run_options,
                                                run_metadata=self.local_run_metadata)

        tl = timeline.Timeline(self.local_run_metadata.step_stats)
        ctf = tl.generate_chrome_trace_format()
        with open(self.get_filename(), 'w') as f:
            f.write(ctf)

        if self.eachTime:
            self.counter += 1

        return output


def get_session(*args, **kwargs):
    """Return a TensorFlow session configured from GPflow's settings.

    When no ``config`` keyword is supplied, one is built as
    ``tf.ConfigProto(**settings.session)``. If
    ``settings.profiling.dump_timeline`` is truthy, all arguments are
    forwarded to a :class:`TracerSession`. Otherwise the tracer-only keywords
    ``output_file_name``, ``output_directory`` and ``each_time`` are removed
    and a plain ``tf.Session`` is returned.
    """
    if 'config' not in kwargs:
        # Pass session configuration options from the settings.
        kwargs['config'] = tf.ConfigProto(**settings.session)

    if not settings.profiling.dump_timeline:
        # These keywords only make sense for TracerSession; strip them before
        # handing the rest to the plain session constructor.
        for tracer_only_key in ("output_file_name", "output_directory", "each_time"):
            kwargs.pop(tracer_only_key, None)
        return tf.Session(*args, **kwargs)

    return TracerSession(*args, **kwargs)
back to top