Skip to content

Commit d9e15bc

Browse files
Alyssa Wang, facebook-github-bot
authored and committed
Perform weight re-init for embedding table in sparse_lookup.py (#22348)
Summary: Pull Request resolved: #22348 This is the last step of LRU hash eviction weight re-init. This diff checks whether there are evicted values in sparse_lookup; if so, it calls the op created in D15709866 to re-init the values for the indices in evicted_values. It also creates a gradient op for the operator. The gradient op just passes the output gradient through as the input gradient. Reviewed By: itomatik Differential Revision: D16044736 fbshipit-source-id: 9afb85209b0de1038c5153bcb7dfc5f52e0b2abb
1 parent c9f41e9 commit d9e15bc

File tree

5 files changed

+122
-7
lines changed

5 files changed

+122
-7
lines changed

caffe2/operators/copy_rows_to_tensor_op.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
namespace caffe2 {
44
namespace {
5+
56
REGISTER_CPU_OPERATOR(CopyRowsToTensor, CopyRowsToTensorOp<CPUContext>);
7+
REGISTER_CPU_GRADIENT_OPERATOR(
8+
CopyRowsToTensorGradient,
9+
CopyRowsToTensorGradientOp<CPUContext>);
610

711
OPERATOR_SCHEMA(CopyRowsToTensor)
812
.NumInputs(3)
@@ -30,5 +34,36 @@ OPERATOR_SCHEMA(CopyRowsToTensor)
3034
return out;
3135
});
3236

37+
// Schema for the CopyRowsToTensorGradient operator: exactly one input and
// one output, and output 0 is allowed to share storage with input 0
// (in-place execution).
GRADIENT_OPERATOR_SCHEMA(CopyRowsToTensorGradient)
38+
.NumInputs(1)
39+
.NumOutputs(1)
40+
.AllowInplace({{0, 0}});
41+
42+
class GetCopyRowsToTensorGradient : public GradientMakerBase {
43+
using GradientMakerBase::GradientMakerBase;
44+
vector<OperatorDef> GetGradientDefs() override {
45+
if (g_output_[0].IsDense()) {
46+
return SingleGradientDef(
47+
"CopyRowsToTensorGradient",
48+
"",
49+
vector<string>{GO(0)},
50+
vector<string>{GI(0)});
51+
} else {
52+
return vector<OperatorDef>{CreateOperatorDef(
53+
"CopyRowsToTensorGradient",
54+
"",
55+
std::vector<string>{GO_I(0)},
56+
std::vector<string>{GI_I(0)}),
57+
CreateOperatorDef(
58+
"CopyRowsToTensorGradient",
59+
"",
60+
std::vector<string>{GO_V(0)},
61+
std::vector<string>{GI_V(0)})};
62+
}
63+
}
64+
};
65+
66+
REGISTER_GRADIENT(CopyRowsToTensor, GetCopyRowsToTensorGradient);
67+
3368
} // namespace
3469
} // namespace caffe2

caffe2/operators/copy_rows_to_tensor_op.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,30 @@ class CopyRowsToTensorOp : public Operator<Context> {
5353
protected:
5454
INPUT_TAGS(INPUT_TENSOR, INDICES, ROW);
5555
};
56+
57+
template <class Context>
58+
class CopyRowsToTensorGradientOp : public Operator<Context> {
59+
public:
60+
USE_OPERATOR_CONTEXT_FUNCTIONS;
61+
CopyRowsToTensorGradientOp(const OperatorDef& operator_def, Workspace* ws)
62+
: Operator<Context>(operator_def, ws) {}
63+
64+
bool RunOnDevice() override {
65+
return DispatchHelper<
66+
TensorTypes<at::Half, float, double, int32_t, int64_t>>::
67+
call(this, Input(0));
68+
}
69+
template <typename T>
70+
bool DoRunWithType() {
71+
auto* output = Output(0);
72+
output->ResizeLike(Input(0));
73+
auto* output_data = output->template mutable_data<T>();
74+
auto& input = Input(0);
75+
const auto* input_data = input.template data<T>();
76+
std::memcpy(output_data, input_data, input.size(0) * sizeof(T));
77+
78+
return true;
79+
}
80+
};
81+
5682
} // namespace caffe2

