swh:1:snp:eee76444da62e238a10272cb71070ca8823b3f3d
Tip revision: da207d03e7994d9c5a097126dcd509abedc26bc0 authored by zachzhang07 on 21 November 2024, 08:07:14 UTC
Update readme.md
Update readme.md
Tip revision: da207d0
schedule.py
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameter schedules."""
import torch
import numpy as np
def log_lerp(t, v0, v1):
"""Interpolate log-linearly from `v0` (t=0) to `v1` (t=1)."""
if v0 <= 0 or v1 <= 0:
raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
lv0 = np.log(v0)
lv1 = np.log(v1)
return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0)
class Schedule:
pass
class ConstSchedule(Schedule):
"""Fixes the hyperparameter to a constant value: no schedule is used."""
def __init__(self, val):
self.val = val
def __call__(self, step):
return self.val
def __repr__(self):
return f'ConstSchedule: {self.val}'
class DictSchedule(Schedule):
"""Dictionary maps iterations to hyperparameter values."""
def __init__(self, schedule):
self.schedule = schedule
def __call__(self, step):
return [
self.schedule[t] for t in sorted(self.schedule.keys()) if step >= t
][-1]
def __repr__(self):
return f'DictSchedule: {self.schedule}'
class LogLerpSchedule(Schedule):
"""Log-linearly interpolates a hyperparameter."""
def __init__(self, start, end, v0, v1, zero_before_start=False):
self.start = start
self.end = end
self.v0 = v0
self.v1 = v1
self.zero_before_start = zero_before_start
def __call__(self, step):
def h(_step):
t = (_step - self.start) / (self.end - self.start)
return log_lerp(t, self.v0, self.v1)
if self.zero_before_start and step < self.start:
return 0.0
else:
return h(step)
def __repr__(self):
return (
f'LogLerpSchedule: start: {self.start}, end: {self.end}, v0: {self.v0},'
f' v1: {self.v1}'
)