Skip to content

Commit 08bdd69

Browse files
ChunliFfacebook-github-bot
authored andcommitted
Extract feature length information from SigridTransforms op (#20384)
Summary: Pull Request resolved: #20384 Pull Request resolved: #20171 Extract feature length information from SigridTransforms op Reviewed By: ipiszy Differential Revision: D15219408 fbshipit-source-id: 307d2b65b208d3af6977d90246d0372795c45815
1 parent 428104c commit 08bdd69

File tree

4 files changed

+78
-200
lines changed

4 files changed

+78
-200
lines changed

caffe2/opt/backend_transformer_base.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ ShapeInfoMap BackendTransformerBase::inferShapes(
120120
}
121121
}
122122
auto eng = BoundShapeInferencerRegistry()->Create("C10", spec);
123-
eng->InferBoundShapeAndType(*pred_net, shape_map);
123+
eng->InferBoundShapeAndType(*pred_net, shape_map, ws);
124124
const auto& out_map = eng->shape_info();
125125
shape_map.clear();
126126
for (const auto& kv : out_map) {

caffe2/opt/bound_shape_inference_test.cc

Lines changed: 10 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ TEST(BoundShapeInference, SparseLengthsSum) {
5151
"Weights", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {1000, 16}));
5252
BoundShapeSpec spec(20, 1000);
5353
BoundShapeInferencer eng(spec);
54-
eng.InferBoundShapeAndType(net, shape_map);
54+
eng.InferBoundShapeAndType(net, shape_map, nullptr);
5555
const auto& out_shape = eng.shape_info();
5656
verifyShapeInfo(
5757
out_shape, "Weights", ShapeInfo::DimType::CONSTANT, {1000, 16});
@@ -86,7 +86,7 @@ TEST(BoundShapeInference, SparseLengthsSumFused8BitRowwise) {
8686
ShapeInfo::DimType::CONSTANT, {1000, 58}, TensorProto_DataType_INT8));
8787
BoundShapeSpec spec(20, 1000);
8888
BoundShapeInferencer eng(spec);
89-
eng.InferBoundShapeAndType(net, shape_map);
89+
eng.InferBoundShapeAndType(net, shape_map, nullptr);
9090
const auto& out_shape = eng.shape_info();
9191
verifyShapeInfo(
9292
out_shape,
@@ -127,7 +127,7 @@ TEST(BoundShapeInference, LengthsRangeFill) {
127127
ShapeInfoMap shape_map;
128128
BoundShapeSpec spec(20, 1000);
129129
BoundShapeInferencer eng(spec);
130-
eng.InferBoundShapeAndType(net, shape_map);
130+
eng.InferBoundShapeAndType(net, shape_map, nullptr);
131131
const auto& out_shape = eng.shape_info();
132132
verifyShapeInfo(
133133
out_shape,
@@ -175,7 +175,7 @@ TEST(BoundShapeInference, Reshape) {
175175
shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
176176
BoundShapeSpec spec(20, 1000);
177177
BoundShapeInferencer eng(spec);
178-
eng.InferBoundShapeAndType(net, shape_map);
178+
eng.InferBoundShapeAndType(net, shape_map, nullptr);
179179
const auto& out_shape = eng.shape_info();
180180
verifyShapeInfo(
181181
out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
@@ -203,7 +203,7 @@ TEST(BoundShapeInference, ConcatMissingInput) {
203203
"I0",
204204
makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60}));
205205
BoundShapeInferencer eng(spec);
206-
eng.InferBoundShapeAndType(net, shape_map);
206+
eng.InferBoundShapeAndType(net, shape_map, nullptr);
207207
const auto& out_shape = eng.shape_info();
208208
verifyShapeInfo(
209209
out_shape, "I0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60});
@@ -233,7 +233,7 @@ TEST(BoundShapeInference, ConcatInferInputBackwards) {
233233
"W0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {101, 16}));
234234
shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
235235
BoundShapeInferencer eng(spec);
236-
eng.InferBoundShapeAndType(net, shape_map);
236+
eng.InferBoundShapeAndType(net, shape_map, nullptr);
237237
const auto& out_shape = eng.shape_info();
238238
verifyShapeInfo(
239239
out_shape, "I0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60});
@@ -274,7 +274,7 @@ TEST(BoundShapeInference, Split) {
274274
"X1",
275275
makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2, 48}));
276276
BoundShapeInferencer eng(spec);
277-
eng.InferBoundShapeAndType(net, shape_map);
277+
eng.InferBoundShapeAndType(net, shape_map, nullptr);
278278
const auto& out_shape = eng.shape_info();
279279
verifyShapeInfo(
280280
out_shape, "X", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48});
@@ -317,7 +317,7 @@ TEST(BoundShapeInference, FC) {
317317
shape_map.emplace("B1", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {1024}));
318318
BoundShapeSpec spec(20, 1000);
319319
BoundShapeInferencer eng(spec);
320-
eng.InferBoundShapeAndType(net, shape_map);
320+
eng.InferBoundShapeAndType(net, shape_map, nullptr);
321321
const auto& out_shape = eng.shape_info();
322322
verifyShapeInfo(
323323
out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
@@ -342,54 +342,14 @@ TEST(BoundShapeInference, FC3D) {
342342
shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
343343
BoundShapeSpec spec(20, 1000);
344344
BoundShapeInferencer eng(spec);
345-
eng.InferBoundShapeAndType(net, shape_map);
345+
eng.InferBoundShapeAndType(net, shape_map, nullptr);
346346
const auto& out_shape = eng.shape_info();
347347
verifyShapeInfo(
348348
out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
349349
verifyShapeInfo(
350350
out_shape, "Out0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
351351
}
352352

353-
TEST(BoundShapeInference, ClipRangesGatherSigridHash) {
354-
FLAGS_caffe2_extract_feature_length_for_shape_inference = true;
355-
NetDef net;
356-
net.add_op()->CopyFrom(CreateOperatorDef(
357-
"ClipRangesGatherSigridHash",
358-
"",
359-
{"R0", "V0"},
360-
{"F0_lengths_0", "F0_values_0", "F1_lengths_0", "F1_values_0"},
361-
{MakeArgument<std::vector<int>>("max_lengths", {200, 400})}));
362-
ShapeInfoMap shape_map;
363-
BoundShapeSpec spec(50, 1000);
364-
BoundShapeInferencer eng(spec);
365-
eng.InferBoundShapeAndType(net, shape_map);
366-
const auto& out_shape = eng.shape_info();
367-
verifyShapeInfo(
368-
out_shape,
369-
"F0_lengths_0",
370-
ShapeInfo::DimType::BATCH,
371-
{spec.max_batch_size},
372-
TensorProto_DataType_INT32);
373-
verifyShapeInfo(
374-
out_shape,
375-
"F0_values_0",
376-
ShapeInfo::DimType::SEQ,
377-
{spec.max_batch_size * 200},
378-
TensorProto_DataType_INT64);
379-
verifyShapeInfo(
380-
out_shape,
381-
"F1_lengths_0",
382-
ShapeInfo::DimType::BATCH,
383-
{spec.max_batch_size},
384-
TensorProto_DataType_INT32);
385-
verifyShapeInfo(
386-
out_shape,
387-
"F1_values_0",
388-
ShapeInfo::DimType::SEQ,
389-
{spec.max_batch_size * 400},
390-
TensorProto_DataType_INT64);
391-
}
392-
393353
TEST(BoundShapeInference, Combo0) {
394354
NetDef net;
395355
net.add_op()->CopyFrom(CreateOperatorDef(
@@ -421,56 +381,9 @@ TEST(BoundShapeInference, Combo0) {
421381
"Indices", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {2}));
422382
BoundShapeSpec spec(20, 1000);
423383
BoundShapeInferencer eng(spec);
424-
eng.InferBoundShapeAndType(net, shape_map);
384+
eng.InferBoundShapeAndType(net, shape_map, nullptr);
425385
const auto& out_shape = eng.shape_info();
426386
LOG(INFO) << eng.PrintShapeInfo();
427387
verifyShapeInfo(
428388
out_shape, "Gout", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2});
429389
}
430-
431-
TEST(BoundShapeInference, Combo1) {
432-
FLAGS_caffe2_extract_feature_length_for_shape_inference = true;
433-
NetDef net;
434-
net.add_op()->CopyFrom(CreateOperatorDef(
435-
"ClipRangesGatherSigridHash",
436-
"",
437-
{"R0", "V0"},
438-
{"F0_lengths_0", "F0_values_0", "F1_lengths_0", "F1_values_0"},
439-
{MakeArgument<std::vector<int>>("max_lengths", {300, 400})}));
440-
441-
net.add_op()->CopyFrom(CreateOperatorDef(
442-
"SparseLengthsSumFused8BitRowwise",
443-
"",
444-
{"Weights", "F0_values_0", "F0_lengths_0"},
445-
{"Out"},
446-
{}));
447-
ShapeInfoMap shape_map;
448-
shape_map.emplace(
449-
"Weights",
450-
makeTensorInfo(
451-
ShapeInfo::DimType::CONSTANT, {1000, 58}, TensorProto_DataType_INT8));
452-
BoundShapeSpec spec(20, 1000);
453-
BoundShapeInferencer eng(spec);
454-
eng.InferBoundShapeAndType(net, shape_map);
455-
const auto& out_shape = eng.shape_info();
456-
verifyShapeInfo(
457-
out_shape,
458-
"Weights",
459-
ShapeInfo::DimType::CONSTANT,
460-
{1000, 58},
461-
TensorProto_DataType_INT8);
462-
verifyShapeInfo(
463-
out_shape,
464-
"F0_values_0",
465-
ShapeInfo::DimType::SEQ,
466-
{spec.max_batch_size * 300},
467-
TensorProto_DataType_INT64);
468-
verifyShapeInfo(
469-
out_shape,
470-
"F0_lengths_0",
471-
ShapeInfo::DimType::BATCH,
472-
{spec.max_batch_size},
473-
TensorProto_DataType_INT32);
474-
verifyShapeInfo(
475-
out_shape, "Out", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 50});
476-
}

0 commit comments

Comments
 (0)