Status: Closed
Labels: high priority, oncall: distributed, triaged
Description
~/Workspace$ pip show torch
Name: torch
Version: 1.3.1
Code for all_reduce on sparse tensors:
def run_sparse(rank, size):
    a = torch.randn(5) * (rank + 1)
    mask = a > 0                                      # sparsity mask
    i = mask.nonzero()                                # indices of the nonzero values
    v = a[mask]                                       # the nonzero values
    t = torch.sparse.FloatTensor(i.t(), v, a.size())  # build a sparse COO tensor
    pprint("Before\t", t.to_dense())
    dist.all_reduce(t)
    pprint("After\t", t.to_dense())
Results
~/Workspace$ python main.py
Message from 0 Before tensor([0.1167, 0.0000, 0.8951, 0.4808, 0.8937])
Message from 1 Before tensor([0.2334, 0.0000, 1.7901, 0.9616, 1.7874])
Message from 1 After tensor([0.2334, 0.0000, 1.7901, 0.9616, 1.7874])
Message from 0 After tensor([0.1167, 0.0000, 0.8951, 0.4808, 0.8937])
Since sparse allreduce is not documented yet (#1303), I assumed the usage is the same as for dense tensors. However, it does not yield the expected result: each rank should end up with the elementwise sum of both "Before" tensors (roughly tensor([0.3501, 0.0000, 2.6852, 1.4424, 2.6811])), but the tensors come back unchanged after the call.
The full code snippet is attached below:
import os, time
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def pprint(*msg):
    # `rank` is the module-level loop variable inherited by the forked child process
    print("Message from %s" % rank, *msg)

def run(rank, size):
    # Dense allreduce: works as expected.
    t = torch.ones(5) * (rank + 1)
    pprint("Before\t", t)
    dist.all_reduce(t)
    pprint("After\t", t)

def run_sparse(rank, size):
    # Sparse allreduce: the tensors come back unchanged.
    a = torch.randn(5) * (rank + 1)
    mask = a > 0                                      # sparsity mask
    i = mask.nonzero()                                # indices of the nonzero values
    v = a[mask]                                       # the nonzero values
    t = torch.sparse.FloatTensor(i.t(), v, a.size())
    pprint("Before\t", t.to_dense())
    dist.all_reduce(t)
    pprint("After\t", t.to_dense())

def init_process(rank, size, fn, backend='gloo'):
    """Initialize the distributed environment."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run_sparse))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528