caffe2/python/layers/sparse_lookup.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ def __init__(self, model, input_record, inner_shape, reducer,
125125

126126
self.weight_init = weight_init or default_init_op
127127

128+
self.evicted_values = None
129+
if schema.equal_schemas(self.input_record, IdListWithEvicted) or \
130+
schema.equal_schemas(self.input_record, IdScoreListWithEvicted,
131+
check_field_types=False):
132+
self.evicted_values = self.input_record._evicted_values
133+
128134
# If fp16 is used, make sure fp16 init op is used
129135
if self.trainer_version == "fp16":
130136
assert self.reducer in self._fp16_compatible_reducers, (
@@ -169,6 +175,14 @@ def __init__(self, model, input_record, inner_shape, reducer,
169175
average_length=avg_length),
170176
regularizer=regularizer
171177
)
178+
if self.evicted_values:
179+
self.reinit_vec = self.create_param(
180+
param_name="reinit_vec",
181+
shape=inner_shape,
182+
initializer=self.weight_init,
183+
optimizer=model.NoOptim,
184+
regularizer=None,
185+
)
172186

173187
self.scale_bias_init = ('ConstantFill', {'value': 0.0})
174188

@@ -407,6 +421,9 @@ def _add_ops_id_score_list(self, net, version):
407421
"Trying to create with {}".format(self.reducer)
408422

409423
def _add_ops(self, net, version='fp32'):
424+
if self.evicted_values:
425+
net.CopyRowsToTensor(
426+
[self.w, self.evicted_values.get(), self.reinit_vec], [self.w])
410427
if _is_id_list(self.input_record):
411428
self._add_ops_id_list(net, version=version)
412429
elif _is_id_score_list(self.input_record):

caffe2/python/layers_test.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
is_request_only_scalar,
3232
get_key,
3333
)
34-
3534
import logging
3635
logger = logging.getLogger(__name__)
3736

@@ -231,6 +230,46 @@ def testFCwithAxis2(self):
231230

232231
train_init_net, train_net = self.get_training_nets()
233232

233+
def testSparseLookupSumPoolingWithEviction(self):
    """Check that an evicted embedding row is re-initialized by SparseLookup.

    Builds a one-row embedding table whose input record carries evicted
    values, overwrites row 0 after initialization, runs the training net
    with row 0 listed as evicted, and verifies that the table matches its
    freshly initialized contents element by element.
    """
    # Create test embedding table of 1 row
    record = schema.NewRecord(self.model.net, schema.Struct(
        ('sparse', schema.Struct(
            ('sparse_feature_0', schema.ListWithEvicted(
                schema.Scalar(np.int64,
                              metadata=schema.Metadata(categorical_limit=1)),)),)),
    ))
    embedding_dim = 8
    lengths_blob = record.sparse.sparse_feature_0.lengths.get()
    values_blob = record.sparse.sparse_feature_0.items.get()
    evicted_values_blob = record.sparse.sparse_feature_0._evicted_values.get()
    lengths = np.array([1]).astype(np.int32)
    values = np.array([0]).astype(np.int64)
    # Need to reset row 0
    evicted_values = np.array([0]).astype(np.int64)
    workspace.FeedBlob(lengths_blob, lengths)
    workspace.FeedBlob(values_blob, values)
    workspace.FeedBlob(evicted_values_blob, evicted_values)

    embedding_after_pooling = self.model.SparseLookup(
        record.sparse.sparse_feature_0, [embedding_dim], 'Sum',
        weight_init=("ConstantFill", {"value": 1.0}))

    self.model.output_schema = schema.Struct()
    self.assertEqual(
        schema.Scalar((np.float32, (embedding_dim, ))),
        embedding_after_pooling
    )
    train_init_net, train_net = self.get_training_nets()
    workspace.RunNetOnce(train_init_net)
    embedding_after_init = workspace.FetchBlob("sparse_lookup/w")
    # Change row 0's value before reset
    new_values = np.array([[2, 2, 2, 2, 2, 2, 2, 2]]).astype(np.float32)
    workspace.FeedBlob("sparse_lookup/w", new_values)
    workspace.RunNetOnce(train_net.Proto())
    embedding_after_training = workspace.FetchBlob("sparse_lookup/w")
    # Verify the reset restored row 0 to its initial value. Compare the
    # arrays element-wise: comparing `.all()` of each array only compares
    # two booleans and would pass even if the row had not been reset.
    np.testing.assert_array_equal(
        embedding_after_training, embedding_after_init)
271+
272+
234273

235274
def testSparseLookupSumPooling(self):
236275
record = schema.NewRecord(self.model.net, schema.Struct(

caffe2/python/operator_test/copy_rows_to_tensor_op_test.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ def get_input_tensors():
1515
height = np.random.randint(1, 10)
1616
width = np.random.randint(1, 10)
1717
dtype = np.float32
18-
print("height", height)
19-
print("width", width)
2018
input_tensor = hu.arrays(
2119
dims=[height, width],
2220
dtype=dtype,
@@ -43,12 +41,12 @@ def ref(input_tensor, indices, row):
4341
for idx in indices:
4442
input_tensor[idx] = row
4543
return [input_tensor]
46-
44+
op = core.CreateOperator(
45+
"CopyRowsToTensor", ["input_tensor", "indices", "row"], ["input_tensor"]
46+
)
4747
self.assertReferenceChecks(
4848
device_option=gc,
49-
op=core.CreateOperator(
50-
"CopyRowsToTensor", ["input_tensor", "indices", "row"], ["input_tensor"]
51-
),
49+
op=op,
5250
inputs=[input_tensor, indices, row],
5351
reference=ref,
5452
)

0 commit comments

Comments
 (0)