https://gitlab.com/mcoavoux/mtgpy-release-findings-2021.git
Tip revision: c9972219cd75049269d26632d2bb79619d661298 authored by mcoavoux on 20 May 2021, 13:04:44 UTC
up readme
Asgd.py
import torch
from torch.optim.optimizer import Optimizer
from torch.optim import Adam


class AAdam(Adam):
    """Adam with weight averaging: once more than `start` steps have been taken,
    a running sum of the parameters is kept so that evaluation can use their
    average (see `average` / `cancel_average`)."""

    def __init__(self, *args, start=200, **kwargs):
        super(AAdam, self).__init__(*args, **kwargs)
        self.n_steps = 0
        self.start = start

    def __setstate__(self, state):
        super(AAdam, self).__setstate__(state)

    def step(self, closure=None):
        loss = super(AAdam, self).step(closure)
        self.n_steps += 1
        if self.n_steps > self.start:
            # accumulate the current parameters into the running sum
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    if 'cache' not in param_state:
                        param_state['cache'] = torch.zeros_like(p.data)
                        param_state['saved'] = torch.zeros_like(p.data)
                    param_state['cache'].add_(p)
        return loss

    def average(self):
        # replace the parameters by their running average (typically before
        # evaluation), saving the current values for `cancel_average`
        if self.n_steps <= self.start:
            return
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['saved'].copy_(p.data)
                p.data.copy_(param_state['cache'] / (self.n_steps - self.start))

    def cancel_average(self):
        # restore the parameters saved by `average` (typically after evaluation)
        if self.n_steps <= self.start:
            return
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state['saved'])
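
# Minimal usage sketch (not part of the original repository): it illustrates the
# train / average / cancel_average cycle suggested by the methods above, on a
# throwaway linear model. `_demo_aadam` and all data inside it are hypothetical.
def _demo_aadam():
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    criterion = torch.nn.MSELoss()
    optimizer = AAdam(model.parameters(), lr=1e-3, start=5)
    x, y = torch.randn(64, 4), torch.randn(64, 1)
    for _ in range(20):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()            # also accumulates weights once n_steps > start
    optimizer.average()             # evaluate with the averaged weights ...
    with torch.no_grad():
        dev_loss = criterion(model(x), y)
    optimizer.cancel_average()      # ... then restore the raw weights to keep training
    return dev_loss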
class MyAsgd(Optimizer):
    """SGD with a warmed-up, inverse-time-decayed learning rate, optional
    momentum, optional Gaussian gradient noise, and weight averaging from
    step `k` onwards (see `average` / `cancel_average`)."""

    def __init__(self, params, lr=0.01, momentum=0, weight_decay=0,
                 k=-1, gaussian_noise=True, noise=0.01, dc=1e-7):
        defaults = dict(lr=lr, momentum=momentum,
                        weight_decay=weight_decay)
        super(MyAsgd, self).__init__(params, defaults)
        self.n_steps = 1
        self.k = k
        self.gaussian_noise = gaussian_noise
        self.noise = noise
        self.lr = lr
        # if a second parameter group is given (e.g. BERT parameters),
        # remember its base learning rate separately
        self.lr_bert = None
        if len(self.param_groups) > 1:
            self.lr_bert = self.param_groups[1]["lr"]
        self.dc = dc
        #self.warmup = warmup
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['noise'] = torch.zeros_like(p.data)
                param_state['cache'] = torch.zeros_like(p.data)
                param_state['cache'].copy_(p.data)

    def _get_rate(self, i):
        # i == 0: not bert
        # i != 0: bert
        lr_base = [self.lr, self.lr_bert][i]
        if self.n_steps > 1000:
            # inverse-time decay after warmup
            return lr_base / (1 + self.n_steps * self.dc)
        # linear warmup over the first 1000 steps
        standard = lr_base / (1 + self.n_steps * self.dc)
        return standard * self.n_steps / 1000
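
    # Illustrative values for the schedule above (the numbers are examples, not
    # from the original code): with lr=0.01 and dc=1e-7,
    #   n_steps = 500     -> 0.01 / (1 + 500 * 1e-7) * (500 / 1000)   ~ 0.0050  (warmup)
    #   n_steps = 1000    -> 0.01 / (1 + 1000 * 1e-7) * (1000 / 1000) ~ 0.0100  (end of warmup)
    #   n_steps = 1000000 -> 0.01 / (1 + 1000000 * 1e-7)              ~ 0.0091  (decayed)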
    def __setstate__(self, state):
        super(MyAsgd, self).__setstate__(state)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for i, group in enumerate(self.param_groups):
            # recompute the (warmed-up, decayed) learning rate for this group
            group['lr'] = self._get_rate(i)
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            # 'dampening' was not defined in the original code; default to 0,
            # as in torch.optim.SGD
            dampening = group.get('dampening', 0)
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(weight_decay * p.data)
                param_state = self.state[p]
                if momentum != 0:
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p * (1 - dampening))
                    # use the momentum buffer as the update direction (the
                    # original computed the buffer but never applied it)
                    d_p = buf
                if self.gaussian_noise:
                    # annealed Gaussian gradient noise; added out of place so
                    # the momentum buffer is not modified
                    noise = param_state['noise']
                    noise.normal_(0, std=self.noise / (1 + self.n_steps)**0.55)
                    d_p = d_p + noise
                if self.n_steps >= self.k:
                    # accumulate the update, weighted by the step index, for
                    # the weight averaging performed in `average`
                    param_state['cache'].add_(d_p * (-group['lr'] * self.n_steps))
                p.data.add_(d_p * (-group['lr']))
        self.n_steps += 1
        return loss
    def average(self):
        # replace the weights by their running average (typically before
        # evaluation), saving the current values for `cancel_average`
        if self.n_steps < self.k:
            return
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                if 'saved' not in param_state:
                    param_state['saved'] = torch.zeros_like(p.data)
                param_state['saved'].copy_(p.data)
                if 'cache' in param_state:
                    p.data.add_(param_state['cache'] * (-1 / self.n_steps))

    def cancel_average(self):
        # restore the weights saved by `average` (typically after evaluation)
        if self.n_steps < self.k:
            return
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state['saved'])
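
# Minimal usage sketch (not part of the original repository): it shows MyAsgd
# with two parameter groups, mimicking the non-BERT / BERT split that the
# constructor expects (the second group's learning rate is stored in
# self.lr_bert), followed by the average / cancel_average cycle around
# evaluation. `_demo_myasgd` and everything inside it are hypothetical
# placeholders, not code from the original project.
def _demo_myasgd():
    torch.manual_seed(0)
    task_layer = torch.nn.Linear(4, 8)   # stands in for task-specific parameters
    bert_layer = torch.nn.Linear(4, 8)   # stands in for pretrained (BERT) parameters
    head = torch.nn.Linear(8, 1)
    criterion = torch.nn.MSELoss()
    optimizer = MyAsgd(
        [{"params": list(task_layer.parameters()) + list(head.parameters()), "lr": 0.01},
         {"params": bert_layer.parameters(), "lr": 1e-5}],
        lr=0.01, momentum=0.9, k=10, gaussian_noise=True, noise=0.01, dc=1e-7)
    x, y = torch.randn(32, 4), torch.randn(32, 1)

    def forward():
        return head(torch.tanh(task_layer(x) + bert_layer(x)))

    for _ in range(50):
        optimizer.zero_grad()
        loss = criterion(forward(), y)
        loss.backward()
        optimizer.step()
    optimizer.average()             # evaluate with the averaged weights ...
    with torch.no_grad():
        dev_loss = criterion(forward(), y)
    optimizer.cancel_average()      # ... then restore the raw weights to keep training
    return dev_loss


if __name__ == "__main__":
    print(float(_demo_myasgd()))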