|
14 | 14 | import torch.cuda |
15 | 15 | import warnings |
16 | 16 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential |
17 | | -from torch.utils.trainer import Trainer |
18 | | -from torch.utils.trainer.plugins import * |
19 | | -from torch.utils.trainer.plugins.plugin import Plugin |
20 | 17 | from torch.autograd._functions.utils import prepare_onnx_paddings |
21 | 18 | from torch.autograd._functions.utils import check_onnx_broadcast |
22 | 19 | from common import IS_WINDOWS, IS_PPC, skipIfRocm |
|
26 | 23 | from common import TestCase, run_tests, download_file |
27 | 24 |
|
28 | 25 |
|
class SimplePlugin(Plugin):
    """Test double for a trainer Plugin.

    Counts every hook invocation and stashes the arguments each hook
    was last called with, so tests can assert dispatch frequency.
    """

    def __init__(self, interval):
        super(SimplePlugin, self).__init__(interval)
        self.trainer = None
        # One counter per trainer event type.
        self.num_iteration = self.num_epoch = self.num_batch = self.num_update = 0

    def register(self, trainer):
        # Invoked by the trainer when the plugin is attached; remember the
        # owner so tests can assert the association.
        self.trainer = trainer

    def iteration(self, *args):
        # Record the most recent call args and bump the counter.
        self.iteration_args = args
        self.num_iteration += 1

    def epoch(self, *args):
        self.epoch_args = args
        self.num_epoch += 1

    def batch(self, *args):
        self.batch_args = args
        self.num_batch += 1

    def update(self, *args):
        self.update_args = args
        self.num_update += 1
57 | | - |
58 | | - |
class ModelMock(object):
    """Stand-in for a model.

    Ignores its input, returns twice its single grad-tracking parameter,
    and counts how many forward calls were made.
    """

    def __init__(self):
        self.num_calls = 0
        # Single learnable output so gradient accumulation can be checked.
        self.output = torch.ones(1, 1, requires_grad=True)

    def __call__(self, i):
        # The argument is deliberately unused; only the call count and
        # the constant doubled output matter to the tests.
        self.num_calls = self.num_calls + 1
        return self.output * 2
68 | | - |
69 | | - |
class CriterionMock(object):
    """Identity 'loss' function.

    Passes the model output straight through as the loss while counting
    invocations; the target argument is ignored.
    """

    def __init__(self):
        self.num_calls = 0

    def __call__(self, out, target):
        self.num_calls = self.num_calls + 1
        return out
78 | | - |
79 | | - |
class OptimizerMock(object):
    """Optimizer test double.

    Each step() re-evaluates the closure a random number of times between
    min_evals and max_evals inclusive, mimicking line-search optimizers
    (e.g. LBFGS) that call the closure repeatedly per step.
    """

    max_evals = 5
    min_evals = 1

    def __init__(self):
        self.num_steps = 0
        self.num_evals = 0

    def step(self, closure):
        # Draw the evaluation count once, then replay the closure.
        evals_this_step = random.randint(self.min_evals, self.max_evals)
        for _ in range(evals_this_step):
            closure()
            self.num_evals += 1
        self.num_steps += 1

    def zero_grad(self):
        # The mocks never accumulate gradients, so there is nothing to clear.
        pass
96 | | - |
97 | | - |
class DatasetMock(object):
    """Iterable dataset double yielding ten random batches."""

    def __iter__(self):
        # Each item: a (2, 10) input batch paired with two distinct
        # indices drawn from a permutation of range(10).
        for _ in range(10):
            yield torch.randn(2, 10), torch.randperm(10)[:2]

    def __len__(self):
        return 10
