Skip to content

Commit dcd9d73

Browse files
vishwakftwfacebook-github-bot
authored andcommitted
Expunge torch.utils.trainer.* (#12487)
Differential Revision: D10273602 Pulled By: SsnL fbshipit-source-id: 630c1f8ee0e366f7092d4f93dbe1efa96fc860e0
1 parent 8468b7d commit dcd9d73

File tree

11 files changed

+0
-472
lines changed

11 files changed

+0
-472
lines changed

test/test_utils.py

Lines changed: 0 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
import torch.cuda
1515
import warnings
1616
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
17-
from torch.utils.trainer import Trainer
18-
from torch.utils.trainer.plugins import *
19-
from torch.utils.trainer.plugins.plugin import Plugin
2017
from torch.autograd._functions.utils import prepare_onnx_paddings
2118
from torch.autograd._functions.utils import check_onnx_broadcast
2219
from common import IS_WINDOWS, IS_PPC, skipIfRocm
@@ -26,85 +23,6 @@
2623
from common import TestCase, run_tests, download_file
2724

2825

29-
class SimplePlugin(Plugin):
30-
31-
def __init__(self, interval):
32-
super(SimplePlugin, self).__init__(interval)
33-
self.trainer = None
34-
self.num_iteration = 0
35-
self.num_epoch = 0
36-
self.num_batch = 0
37-
self.num_update = 0
38-
39-
def register(self, trainer):
40-
self.trainer = trainer
41-
42-
def iteration(self, *args):
43-
self.iteration_args = args
44-
self.num_iteration += 1
45-
46-
def epoch(self, *args):
47-
self.epoch_args = args
48-
self.num_epoch += 1
49-
50-
def batch(self, *args):
51-
self.batch_args = args
52-
self.num_batch += 1
53-
54-
def update(self, *args):
55-
self.update_args = args
56-
self.num_update += 1
57-
58-
59-
class ModelMock(object):
60-
61-
def __init__(self):
62-
self.num_calls = 0
63-
self.output = torch.ones(1, 1, requires_grad=True)
64-
65-
def __call__(self, i):
66-
self.num_calls += 1
67-
return self.output * 2
68-
69-
70-
class CriterionMock(object):
71-
72-
def __init__(self):
73-
self.num_calls = 0
74-
75-
def __call__(self, out, target):
76-
self.num_calls += 1
77-
return out
78-
79-
80-
class OptimizerMock(object):
81-
max_evals = 5
82-
min_evals = 1
83-
84-
def __init__(self):
85-
self.num_steps = 0
86-
self.num_evals = 0
87-
88-
def step(self, closure):
89-
for i in range(random.randint(self.min_evals, self.max_evals)):
90-
loss = closure()
91-
self.num_evals += 1
92-
self.num_steps += 1
93-
94-
def zero_grad(self):
95-
pass
96-
97-
98-
class DatasetMock(object):
99-
100-
def __iter__(self):
101-
for i in range(10):
102-
yield torch.randn(2, 10), torch.randperm(10)[:2]
103-
104-
def __len__(self):
105-
return 10
106-
107-
10826
class RandomDatasetMock(object):
10927

11028
def __getitem__(self, index):
@@ -279,84 +197,6 @@ def test_multi_drop(self):
279197
self.assertEqual(len(list(dataiter)), 1)
280198

281199

282-
class TestTrainer(TestCase):
283-
284-
intervals = [
285-
[(1, 'iteration')],
286-
[(1, 'epoch')],
287-
[(1, 'batch')],
288-
[(1, 'update')],
289-
[(5, 'iteration')],
290-
[(5, 'epoch')],
291-
[(5, 'batch')],
292-
[(5, 'update')],
293-
[(1, 'iteration'), (1, 'epoch')],
294-
[(5, 'update'), (1, 'iteration')],
295-
[(2, 'epoch'), (1, 'batch')],
296-
]
297-
298-
def setUp(self):
299-
self.optimizer = OptimizerMock()
300-
self.trainer = Trainer(ModelMock(), CriterionMock(),
301-
self.optimizer, DatasetMock())
302-
self.num_epochs = 3
303-
self.dataset_size = len(self.trainer.dataset)
304-
self.num_iters = self.num_epochs * self.dataset_size
305-
306-
def test_register_plugin(self):
307-
for interval in self.intervals:
308-
simple_plugin = SimplePlugin(interval)
309-
self.trainer.register_plugin(simple_plugin)
310-
self.assertEqual(simple_plugin.trainer, self.trainer)
311-
312-
def test_optimizer_step(self):
313-
self.trainer.run(epochs=1)
314-
self.assertEqual(self.trainer.optimizer.num_steps, 10)
315-
316-
def test_plugin_interval(self):
317-
for interval in self.intervals:
318-
self.setUp()
319-
simple_plugin = SimplePlugin(interval)
320-
self.trainer.register_plugin(simple_plugin)
321-
self.trainer.run(epochs=self.num_epochs)
322-
units = {
323-
('iteration', self.num_iters),
324-
('epoch', self.num_epochs),
325-
('batch', self.num_iters),
326-
('update', self.num_iters)
327-
}
328-
for unit, num_triggers in units:
329-
call_every = None
330-
for i, i_unit in interval:
331-
if i_unit == unit:
332-
call_every = i
333-
break
334-
if call_every:
335-
expected_num_calls = math.floor(num_triggers / call_every)
336-
else:
337-
expected_num_calls = 0
338-
num_calls = getattr(simple_plugin, 'num_' + unit)
339-
self.assertEqual(num_calls, expected_num_calls, 0)
340-
341-
def test_model_called(self):
342-
self.trainer.run(epochs=self.num_epochs)
343-
num_model_calls = self.trainer.model.num_calls
344-
num_crit_calls = self.trainer.criterion.num_calls
345-
self.assertEqual(num_model_calls, num_crit_calls)
346-
for num_calls in [num_model_calls, num_crit_calls]:
347-
lower_bound = OptimizerMock.min_evals * self.num_iters
348-
upper_bound = OptimizerMock.max_evals * self.num_iters
349-
self.assertEqual(num_calls, self.trainer.optimizer.num_evals)
350-
self.assertLessEqual(lower_bound, num_calls)
351-
self.assertLessEqual(num_calls, upper_bound)
352-
353-
def test_model_gradient(self):
354-
self.trainer.run(epochs=self.num_epochs)
355-
output_var = self.trainer.model.output
356-
expected_grad = torch.ones(1, 1) * 2 * self.optimizer.num_evals
357-
self.assertEqual(output_var.grad.data, expected_grad)
358-
359-
360200
test_dir = os.path.abspath(os.path.dirname(str(__file__)))
361201

362202

torch/utils/trainer/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

torch/utils/trainer/plugins/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

torch/utils/trainer/plugins/accuracy.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

torch/utils/trainer/plugins/logger.py

Lines changed: 0 additions & 83 deletions
This file was deleted.

torch/utils/trainer/plugins/loss.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

torch/utils/trainer/plugins/monitor.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

torch/utils/trainer/plugins/plugin.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

0 commit comments

Comments
 (0)