test_download.py
# Copyright (c) Facebook, Inc. and its affiliates.
import contextlib
import os
import tempfile
import unittest
from io import StringIO
from unittest import mock
import mmf.utils.download as download
import tests.test_utils as test_utils
TEST_DOWNLOAD_URL = (
"https://dl.fbaipublicfiles.com/mmf/data/tests/visual_entailment_small.zip"
)
TEST_DOWNLOAD_SHASUM = (
"e5831397710b71f58a02c243bb6e731989c8f37ef603aaf3ce18957ecd075bf5"
)
class TestUtilsDownload(unittest.TestCase):
@test_utils.skip_if_no_network
@test_utils.skip_if_macos
def test_download_file_class(self):
# Test normal scenario
resource = download.DownloadableFile(
TEST_DOWNLOAD_URL,
"visual_entailment_small.zip",
hashcode=TEST_DOWNLOAD_SHASUM,
compressed=True,
)
with tempfile.TemporaryDirectory() as d:
with contextlib.redirect_stdout(StringIO()):
resource.download_file(d)
self.assertTrue(os.path.exists(os.path.join(d, "visual_entailment_small")))
self.assertTrue(
os.path.exists(
os.path.join(d, "visual_entailment_small", "db", "train.jsonl")
)
)
self.assertTrue(
os.path.exists(
os.path.join(
d,
"visual_entailment_small",
"features",
"features.lmdb",
"data.mdb",
)
)
)
self.assertTrue(
os.path.exists(os.path.join(d, "visual_entailment_small.zip"))
)
# Test when checksum fails
resource = download.DownloadableFile(
TEST_DOWNLOAD_URL,
"visual_entailment_small.zip",
hashcode="some_random_string",
compressed=True,
)
with tempfile.TemporaryDirectory() as d:
with contextlib.redirect_stdout(StringIO()):
self.assertRaises(AssertionError, resource.download_file, d)
# Test when not compressed
resource = download.DownloadableFile(
TEST_DOWNLOAD_URL,
"visual_entailment_small.zip",
hashcode=TEST_DOWNLOAD_SHASUM,
compressed=False,
)
with tempfile.TemporaryDirectory() as d:
with contextlib.redirect_stdout(StringIO()):
resource.download_file(d)
self.assertTrue(
os.path.exists(os.path.join(d, "visual_entailment_small.zip"))
)
# Check already downloaded scenarios
with mock.patch.object(resource, "checksum") as mocked:
with contextlib.redirect_stdout(StringIO()):
resource.download_file(d)
mocked.assert_called_once_with(d)
with mock.patch("mmf.utils.download.download") as mocked:
with contextlib.redirect_stdout(StringIO()):
resource.download_file(d)
mocked.assert_called_once_with(
resource._url, d, resource._file_name, redownload=False
)
with mock.patch.object(resource, "checksum") as mocked:
resource._hashcode = "some_random_string"
with contextlib.redirect_stdout(StringIO()):
resource.download_file(d)
self.assertTrue(mocked.call_count, 2)
with mock.patch("mmf.utils.download.download") as mocked:
resource._hashcode = "some_random_string"
with contextlib.redirect_stdout(StringIO()):
self.assertRaises(AssertionError, resource.download_file, d)
mocked.assert_called_once_with(
resource._url, d, resource._file_name, redownload=True
)
# Test delete original
resource = download.DownloadableFile(
TEST_DOWNLOAD_URL,
"visual_entailment_small.zip",
hashcode=TEST_DOWNLOAD_SHASUM,
compressed=True,
delete_original=True,
)
with tempfile.TemporaryDirectory() as d:
with contextlib.redirect_stdout(StringIO()):
resource.download_file(d)
self.assertFalse(
os.path.exists(os.path.join(d, "visual_entailment_small.zip"))
)
def test_mark_done(self):
with tempfile.TemporaryDirectory() as d:
path = os.path.join(d, ".built.json")
self.assertFalse(os.path.exists(path))
download.mark_done(d, "0.1")
self.assertTrue(os.path.exists(path))
with open(path) as f:
import json
data = json.load(f)
self.assertEqual(list(data.keys()), ["created_at", "version"])
def test_built(self):
with tempfile.TemporaryDirectory() as d:
# First, test without built file
self.assertFalse(download.built(d, "0.2"))
download.mark_done(d, "0.1")
# Test correct version
self.assertTrue(download.built(d, "0.1"))
# Test wrong version
self.assertFalse(download.built(d, "0.2"))