# test_download.py
# Copyright (c) Facebook, Inc. and its affiliates.

import contextlib
import os
import tempfile
import unittest
from io import StringIO
from unittest import mock

import mmf.utils.download as download
import tests.test_utils as test_utils


# Remote zip archive that the download tests in this module fetch.
TEST_DOWNLOAD_URL = (
    "https://dl.fbaipublicfiles.com/mmf/data/tests/visual_entailment_small.zip"
)
# Expected checksum of that archive, passed as ``hashcode`` to
# ``DownloadableFile``. It is 64 hex digits, so presumably SHA-256 -- confirm
# against mmf.utils.download.
TEST_DOWNLOAD_SHASUM = (
    "e5831397710b71f58a02c243bb6e731989c8f37ef603aaf3ce18957ecd075bf5"
)


class TestUtilsDownload(unittest.TestCase):
    """Tests for the helpers exposed by ``mmf.utils.download``.

    Covers ``DownloadableFile.download_file`` (network required) as well as
    the ``mark_done`` / ``built`` marker-file helpers.
    """

    # File name the archive is saved under, and the directory the tests
    # expect to find after extraction.
    _ARCHIVE_NAME = "visual_entailment_small.zip"
    _EXTRACTED_DIR = "visual_entailment_small"

    def _make_resource(self, hashcode=TEST_DOWNLOAD_SHASUM, **kwargs):
        """Build a ``DownloadableFile`` for the test archive.

        Args:
            hashcode: Expected checksum handed to ``DownloadableFile``.
            **kwargs: Extra keyword arguments forwarded unchanged
                (e.g. ``compressed``, ``delete_original``).
        """
        return download.DownloadableFile(
            TEST_DOWNLOAD_URL, self._ARCHIVE_NAME, hashcode=hashcode, **kwargs
        )

    @staticmethod
    def _download_quietly(resource, download_dir):
        """Call ``resource.download_file(download_dir)`` with stdout discarded."""
        with contextlib.redirect_stdout(StringIO()):
            resource.download_file(download_dir)

    @test_utils.skip_if_no_network
    @test_utils.skip_if_macos
    def test_download_file_class(self):
        """Exercise ``DownloadableFile.download_file`` end to end.

        Scenarios: compressed download, checksum mismatch, uncompressed
        download, repeated download into a populated directory, and
        ``delete_original=True``.
        """
        # Test normal scenario: archive is fetched, extracted and kept.
        resource = self._make_resource(compressed=True)
        with tempfile.TemporaryDirectory() as d:
            self._download_quietly(resource, d)
            expected_paths = (
                os.path.join(d, self._EXTRACTED_DIR),
                os.path.join(d, self._EXTRACTED_DIR, "db", "train.jsonl"),
                os.path.join(
                    d, self._EXTRACTED_DIR, "features", "features.lmdb", "data.mdb"
                ),
                os.path.join(d, self._ARCHIVE_NAME),
            )
            for expected in expected_paths:
                self.assertTrue(os.path.exists(expected))

        # Test when checksum fails
        resource = self._make_resource(hashcode="some_random_string", compressed=True)
        with tempfile.TemporaryDirectory() as d:
            with self.assertRaises(AssertionError):
                self._download_quietly(resource, d)

        # Test when not compressed
        resource = self._make_resource(compressed=False)
        with tempfile.TemporaryDirectory() as d:
            archive_path = os.path.join(d, self._ARCHIVE_NAME)
            self._download_quietly(resource, d)
            self.assertTrue(os.path.exists(archive_path))

            # Check already downloaded scenarios: the archive is now present
            # in ``d`` for every call below.
            with mock.patch.object(resource, "checksum") as mocked:
                self._download_quietly(resource, d)
                mocked.assert_called_once_with(d)

            with mock.patch("mmf.utils.download.download") as mocked:
                self._download_quietly(resource, d)
                mocked.assert_called_once_with(
                    resource._url, d, resource._file_name, redownload=False
                )

            with mock.patch.object(resource, "checksum") as mocked:
                resource._hashcode = "some_random_string"
                self._download_quietly(resource, d)
                # This used to read ``assertTrue(mocked.call_count, 2)``, which
                # passes ``2`` as the failure *message* and therefore only
                # verified that the count was non-zero. Make that check
                # explicit rather than look like an equality assertion.
                # NOTE(review): the same scenario a few lines above observes
                # exactly one ``checksum`` call; confirm the intended count
                # before tightening this assertion.
                mocked.assert_called()

            # With a wrong hashcode and the real ``checksum``, the file is
            # requested again with ``redownload=True`` and the download fails
            # with an AssertionError.
            with mock.patch("mmf.utils.download.download") as mocked:
                resource._hashcode = "some_random_string"
                with self.assertRaises(AssertionError):
                    self._download_quietly(resource, d)
                mocked.assert_called_once_with(
                    resource._url, d, resource._file_name, redownload=True
                )

        # Test delete original: the zip must be gone after the download.
        resource = self._make_resource(compressed=True, delete_original=True)
        with tempfile.TemporaryDirectory() as d:
            self._download_quietly(resource, d)
            self.assertFalse(os.path.exists(os.path.join(d, self._ARCHIVE_NAME)))

    def test_mark_done(self):
        """``mark_done`` writes ``.built.json`` with the expected keys."""
        # Function-scope import, as in the original test; hoisted out of the
        # ``with open(...)`` block so it no longer sits between I/O calls.
        import json

        with tempfile.TemporaryDirectory() as d:
            path = os.path.join(d, ".built.json")
            self.assertFalse(os.path.exists(path))
            download.mark_done(d, "0.1")
            self.assertTrue(os.path.exists(path))

            with open(path) as f:
                data = json.load(f)
            self.assertEqual(list(data), ["created_at", "version"])

    def test_built(self):
        """``built`` is true only when a marker with the same version exists."""
        with tempfile.TemporaryDirectory() as d:
            # First, test without built file
            self.assertFalse(download.built(d, "0.2"))
            download.mark_done(d, "0.1")
            # Test correct version
            self.assertTrue(download.built(d, "0.1"))
            # Test wrong version
            self.assertFalse(download.built(d, "0.2"))
# End of test_download.py