https://github.com/google-research/fixmatch
Tip revision: d4985a158065947dba803e626ee9a6721709c570 authored by David Berthelot on 12 November 2020, 17:50:23 UTC
Merge pull request #46 from daikikatsuragawa/master
ab_fixmatch_nocutout.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from absl import app
from absl import flags
from fixmatch import FixMatch
from libml import utils, data, augment, ctaugment
FLAGS = flags.FLAGS


class AugmentPoolCTANoCutOut(augment.AugmentPoolCTA):
    @staticmethod
    def numpy_apply_policies(arglist):
        x, cta, probe = arglist
        if x.ndim == 3:
            assert probe
            policy = cta.policy(probe=True)
            return dict(policy=policy,
                        probe=ctaugment.apply(x, policy),
                        image=x)
        assert not probe
        policy = lambda: cta.policy(probe=False)
        return dict(image=np.stack([x[0]] + [ctaugment.apply(y, policy()) for y in x[1:]]).astype('f'))
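

# AugmentPoolCTANoCutOut above mirrors augment.AugmentPoolCTA's numpy_apply_policies,
# sampling a fresh CTAugment policy for every strongly-augmented copy, but (as the
# name suggests) presumably without the CutOut step that the default pool adds on top
# of each policy. AB_FixMatch_NoCutOut below is FixMatch with this pool swapped in
# via AUGMENT_POOL_CLASS, i.e. the "no CutOut" ablation this script exists to run.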
class AB_FixMatch_NoCutOut(FixMatch):
    AUGMENT_POOL_CLASS = AugmentPoolCTANoCutOut


def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = AB_FixMatch_NoCutOut(
        os.path.join(FLAGS.train_dir, dataset.name, FixMatch.cta_name()),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)  # 512 epochs (which is 524K parameter updates)


if __name__ == '__main__':
    utils.setup_tf()
    flags.DEFINE_float('confidence', 0.95, 'Confidence threshold.')
    flags.DEFINE_float('wd', 0.0005, 'Weight decay.')
    flags.DEFINE_float('wu', 1, 'Pseudo label loss weight.')
    flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
    flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
    flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
    flags.DEFINE_integer('uratio', 7, 'Unlabeled batch size ratio.')
    FLAGS.set_default('augment', 'd.d.d')
    FLAGS.set_default('dataset', 'cifar10.3@250-1')
    FLAGS.set_default('batch', 64)
    FLAGS.set_default('lr', 0.03)
    FLAGS.set_default('train_kimg', 1 << 16)
    app.run(main)
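
# Example invocation (an illustrative sketch, not taken from the repository): the flag
# names are those defined above, --train_dir is defined by the shared training code
# this script imports, and the dataset name and output path are assumptions.
#
#   CUDA_VISIBLE_DEVICES=0 python ab_fixmatch_nocutout.py \
#       --dataset=cifar10.3@250-1 --train_dir ./experiments/fixmatch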