Skip to content

Commit b06dd9b

Browse files
authored
Multi-GPU Support (#463)
For your standard NNUE network, this is **generally not beneficial** on consumer multi-GPU systems as the GPU-to-GPU bandwidth is too low unless you increase the batch size massively. Ideally you would have an NVLink system or be training an arch that is slow and quite dense.
1 parent 3ce1393 commit b06dd9b

File tree

23 files changed

+550
-106
lines changed

23 files changed

+550
-106
lines changed

crates/acyclib/src/device.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pub mod cpu;
22
pub mod function;
3+
pub mod multi;
34
pub mod operation;
45
pub mod tensor;
56

crates/acyclib/src/device/multi.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use std::sync::Arc;
2+
3+
use crate::device::{
4+
Device,
5+
cpu::{CpuError, CpuThread},
6+
tensor::TensorRef,
7+
};
8+
9+
pub trait MultiDeviceComm<D: Device> {
10+
fn new(devices: Vec<Arc<D>>) -> Self;
11+
12+
fn reduce_sum_into_rank(&self, rank: usize, buffers: &[TensorRef<D>]) -> Result<(), D::DeviceError>;
13+
14+
fn scatter_rank_into_rest(&self, rank: usize, buffers: &[TensorRef<D>]) -> Result<(), D::DeviceError>;
15+
}
16+
17+
pub trait MultiDevice: Device {
18+
type Comm: MultiDeviceComm<Self>;
19+
}
20+
21+
impl MultiDevice for CpuThread {
22+
type Comm = ();
23+
}
24+
25+
impl MultiDeviceComm<CpuThread> for () {
26+
fn new(_: Vec<Arc<CpuThread>>) -> Self {}
27+
28+
fn reduce_sum_into_rank(&self, rank: usize, buffers: &[TensorRef<CpuThread>]) -> Result<(), CpuError> {
29+
let mut buf = buffers[rank].dense_mut();
30+
31+
for (i, other) in buffers.iter().enumerate() {
32+
if rank != i {
33+
buf.add(1.0, &other.dense()).map_err(|_| CpuError)?;
34+
}
35+
}
36+
37+
Ok(())
38+
}
39+
40+
fn scatter_rank_into_rest(&self, rank: usize, buffers: &[TensorRef<CpuThread>]) -> Result<(), CpuError> {
41+
let buf = buffers[rank].dense();
42+
43+
for (i, other) in buffers.iter().enumerate() {
44+
if rank != i {
45+
other.dense_mut().copy_from(&buf)?;
46+
}
47+
}
48+
49+
Ok(())
50+
}
51+
}

crates/acyclib/src/graph.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
pub mod builder;
22
pub mod ir;
3+
pub mod like;
4+
pub mod multi;
35

46
use std::{collections::HashMap, fmt::Debug, sync::Arc};
57

crates/acyclib/src/graph/builder.rs

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@ use std::{
1111

1212
use crate::{
1313
dag::NodeId,
14-
device::{Device, function::Reduce, tensor::Shape},
14+
device::{
15+
Device,
16+
function::Reduce,
17+
multi::{MultiDevice, MultiDeviceComm},
18+
tensor::Shape,
19+
},
1520
graph::{
1621
Graph, GraphNodeId, GraphNodeIdTy,
1722
ir::{
@@ -21,6 +26,7 @@ use crate::{
2126
},
2227
passes::GraphIRPass,
2328
},
29+
multi::MultiDeviceGraph,
2430
},
2531
};
2632

@@ -121,8 +127,8 @@ where
121127
SparseAffineActivate: GraphIROperationCompilable<B>,
122128
Select: GraphIROperationCompilable<B>,
123129
{
124-
pub fn build(self, device: D) -> Graph<D> {
125-
let mut ir = self.ir.into_inner().unwrap();
130+
fn optimise(&mut self) {
131+
let mut ir = self.ir.try_lock().unwrap();
126132
let root = ir.root().unwrap();
127133

128134
if ir.get(root.idx).unwrap().ty().batched {
@@ -136,7 +142,7 @@ where
136142
let unoptim = format!("subgraph cluster_0 {{\nlabel=\"Unoptimised\";\n{opts}{unoptim}}}");
137143

138144
ir.optimise().unwrap();
139-
for pass in self.custom_passes.into_inner().unwrap() {
145+
for pass in self.custom_passes.try_lock().unwrap().iter() {
140146
ir.apply_any_pass(pass.as_ref()).unwrap();
141147
}
142148

@@ -147,14 +153,18 @@ where
147153
write!(&mut file, "digraph G {{\n{unoptim}\n{optim}}}").unwrap();
148154
} else {
149155
ir.optimise().unwrap();
150-
for pass in self.custom_passes.into_inner().unwrap() {
156+
for pass in self.custom_passes.try_lock().unwrap().iter() {
151157
ir.apply_any_pass(pass.as_ref()).unwrap();
152158
}
153159
}
154160

155161
if self.dump_ir_on_build {
156162
println!("{}", ir.formatted().unwrap());
157163
}
164+
}
165+
166+
fn compile(&self, device: D) -> Graph<D> {
167+
let ir = self.ir.try_lock().unwrap();
158168

159169
let graph = ir.compile(device).unwrap();
160170

@@ -187,4 +197,28 @@ where
187197

188198
graph
189199
}
200+
201+
pub fn build(mut self, device: D) -> Graph<D> {
202+
self.optimise();
203+
self.compile(device)
204+
}
205+
}
206+
207+
impl<D: Device<Marker = B> + MultiDevice, B: BackendMarker<Backend = D>> GraphBuilder<B>
208+
where
209+
SparseAffineActivate: GraphIROperationCompilable<B>,
210+
Select: GraphIROperationCompilable<B>,
211+
{
212+
pub fn build_multi(mut self, devices: Vec<D>) -> MultiDeviceGraph<D> {
213+
if devices.is_empty() {
214+
panic!("No devices specified for multi-device training!");
215+
}
216+
217+
self.optimise();
218+
219+
let graphs = devices.into_iter().map(|d| self.compile(d)).collect::<Vec<_>>();
220+
let comm = D::Comm::new(graphs.iter().map(|g| g.device()).collect());
221+
222+
MultiDeviceGraph { comm, graphs }
223+
}
190224
}

crates/acyclib/src/graph/ir/compile.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ where
9999
Ok(())
100100
}
101101

102-
pub fn compile(self, device: B::Backend) -> Result<Graph<B::Backend>, GraphIRCompileError> {
102+
pub fn compile(&self, device: B::Backend) -> Result<Graph<B::Backend>, GraphIRCompileError> {
103103
let root = self.root()?.idx;
104104
let root_data = self.get(root).unwrap().ty();
105105

crates/acyclib/src/graph/like.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
use std::sync::Arc;
2+
3+
use crate::{
4+
device::{Device, OperationError, tensor::TensorRef},
5+
graph::{Graph, GraphNodeId},
6+
};
7+
8+
pub trait GraphLike<D: Device> {
9+
fn devices(&self) -> Vec<Arc<D>>;
10+
11+
fn primary(&self) -> &Graph<D>;
12+
13+
fn primary_mut(&mut self) -> &mut Graph<D>;
14+
15+
fn get_all(&self, id: GraphNodeId) -> Result<Vec<TensorRef<D>>, OperationError<<D as Device>::DeviceError>>;
16+
17+
fn get_output_value(&self) -> Result<f32, OperationError<D::DeviceError>>;
18+
19+
fn execute_fn(&mut self, name: &str) -> Result<(), OperationError<D::DeviceError>>;
20+
21+
fn reduce_sum_into_first(&self, buffers: &[TensorRef<D>]) -> Result<(), D::DeviceError>;
22+
23+
fn scatter_first_into_rest(&self, buffers: &[TensorRef<D>]) -> Result<(), D::DeviceError>;
24+
}
25+
26+
impl<D: Device> GraphLike<D> for Graph<D> {
27+
fn devices(&self) -> Vec<Arc<D>> {
28+
vec![self.device()]
29+
}
30+
31+
fn primary(&self) -> &Graph<D> {
32+
self
33+
}
34+
35+
fn primary_mut(&mut self) -> &mut Graph<D> {
36+
self
37+
}
38+
39+
fn get_all(&self, id: GraphNodeId) -> Result<Vec<TensorRef<D>>, OperationError<<D as Device>::DeviceError>> {
40+
self.get(id).map(|x| vec![x])
41+
}
42+
43+
fn get_output_value(&self) -> Result<f32, OperationError<<D as Device>::DeviceError>> {
44+
self.get_output_val()
45+
}
46+
47+
fn execute_fn(&mut self, name: &str) -> Result<(), OperationError<<D as Device>::DeviceError>> {
48+
self.execute(name)
49+
}
50+
51+
fn reduce_sum_into_first(&self, _: &[TensorRef<D>]) -> Result<(), <D as Device>::DeviceError> {
52+
Ok(())
53+
}
54+
55+
fn scatter_first_into_rest(&self, _: &[TensorRef<D>]) -> Result<(), <D as Device>::DeviceError> {
56+
Ok(())
57+
}
58+
}

crates/acyclib/src/graph/multi.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
use std::sync::Arc;
2+
3+
use crate::{
4+
device::{
5+
Device, OperationError,
6+
multi::{MultiDevice, MultiDeviceComm},
7+
tensor::TensorRef,
8+
},
9+
graph::{Graph, GraphNodeId, like::GraphLike},
10+
};
11+
12+
pub struct MultiDeviceGraph<D: Device + MultiDevice> {
13+
pub(super) comm: D::Comm,
14+
pub(super) graphs: Vec<Graph<D>>,
15+
}
16+
17+
impl<D: Device + MultiDevice> GraphLike<D> for MultiDeviceGraph<D> {
18+
fn devices(&self) -> Vec<Arc<D>> {
19+
self.graphs.iter().map(Graph::device).collect()
20+
}
21+
22+
fn primary(&self) -> &Graph<D> {
23+
&self.graphs[0]
24+
}
25+
26+
fn primary_mut(&mut self) -> &mut Graph<D> {
27+
&mut self.graphs[0]
28+
}
29+
30+
fn get_all(&self, id: GraphNodeId) -> Result<Vec<TensorRef<D>>, OperationError<<D as Device>::DeviceError>> {
31+
self.graphs.iter().map(|g| g.get(id)).collect()
32+
}
33+
34+
fn get_output_value(&self) -> Result<f32, OperationError<<D as Device>::DeviceError>> {
35+
let mut sum = 0.0;
36+
37+
for g in &self.graphs {
38+
sum += g.get_output_val()?;
39+
}
40+
41+
Ok(sum)
42+
}
43+
44+
fn execute_fn(&mut self, name: &str) -> Result<(), OperationError<<D as Device>::DeviceError>> {
45+
for g in &mut self.graphs {
46+
g.execute(name)?;
47+
}
48+
49+
Ok(())
50+
}
51+
52+
fn reduce_sum_into_first(&self, buffers: &[TensorRef<D>]) -> Result<(), <D as Device>::DeviceError> {
53+
self.comm.reduce_sum_into_rank(0, buffers)
54+
}
55+
56+
fn scatter_first_into_rest(&self, buffers: &[TensorRef<D>]) -> Result<(), <D as Device>::DeviceError> {
57+
self.comm.scatter_rank_into_rest(0, buffers)
58+
}
59+
}

crates/acyclib/src/trainer.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ use schedule::TrainingSchedule;
99

1010
use std::{sync::mpsc, thread, time::Instant};
1111

12-
use crate::device::{Device, OperationError};
12+
use crate::{
13+
device::{Device, OperationError},
14+
graph::like::GraphLike,
15+
};
1316

1417
#[derive(Debug)]
1518
pub enum DataLoadingError {
@@ -23,6 +26,7 @@ pub enum TrainerError<D: Device> {
2326
DataLoadingError(DataLoadingError),
2427
GradientCalculationError(OperationError<D::DeviceError>),
2528
Unexpected(OperationError<D::DeviceError>),
29+
MoreDevicesThanBatchSize(usize, usize),
2630
IoError,
2731
}
2832

@@ -32,12 +36,12 @@ impl<D: Device> From<DataLoadingError> for TrainerError<D> {
3236
}
3337
}
3438

35-
pub struct Trainer<D: Device, O: OptimiserState<D>, S> {
36-
pub optimiser: Optimiser<D, O>,
39+
pub struct Trainer<D: Device, G: GraphLike<D>, O: OptimiserState<D>, S> {
40+
pub optimiser: Optimiser<D, G, O>,
3741
pub state: S,
3842
}
3943

40-
impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
44+
impl<D: Device, G: GraphLike<D>, O: OptimiserState<D>, S> Trainer<D, G, O, S> {
4145
pub fn train_custom(
4246
&mut self,
4347
schedule: TrainingSchedule,
@@ -52,7 +56,9 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
5256
let lr = schedule.lr_schedule;
5357
let steps = schedule.steps;
5458

55-
self.optimiser.graph.synchronise().unwrap();
59+
if self.optimiser.graph.devices().len() > steps.batch_size {
60+
return Err(TrainerError::MoreDevicesThanBatchSize(self.optimiser.graph.devices().len(), steps.batch_size));
61+
}
5662

5763
let (sender, receiver) = mpsc::sync_channel::<PreparedBatchHost>(32);
5864

@@ -88,7 +94,7 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
8894
let first_batch =
8995
receiver.recv().map_err(|_| TrainerError::DataLoadingError(DataLoadingError::NoBatchesReceived))?;
9096

91-
let mut batch_on_device = PreparedBatchDevice::new(self.optimiser.graph.device(), &first_batch)
97+
let mut batch_on_device = PreparedBatchDevice::new(self.optimiser.graph.devices(), &first_batch)
9298
.map_err(|_| TrainerError::DataLoadingError(DataLoadingError::CopyToDevice))?;
9399

94100
let mut batch_queued = true;
@@ -122,14 +128,14 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
122128

123129
batch_on_device.load_into_graph(&mut self.optimiser.graph)?;
124130

125-
fn step<D: Device, S: OptimiserState<D>>(
126-
optim: &mut Optimiser<D, S>,
131+
fn step<D: Device, G: GraphLike<D>, S: OptimiserState<D>>(
132+
optim: &mut Optimiser<D, G, S>,
127133
gradient_factor: f32,
128134
learning_rate: f32,
129135
) -> Result<(), OperationError<D::DeviceError>> {
130-
optim.graph.execute("zero_grads")?;
131-
optim.graph.execute("forward")?;
132-
optim.graph.execute("backward")?;
136+
optim.graph.execute_fn("zero_grads")?;
137+
optim.graph.execute_fn("forward")?;
138+
optim.graph.execute_fn("backward")?;
133139
optim.update(gradient_factor, learning_rate)
134140
}
135141

@@ -143,7 +149,7 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
143149
batch_queued = false;
144150
}
145151

146-
let error = self.optimiser.graph.get_output_val().unwrap() / this_batch_size as f32;
152+
let error = self.optimiser.graph.get_output_value().unwrap() / this_batch_size as f32;
147153

148154
running_loss += error;
149155
superbatch_positions += this_batch_size;

0 commit comments

Comments
 (0)