https://github.com/google-research/s4l
Raw File
Tip revision: 8f1cf0555dad64d987309e3bee682cf8390bf48a authored by Avital Oliver on 06 November 2019, 09:59:56 UTC
Add MOAM step 1
Tip revision: 8f1cf05
representation_lib.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for training and evaluation of representation learning.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import math
import os

from absl import flags

import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.contrib.tpu import RunConfig, TPUConfig, TPUEstimator
from tensorflow.contrib.training import checkpoints_iterator

import datasets
import utils
import semi_supervised

# Handle to absl's global flag registry. Every `FLAGS.<name>` read in this
# module relies on the flag having been defined elsewhere (no flag definitions
# are visible in this file).
FLAGS = flags.FLAGS

# Number of iterations (= training steps) per TPU training loop; passed to
# `TPUConfig(iterations_per_loop=...)`. Use >100 for good speed. This is also
# the minimum number of steps between checkpoints.
TPU_ITERATIONS_PER_LOOP = 200


def _eval_input_fn_and_steps(split_name, eval_batch_size):
  """Builds the evaluation input_fn for a split and its number of eval steps.

  Args:
    split_name: name of the dataset split to evaluate on.
    eval_batch_size: global batch size used for evaluation.

  Returns:
    A tuple `(input_fn, num_steps)`: `input_fn` is `datasets.get_data` bound
    to a single, unshuffled, non-training pass over `split_name`; `num_steps`
    is the number of full batches in the split (the remainder is dropped).
  """
  input_fn = functools.partial(
      datasets.get_data,
      split_name=split_name,
      preprocessing=FLAGS.get_flag_value("preprocessing_eval",
                                         FLAGS.preprocessing),
      is_training=False,
      shuffle=False,
      num_epochs=1,
      drop_remainder=True)
  # Contrary to what the documentation claims, the `train` and the
  # `evaluate` functions NEED to have `max_steps` and/or `steps` set and
  # cannot make use of the iterator's end-of-input exception, so the number
  # of steps is computed explicitly here.
  num_steps = datasets.get_count(split_name) // eval_batch_size
  return input_fn, num_steps


