https://github.com/facebookresearch/pythia
Revision 5d943274052f5f6aa2584257b4e55a44cf260847 authored by Jun Chen on 03 June 2021, 06:45:35 UTC, committed by Facebook GitHub Bot on 03 June 2021, 06:46:37 UTC
Summary:
Pull Request resolved: https://github.com/facebookresearch/mmf/pull/956

Add general knowledge distillation support to MMF.

# High-level Design

* The knowledge distillation model contains a teacher model and a student model. The teacher generates soft targets, and the student is trained to mimic them.
* Add a processor that handles the text and features for both the teacher and student models. The teacher's and student's text/feature processors can differ.
* In the YAML config, the student model still uses the usual names such as `text_processor` in `dataset_config`. Keys such as `teacher_text_processor` or `teacher_feature_key` can be added to specify the teacher's processor or feature key. The teacher model is built and loaded from `model_config.kd.teacher.pretrained_checkpoint`, and its parameters can be either frozen or fine-tuned. The student config uses the same format as the teacher's and also supports loading a pretrained model.
* Support three distillation losses: MSELoss, CosineEmbeddingLoss, and KLDivergence. The final loss is a weighted sum of distillation and metric losses.
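The config behavior in the bullets above can be illustrated with a small lookup helper. The following is a minimal sketch, assuming a flattened dictionary in place of MMF's real nested `dataset_config`. The helper `resolve_teacher_setting` is hypothetical, and the rule that the teacher falls back to the student's entry when no `teacher_*` key is given is an assumption. It is not confirmed MMF behavior. The key names `text_processor`, `teacher_text_processor`, and `teacher_feature_key` come from the description; the concrete values are made up for illustration.

```python
from typing import Any, Mapping, Optional


def resolve_teacher_setting(config: Mapping[str, Any], name: str) -> Optional[Any]:
    """Look up the teacher's setting for `name` (e.g. "text_processor").

    Hypothetical helper. Assumption: if no "teacher_<name>" entry exists,
    the teacher reuses the student's entry. Returns None if neither exists.
    """
    teacher_name = f"teacher_{name}"
    if teacher_name in config:
        return config[teacher_name]
    return config.get(name)


# Hypothetical, flattened stand-in for a dataset config section.
dataset_config = {
    "text_processor": {"type": "bert_tokenizer", "params": {"max_seq_length": 128}},
    "teacher_text_processor": {"type": "bert_tokenizer", "params": {"max_seq_length": 256}},
    "feature_key": "image_feature_0",  # illustrative value; no teacher override
}

student_proc = dataset_config["text_processor"]
teacher_proc = resolve_teacher_setting(dataset_config, "text_processor")
teacher_feature_key = resolve_teacher_setting(dataset_config, "feature_key")
```

In this sketch the student always reads the plain keys, so existing dataset configs stay valid. A teacher-specific entry is only needed when the teacher's preprocessing or features actually differ.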
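The three distillation losses and the weighted total described above can be sketched in plain Python. This sketch deliberately does not use PyTorch's `nn.MSELoss`, `nn.CosineEmbeddingLoss`, and `nn.KLDivLoss`, so that it stays self-contained. It is an illustration under stated assumptions, not the MMF implementation. Several choices here are assumptions rather than details from the source: the temperature parameter, the T² scaling of the KL term (a common distillation convention), the use of cross-entropy as the metric loss, and the default weights (`metric_weight=0.1` is an illustrative small value).

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; subtracting the max keeps exp() stable."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_div_loss(teacher: Sequence[float], student: Sequence[float],
                temperature: float = 1.0) -> float:
    """KL(p_teacher || p_student) on softened distributions, scaled by T^2 (assumed)."""
    p = softmax(teacher, temperature)  # teacher soft targets
    q = softmax(student, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)
    return kl * temperature ** 2


def mse_loss(teacher: Sequence[float], student: Sequence[float]) -> float:
    """Mean squared error between teacher and student outputs."""
    return sum((t - s) ** 2 for t, s in zip(teacher, student)) / len(teacher)


def cosine_embedding_loss(teacher: Sequence[float], student: Sequence[float]) -> float:
    """1 - cos(teacher, student), i.e. a cosine embedding loss with target +1."""
    dot = sum(t * s for t, s in zip(teacher, student))
    norm_t = math.sqrt(sum(t * t for t in teacher))
    norm_s = math.sqrt(sum(s * s for s in student))
    return 1.0 - dot / (norm_t * norm_s)


def cross_entropy(student: Sequence[float], label: int) -> float:
    """Classification (metric) loss on the student's unsoftened logits."""
    return -math.log(softmax(student)[label])


def total_loss(teacher: Sequence[float], student: Sequence[float], label: int,
               distill_fn=kl_div_loss, distill_weight: float = 1.0,
               metric_weight: float = 0.1) -> float:
    """Weighted sum of one distillation loss and one metric loss."""
    return (distill_weight * distill_fn(teacher, student)
            + metric_weight * cross_entropy(student, label))


# Usage with made-up logits for a 3-class problem.
teacher_logits = [2.0, 0.5, -1.0]
student_logits = [1.5, 0.7, -0.5]
loss = total_loss(teacher_logits, student_logits, label=0)
```

Passing `distill_fn=mse_loss` or `distill_fn=cosine_embedding_loss` swaps the distillation term without changing the weighting. This mirrors the description's point that the loss choice and the loss weights are configured independently.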

# Summary of experiment results
1. Set the learning rate (LR) between 1e-4 and 1e-3.
2. KL divergence works best as the distillation loss.
3. The classification loss weight should be small.
4. The teacher architecture does not matter much.
5. Whether the student parameters are pretrained makes little difference.
6. If the teacher and student use the same image features, extra data is not needed for the student to reach similar performance.
7. Otherwise, unlabeled data helps improve the student model's performance.

Reviewed By: vedanuj

Differential Revision: D27726297

fbshipit-source-id: 37673782265aa1e80af0172b974ae430d6fd7a6e
1 parent c2af5c1
[feat] Implement knowledge distillation in MMF (#956)