Skip to content

Commit e92b7c5

Browse files
committed
added conditions for merge and split
1 parent 0d6aff8 commit e92b7c5

File tree

5 files changed

+53
-7
lines changed

5 files changed

+53
-7
lines changed

src/nn/layers/activation.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,41 @@ struct ReLU : public Layer {
3939
}
4040
};
4141

42+
struct Linear : public Layer {
43+
Layer* prev;
44+
GradientOperation grad_op;
45+
int use_id;
46+
float scalar;
47+
explicit Linear(Layer* prev, float scalar = 1)
48+
: Layer(prev->size)
49+
, prev(prev)
50+
, scalar(scalar) {
51+
prev->use();
52+
use_id = prev->use();
53+
}
54+
55+
void compile(size_t batch_size) override {
56+
this->compile_suboutput(batch_size, Tape {size, batch_size});
57+
this->grad_op = use_id == prev->used() ? SET : INCREMENT;
58+
}
59+
60+
void forward() override {
61+
Layer::forward();
62+
operations::linear<data::GPU>(prev->dense_output.values, dense_output.values, scalar);
63+
}
64+
65+
void backward() override {
66+
Layer::backward();
67+
operations::linear_bp<data::GPU>(prev->dense_output.values,
68+
prev->dense_output.gradients,
69+
dense_output.values,
70+
dense_output.gradients,
71+
scalar,
72+
grad_op);
73+
}
74+
};
75+
76+
4277
struct Sigmoid : public Layer {
4378
Layer* prev;
4479
GradientOperation grad_op;

src/nn/layers/merge.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,20 @@ struct Merge : public Layer {
1717
Layer* l1;
1818
Layer* l2;
1919

20+
int use_id_1;
21+
int use_id_2;
22+
2023
Merge(Layer* l1, Layer* l2)
2124
: Layer(l1->size + l2->size)
2225
, l1(l1)
2326
, l2(l2) {
24-
l1->use();
25-
l2->use();}
27+
use_id_1 = l1->use();
28+
use_id_2 = l2->use();}
2629

2730
void compile(size_t batch_size) override {
2831
this->compile_suboutput(batch_size, Tape(size, batch_size));
32+
ERROR(use_id_1 == l1->used());
33+
ERROR(use_id_2 == l2->used());
2934
}
3035

3136
void compile_suboutput(size_t batch_size, const Tape& output) override {

src/nn/layers/split.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,12 @@ struct SplitHead : public Layer {
4343

4444
struct Split : public Layer {
4545
std::vector<SplitHead> heads;
46+
Layer* prev;
47+
int use_id;
48+
4649

4750
explicit Split(Layer* prev, const std::vector<size_t>& head_sizes)
48-
: Layer(prev->size) {
51+
: Layer(prev->size), prev(prev) {
4952
size_t total_assigned_size = 0;
5053
for (size_t i = 0; i < head_sizes.size(); ++i) {
5154
heads.emplace_back(prev, head_sizes[i], total_assigned_size);
@@ -55,7 +58,7 @@ struct Split : public Layer {
5558
size_t remaining_size = prev->size - total_assigned_size;
5659
heads.emplace_back(prev, remaining_size, total_assigned_size);
5760

58-
prev->use();
61+
use_id = prev->use();
5962
}
6063

6164
SplitHead* operator[](size_t index) {
@@ -66,6 +69,7 @@ struct Split : public Layer {
6669
}
6770

6871
void compile(size_t batch_size) override {
72+
ERROR(use_id == prev->used()); // make sure the operation would be "set"
6973
for (auto& head : heads) {
7074
head.compile(batch_size);
7175
}

src/nn/layers/weighted_sum.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ struct WeightedSum : public Layer {
4343
prev_2->dense_output.gradients,
4444
dense_output.gradients,
4545
alpha,
46-
beta);
46+
beta,
47+
grad_op_1,
48+
grad_op_2);
4749
}
4850
};
4951

src/operations/activation/activations.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
namespace operations {
77

88
DEFINE_ACTIVATION(linear
9-
, A[ida];
10-
, B_grd[idb];);
9+
, A[ida] * scalar;
10+
, B_grd[idb] * scalar;);
1111

1212
DEFINE_ACTIVATION(sigmoid
1313
, 1.0f / (1 + exp(-A[ida] * scalar));

0 commit comments

Comments
 (0)