106 | | - |
107 | | - |
108 | 26 | class RandomDatasetMock(object): |
109 | 27 |
|
110 | 28 | def __getitem__(self, index): |
@@ -279,84 +197,6 @@ def test_multi_drop(self): |
279 | 197 | self.assertEqual(len(list(dataiter)), 1) |
280 | 198 |
|
281 | 199 |
|
class TestTrainer(TestCase):
    """Runs the Trainer loop on mock components, checking plugin dispatch
    frequency, optimizer stepping, forward-call counts, and gradient
    accumulation."""

    # Plugin trigger configurations: each entry is a list of
    # (every_n_triggers, event_name) pairs.
    intervals = [
        [(1, 'iteration')],
        [(1, 'epoch')],
        [(1, 'batch')],
        [(1, 'update')],
        [(5, 'iteration')],
        [(5, 'epoch')],
        [(5, 'batch')],
        [(5, 'update')],
        [(1, 'iteration'), (1, 'epoch')],
        [(5, 'update'), (1, 'iteration')],
        [(2, 'epoch'), (1, 'batch')],
    ]

    def setUp(self):
        self.optimizer = OptimizerMock()
        self.trainer = Trainer(ModelMock(), CriterionMock(),
                               self.optimizer, DatasetMock())
        self.num_epochs = 3
        self.dataset_size = len(self.trainer.dataset)
        self.num_iters = self.num_epochs * self.dataset_size

    def test_register_plugin(self):
        # Registration must hand the trainer back to the plugin.
        for interval in self.intervals:
            plugin = SimplePlugin(interval)
            self.trainer.register_plugin(plugin)
            self.assertEqual(plugin.trainer, self.trainer)

    def test_optimizer_step(self):
        # One optimizer step per dataset item over a single epoch.
        self.trainer.run(epochs=1)
        self.assertEqual(self.trainer.optimizer.num_steps, 10)

    def test_plugin_interval(self):
        for interval in self.intervals:
            self.setUp()  # rebuild trainer/optimizer for each configuration
            plugin = SimplePlugin(interval)
            self.trainer.register_plugin(plugin)
            self.trainer.run(epochs=self.num_epochs)
            # How many times each event type fires over the full run.
            triggers = {
                ('iteration', self.num_iters),
                ('epoch', self.num_epochs),
                ('batch', self.num_iters),
                ('update', self.num_iters)
            }
            for unit, num_triggers in triggers:
                # The configured interval for this event, or None if the
                # plugin was not registered for it.
                call_every = next(
                    (i for i, unit_name in interval if unit_name == unit),
                    None)
                expected_num_calls = (
                    math.floor(num_triggers / call_every) if call_every else 0)
                num_calls = getattr(plugin, 'num_' + unit)
                self.assertEqual(num_calls, expected_num_calls, 0)

    def test_model_called(self):
        self.trainer.run(epochs=self.num_epochs)
        num_model_calls = self.trainer.model.num_calls
        num_crit_calls = self.trainer.criterion.num_calls
        # Model and criterion are always invoked in lockstep.
        self.assertEqual(num_model_calls, num_crit_calls)
        lower_bound = OptimizerMock.min_evals * self.num_iters
        upper_bound = OptimizerMock.max_evals * self.num_iters
        for num_calls in (num_model_calls, num_crit_calls):
            # Every closure evaluation triggers exactly one forward pass.
            self.assertEqual(num_calls, self.trainer.optimizer.num_evals)
            self.assertLessEqual(lower_bound, num_calls)
            self.assertLessEqual(num_calls, upper_bound)

    def test_model_gradient(self):
        self.trainer.run(epochs=self.num_epochs)
        output_var = self.trainer.model.output
        # d(loss)/d(output) is 2 per closure evaluation and gradients
        # accumulate across evaluations.
        expected_grad = torch.ones(1, 1) * 2 * self.optimizer.num_evals
        self.assertEqual(output_var.grad.data, expected_grad)
358 | | - |
359 | | - |
360 | 200 | test_dir = os.path.abspath(os.path.dirname(str(__file__))) |
361 | 201 |
|
362 | 202 |
|
|
0 commit comments