@@ -9,7 +9,10 @@ use schedule::TrainingSchedule;
 
 use std::{sync::mpsc, thread, time::Instant};
 
-use crate::device::{Device, OperationError};
+use crate::{
+    device::{Device, OperationError},
+    graph::like::GraphLike,
+};
 
 #[derive(Debug)]
 pub enum DataLoadingError {
@@ -23,6 +26,7 @@ pub enum TrainerError<D: Device> {
     DataLoadingError(DataLoadingError),
     GradientCalculationError(OperationError<D::DeviceError>),
     Unexpected(OperationError<D::DeviceError>),
+    MoreDevicesThanBatchSize(usize, usize),
     IoError,
 }
 
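For context, a caller can pattern-match the new variant to produce a readable message. A minimal sketch, assuming only the enum shape shown above; the helper name and message wording are illustrative, not part of the PR:

// Hypothetical helper; only the variant's (device count, batch size)
// payload comes from the diff above.
fn report_trainer_error<D: Device>(err: &TrainerError<D>) {
    if let TrainerError::MoreDevicesThanBatchSize(devices, batch_size) = err {
        eprintln!("cannot split a batch of {batch_size} positions across {devices} devices");
    }
}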
@@ -32,12 +36,12 @@ impl<D: Device> From<DataLoadingError> for TrainerError<D> {
     }
 }
 
-pub struct Trainer<D: Device, O: OptimiserState<D>, S> {
-    pub optimiser: Optimiser<D, O>,
+pub struct Trainer<D: Device, G: GraphLike<D>, O: OptimiserState<D>, S> {
+    pub optimiser: Optimiser<D, G, O>,
     pub state: S,
 }
 
-impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
+impl<D: Device, G: GraphLike<D>, O: OptimiserState<D>, S> Trainer<D, G, O, S> {
     pub fn train_custom(
         &mut self,
         schedule: TrainingSchedule,
@@ -52,7 +56,9 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
         let lr = schedule.lr_schedule;
         let steps = schedule.steps;
 
-        self.optimiser.graph.synchronise().unwrap();
+        if self.optimiser.graph.devices().len() > steps.batch_size {
+            return Err(TrainerError::MoreDevicesThanBatchSize(self.optimiser.graph.devices().len(), steps.batch_size));
+        }
 
         let (sender, receiver) = mpsc::sync_channel::<PreparedBatchHost>(32);
 
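The guard reflects that each batch is presumably partitioned across `graph.devices()`: with more devices than positions, at least one device would receive an empty slice. A sketch of such a partition under that assumption; the function name and split policy are hypothetical, not taken from the PR:

// Hypothetical even partition of a batch over devices. Under the guard's
// invariant num_devices <= batch_size, `base` is at least 1 and the first
// `extra` chunks absorb the remainder, so no chunk is ever empty.
fn chunk_sizes(batch_size: usize, num_devices: usize) -> Vec<usize> {
    assert!(num_devices <= batch_size, "mirrors the guard above");
    let base = batch_size / num_devices;
    let extra = batch_size % num_devices;
    (0..num_devices).map(|i| base + usize::from(i < extra)).collect()
}

For example, 16384 positions over 3 devices would give chunks of 5462, 5461 and 5461.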
@@ -88,7 +94,7 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
         let first_batch =
             receiver.recv().map_err(|_| TrainerError::DataLoadingError(DataLoadingError::NoBatchesReceived))?;
 
-        let mut batch_on_device = PreparedBatchDevice::new(self.optimiser.graph.device(), &first_batch)
+        let mut batch_on_device = PreparedBatchDevice::new(self.optimiser.graph.devices(), &first_batch)
             .map_err(|_| TrainerError::DataLoadingError(DataLoadingError::CopyToDevice))?;
 
         let mut batch_queued = true;
@@ -122,14 +128,14 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
 
         batch_on_device.load_into_graph(&mut self.optimiser.graph)?;
 
-        fn step<D: Device, S: OptimiserState<D>>(
-            optim: &mut Optimiser<D, S>,
+        fn step<D: Device, G: GraphLike<D>, S: OptimiserState<D>>(
+            optim: &mut Optimiser<D, G, S>,
             gradient_factor: f32,
             learning_rate: f32,
         ) -> Result<(), OperationError<D::DeviceError>> {
-            optim.graph.execute("zero_grads")?;
-            optim.graph.execute("forward")?;
-            optim.graph.execute("backward")?;
+            optim.graph.execute_fn("zero_grads")?;
+            optim.graph.execute_fn("forward")?;
+            optim.graph.execute_fn("backward")?;
             optim.update(gradient_factor, learning_rate)
         }
 
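For orientation, each batch presumably drives this helper as zero-grads, then forward, then backward, then the optimiser update. A hedged sketch of the call site; the 1 / batch size gradient scaling and the error mapping are assumptions, not shown in this hunk:

// Hypothetical call site; only `step`'s signature comes from the diff.
// Dividing gradients by the batch size to obtain a mean gradient is a
// common convention and is assumed here. `learning_rate` would come from
// the schedule's `lr` function, whose signature is not shown.
step(&mut self.optimiser, 1.0 / this_batch_size as f32, learning_rate)
    .map_err(TrainerError::GradientCalculationError)?;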
@@ -143,7 +149,7 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
             batch_queued = false;
         }
 
-        let error = self.optimiser.graph.get_output_val().unwrap() / this_batch_size as f32;
+        let error = self.optimiser.graph.get_output_value().unwrap() / this_batch_size as f32;
 
         running_loss += error;
         superbatch_positions += this_batch_size;