https://github.com/SmartDataAnalytics/AK-DE-biGRU
Tip revision: 3dca7cc2b59cd54781c8eb6751c50cdbe84a8a07 authored by Debanjan Chaudhuri (Deep) on 11 September 2018, 22:01:34 UTC
Update README.md
Update README.md
Tip revision: 3dca7cc
util.py
import os
import torch
def save_model(model, name, directory='models'):
    """Serialize ``model.state_dict()`` to ``<directory>/<name>.bin`` via ``torch.save``.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        name: file stem; the file written is ``<name>.bin``.
        directory: destination directory, created (with parents) if missing.
            Defaults to ``'models'``, matching the previous hard-coded location.

    Returns:
        The path of the file that was written.
    """
    # exist_ok=True avoids the race of "check exists, then create" when
    # several processes save into the same fresh directory.
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, '{}.bin'.format(name))
    torch.save(model.state_dict(), path)
    return path
def load_model(model, name, gpu=True, directory='models'):
    """Load ``<directory>/<name>.bin`` into *model* via ``load_state_dict``.

    Args:
        model: object exposing ``load_state_dict()`` (e.g. a ``torch.nn.Module``).
        name: file stem; the file read is ``<name>.bin``.
        gpu: if True, tensors are restored with ``torch.load``'s default device
            mapping. If False, every tensor is mapped to CPU, so a checkpoint
            saved on a GPU machine can be loaded on a CPU-only host.
        directory: directory holding the checkpoint. Defaults to ``'models'``,
            matching the previous hard-coded location.

    Returns:
        The same *model* object, now holding the loaded weights.
    """
    path = os.path.join(directory, '{}.bin'.format(name))
    # None is torch.load's default; 'cpu' is the documented equivalent of
    # ``lambda storage, loc: storage``.
    map_location = None if gpu else 'cpu'
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(path, map_location=map_location))
    return model
def clip_gradient_threshold(model, min, max):
    """Clamp, in place, each existing parameter gradient of *model* to ``[min, max]``.

    Parameters whose ``grad`` is ``None`` (no backward pass reached them)
    are left untouched. Nothing is returned.
    """
    # Parameter names shadow the builtins min/max, but they are part of the
    # public keyword interface and so are kept.
    live_grads = (param.grad for param in model.parameters() if param.grad is not None)
    for grad in live_grads:
        grad.data.clamp_(min, max)