Revision 0ee1127c81bf245225cd7db50d631e8677abaefc authored by Sasha Sheng on 08 February 2021, 10:50:33 UTC, committed by Facebook GitHub Bot on 08 February 2021, 10:52:20 UTC
Summary:
* pytorch lightning stub mostly involving training
* Tests for lightning trainer included
* built on top of the mmf grad accumulation fix: https://github.com/facebookresearch/mmf/pull/747

- [X] MVP 0. Training: Goal - Train a model from scratch and reach similar accuracy as using mmf_trainer
   - [X] Setup the training pipeline: done
   - [X] Training on the right device: done
   - [X] Clip gradients: done
   - [X] Optimizer: done
   - [X] FP16 Support: done
   - [X] LR scheduler (incl. warmup etc): done
   - [X] testcase: train visual_bert on vqa from scratch for 10 iterations, compare the value: done
- [x] tests included in this PR (tests are only done for pytorch lightning integration):
   - [X] Vanilla Training w/o grad accumulate, make sure loss for 5 iters are the same between mmf and pl
      - [X] Optimizer working as intended as a part of this PR.
   - [X] `max_updates` and `max_epochs` calculation
   - [x] Training with grad accumulate
   - [x] Training with LR schedule achieves a different value compared to without LR schedule
   - [x] Training with LR schedule for PL is the same as training with LR schedule for `mmf_trainer`
   - [x] Training with gradient clipping; make sure all grads are within the `grad_clipping` threshold
   - [x] Training with gradient clipping is the same as training with gradient clipping for `mmf_trainer`

Pull Request resolved: https://github.com/facebookresearch/mmf/pull/748

Reviewed By: apsdehal, simran2905

Differential Revision: D26192869

Pulled By: ytsheng

fbshipit-source-id: 203a91e893d6b878bbed80ed84960dd059cfc90c
1 parent fc72ef0
Raw File
pyproject.toml
[tool.isort]
# This is required to make sorting same as fbcode as all absolute imports
# are considered third party there
# Keep this list alphabetically sorted so additions are easy to spot.
known_third_party = [
    "PIL",
    "cv2",
    "demjson",
    "fairscale",
    "h5py",
    "lib",
    "lmdb",
    "maskrcnn_benchmark",
    "mmf",
    "numpy",
    "omegaconf",
    "packaging",
    "pycocoevalcap",
    "pytorch_lightning",
    "pytorch_sphinx_theme",
    "recommonmark",
    "requests",
    "setuptools",
    "sklearn",
    "termcolor",
    "tests",
    "torch",
    "torchtext",
    "torchvision",
    "tqdm",
    "transformers",
]
# isort accepts a comma-separated string of glob patterns to skip.
skip_glob = "**/build/**,website/**"
combine_as_imports = true
force_grid_wrap = false
include_trailing_comma = true
# line_length / multi_line_output = 3 / use_parentheses match black's
# wrapping style (88 columns, vertical hanging indent).
line_length = 88
multi_line_output = 3
use_parentheses = true
lines_after_imports = 2

[tool.black]
line-length = 88
# Verbose-mode regex matched against file paths; any path containing one of
# these directory names (VCS metadata, tool caches, build outputs, the
# website) is excluded from black formatting.
exclude = '''
/(
    \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
  | website
)/
'''
back to top