Today there are mainly three ways to scale up distributed training: Data Parallel, Tensor Parallel and Pipeline Parallel. Each of them works on a separate dimension, where solutions have been built independently (e.g. PyTorch DDP, FSDP, ShardedTensor, PiPPy, etc.). When training really large models, users would like to use these technologies together (i.e. 3-D Parallelism), but the interoperability of the existing solutions is not great and they are often hard to compose (e.g. users might want arbitrary combinations of data parallel, tensor parallel and pipeline parallel). This is becoming an issue for users, and one of the biggest reasons is that there is no common abstraction that builds a bridge between the different parallelism strategies.

An ideal scenario is that users could build their distributed program just like authoring on a single node/device, without worrying about how to do distributed training in a cluster, and our solutions could help them run distributed training in an efficient manner. For example, researchers just need to build their big transformer model, and PyTorch Distributed automatically figures out how to split the model and run pipeline parallel across nodes, and how to run data parallel and tensor parallel within each node. In order to achieve this, we need some common abstractions to distribute tensor values and run distributed computations accordingly.

There are many recent works on tensor-level parallelism that provide common abstractions; see `Related Works` in the last section for more details. Inspired by [GSPMD](https://arxiv.org/pdf/2105.04663.pdf), [OneFlow](https://arxiv.org/pdf/2110.15032.pdf) and [TF’s DTensor](https://www.tensorflow.org/guide/dtensor_overview), we introduce DistributedTensor as the next generation of ShardedTensor to provide basic abstractions for distributing storage and computation. It serves as one of the basic building blocks for distributed program translation and describes the layout of a distributed training program. With the DistributedTensor abstraction, we can seamlessly build parallelism strategies such as tensor parallelism, DDP and FSDP.

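To make the abstraction more concrete, here is a minimal sketch of distributing a regular tensor by describing its layout with a device mesh and per-dimension placements. The names (`DeviceMesh`, `Shard`, `Replicate`, `distribute_tensor`) follow the high level API introduced later in this document; the import path and sizes shown are illustrative only.

```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, Replicate, distribute_tensor

# A 1-D mesh over 4 GPUs; the mesh describes which devices participate
# in the distribution and how they are arranged.
mesh = DeviceMesh("cuda", [0, 1, 2, 3])

# A regular tensor authored as if on a single device.
big_weight = torch.randn(8192, 8192)

# Shard dim 0 of the tensor across the 4 devices of the mesh: each rank
# only materializes a 2048 x 8192 local shard, while the DistributedTensor
# still behaves logically like the full 8192 x 8192 tensor.
sharded_weight = distribute_tensor(big_weight, mesh, [Shard(0)])

# Replication is just a different placement on the same mesh, which is
# how data-parallel style layouts are expressed.
replicated_bias = distribute_tensor(torch.zeros(8192), mesh, [Replicate()])
```
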
## Value Proposition

DistributedTensor primarily:

- Offers a uniform way to save/load `state_dict` during checkpointing, even when there are complex tensor storage distribution strategies such as combining tensor parallelism with parameter sharding in FSDP.
- Enables Tensor Parallelism in eager mode. Compared to ShardedTensor, DistributedTensor allows additional flexibility to mix sharding and replication (see the sketch after this list).
- Serves as the entry point of an SPMD programming model and the foundational building block for compiler-based distributed training.

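The sketch below illustrates the sharding/replication flexibility mentioned above: on a 2-D device mesh, a parameter can be replicated along one mesh dimension (data parallel) while being sharded along the other (tensor parallel), a layout that is awkward to express with ShardedTensor alone. The import path and mesh shape are illustrative.

```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, Replicate, distribute_tensor

# A 2-D mesh over 4 GPUs: the first mesh dimension can be used for
# data parallel, the second for tensor parallel.
mesh_2d = DeviceMesh("cuda", [[0, 1], [2, 3]])

weight = torch.randn(1024, 1024)

# Replicate across mesh dim 0 (data parallel) and shard tensor dim 0
# across mesh dim 1 (tensor parallel).
dp_tp_weight = distribute_tensor(weight, mesh_2d, [Replicate(), Shard(0)])
```
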
Users can use DistributedTensor constructors directly to create a distributed tensor (e.g. `distributed.ones/empty`), but for existing modules like `nn.Linear` that already have `torch.Tensor` as parameters, how do we make them distributed parameters? We offer a way to directly distribute a `torch.Tensor` and module-level APIs to distribute the module parameters. Below is the high level API we introduce:

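The following is a sketch of the two proposed entry points, a tensor-level `distribute_tensor` and a module-level `distribute_module`; the argument names and defaults shown are indicative rather than final.

```python
from typing import Callable, Optional, Sequence

import torch
import torch.nn as nn

# Tensor-level API: take a regular torch.Tensor and return a
# DistributedTensor laid out on `device_mesh` according to `placements`
# (e.g. Shard(dim) or Replicate() for each mesh dimension).
def distribute_tensor(
    tensor: torch.Tensor,
    device_mesh: Optional["DeviceMesh"] = None,
    placements: Optional[Sequence["Placement"]] = None,
) -> "DistributedTensor":
    ...

# Module-level API: convert the parameters of an existing nn.Module into
# DistributedTensors. `partition_fn` decides how each parameter is placed
# on the mesh; `input_fn`/`output_fn` optionally convert the module's
# inputs/outputs (e.g. replicate inputs, gather outputs).
def distribute_module(
    module: nn.Module,
    device_mesh: Optional["DeviceMesh"] = None,
    partition_fn: Optional[Callable[[str, nn.Module, "DeviceMesh"], None]] = None,
    input_fn: Optional[Callable] = None,
    output_fn: Optional[Callable] = None,
) -> nn.Module:
    ...
```
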
DistributedTensor provides efficient solutions for cases like Tensor Parallelism. But when using DTensor's replication in a data parallel fashion, it might become observably slower than our existing solutions like DDP/FSDP. This is mainly because DDP/FSDP have a global view of the entire model architecture and can therefore optimize specifically for data parallel, e.g. collective fusion and communication/computation overlap. In contrast, DistributedTensor is a Tensor-like object and can only optimize within individual tensor operations.

To make the performance on par with DDP/FSDP when using DistributedTensor directly for data parallel training, we are exploring a compiler-based solution on top of DistributedTensor, which can extract graph information from user programs to expose more performance optimization opportunities.