From 21b4d2f52c1ab34cf2da0adba1c525af71beef21 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Wed, 25 Jul 2018 00:50:27 +0000 Subject: [PATCH] Add verify_ssl option to gluon.utils.download (#11546) * Add verify_ssl option to gluon.utils.download Sometimes datasets may be hosted on servers that serve invalid SSL certificates. * Add warning * Add test * Mock gluon.utils.download tests * Add Py2 mock dependency to Jenkinsfile --- Jenkinsfile | 2 ++ python/mxnet/gluon/utils.py | 14 ++++++-- tests/python/unittest/test_gluon_utils.py | 44 ++++++++++++++++++++--- 3 files changed, 54 insertions(+), 6 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index e81fb5039908..ed998ee147cb 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -849,6 +849,7 @@ try { bat """xcopy C:\\mxnet\\data data /E /I /Y xcopy C:\\mxnet\\model model /E /I /Y call activate py2 + pip install mock set PYTHONPATH=${env.WORKSPACE}\\pkg_vc14_cpu\\python del /S /Q ${env.WORKSPACE}\\pkg_vc14_cpu\\python\\*.pyc C:\\mxnet\\test_cpu.bat""" @@ -893,6 +894,7 @@ try { bat """xcopy C:\\mxnet\\data data /E /I /Y xcopy C:\\mxnet\\model model /E /I /Y call activate py2 + pip install mock set PYTHONPATH=${env.WORKSPACE}\\pkg_vc14_gpu\\python del /S /Q ${env.WORKSPACE}\\pkg_vc14_gpu\\python\\*.pyc C:\\mxnet\\test_gpu.bat""" diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index fcb7c97b9809..f04479d23716 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -175,7 +175,7 @@ def check_sha1(filename, sha1_hash): return sha1.hexdigest() == sha1_hash -def download(url, path=None, overwrite=False, sha1_hash=None, retries=5): +def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): """Download an given URL Parameters @@ -192,6 +192,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5): but doesn't match. retries : integer, default 5 The number of times to attempt the download in case of failure or non 200 return codes + verify_ssl : bool, default True + Verify SSL certificates. Returns ------- @@ -200,6 +202,9 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5): """ if path is None: fname = url.split('/')[-1] + # Empty filenames are invalid + assert fname, 'Can\'t construct file-name from this URL. ' \ + 'Please set the `path` option manually.' else: path = os.path.expanduser(path) if os.path.isdir(path): @@ -208,6 +213,11 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5): fname = path assert retries >= 0, "Number of retries should be at least 0" + if not verify_ssl: + warnings.warn( + 'Unverified HTTPS request is being made (verify_ssl=False). ' + 'Adding certificate verification is strongly advised.') + if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) if not os.path.exists(dirname): @@ -217,7 +227,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5): # pylint: disable=W0703 try: print('Downloading %s from %s...'%(fname, url)) - r = requests.get(url, stream=True) + r = requests.get(url, stream=True, verify=verify_ssl) if r.status_code != 200: raise RuntimeError("Failed downloading url %s"%url) with open(fname, 'wb') as f: diff --git a/tests/python/unittest/test_gluon_utils.py b/tests/python/unittest/test_gluon_utils.py index a5d3b1401a33..431852427f53 100644 --- a/tests/python/unittest/test_gluon_utils.py +++ b/tests/python/unittest/test_gluon_utils.py @@ -15,20 +15,56 @@ # specific language governing permissions and limitations # under the License. +import io import os import tempfile +import warnings +try: + from unittest import mock +except ImportError: + import mock import mxnet as mx -from nose.tools import * +import requests +from nose.tools import raises + + +class MockResponse(requests.Response): + def __init__(self, status_code, content): + super(MockResponse, self).__init__() + assert isinstance(status_code, int) + self.status_code = status_code + self.raw = io.BytesIO(content.encode('utf-8')) @raises(Exception) +@mock.patch( + 'requests.get', mock.Mock(side_effect=requests.exceptions.ConnectionError)) def test_download_retries(): mx.gluon.utils.download("http://doesnotexist.notfound") + +@mock.patch( + 'requests.get', + mock.Mock(side_effect= + lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100))) def test_download_successful(): tmp = tempfile.mkdtemp() tmpfile = os.path.join(tmp, 'README.md') - mx.gluon.utils.download("https://github.com/raw/apache/incubator-mxnet/master/README.md", - path=tmpfile) - assert os.path.getsize(tmpfile) > 100 \ No newline at end of file + mx.gluon.utils.download( + "https://github.com/raw/apache/incubator-mxnet/master/README.md", + path=tmpfile) + assert os.path.getsize(tmpfile) > 100 + + +@mock.patch( + 'requests.get', + mock.Mock( + side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT'))) +def test_download_ssl_verify(): + with warnings.catch_warnings(record=True) as warnings_: + mx.gluon.utils.download( + "https://mxnet.incubator.apache.org/index.html", verify_ssl=False) + assert any( + str(w.message).startswith('Unverified HTTPS request') + for w in warnings_)