diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index b11dd99d6481..f04479d23716 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -202,6 +202,9 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ """ if path is None: fname = url.split('/')[-1] + # Empty filenames are invalid + assert fname, 'Can\'t construct file-name from this URL. ' \ + 'Please set the `path` option manually.' else: path = os.path.expanduser(path) if os.path.isdir(path): diff --git a/tests/python/unittest/test_gluon_utils.py b/tests/python/unittest/test_gluon_utils.py index 26ccef253590..6edb313f0a86 100644 --- a/tests/python/unittest/test_gluon_utils.py +++ b/tests/python/unittest/test_gluon_utils.py @@ -15,28 +15,53 @@ # specific language governing permissions and limitations # under the License. +import io import os import tempfile +import warnings +import mock import mxnet as mx -from nose.tools import * +import requests +from nose.tools import raises + + +class MockResponse(requests.Response): + def __init__(self, status_code, content): + super(MockResponse, self).__init__() + assert isinstance(status_code, int) + self.status_code = status_code + self.raw = io.BytesIO(content.encode('utf-8')) @raises(Exception) +@mock.patch( + 'requests.get', mock.Mock(side_effect=requests.exceptions.ConnectionError)) def test_download_retries(): mx.gluon.utils.download("http://doesnotexist.notfound") + +@mock.patch( + 'requests.get', + mock.Mock(side_effect= + lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100))) def test_download_successful(): tmp = tempfile.mkdtemp() tmpfile = os.path.join(tmp, 'README.md') - mx.gluon.utils.download("https://github.com/raw/apache/incubator-mxnet/master/README.md", - path=tmpfile) + mx.gluon.utils.download( + "https://github.com/raw/apache/incubator-mxnet/master/README.md", + path=tmpfile) assert os.path.getsize(tmpfile) > 100 + +@mock.patch( + 'requests.get', + mock.Mock( + side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT'))) def test_download_ssl_verify(): with warnings.catch_warnings(record=True) as warnings_: mx.gluon.utils.download( - "https://mxnet.incubator.apache.org/", verify_ssl=False) - assert any( - str(w.message[0]).startswith('Unverified HTTPS request') - for w in warnings_) + "https://mxnet.incubator.apache.org/index.html", verify_ssl=False) + assert any( + str(w.message).startswith('Unverified HTTPS request') + for w in warnings_)