Skip to content

Commit b81ea65

Browse files
committed
always disallow setting those attrs, add a test
1 parent 963b340 commit b81ea65

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

test/test_dataloader.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,14 @@ def _test_error(self, loader):
299299
math.ceil(float(len(loader.dataset)) / loader.batch_size))
300300
return
301301

302+
def test_invalid_assign_after_init(self):
303+
dl = DataLoader(self.dataset)
304+
for attr in ('batch_size', 'sampler', 'drop_last'):
305+
def fn():
306+
setattr(dl, attr, {})
307+
308+
self.assertRaises(ValueError, fn)
309+
302310
def test_sequential(self):
303311
self._test_sequential(DataLoader(self.dataset))
304312

torch/utils/data/dataloader.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -411,16 +411,18 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sam
411411

412412
if batch_sampler is not None:
413413
if batch_size > 1 or shuffle or sampler is not None or drop_last:
414-
raise ValueError('batch_sampler is mutually exclusive with '
415-
'batch_size, shuffle, sampler, and drop_last')
414+
raise ValueError('batch_sampler option is mutually exclusive '
415+
'with batch_size, shuffle, sampler, and '
416+
'drop_last')
416417
self.batch_size = None
417418
self.drop_last = None
418419

419420
if sampler is not None and shuffle:
420-
raise ValueError('sampler is mutually exclusive with shuffle')
421+
raise ValueError('sampler option is mutually exclusive with '
422+
'shuffle')
421423

422424
if self.num_workers < 0:
423-
raise ValueError('num_workers cannot be negative; '
425+
raise ValueError('num_workers option cannot be negative; '
424426
'use num_workers=0 to disable multiprocessing.')
425427

426428
if batch_sampler is None:
@@ -436,11 +438,9 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sam
436438
self.__initialized = True
437439

438440
def __setattr__(self, attr, val):
439-
if self.__initialized and self.batch_sampler is not None and \
440-
attr in ('batch_size', 'sampler', 'drop_last') and \
441-
val is not None:
442-
raise ValueError('{} should not be set when batch_sampler is '
443-
'used'.format(attr))
441+
if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
442+
raise ValueError('{} attribute should not be set after {} is '
443+
'initialized'.format(attr, self.__class__.__name__))
444444

445445
super(DataLoader, self).__setattr__(attr, val)
446446

0 commit comments

Comments
 (0)