Commit d1204eb

Improve MNIST example (#484)
1 parent: f270e3e

File tree

3 files changed (+101 −59):
- crates/acyclib/src/device/tensor.rs
- crates/acyclib/src/trainer/dataloader.rs
- examples/extra/mnist.rs


crates/acyclib/src/device/tensor.rs

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ impl<D: Device> TensorRef<D> {
         self.inner.try_read().unwrap()
     }

-    fn borrow_mut(&self) -> RwLockWriteGuard<'_, Tensor<D>> {
+    pub(crate) fn borrow_mut(&self) -> RwLockWriteGuard<'_, Tensor<D>> {
         self.inner.try_write().unwrap()
     }

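Note: widening `borrow_mut` from private to `pub(crate)` is what lets the new `PreparedBatchDevice::copy_into_graph` helper in the trainer module (next file) obtain a write guard on a tensor's values.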
crates/acyclib/src/trainer/dataloader.rs

Lines changed: 15 additions & 0 deletions
@@ -254,4 +254,19 @@ impl<D: Device> PreparedBatchDevice<D> {

         Ok(())
     }
+
+    pub fn copy_into_graph<G: GraphLike<D>>(&self, graph: &mut G) -> Result<(), TrainerError<D>> {
+        for (id, matrices) in &self.inputs {
+            if let Some(idx) = graph.primary().input_idx(id) {
+                let tensors =
+                    graph.get_all(GraphNodeId::new(idx, GraphNodeIdTy::Values)).map_err(TrainerError::Unexpected)?;
+
+                for (tensor, matrix) in tensors.into_iter().zip(matrices.iter()) {
+                    matrix.copy_into(&mut tensor.borrow_mut().values).map_err(TrainerError::Unexpected)?;
+                }
+            }
+        }
+
+        Ok(())
+    }
 }
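For context, a minimal usage sketch of the new helper (a sketch only: it mirrors the call sites added in examples/extra/mnist.rs below, and `graph`, `prepare`, `valid_images`, and `valid_labels` are placeholders taken from that example):

    // Build a host-side batch keyed by input name and stage it on-device once.
    let valid = prepare(&valid_images, &valid_labels); // PreparedBatchHost
    let valid_gpu = PreparedBatchDevice::new(graph.devices(), &valid)?;

    // Later, before an evaluation pass, overwrite the graph's input tensors
    // with the staged data and run forward.
    valid_gpu.copy_into_graph(&mut graph)?;
    graph.forward()?;

This avoids reloading the same validation batch from the host on every evaluation.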

examples/extra/mnist.rs

Lines changed: 85 additions & 58 deletions
@@ -4,71 +4,70 @@
 /// - train-labels.idx1-ubyte
 /// - t10k-images.idx3-ubyte
 /// - t10k-labels.idx1-ubyte
-use std::{fs::File, io::Read, time::Instant};
+use std::{collections::HashMap, fs::File, io::Read};

 use acyclib::{
     device::{OperationError, tensor::Shape},
-    graph::{Graph, GraphNodeId, GraphNodeIdTy, Node, builder::GraphBuilder},
+    graph::{Graph, GraphNodeId, GraphNodeIdTy, Node, builder::GraphBuilder, like::GraphLike},
+    trainer::{
+        DataLoadingError, Trainer,
+        dataloader::{DataLoader, HostDenseMatrix, HostMatrix, PreparedBatchDevice, PreparedBatchHost},
+        optimiser::{Optimiser, adam::AdamW},
+        schedule::{TrainingSchedule, TrainingSteps},
+    },
 };
-use bullet_lib::nn::{DeviceError, ExecutionContext};
+use bullet_lib::nn::{DeviceError, ExecutionContext, optimiser::AdamWParams};

 fn main() -> Result<(), OperationError<DeviceError>> {
     let images = Images::new("data/mnist/train-images.idx3-ubyte");
     let labels = Labels::new("data/mnist/train-labels.idx1-ubyte");
-    let batch_size = labels.vals.len() / 10;
-
-    let validation_images = Images::new("data/mnist/t10k-images.idx3-ubyte");
-    let validation_labels = Labels::new("data/mnist/t10k-labels.idx1-ubyte");
-    let validation_batch_size = validation_labels.vals.len() / 10;
-
-    assert_eq!(batch_size, images.batch_size());
-    assert_eq!(validation_batch_size, validation_images.batch_size());
-
-    let (mut graph, outputs) = make_model(&images);
-
-    graph.load_from_file("checkpoints/mnist.bin", false)?;
-
-    graph.get_input("inputs").dense_mut().load_from_slice(Some(batch_size), &images.vals)?;
-    graph.get_input("targets").dense_mut().load_from_slice(Some(batch_size), &labels.vals)?;
-
-    let t = Instant::now();
-    let lr = 0.0001;
-
-    for epoch in 1..=10 {
-        graph.zero_grads()?;
-        graph.forward()?;
-        graph.backward()?;

-        if epoch % 10 == 0 {
-            let valid_acc = calculate_accuracy(&mut graph, outputs, &validation_images, &validation_labels)?;
-            let train_acc = calculate_accuracy(&mut graph, outputs, &images, &labels)?;
-
-            println!(
-                "epoch {epoch} train accuracy {train_acc:.2}% validation accuarcy {valid_acc:.2}% time {:.3}s",
-                t.elapsed().as_secs_f32()
-            );
-
-            graph.get_input("inputs").dense_mut().load_from_slice(Some(batch_size), &images.vals)?;
-            graph.get_input("targets").dense_mut().load_from_slice(Some(batch_size), &labels.vals)?;
-        }
-
-        for id in &graph.weight_ids() {
-            let idx = graph.weight_idx(id).unwrap();
-            let weight = graph.get(GraphNodeId::new(idx, GraphNodeIdTy::Values)).unwrap();
-
-            if let Ok(grd) = graph.get(GraphNodeId::new(idx, GraphNodeIdTy::Gradients)) {
-                weight.dense_mut().add(-lr, &*grd.dense())?;
-            }
-        }
-    }
+    let valid_images = Images::new("data/mnist/t10k-images.idx3-ubyte");
+    let valid_labels = Labels::new("data/mnist/t10k-labels.idx1-ubyte");
+
+    let (graph, outputs) = make_model(&images);
+    let params = AdamWParams { min_weight: -1000.0, max_weight: 1000.0, ..Default::default() };
+    let optimiser = Optimiser::<_, _, AdamW<_>>::new(graph, params)?;
+    let mut trainer = Trainer { optimiser, state: () };
+
+    let schedule = TrainingSchedule {
+        steps: TrainingSteps {
+            batch_size: labels.vals.len() / 10,
+            batches_per_superbatch: 100,
+            start_superbatch: 1,
+            end_superbatch: 20,
+        },
+        log_rate: 10,
+        lr_schedule: Box::new(|_, sb| 0.001 * 0.9f32.powi(sb as i32 - 1)),
+    };
+
+    let dataloader = ImageDataLoader { images: images.clone(), labels: labels.clone() };
+
+    let valid = prepare(&valid_images, &valid_labels);
+    let valid_gpu = PreparedBatchDevice::new(trainer.optimiser.graph.devices(), &valid)?;
+
+    trainer
+        .train_custom(
+            schedule,
+            dataloader,
+            |_, _, _, _| {},
+            |trainer, _| {
+                let graph = &mut trainer.optimiser.graph;
+                let train_accuracy = calculate_accuracy(graph, outputs, &images, &labels).unwrap();
+                valid_gpu.copy_into_graph(graph).unwrap();
+                let valid_accuracy = calculate_accuracy(graph, outputs, &valid_images, &valid_labels).unwrap();
+                println!("Train accuracy {train_accuracy:.2}% validation accuracy {valid_accuracy:.2}%");
+            },
+        )
+        .unwrap();

     Ok(())
 }

 fn make_model(images: &Images) -> (Graph<ExecutionContext>, Node) {
     let builder = GraphBuilder::default();

-    let inputs = builder.new_dense_input("inputs", Shape::new(images.rows, images.cols));
+    let inputs = builder.new_dense_input("inputs", images.shape);
     let targets = builder.new_dense_input("targets", Shape::new(10, 1));

     let l0 = builder.new_affine("l0", 28 * 28, 128);
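This hunk swaps the hand-rolled loop (manual `zero_grads`/`forward`/`backward` plus a fixed `lr = 0.0001` gradient-descent update over `weight_ids`) for the library's `Trainer`/`Optimiser` machinery: AdamW with wide weight clamps, 20 superbatches of 100 batches each, and a decaying learning-rate schedule. The schedule `0.001 * 0.9f32.powi(sb - 1)` gives 0.001 on superbatch 1, decaying to 0.001 × 0.9^19 ≈ 1.35e-4 by superbatch 20.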
@@ -94,8 +93,6 @@ fn calculate_accuracy(
     labels: &Labels,
 ) -> Result<f32, OperationError<DeviceError>> {
     let batch_size = images.batch_size();
-    graph.get_input("inputs").dense_mut().load_from_slice(Some(batch_size), &images.vals)?;
-    graph.get_input("targets").dense_mut().load_from_slice(Some(batch_size), &labels.vals)?;
     let _ = graph.forward()?;

     let vals = graph.get(GraphNodeId::new(output_node.idx(), GraphNodeIdTy::Values))?.borrow().get_dense_vals()?;
@@ -118,15 +115,45 @@ fn calculate_accuracy(
         }
     }

-    assert_eq!(batch_size, labels.indices.len());
-
     Ok(100.0 * correct as f32 / batch_size as f32)
 }

+fn prepare(images: &Images, labels: &Labels) -> PreparedBatchHost {
+    let batch_size = labels.vals.len() / 10;
+    let mut inputs = HashMap::new();
+
+    let wrap = |x: &Vec<f32>, s| HostMatrix::Dense(HostDenseMatrix::new(x.clone(), Some(batch_size), s));
+
+    let x = wrap(&images.vals, images.shape);
+    inputs.insert("inputs".to_string(), x);
+
+    let y = wrap(&labels.vals, Shape::new(10, 1));
+    inputs.insert("targets".to_string(), y);
+
+    PreparedBatchHost { batch_size, inputs }
+}
+
+struct ImageDataLoader {
+    images: Images,
+    labels: Labels,
+}
+
+impl DataLoader for ImageDataLoader {
+    type Error = DataLoadingError;
+
+    fn map_batches<F: FnMut(PreparedBatchHost) -> bool>(self, batch_size: usize, mut f: F) -> Result<(), Self::Error> {
+        assert_eq!(batch_size, self.labels.vals.len() / 10);
+
+        while !f(prepare(&self.images, &self.labels)) {}
+
+        Ok(())
+    }
+}
+
+#[derive(Clone)]
 struct Images {
     vals: Vec<f32>,
-    rows: usize,
-    cols: usize,
+    shape: Shape,
 }

 impl Images {
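Note on the `DataLoader` impl above: `prepare` packs the entire training set as a single batch (the label vector is one-hot over 10 classes, hence `vals.len() / 10` examples), and `map_batches` simply replays that same batch until the trainer's callback `f` returns `true` to signal that training is finished. The `#[derive(Clone)]` additions on `Images` and `Labels` exist so `main` can hand owned copies to `ImageDataLoader` while keeping the originals for accuracy reporting.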
@@ -152,16 +179,16 @@ impl Images {

         Self {
             vals: bytes.iter().map(|&x| f32::from(x) / f32::from(u8::MAX)).collect(),
-            rows: rows as usize,
-            cols: cols as usize,
+            shape: Shape::new(rows as usize, cols as usize),
         }
     }

     pub fn batch_size(&self) -> usize {
-        self.vals.len() / (self.rows * self.cols)
+        self.vals.len() / self.shape.size()
     }
 }

+#[derive(Clone)]
 struct Labels {
     vals: Vec<f32>,
     indices: Vec<u8>,
