Skip to content

Commit 9bfb72b

Browse files
tensorflower-gardener authored and gunan committed
Added retries to download function.
Change: 139964468
1 parent e088c88 commit 9bfb72b

3 files changed

Lines changed: 170 additions & 1 deletion

File tree

tensorflow/contrib/learn/python/learn/datasets/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ filegroup(
4444
visibility = ["//tensorflow:__subpackages__"],
4545
)
4646

47+
# Test target that runs base_test.py (Python 2 and 3 compatible sources).
py_test(
    name = "base_test",
    size = "small",
    srcs = ["base_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/contrib/learn",
        "//tensorflow/python:framework_test_lib",
    ],
)
58+
4759
py_test(
4860
name = "load_csv_test",
4961
size = "small",

tensorflow/contrib/learn/python/learn/datasets/base.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
import csv
2424
import os
2525
from os import path
26+
import random
2627
import tempfile
28+
import time
2729

2830
import numpy as np
2931
from six.moves import urllib
@@ -121,6 +123,73 @@ def load_boston(data_path=None):
121123
features_dtype=np.float)
122124

123125

126+
def retry(initial_delay,
          max_delay,
          factor=2.0,
          jitter=0.25,
          is_retriable=None):
  """Simple decorator for wrapping retriable functions.

  The wrapped function is attempted once for every delay in the backoff
  schedule (initial_delay, initial_delay * factor, ... while <= max_delay).
  After a failed, retriable attempt the wrapper sleeps for that (jittered)
  delay. Once the schedule is exhausted one final attempt is made and any
  exception it raises propagates to the caller.

  Note: if the delay never grows past max_delay (e.g. factor == 1 or
  initial_delay == 0) the schedule is unbounded, so the call is retried until
  it succeeds or raises a non-retriable error.

  Args:
    initial_delay: the initial delay, in seconds (it is passed to time.sleep).
    max_delay: the maximum delay allowed (actual max is
      max_delay * (1 + jitter)).
    factor: each subsequent retry, the delay is multiplied by this value.
      (must be >= 1).
    jitter: to avoid lockstep, the returned delay is multiplied by a random
      number between (1-jitter) and (1+jitter). To add a 20% jitter, set
      jitter = 0.2. Must be < 1.
    is_retriable: (optional) a function that takes an Exception as an argument
      and returns true if retry should be applied. When None, every Exception
      is retried.

  Returns:
    A decorator that wraps a function with the retry logic described above.

  Raises:
    ValueError: if factor < 1 or jitter >= 1.
  """
  # Function-scope import so this block does not depend on a new file-level
  # import.
  import functools

  if factor < 1:
    raise ValueError('factor must be >= 1; was %f' % (factor,))

  if jitter >= 1:
    raise ValueError('jitter must be < 1; was %f' % (jitter,))

  # Generator to compute the individual delays
  def delays():
    delay = initial_delay
    while delay <= max_delay:
      yield delay * random.uniform(1 - jitter, 1 + jitter)
      delay *= factor

  def wrap(fn):
    """Wrapper function factory invoked by decorator magic."""

    @functools.wraps(fn)  # Keep fn's __name__/__doc__ on the wrapper.
    def wrapped_fn(*args, **kwargs):
      """The actual wrapper function that applies the retry logic."""
      for delay in delays():
        try:
          return fn(*args, **kwargs)
        except Exception as e:  # pylint: disable=broad-except
          if is_retriable is not None and not is_retriable(e):
            raise
          # Back off before the next attempt. This also applies when no
          # is_retriable predicate was given; otherwise the default mode would
          # hammer fn in a tight loop with no delay at all.
          time.sleep(delay)
      # Schedule exhausted: last attempt, let any exception propagate.
      return fn(*args, **kwargs)

    return wrapped_fn

  return wrap
177+
178+
179+
_RETRIABLE_ERRNOS = {
180+
110, # Connection timed out [socket.py]
181+
}
182+
183+
184+
def _is_retriable(e):
185+
return isinstance(e, IOError) and e.errno in _RETRIABLE_ERRNOS
186+
187+
188+
@retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable)
def urlretrieve_with_retry(url, filename):
  """Retrieves `url` into `filename`, retrying errors accepted by _is_retriable.

  Retries follow the backoff configured in the decorator above (initial delay
  1.0, maximum delay 16.0); errors rejected by `_is_retriable` propagate
  immediately.

  Args:
    url: the URL to retrieve.
    filename: the local path passed to `urllib.request.urlretrieve`.

  Returns:
    Whatever `urllib.request.urlretrieve` returns, so this wrapper can be used
    as a drop-in replacement for it.
  """
  return urllib.request.urlretrieve(url, filename)
