44/// - train-labels.idx1-ubyte
55/// - t10k-images.idx3-ubyte
66/// - t10k-labels.idx1-ubyte
7- use std:: { fs :: File , io :: Read , time :: Instant } ;
7+ use std:: { collections :: HashMap , fs :: File , io :: Read } ;
88
99use acyclib:: {
1010 device:: { OperationError , tensor:: Shape } ,
11- graph:: { Graph , GraphNodeId , GraphNodeIdTy , Node , builder:: GraphBuilder } ,
11+ graph:: { Graph , GraphNodeId , GraphNodeIdTy , Node , builder:: GraphBuilder , like:: GraphLike } ,
12+ trainer:: {
13+ DataLoadingError , Trainer ,
14+ dataloader:: { DataLoader , HostDenseMatrix , HostMatrix , PreparedBatchDevice , PreparedBatchHost } ,
15+ optimiser:: { Optimiser , adam:: AdamW } ,
16+ schedule:: { TrainingSchedule , TrainingSteps } ,
17+ } ,
1218} ;
13- use bullet_lib:: nn:: { DeviceError , ExecutionContext } ;
19+ use bullet_lib:: nn:: { DeviceError , ExecutionContext , optimiser :: AdamWParams } ;
1420
1521fn main ( ) -> Result < ( ) , OperationError < DeviceError > > {
1622 let images = Images :: new ( "data/mnist/train-images.idx3-ubyte" ) ;
1723 let labels = Labels :: new ( "data/mnist/train-labels.idx1-ubyte" ) ;
18- let batch_size = labels. vals . len ( ) / 10 ;
19-
20- let validation_images = Images :: new ( "data/mnist/t10k-images.idx3-ubyte" ) ;
21- let validation_labels = Labels :: new ( "data/mnist/t10k-labels.idx1-ubyte" ) ;
22- let validation_batch_size = validation_labels. vals . len ( ) / 10 ;
23-
24- assert_eq ! ( batch_size, images. batch_size( ) ) ;
25- assert_eq ! ( validation_batch_size, validation_images. batch_size( ) ) ;
26-
27- let ( mut graph, outputs) = make_model ( & images) ;
28-
29- graph. load_from_file ( "checkpoints/mnist.bin" , false ) ?;
30-
31- graph. get_input ( "inputs" ) . dense_mut ( ) . load_from_slice ( Some ( batch_size) , & images. vals ) ?;
32- graph. get_input ( "targets" ) . dense_mut ( ) . load_from_slice ( Some ( batch_size) , & labels. vals ) ?;
33-
34- let t = Instant :: now ( ) ;
35- let lr = 0.0001 ;
36-
37- for epoch in 1 ..=10 {
38- graph. zero_grads ( ) ?;
39- graph. forward ( ) ?;
40- graph. backward ( ) ?;
4124
42- if epoch % 10 == 0 {
43- let valid_acc = calculate_accuracy ( & mut graph, outputs, & validation_images, & validation_labels) ?;
44- let train_acc = calculate_accuracy ( & mut graph, outputs, & images, & labels) ?;
45-
46- println ! (
47- "epoch {epoch} train accuracy {train_acc:.2}% validation accuarcy {valid_acc:.2}% time {:.3}s" ,
48- t. elapsed( ) . as_secs_f32( )
49- ) ;
50-
51- graph. get_input ( "inputs" ) . dense_mut ( ) . load_from_slice ( Some ( batch_size) , & images. vals ) ?;
52- graph. get_input ( "targets" ) . dense_mut ( ) . load_from_slice ( Some ( batch_size) , & labels. vals ) ?;
53- }
54-
55- for id in & graph. weight_ids ( ) {
56- let idx = graph. weight_idx ( id) . unwrap ( ) ;
57- let weight = graph. get ( GraphNodeId :: new ( idx, GraphNodeIdTy :: Values ) ) . unwrap ( ) ;
58-
59- if let Ok ( grd) = graph. get ( GraphNodeId :: new ( idx, GraphNodeIdTy :: Gradients ) ) {
60- weight. dense_mut ( ) . add ( -lr, & * grd. dense ( ) ) ?;
61- }
62- }
63- }
25+ let valid_images = Images :: new ( "data/mnist/t10k-images.idx3-ubyte" ) ;
26+ let valid_labels = Labels :: new ( "data/mnist/t10k-labels.idx1-ubyte" ) ;
27+
28+ let ( graph, outputs) = make_model ( & images) ;
29+ let params = AdamWParams { min_weight : -1000.0 , max_weight : 1000.0 , ..Default :: default ( ) } ;
30+ let optimiser = Optimiser :: < _ , _ , AdamW < _ > > :: new ( graph, params) ?;
31+ let mut trainer = Trainer { optimiser, state : ( ) } ;
32+
33+ let schedule = TrainingSchedule {
34+ steps : TrainingSteps {
35+ batch_size : labels. vals . len ( ) / 10 ,
36+ batches_per_superbatch : 100 ,
37+ start_superbatch : 1 ,
38+ end_superbatch : 20 ,
39+ } ,
40+ log_rate : 10 ,
41+ lr_schedule : Box :: new ( |_, sb| 0.001 * 0.9f32 . powi ( sb as i32 - 1 ) ) ,
42+ } ;
43+
44+ let dataloader = ImageDataLoader { images : images. clone ( ) , labels : labels. clone ( ) } ;
45+
46+ let valid = prepare ( & valid_images, & valid_labels) ;
47+ let valid_gpu = PreparedBatchDevice :: new ( trainer. optimiser . graph . devices ( ) , & valid) ?;
48+
49+ trainer
50+ . train_custom (
51+ schedule,
52+ dataloader,
53+ |_, _, _, _| { } ,
54+ |trainer, _| {
55+ let graph = & mut trainer. optimiser . graph ;
56+ let train_accuracy = calculate_accuracy ( graph, outputs, & images, & labels) . unwrap ( ) ;
57+ valid_gpu. copy_into_graph ( graph) . unwrap ( ) ;
58+ let valid_accuracy = calculate_accuracy ( graph, outputs, & valid_images, & valid_labels) . unwrap ( ) ;
59+ println ! ( "Train accuracy {train_accuracy:.2}% validation accuracy {valid_accuracy:.2}%" ) ;
60+ } ,
61+ )
62+ . unwrap ( ) ;
6463
6564 Ok ( ( ) )
6665}
6766
6867fn make_model ( images : & Images ) -> ( Graph < ExecutionContext > , Node ) {
6968 let builder = GraphBuilder :: default ( ) ;
7069
71- let inputs = builder. new_dense_input ( "inputs" , Shape :: new ( images. rows , images . cols ) ) ;
70+ let inputs = builder. new_dense_input ( "inputs" , images. shape ) ;
7271 let targets = builder. new_dense_input ( "targets" , Shape :: new ( 10 , 1 ) ) ;
7372
7473 let l0 = builder. new_affine ( "l0" , 28 * 28 , 128 ) ;
@@ -94,8 +93,6 @@ fn calculate_accuracy(
9493 labels : & Labels ,
9594) -> Result < f32 , OperationError < DeviceError > > {
9695 let batch_size = images. batch_size ( ) ;
97- graph. get_input ( "inputs" ) . dense_mut ( ) . load_from_slice ( Some ( batch_size) , & images. vals ) ?;
98- graph. get_input ( "targets" ) . dense_mut ( ) . load_from_slice ( Some ( batch_size) , & labels. vals ) ?;
9996 let _ = graph. forward ( ) ?;
10097
10198 let vals = graph. get ( GraphNodeId :: new ( output_node. idx ( ) , GraphNodeIdTy :: Values ) ) ?. borrow ( ) . get_dense_vals ( ) ?;
@@ -118,15 +115,45 @@ fn calculate_accuracy(
118115 }
119116 }
120117
121- assert_eq ! ( batch_size, labels. indices. len( ) ) ;
122-
123118 Ok ( 100.0 * correct as f32 / batch_size as f32 )
124119}
125120
121+ fn prepare ( images : & Images , labels : & Labels ) -> PreparedBatchHost {
122+ let batch_size = labels. vals . len ( ) / 10 ;
123+ let mut inputs = HashMap :: new ( ) ;
124+
125+ let wrap = |x : & Vec < f32 > , s| HostMatrix :: Dense ( HostDenseMatrix :: new ( x. clone ( ) , Some ( batch_size) , s) ) ;
126+
127+ let x = wrap ( & images. vals , images. shape ) ;
128+ inputs. insert ( "inputs" . to_string ( ) , x) ;
129+
130+ let y = wrap ( & labels. vals , Shape :: new ( 10 , 1 ) ) ;
131+ inputs. insert ( "targets" . to_string ( ) , y) ;
132+
133+ PreparedBatchHost { batch_size, inputs }
134+ }
135+
136+ struct ImageDataLoader {
137+ images : Images ,
138+ labels : Labels ,
139+ }
140+
141+ impl DataLoader for ImageDataLoader {
142+ type Error = DataLoadingError ;
143+
144+ fn map_batches < F : FnMut ( PreparedBatchHost ) -> bool > ( self , batch_size : usize , mut f : F ) -> Result < ( ) , Self :: Error > {
145+ assert_eq ! ( batch_size, self . labels. vals. len( ) / 10 ) ;
146+
147+ while !f ( prepare ( & self . images , & self . labels ) ) { }
148+
149+ Ok ( ( ) )
150+ }
151+ }
152+
153+ #[ derive( Clone ) ]
126154struct Images {
127155 vals : Vec < f32 > ,
128- rows : usize ,
129- cols : usize ,
156+ shape : Shape ,
130157}
131158
132159impl Images {
@@ -152,16 +179,16 @@ impl Images {
152179
153180 Self {
154181 vals : bytes. iter ( ) . map ( |& x| f32:: from ( x) / f32:: from ( u8:: MAX ) ) . collect ( ) ,
155- rows : rows as usize ,
156- cols : cols as usize ,
182+ shape : Shape :: new ( rows as usize , cols as usize ) ,
157183 }
158184 }
159185
160186 pub fn batch_size ( & self ) -> usize {
161- self . vals . len ( ) / ( self . rows * self . cols )
187+ self . vals . len ( ) / self . shape . size ( )
162188 }
163189}
164190
191+ #[ derive( Clone ) ]
165192struct Labels {
166193 vals : Vec < f32 > ,
167194 indices : Vec < u8 > ,
0 commit comments