
Conversation

@AlexanderRadionov
Contributor

Added an ind_worker_queue parameter to data.DataLoader. It makes preprocessing deterministic.

DataLoader in multiprocessing mode can behave non-deterministically. Even if the random seed is fixed, each subprocess may receive tasks in an unstable order, because I/O times differ while data is loading. If you apply augmentation during data loading, this makes the results unreproducible. See https://discuss.pytorch.org/t/deterministic-non-deterministic-results-with-pytorch/9087

To fix this issue I added an individual queue for each worker, so each worker receives its tasks in a stable order and, as a result, each subprocess produces stable results.
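
A minimal sketch of the assignment idea (not the actual DataLoader internals; the plain lists and hard-coded batches stand in for the real inter-process queues): instead of one shared index queue that whichever worker is free drains first, batch indices are pushed round-robin into per-worker queues, so batch i is always handled by worker i % num_workers regardless of I/O timing.

num_workers = 3
# One index queue per worker instead of a single shared queue
# (plain lists stand in for the real inter-process queues).
index_queues = [[] for _ in range(num_workers)]

batches = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
for i, batch_indices in enumerate(batches):
    # Deterministic assignment: batch i always goes to worker i % num_workers,
    # so I/O timing no longer decides which worker's RNG augments which batch.
    index_queues[i % num_workers].append((i, batch_indices))

print(index_queues[0])  # [(0, [0, 1, 2, 3]), (3, [12, 13, 14, 15])]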

To reproduce the issue, change ind_worker_queue to False and run the script below several times:

import time
import numpy as np
import random
from datetime import datetime
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

GLOBAL_SEED = 1024

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
    global GLOBAL_WORKER_ID
    GLOBAL_WORKER_ID = worker_id
    set_seed(GLOBAL_SEED + worker_id)


class RandomAugmentationDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, item):
        global GLOBAL_WORKER_ID
        rnd = datetime.now().microsecond % 5  # intentionally unstable delay; does not depend on the random seed
        if rnd < 3:
            time.sleep(rnd * 0.1)
        return item, GLOBAL_WORKER_ID, self.data[item] + np.random.random_sample(self.data[item].shape)

    def __len__(self):
        return self.data.shape[0]

set_seed(GLOBAL_SEED)
data = np.random.random_sample((128, 128))
dataset = RandomAugmentationDataset(data)

expected = [
    (2070.646908929146, [97, 0, 79, 2, 85, 121, 117, 56, 83, 43, 70, 60, 108, 47, 101, 95], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    (2063.337454820789, [71, 91, 20, 34, 107, 27, 52, 55, 69, 90, 127, 84, 81, 105, 46, 124], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
    (2031.0978096882036, [118, 86, 44, 14, 18, 92, 98, 36, 49, 13, 111, 28, 116, 10, 104, 87], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
    (2045.9375171891238, [103, 15, 19, 80, 29, 120, 1, 106, 123, 24, 39, 31, 3, 65, 45, 50], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    (2083.1401737400893, [11, 89, 12, 58, 54, 32, 67, 17, 110, 40, 82, 9, 30, 94, 125, 35], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
    (2035.8920455019213, [41, 74, 33, 68, 64, 8, 76, 42, 126, 53, 122, 26, 114, 75, 57, 112], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
    (2078.9789912514643, [109, 119, 113, 37, 5, 22, 59, 51, 62, 25, 38, 66, 6, 102, 73, 96], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    (2054.459126893385, [16, 72, 88, 115, 48, 61, 7, 93, 4, 77, 63, 100, 99, 21, 78, 23], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
]

s = time.time()
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=3, worker_init_fn=worker_init_fn, ind_worker_queue=True)
for i, (items, worker_ids, batch) in enumerate(dataloader):
    assert round(expected[i][0], 5) == round(np.sum(batch.numpy()), 5)
    assert expected[i][1] == items.numpy().tolist()

e = time.time()
print("time", e-s)

@ssnl
Collaborator

ssnl commented Jan 12, 2018

Thanks! Could you add a test in test_dataloader.py as well?

@meownoid

Personally, I don't think relying on iterator consumption order is a good idea. One accidental next() call can ruin everything.

@AlexanderRadionov
Contributor Author

@ssnl Done

Collaborator

@ssnl ssnl left a comment

Sorry for the late review. Thanks for the contribution. I think after addressing the comments and resolving the conflicts, this will be good to go :).

@ssnl
Collaborator

ssnl commented Feb 13, 2018

@pytorchbot add to whitelist

@ssnl
Collaborator

ssnl commented Feb 13, 2018

Sorry I forgot to enable CI on this PR. Should be testing now.

@apaszke
Contributor

apaszke commented Feb 13, 2018

@AlexanderRadionov do you have any benchmarks for the new data loader? @colesbury and I are worried that it will lower throughput, which matters for training on e.g. 8x Volta GPUs.

@AlexanderRadionov
Contributor Author

I ran some synthetic tests, just iterating over the dataloader, and I have not seen any issue.
Here is the code:

import time
from datetime import datetime
import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

GLOBAL_SEED = 1024

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
    global GLOBAL_WORKER_ID
    GLOBAL_WORKER_ID = worker_id
    set_seed(GLOBAL_SEED + worker_id)


class RandomAugmentationDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, item):
        global GLOBAL_WORKER_ID
        return item, GLOBAL_WORKER_ID, self.data[item] + np.random.random_sample(self.data[item].shape)

    def __len__(self):
        return self.data.shape[0]

set_seed(GLOBAL_SEED)
data = np.random.random_sample((12800, 12800))
dataset = RandomAugmentationDataset(data)

s = time.time()
for run in range(0, 11):
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=3, worker_init_fn=worker_init_fn, ind_worker_queue=True)
    for items, worker_ids, batch in dataloader:
        pass
e = time.time()
print("time", e-s)

I ran it with ind_worker_queue=True and ind_worker_queue=False several times.
False (default):
time 17.274786710739136
time 16.919730186462402
True (overridden):
time 16.521615266799927
time 16.571028470993042

I think my changes should not cause a slowdown, because the individual-worker-queue behavior is not the default.

@apaszke
Contributor

apaszke commented Feb 13, 2018

I'm thinking more of a real ImageNet-like workload that actually has to read from disk, and some requests may get stalled for longer than others, slowing down the whole loading process (unlike previously).

@AlexanderRadionov
Contributor Author

AlexanderRadionov commented Feb 13, 2018

I tried to emulate read-time differences with the original code posted at the head of the merge request (the random sleep). Here are the results:

False (default):
time 123.02024149894714
time 123.15355587005615

True (overridden):
time 126.91952013969421
time 126.23874187469482

Do you think this is a problem for non-default behavior? @colesbury @apaszke

@ssnl
Collaborator

ssnl commented Feb 16, 2018

@apaszke Yes, the process will be slowed, but having reproducible results is very important. Do you have suggestions on how to do deterministic multiprocess data loading in a faster way?

@apaszke
Contributor

apaszke commented Feb 16, 2018

@AlexanderRadionov what was the duration of the sleep you've used? I'd still want to try this on a real machine before merging, but I think it's ok. The only change I'd request would be to remove the flag, and make this mode always enabled.

@AlexanderRadionov
Contributor Author

AlexanderRadionov commented Feb 16, 2018

@apaszke I used a random sleep with a duration in the range [0.001; 0.003] s.

I also tested on real data. The dataset has 96265 images, 8.7 GB in total.

False (default)
3 workers: 237s
6 workers: 178s

True (overridden)
3 workers: 242s
6 workers: 177s

@apaszke
Contributor

apaszke commented Feb 21, 2018

LGTM. I just want to run a real-world ImageNet benchmark, and then this should be good to merge.

@AlexanderRadionov
Contributor Author

@apaszke I don't have enough resources to run an ImageNet benchmark.

@apaszke
Contributor

apaszke commented Feb 22, 2018

@AlexanderRadionov no worries, I wasn't requesting that from you. I was going to do it myself

@grafi-tt

grafi-tt commented Mar 1, 2018

Just FYI, another strategy for ensuring determinism is storing PRNG states in shared memory. It doesn't require fixing the worker assignment, though its implementation is complex and serialization/deserialization overhead is incurred.

A Chainer PR, chainer/chainer#4230, uses this strategy, because fixing the process assignment is impossible in Chainer, which uses multiprocessing.Pool.

@apaszke
Contributor

apaszke commented Mar 1, 2018

@grafi-tt I'm not sure how this would work. You need to ensure that the sequence of locks for the PRNG will be the same at every run, but this is very hard or prohibitively slow if you sample multiple times while loading a data item (e.g. sample a mask, sample a rotation, ...).

@grafi-tt

grafi-tt commented Mar 5, 2018

@apaszke It's simple: cache the state. When a worker process is resumed, the PRNG state is deserialized from shared memory (e.g. with numpy.RandomState.set_state()). When the worker process yields a result, the PRNG state is serialized to shared memory (e.g. with numpy.RandomState.get_state()). While the worker is running, there is no overhead.
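
A minimal sketch of the serialize/deserialize mechanics described above. The shared_states dict, the per-slot keying, and the seed constant are illustrative assumptions; the actual scheme in chainer/chainer#4230 may differ.

import numpy as np

shared_states = {}  # stand-in for shared memory, keyed by a logical worker slot

def process_item(slot_id, item):
    rng = np.random.RandomState()
    if slot_id in shared_states:
        rng.set_state(shared_states[slot_id])  # resume: deserialize the cached PRNG state
    else:
        rng.seed(1024 + slot_id)               # first task for this slot: seed deterministically
    result = item + rng.random_sample()        # randomized work; no overhead while running
    shared_states[slot_id] = rng.get_state()   # yield: serialize the state back
    return result

# Replaying the same sequence of (slot_id, item) tasks reproduces the same results.
print(process_item(0, 1.0), process_item(0, 1.0))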

@apaszke
Contributor

apaszke commented Mar 5, 2018

@grafi-tt oh I see. This makes sense but is still not ideal:

  • you still need a bunch of synchronization, and everyone needs to lock the shared memory and put their random state in it, while everyone else is just waiting (this is a fairly expensive memcpy)
  • you really want the random sequences in different workers to seem independent of each other. If I understand correctly, you can end up with the following behavior with two workers (see the sketch after this list):
    • worker 1 starts, reads state 0
    • worker 2 starts, reads state 0
    • both workers do some work and sample e.g. twice. since both started from the same random state, they both end up in state 1
    • worker 1 finishes, saves state 1
    • worker 2 finishes, saves state 1
    • then the process repeats, and the sampling results are identical in both workers.
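
A minimal standalone sketch (plain numpy, no DataLoader) of the lockstep scenario above: two workers that deserialize the same cached state and then draw the same number of samples produce identical results.

import numpy as np

shared_state = np.random.RandomState(0).get_state()  # "state 0" sitting in shared memory

def run_worker(cached_state):
    rng = np.random.RandomState()
    rng.set_state(cached_state)          # worker resumes from the cached state
    samples = rng.random_sample(2)       # sample twice while loading an item
    return samples, rng.get_state()      # worker writes "state 1" back

samples_1, _ = run_worker(shared_state)  # worker 1 reads state 0 before anyone writes back
samples_2, _ = run_worker(shared_state)  # worker 2 also reads state 0
print(np.allclose(samples_1, samples_2))  # True: both workers drew identical samples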

@ezyang
Contributor

ezyang commented Mar 6, 2018

@pytorchbot test this please

@ezyang
Contributor

ezyang commented Mar 6, 2018

@pytorchbot retest this please

@ssnl
Collaborator

ssnl commented Mar 7, 2018

@pytorchbot retest this please

@ezyang ezyang merged commit 8317803 into pytorch:master Mar 23, 2018
sighingnow added a commit to sighingnow/pytorch that referenced this pull request Mar 25, 2018
* upstream/master: (663 commits)
  Fix "command not found" error in perf test (pytorch#5982)
  add pip mkl-devel to the error message when mkl is found but mkl headers are not (pytorch#5984)
  Support batch LowerCholeskyTransform (pytorch#5980)
  Linearly interpolating upsampling fix (pytorch#5927)
  Store perf numbers in S3 (pytorch#5951)
  Modidy setup docs for Windows (pytorch#5981)
  Group Normalization (pytorch#5968)
  [distributions] Implement Power transform (pytorch#5976)
  Disable TestBottleneck test_cuda on Windows (pytorch#5977)
  Fix crash when cat-ing empty cuda tensors (pytorch#5971)
  Update no_unions flag for nanopb gen and update ONNX proto files (pytorch#5972)
  Expose gradients w.r.t. input & weight for conv1d, conv2d, conv3d in Python (pytorch#5408)
  Fixed non-determinate preprocessing on DataLoader (pytorch#4640)
  add AVX2 implementation for sigmoid function (pytorch#5010)
  Implement torch.util.bottleneck (pytorch#5216)
  Remove pragma once from cpp file (pytorch#5965)
  fix mvn docs (pytorch#5967)
  Fix incorrect rendering of Tensor.index_*_ doc examples. (pytorch#5969)
  Implement range for loop in script (pytorch#5827)
  Add windows doc (pytorch#5859)
  ...

# Conflicts:
#	aten/src/TH/generic/THTensorMath.c
#	torch/_tensor_docs.py
#	torch/csrc/generic/methods/TensorCompare.cwrap
@brando90

Is this in PyTorch 3.1?

@ssnl
Collaborator

ssnl commented Mar 25, 2018 via email

@brando90

@ssnl then where is this?

@ssnl
Collaborator

ssnl commented Mar 25, 2018 via email

@brando90

@ssnl When will it be released? It'd be really nice to have determinism in the data loader.

@ssnl
Collaborator

ssnl commented Mar 25, 2018 via email

@bombs-kim
Contributor

bombs-kim commented Jun 12, 2018

@AlexanderRadionov @ssnl @apaszke I agree that having reproducible results is very important, but that doesn't mean you have to synchronize the workers in the dataloader. Multiprocessing is non-deterministic in nature, and if you try to impose some unnatural synchronization there must be a performance trade-off, whether it's trivial or not. The synchronization also makes the code more complicated and harder to maintain.

I think what you really need is a deterministic dataset, rather than a deterministic data loader.

from torch.utils.data import Dataset

set_seed(GLOBAL_SEED)  # set_seed / GLOBAL_SEED as defined in the scripts above

class SomeDataset(Dataset):
    def __init__(self, data):
        self.data = data
        # Pre-generate all randomness once, indexed by sample, so the result
        # does not depend on which worker happens to process which index.
        # make_random_numbers_in_advance is a placeholder for that step.
        self.random_numbers = make_random_numbers_in_advance(len(self.data))

    def __getitem__(self, idx):
        random_number = self.random_numbers[idx]
        # do some work with random_number and self.data[idx]
        return processed_item

OZA15015 pushed a commit to OZA15015/pruning that referenced this pull request Sep 6, 2020
Also:
* Single worker limitation not needed anymore, been fixed in PyTorch
  since v0.4.0 (pytorch/pytorch#4640)
* compress_classifier.py: If run in evaluation mode (--eval), enable
  deterministic mode.
* Call utils.set_deterministic at data loaders creation if
  deterministic argument is set (don't assume user calls it outside)
* Disable CUDNN benchmark mode in utils.set_deterministic
  (https://pytorch.org/docs/stable/notes/randomness.html#cudnn)
fangvv pushed a commit to fangvv/distiller that referenced this pull request May 23, 2023
Also:
* Single worker limitation not needed anymore, been fixed in PyTorch
  since v0.4.0 (pytorch/pytorch#4640)
* compress_classifier.py: If run in evaluation mode (--eval), enable
  deterministic mode.
* Call utils.set_deterministic at data loaders creation if
  deterministic argument is set (don't assume user calls it outside)
* Disable CUDNN benchmark mode in utils.set_deterministic
  (https://pytorch.org/docs/stable/notes/randomness.html#cudnn)
terran0213 pushed a commit to terran0213/scale-distil that referenced this pull request Oct 28, 2025
Also:
* Single worker limitation not needed anymore, been fixed in PyTorch
  since v0.4.0 (pytorch/pytorch#4640)
* compress_classifier.py: If run in evaluation mode (--eval), enable
  deterministic mode.
* Call utils.set_deterministic at data loaders creation if
  deterministic argument is set (don't assume user calls it outside)
* Disable CUDNN benchmark mode in utils.set_deterministic
  (https://pytorch.org/docs/stable/notes/randomness.html#cudnn)