191+
192+
124193
def maybe_download(filename, work_directory, source_url):
125194
"""Download the data from source url, unless it's already here.
126195
@@ -138,7 +207,7 @@ def maybe_download(filename, work_directory, source_url):
138207
if not gfile.Exists(filepath):
139208
with tempfile.NamedTemporaryFile() as tmpfile:
140209
temp_file_name = tmpfile.name
141-
urllib.request.urlretrieve(source_url, temp_file_name)
210+
urlretrieve_with_retry(source_url, temp_file_name)
142211
gfile.Copy(temp_file_name, filepath)
143212
with gfile.GFile(filepath) as f:
144213
size = f.size()
tensorflow/contrib/learn/python/learn/datasets/base_test.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import tensorflow as tf
21+
22+
from tensorflow.contrib.learn.python.learn.datasets import base
23+
24+
mock = tf.test.mock

# IOError carrying errno 110, used below as the error that triggers retries.
_TIMEOUT = IOError(110, "timeout")


class BaseTest(tf.test.TestCase):
  """Tests for the retry behavior of base.urlretrieve_with_retry."""

  def _assertFullBackoff(self, mock_time):
    """Asserts that the patched time.sleep saw the full 1..16 backoff.

    Args:
      mock_time: the mock that replaced the `time` module inside `base`.
    """
    actual_list = [arg[0][0] for arg in mock_time.sleep.call_args_list]
    expected_list = [1, 2, 4, 8, 16]
    # Check the length first: zip() below would silently truncate otherwise.
    self.assertEqual(len(actual_list), len(expected_list))
    for actual, expected in zip(actual_list, expected_list):
      # Each delay may deviate from its nominal value by the 25% jitter.
      self.assertLessEqual(abs(actual - expected), 0.25 * expected)

  def testUrlretrieveRetriesOnIOError(self):
    with mock.patch.object(base, "time") as mock_time:
      with mock.patch.object(base, "urllib") as mock_urllib:
        # Five timeouts followed by a success on the final attempt.
        mock_urllib.request.urlretrieve.side_effect = [
            _TIMEOUT,
            _TIMEOUT,
            _TIMEOUT,
            _TIMEOUT,
            _TIMEOUT,
            None,
        ]
        base.urlretrieve_with_retry("http://dummy.com", "/tmp/dummy")

    # Assert full backoff was tried
    self._assertFullBackoff(mock_time)

  def testUrlretrieveRaisesAfterRetriesAreExhausted(self):
    with mock.patch.object(base, "time") as mock_time:
      with mock.patch.object(base, "urllib") as mock_urllib:
        # Every attempt, including the final one, times out.
        mock_urllib.request.urlretrieve.side_effect = [
            _TIMEOUT,
            _TIMEOUT,
            _TIMEOUT,
            _TIMEOUT,
            _TIMEOUT,
            _TIMEOUT,
        ]
        with self.assertRaises(IOError):
          base.urlretrieve_with_retry("http://dummy.com", "/tmp/dummy")

    # Assert full backoff was tried
    self._assertFullBackoff(mock_time)

  def testUrlretrieveRaisesOnNonRetriableErrorWithoutRetry(self):
    with mock.patch.object(base, "time") as mock_time:
      with mock.patch.object(base, "urllib") as mock_urllib:
        mock_urllib.request.urlretrieve.side_effect = [
            IOError(2, "No such file or directory"),
        ]
        with self.assertRaises(IOError):
          base.urlretrieve_with_retry("http://dummy.com", "/tmp/dummy")

    # Assert no retries: the download was attempted exactly once and the
    # wrapper never slept. (Checking `mock_time.called` would be vacuous, since
    # the patched module object itself is never called.)
    self.assertEqual(mock_urllib.request.urlretrieve.call_count, 1)
    self.assertFalse(mock_time.sleep.called)


if __name__ == "__main__":
  tf.test.main()

0 commit comments

Comments
 (0)