Skip to content

Commit

Permalink
Add verify_ssl option to gluon.utils.download (apache#11546)
Browse files Browse the repository at this point in the history
* Add verify_ssl option to gluon.utils.download

Sometimes datasets may be hosted on servers that serve invalid SSL certificates.

* Add warning

* Add test

* Mock gluon.utils.download tests

* Add Py2 mock dependency to Jenkinsfile
  • Loading branch information
leezu authored and szha committed Jul 25, 2018
1 parent 7ffe3b0 commit 21b4d2f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 6 deletions.
2 changes: 2 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,7 @@ try {
bat """xcopy C:\\mxnet\\data data /E /I /Y
xcopy C:\\mxnet\\model model /E /I /Y
call activate py2
pip install mock
set PYTHONPATH=${env.WORKSPACE}\\pkg_vc14_cpu\\python
del /S /Q ${env.WORKSPACE}\\pkg_vc14_cpu\\python\\*.pyc
C:\\mxnet\\test_cpu.bat"""
Expand Down Expand Up @@ -893,6 +894,7 @@ try {
bat """xcopy C:\\mxnet\\data data /E /I /Y
xcopy C:\\mxnet\\model model /E /I /Y
call activate py2
pip install mock
set PYTHONPATH=${env.WORKSPACE}\\pkg_vc14_gpu\\python
del /S /Q ${env.WORKSPACE}\\pkg_vc14_gpu\\python\\*.pyc
C:\\mxnet\\test_gpu.bat"""
Expand Down
14 changes: 12 additions & 2 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def check_sha1(filename, sha1_hash):
return sha1.hexdigest() == sha1_hash


def download(url, path=None, overwrite=False, sha1_hash=None, retries=5):
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download an given URL
Parameters
Expand All @@ -192,6 +192,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5):
but doesn't match.
retries : integer, default 5
The number of times to attempt the download in case of failure or non 200 return codes
verify_ssl : bool, default True
Verify SSL certificates.
Returns
-------
Expand All @@ -200,6 +202,9 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5):
"""
if path is None:
fname = url.split('/')[-1]
# Empty filenames are invalid
assert fname, 'Can\'t construct file-name from this URL. ' \
'Please set the `path` option manually.'
else:
path = os.path.expanduser(path)
if os.path.isdir(path):
Expand All @@ -208,6 +213,11 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5):
fname = path
assert retries >= 0, "Number of retries should be at least 0"

if not verify_ssl:
warnings.warn(
'Unverified HTTPS request is being made (verify_ssl=False). '
'Adding certificate verification is strongly advised.')

if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
Expand All @@ -217,7 +227,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5):
# pylint: disable=W0703
try:
print('Downloading %s from %s...'%(fname, url))
r = requests.get(url, stream=True)
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError("Failed downloading url %s"%url)
with open(fname, 'wb') as f:
Expand Down
44 changes: 40 additions & 4 deletions tests/python/unittest/test_gluon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,56 @@
# specific language governing permissions and limitations
# under the License.

import io
import os
import tempfile
import warnings

try:
from unittest import mock
except ImportError:
import mock
import mxnet as mx
from nose.tools import *
import requests
from nose.tools import raises


class MockResponse(requests.Response):
def __init__(self, status_code, content):
super(MockResponse, self).__init__()
assert isinstance(status_code, int)
self.status_code = status_code
self.raw = io.BytesIO(content.encode('utf-8'))


@raises(Exception)
@mock.patch(
'requests.get', mock.Mock(side_effect=requests.exceptions.ConnectionError))
def test_download_retries():
mx.gluon.utils.download("http://doesnotexist.notfound")


@mock.patch(
'requests.get',
mock.Mock(side_effect=
lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
def test_download_successful():
tmp = tempfile.mkdtemp()
tmpfile = os.path.join(tmp, 'README.md')
mx.gluon.utils.download("https://github.com/raw/apache/incubator-mxnet/master/README.md",
path=tmpfile)
assert os.path.getsize(tmpfile) > 100
mx.gluon.utils.download(
"https://github.com/raw/apache/incubator-mxnet/master/README.md",
path=tmpfile)
assert os.path.getsize(tmpfile) > 100


@mock.patch(
'requests.get',
mock.Mock(
side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT')))
def test_download_ssl_verify():
with warnings.catch_warnings(record=True) as warnings_:
mx.gluon.utils.download(
"https://mxnet.incubator.apache.org/index.html", verify_ssl=False)
assert any(
str(w.message).startswith('Unverified HTTPS request')
for w in warnings_)

0 comments on commit 21b4d2f

Please sign in to comment.