https://github.com/facebookresearch/pythia
Tip revision: dabf95f523cd07e93380c6931e5140ade0f50b2f authored by Sethu Sankaran on 26 October 2021, 19:18:43 UTC
Revert D30704069: [feat] Add a refiner head that can be used with MMFT
Revert D30704069: [feat] Add a refiner head that can be used with MMFT
Tip revision: dabf95f
xla.py
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
from mmf.utils.distributed import is_main
try:
import torch_xla.core.xla_model as xm
except ImportError:
xm = None
def save_xla_ckpt(ckpt, file_or_path):
    """Save an MMF checkpoint from a TPU process, writing only on the main one.

    Unlike plain ``xm.save`` (which writes once per node), this converts only
    the tensor-holding entries to CPU and writes a single copy from the global
    main process. For a full MMF checkpoint dict, just the "model" and
    "optimizer" entries are moved to CPU — other entries such as lr_scheduler
    can trip xm.save's mappingproxy handling and are left untouched. All
    processes still participate in the conversion collectives and the final
    rendezvous, so every rank must call this function.

    Args:
        ckpt: Either a full MMF checkpoint dict (containing "model" and
            "optimizer" keys) or an arbitrary object to convert and save.
        file_or_path: Destination accepted by ``torch.save``.
    """
    write_here = is_main()
    # A "full" checkpoint is recognized by the presence of both tensor-holding
    # entries; anything else is converted wholesale.
    if isinstance(ckpt, dict) and "model" in ckpt and "optimizer" in ckpt:
        for key in ("model", "optimizer"):
            # _maybe_convert_to_cpu is collective: every rank calls it, but only
            # the rank with convert=True materializes CPU copies.
            ckpt[key] = xm._maybe_convert_to_cpu(ckpt[key], convert=write_here)
    else:
        ckpt = xm._maybe_convert_to_cpu(ckpt, convert=write_here)
    if write_here:
        torch.save(ckpt, file_or_path)
    # Barrier so no rank races ahead before the checkpoint is on disk.
    xm.rendezvous("mmf.utils.checkpoint.save_xla_ckpt")