|
1 | 1 | from __future__ import print_function |
2 | 2 | import sys |
3 | 3 | import os |
| 4 | +import re |
4 | 5 | import math |
5 | 6 | import shutil |
6 | 7 | import random |
@@ -385,6 +386,105 @@ def _transform_MultiMarginCriterion(self, input, target): |
385 | 386 | return input, target.sub(1) |
386 | 387 |
|
387 | 388 |
|
| 389 | +class TestBottleneck(TestCase): |
| 390 | + def _run(self, command): |
| 391 | + """Returns (return-code, stdout, stderr)""" |
| 392 | + import subprocess |
| 393 | + from common import PY3 |
| 394 | + |
| 395 | + p = subprocess.Popen(command, stdout=subprocess.PIPE, |
| 396 | + stderr=subprocess.PIPE, shell=True) |
| 397 | + output, err = p.communicate() |
| 398 | + rc = p.returncode |
| 399 | + if PY3: |
| 400 | + output = output.decode("ascii") |
| 401 | + err = err.decode("ascii") |
| 402 | + return (rc, output, err) |
| 403 | + |
| 404 | + def _run_bottleneck(self, test_file, scriptargs=''): |
| 405 | + import os |
| 406 | + curdir = os.path.dirname(os.path.abspath(__file__)) |
| 407 | + filepath = '{}/{}'.format(curdir, test_file) |
| 408 | + if scriptargs != '': |
| 409 | + mark = '-- ' |
| 410 | + scriptargs = ' {}'.format(scriptargs) |
| 411 | + else: |
| 412 | + mark = '' |
| 413 | + rc, out, err = self._run( |
| 414 | + 'python -m torch.utils.bottleneck {}{}{}'.format(mark, filepath, scriptargs)) |
| 415 | + return rc, out, err |
| 416 | + |
| 417 | + def _check_run_args(self): |
| 418 | + # Check that this fails due to missing args |
| 419 | + rc, out, err = self._run_bottleneck('bottleneck/test_args.py') |
| 420 | + self.assertEqual(rc, 2, None, self._fail_msg('Missing args should error', out + err)) |
| 421 | + |
| 422 | + # This should succeed |
| 423 | + rc, out, err = self._run_bottleneck('bottleneck/test_args.py', '--foo foo --bar bar') |
| 424 | + self.assertEqual(rc, 0, None, self._fail_msg('Should pass args to script', out + err)) |
| 425 | + |
| 426 | + def _fail_msg(self, msg, output): |
| 427 | + return '{}, output was:\n{}'.format(msg, output) |
| 428 | + |
| 429 | + def _check_environment_summary(self, output): |
| 430 | + results = re.search('Environment Summary', output) |
| 431 | + self.assertIsNotNone(results, self._fail_msg('Should have Enviroment Summary', output)) |
| 432 | + |
| 433 | + # Up to five lines away from the heading, there should be the version number |
| 434 | + results = re.search(r'Environment Summary.*(\n.*){,5}\nPyTorch \d+\.\d+', output) |
| 435 | + self.assertIsNotNone(results, self._fail_msg('Should have PyTorch version', output)) |
| 436 | + |
| 437 | + def _check_cprof_summary(self, output): |
| 438 | + results = re.search('cProfile output', output) |
| 439 | + self.assertIsNotNone(results, self._fail_msg('Should have cProfile output', output)) |
| 440 | + |
| 441 | + # This assumes that after the cProfile output section we have |
| 442 | + # the autograd profiler output |
| 443 | + results = re.search(r'cProfile output.*(\n.*){6,50}\n.*autograd profiler output', output) |
| 444 | + self.assertIsNotNone(results, self._fail_msg( |
| 445 | + 'Distance between cProfile and autograd prof out not in [6, 50] lines', output)) |
| 446 | + |
| 447 | + def _check_autograd_summary(self, output): |
| 448 | + results = re.search('autograd profiler output', output) |
| 449 | + self.assertIsNotNone(results, self._fail_msg('Should have autograd profiler output', output)) |
| 450 | + |
| 451 | + # This assumes that after the autograd profiler output is the end of the |
| 452 | + # output. |
| 453 | + results = re.search(r'autograd profiler output.*(\n.*){6,100}', output) |
| 454 | + self.assertIsNotNone(results, self._fail_msg( |
| 455 | + 'Distance between autograd prof output and end of output not in [6, 100] lines', output)) |
| 456 | + |
| 457 | + def _check_cuda(self, output): |
| 458 | + if torch.cuda.is_available(): |
| 459 | + results = re.search('CUDA mode', output) |
| 460 | + self.assertIsNotNone(results, self._fail_msg('Should tell users CUDA', output)) |
| 461 | + else: |
| 462 | + results = re.search('CUDA mode', output) |
| 463 | + self.assertIsNone(results, self._fail_msg('Should not tell users about CUDA', output)) |
| 464 | + |
| 465 | + @unittest.skipIf(torch.cuda.is_available(), 'CPU-only test') |
| 466 | + def test_cpu_only(self): |
| 467 | + rc, out, err = self._run_bottleneck('bottleneck/test.py') |
| 468 | + self.assertEqual(rc, 0, 'Run failed with\n{}'.format(err)) |
| 469 | + |
| 470 | + self._check_run_args() |
| 471 | + self._check_environment_summary(out) |
| 472 | + self._check_autograd_summary(out) |
| 473 | + self._check_cprof_summary(out) |
| 474 | + self._check_cuda(out) |
| 475 | + |
| 476 | + @unittest.skipIf(not torch.cuda.is_available(), 'No CUDA') |
| 477 | + def test_cuda(self): |
| 478 | + rc, out, err = self._run_bottleneck('bottleneck/test_cuda.py') |
| 479 | + self.assertEqual(rc, 0, 'Run failed with\n{}'.format(err)) |
| 480 | + |
| 481 | + self._check_run_args() |
| 482 | + self._check_environment_summary(out) |
| 483 | + self._check_autograd_summary(out) |
| 484 | + self._check_cprof_summary(out) |
| 485 | + self._check_cuda(out) |
| 486 | + |
| 487 | + |
388 | 488 | class TestONNXUtils(TestCase): |
389 | 489 | def test_prepare_onnx_paddings(self): |
390 | 490 | sizes = [2, 3, 4] |
|
0 commit comments