
Allreduce for sparse tensors is not working #31413

@Lyken17

Description

~/Workspace$ pip show torch
Name: torch
Version: 1.3.1

Code for allreducing sparse tensors:

def run_sparse(rank, size):
    a = torch.randn(5) * (rank + 1)
    mask = a > 0                  # keep only the positive entries
    i = mask.nonzero()            # indices of the kept values
    v = a[mask]                   # the kept values themselves
    t = torch.sparse.FloatTensor(i.t(), v, a.size())  # sparse COO tensor

    pprint(rank, "Before\t", t.to_dense())
    dist.all_reduce(t)
    pprint(rank, "After\t", t.to_dense())

Results

~/Workspace$ python main.py
Message from 0 Before	 tensor([0.1167, 0.0000, 0.8951, 0.4808, 0.8937])
Message from 1 Before	 tensor([0.2334, 0.0000, 1.7901, 0.9616, 1.7874])
Message from 1 After	 tensor([0.2334, 0.0000, 1.7901, 0.9616, 1.7874])
Message from 0 After	 tensor([0.1167, 0.0000, 0.8951, 0.4808, 0.8937])

Since sparse allreduce is not documented yet (#1303), I assumed the usage is the same as for dense tensors. However, it does not yield the expected result: as the output shows, the "After" tensors are identical to the "Before" tensors on both ranks, whereas a SUM allreduce should leave both ranks holding the elementwise sum, tensor([0.3501, 0.0000, 2.6852, 1.4424, 2.6811]).
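As a stopgap (my own workaround, not an officially documented pattern), reducing the dense form and re-sparsifying afterwards does produce the expected sums, at the cost of communicating the full dense tensor:

dense = t.to_dense()        # fall back to the dense representation
dist.all_reduce(dense)      # dense allreduce works as documented
t = dense.to_sparse()       # back to a sparse COO tensor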

The full code snippet is attached below:

import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process


def pprint(rank, *msg):
    print("Message from %s" % rank, *msg)


def run(rank, size):
    # dense allreduce: works as expected
    t = torch.ones(5) * (rank + 1)
    pprint(rank, "Before\t", t)
    dist.all_reduce(t)
    pprint(rank, "After\t", t)


def run_sparse(rank, size):
    # sparse allreduce: "After" comes back unchanged
    a = torch.randn(5) * (rank + 1)
    mask = a > 0                  # keep only the positive entries
    i = mask.nonzero()            # indices of the kept values
    v = a[mask]                   # the kept values themselves
    t = torch.sparse.FloatTensor(i.t(), v, a.size())

    pprint(rank, "Before\t", t.to_dense())
    dist.all_reduce(t)
    pprint(rank, "After\t", t.to_dense())


def init_process(rank, size, fn, backend='gloo'):
    """Initialize the distributed environment."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run_sparse))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
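For completeness, here is one way the sparse reduction could be emulated on top of dense all_gather. This is only a sketch: sparse_all_reduce is a hypothetical helper of mine, not a torch.distributed API, and it assumes the gloo backend, CPU tensors with the same dense shape on every rank, and at least one nonzero entry on some rank.

def sparse_all_reduce(t, size):
    """Emulate a SUM allreduce of a sparse COO tensor via all_gather.

    Sketch only: pads indices/values to a common length, gathers them
    from all ranks, and lets coalesce() sum duplicate indices.
    """
    t = t.coalesce()
    idx, val = t._indices(), t._values()

    # ranks may hold different numbers of nonzeros, so agree on the max
    nnz = torch.tensor([val.numel()])
    nnz_list = [torch.zeros_like(nnz) for _ in range(size)]
    dist.all_gather(nnz_list, nnz)
    max_nnz = int(max(n.item() for n in nnz_list))

    # zero padding is harmless: it contributes value 0.0 at index 0
    pad_idx = torch.zeros(idx.size(0), max_nnz, dtype=idx.dtype)
    pad_idx[:, :idx.size(1)] = idx
    pad_val = torch.zeros(max_nnz, dtype=val.dtype)
    pad_val[:val.numel()] = val

    idx_list = [torch.zeros_like(pad_idx) for _ in range(size)]
    val_list = [torch.zeros_like(pad_val) for _ in range(size)]
    dist.all_gather(idx_list, pad_idx)
    dist.all_gather(val_list, pad_val)

    # concatenate all pieces; coalesce() sums values at duplicate indices
    all_idx = torch.cat(idx_list, dim=1)
    all_val = torch.cat(val_list)
    return torch.sparse_coo_tensor(all_idx, all_val, t.size()).coalesce()

Inside run_sparse, the call site would then read t = sparse_all_reduce(t, size) instead of dist.all_reduce(t).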

cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528

Metadata

Labels

high priority, oncall: distributed, triaged