def train_and_eval():
  """Trains a network on (self) supervised data, or evaluates its checkpoints.

  The mode is selected by `FLAGS.run_eval`:

  * Evaluation: every checkpoint yielded by `checkpoints_iterator` for the
    model directory is evaluated on `FLAGS.val_split` and exported as a hub
    module under `<checkpoint_dir>/export/hub`. The loop ends when the file
    `<workdir>/TRAINING_IS_DONE` exists or when the iterator stops yielding
    (it is given a timeout of `FLAGS.eval_timeout_mins` minutes). Afterwards
    the latest checkpoint is evaluated once more on the validation split and,
    if `FLAGS.test_split` is set, on the test split.
  * Training: the estimator is trained on `FLAGS.train_split` up to a number
    of steps derived from the last entry of `FLAGS.schedule` (in epochs).

  Returns:
    In evaluation mode, the result of `estimator.evaluate` on the validation
    split for the latest checkpoint. In training mode, the value returned by
    `estimator.train`.
  """
  checkpoint_dir = FLAGS.get_flag_value("checkpoint", FLAGS.workdir)
  tf.gfile.MakeDirs(checkpoint_dir)

  # Decide TPU usage once, so that the cluster resolver and the estimator's
  # `use_tpu` can never disagree (e.g. when `tpu_name` is an empty string).
  use_tpu = bool(FLAGS.tpu_name)
  cluster = TPUClusterResolver(tpu=[FLAGS.tpu_name]) if use_tpu else None

  config = RunConfig(
      model_dir=checkpoint_dir,
      tf_random_seed=FLAGS.random_seed,
      cluster=cluster,
      keep_checkpoint_max=None,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=TPUConfig(iterations_per_loop=TPU_ITERATIONS_PER_LOOP))

  # Optionally resume from a stored checkpoint.
  if FLAGS.path_to_initial_ckpt:
    warm_start_from = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=FLAGS.path_to_initial_ckpt,
        # The square bracket is important for loading all the
        # variables from GLOBAL_VARIABLES collection.
        # See https://www.tensorflow.org/api_docs/python/tf/estimator/WarmStartSettings  # pylint: disable=line-too-long
        # section vars_to_warm_start for more details.
        vars_to_warm_start=[FLAGS.vars_to_restore]
    )
  else:
    warm_start_from = None

  eval_batch_size = FLAGS.get_flag_value("eval_batch_size", FLAGS.batch_size)

  # The global batch-sizes are passed to the TPU estimator, and it will pass
  # along the local batch size in the model_fn's `params` argument dict.
  estimator = TPUEstimator(
      model_fn=semi_supervised.get_model(FLAGS.task),
      model_dir=checkpoint_dir,
      config=config,
      use_tpu=use_tpu,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=eval_batch_size,
      warm_start_from=warm_start_from
  )

  if FLAGS.run_eval:
    data_fn, num_steps = _eval_input_fn_and_steps(FLAGS.val_split,
                                                  eval_batch_size)
    tf.logging.info("val_steps: %d", num_steps)

    # Loop-invariant objects are built once instead of per checkpoint.
    hub_exporter = hub.LatestModuleExporter("hub", serving_input_fn)
    export_dir = os.path.join(checkpoint_dir, "export/hub")
    done_marker = os.path.join(FLAGS.workdir, "TRAINING_IS_DONE")

    for checkpoint in checkpoints_iterator(
        estimator.model_dir, timeout=FLAGS.eval_timeout_mins * 60):
      estimator.evaluate(
          checkpoint_path=checkpoint, input_fn=data_fn, steps=num_steps)
      hub_exporter.export(estimator, export_dir, checkpoint)
      # This is here instead of using the above `checkpoints_iterator`'s
      # `timeout_fn` param, because that would wait forever on failed
      # trainers which will never create this file.
      if tf.gfile.Exists(done_marker):
        break

    # Evaluates the latest checkpoint on validation set.
    result_dict_val = estimator.evaluate(input_fn=data_fn, steps=num_steps)
    tf.logging.info(result_dict_val)

    # Optionally evaluates the latest checkpoint on test set.
    if FLAGS.test_split:
      test_data_fn, test_steps = _eval_input_fn_and_steps(FLAGS.test_split,
                                                          eval_batch_size)
      result_dict_test = estimator.evaluate(
          input_fn=test_data_fn, steps=test_steps)
      tf.logging.info(result_dict_test)
    return result_dict_val

  train_data_fn = functools.partial(
      datasets.get_data,
      split_name=FLAGS.train_split,
      preprocessing=FLAGS.preprocessing,
      is_training=True,
      num_epochs=None,  # read data indefinitely for training
      drop_remainder=True)

  # We compute the number of steps and make use of Estimator's max_steps
  # arguments instead of relying on the Dataset's iterator to run out after
  # a number of epochs so that we can use "fractional" epochs, which are
  # used by regression tests. (And because TPUEstimator needs it anyways.)
  num_samples = datasets.get_count(FLAGS.train_split)
  if FLAGS.num_supervised_examples:
    num_samples = FLAGS.num_supervised_examples
  # Depending on whether we drop the last batch each epoch or only at the
  # very end, this should be ordered differently for rounding.
  updates_per_epoch = num_samples // FLAGS.batch_size
  # The last schedule entry is the total number of (possibly fractional)
  # epochs to train for.
  epochs = utils.str2intlist(FLAGS.schedule, strict_int=False)[-1]
  num_steps = int(math.ceil(epochs * updates_per_epoch))
  tf.logging.info("train_steps: %d", num_steps)

  return estimator.train(
      train_data_fn,
      max_steps=num_steps)


def serving_input_fn():  # pylint: disable=missing-docstring
  """Creates the serving input receiver used when exporting the model.

  A single float32 placeholder, shaped according to
  `FLAGS.serving_input_shape`, is registered under the key
  `FLAGS.serving_input_key`. The same dict serves as both the model features
  and the receiver tensors.

  Returns:
    A `tf.estimator.export.ServingInputReceiver`.
  """
  placeholder_shape = utils.str2intlist(FLAGS.serving_input_shape)
  image_placeholder = tf.placeholder(dtype=tf.float32, shape=placeholder_shape)
  inputs = {FLAGS.serving_input_key: image_placeholder}
  return tf.estimator.export.ServingInputReceiver(
      features=inputs, receiver_tensors=inputs)
back to top