@@ -51,7 +51,7 @@ TEST(BoundShapeInference, SparseLengthsSum) {
5151 " Weights" , makeTensorInfo (ShapeInfo::DimType::CONSTANT, {1000 , 16 }));
5252 BoundShapeSpec spec (20 , 1000 );
5353 BoundShapeInferencer eng (spec);
54- eng.InferBoundShapeAndType (net, shape_map);
54+ eng.InferBoundShapeAndType (net, shape_map, nullptr );
5555 const auto & out_shape = eng.shape_info ();
5656 verifyShapeInfo (
5757 out_shape, " Weights" , ShapeInfo::DimType::CONSTANT, {1000 , 16 });
@@ -86,7 +86,7 @@ TEST(BoundShapeInference, SparseLengthsSumFused8BitRowwise) {
8686 ShapeInfo::DimType::CONSTANT, {1000 , 58 }, TensorProto_DataType_INT8));
8787 BoundShapeSpec spec (20 , 1000 );
8888 BoundShapeInferencer eng (spec);
89- eng.InferBoundShapeAndType (net, shape_map);
89+ eng.InferBoundShapeAndType (net, shape_map, nullptr );
9090 const auto & out_shape = eng.shape_info ();
9191 verifyShapeInfo (
9292 out_shape,
@@ -127,7 +127,7 @@ TEST(BoundShapeInference, LengthsRangeFill) {
127127 ShapeInfoMap shape_map;
128128 BoundShapeSpec spec (20 , 1000 );
129129 BoundShapeInferencer eng (spec);
130- eng.InferBoundShapeAndType (net, shape_map);
130+ eng.InferBoundShapeAndType (net, shape_map, nullptr );
131131 const auto & out_shape = eng.shape_info ();
132132 verifyShapeInfo (
133133 out_shape,
@@ -175,7 +175,7 @@ TEST(BoundShapeInference, Reshape) {
175175 shape_map.emplace (" B0" , makeTensorInfo (ShapeInfo::DimType::CONSTANT, {16 }));
176176 BoundShapeSpec spec (20 , 1000 );
177177 BoundShapeInferencer eng (spec);
178- eng.InferBoundShapeAndType (net, shape_map);
178+ eng.InferBoundShapeAndType (net, shape_map, nullptr );
179179 const auto & out_shape = eng.shape_info ();
180180 verifyShapeInfo (
181181 out_shape, " X0" , ShapeInfo::DimType::BATCH, {spec.max_batch_size , 1024 });
@@ -203,7 +203,7 @@ TEST(BoundShapeInference, ConcatMissingInput) {
203203 " I0" ,
204204 makeTensorInfo (ShapeInfo::DimType::BATCH, {spec.max_batch_size , 60 }));
205205 BoundShapeInferencer eng (spec);
206- eng.InferBoundShapeAndType (net, shape_map);
206+ eng.InferBoundShapeAndType (net, shape_map, nullptr );
207207 const auto & out_shape = eng.shape_info ();
208208 verifyShapeInfo (
209209 out_shape, " I0" , ShapeInfo::DimType::BATCH, {spec.max_batch_size , 60 });
@@ -233,7 +233,7 @@ TEST(BoundShapeInference, ConcatInferInputBackwards) {
233233 " W0" , makeTensorInfo (ShapeInfo::DimType::CONSTANT, {101 , 16 }));
234234 shape_map.emplace (" B0" , makeTensorInfo (ShapeInfo::DimType::CONSTANT, {16 }));
235235 BoundShapeInferencer eng (spec);
236- eng.InferBoundShapeAndType (net, shape_map);
236+ eng.InferBoundShapeAndType (net, shape_map, nullptr );
237237 const auto & out_shape = eng.shape_info ();
238238 verifyShapeInfo (
239239 out_shape, " I0" , ShapeInfo::DimType::BATCH, {spec.max_batch_size , 60 });
@@ -274,7 +274,7 @@ TEST(BoundShapeInference, Split) {
274274 " X1" ,
275275 makeTensorInfo (ShapeInfo::DimType::BATCH, {spec.max_batch_size , 2 , 48 }));
276276 BoundShapeInferencer eng (spec);
277- eng.InferBoundShapeAndType (net, shape_map);
277+ eng.InferBoundShapeAndType (net, shape_map, nullptr );
278278 const auto & out_shape = eng.shape_info ();
279279 verifyShapeInfo (
280280 out_shape, " X" , ShapeInfo::DimType::BATCH, {spec.max_batch_size , 48 });
@@ -317,7 +317,7 @@ TEST(BoundShapeInference, FC) {
317317 shape_map.emplace (" B1" , makeTensorInfo (ShapeInfo::DimType::CONSTANT, {1024 }));
318318 BoundShapeSpec spec (20 , 1000 );
319319 BoundShapeInferencer eng (spec);
320- eng.InferBoundShapeAndType (net, shape_map);
320+ eng.InferBoundShapeAndType (net, shape_map, nullptr );
321321 const auto & out_shape = eng.shape_info ();
322322 verifyShapeInfo (
323323 out_shape, " X0" , ShapeInfo::DimType::BATCH, {spec.max_batch_size , 1024 });
@@ -342,54 +342,14 @@ TEST(BoundShapeInference, FC3D) {
342342 shape_map.emplace (" B0" , makeTensorInfo (ShapeInfo::DimType::CONSTANT, {16 }));
343343 BoundShapeSpec spec (20 , 1000 );
344344 BoundShapeInferencer eng (spec);
345- eng.InferBoundShapeAndType (net, shape_map);
345+ eng.InferBoundShapeAndType (net, shape_map, nullptr );
346346 const auto & out_shape = eng.shape_info ();
347347 verifyShapeInfo (
348348 out_shape, " X0" , ShapeInfo::DimType::BATCH, {spec.max_batch_size , 1024 });
349349 verifyShapeInfo (
350350 out_shape, " Out0" , ShapeInfo::DimType::BATCH, {spec.max_batch_size , 16 });
351351}
352352
353- TEST (BoundShapeInference, ClipRangesGatherSigridHash) {
354- FLAGS_caffe2_extract_feature_length_for_shape_inference = true ;
355- NetDef net;
356- net.add_op ()->CopyFrom (CreateOperatorDef (
357- " ClipRangesGatherSigridHash" ,
358- " " ,
359- {" R0" , " V0" },
360- {" F0_lengths_0" , " F0_values_0" , " F1_lengths_0" , " F1_values_0" },
361- {MakeArgument<std::vector<int >>(" max_lengths" , {200 , 400 })}));
362- ShapeInfoMap shape_map;
363- BoundShapeSpec spec (50 , 1000 );
364- BoundShapeInferencer eng (spec);
365- eng.InferBoundShapeAndType (net, shape_map);
366- const auto & out_shape = eng.shape_info ();
367- verifyShapeInfo (
368- out_shape,
369- " F0_lengths_0" ,
370- ShapeInfo::DimType::BATCH,
371- {spec.max_batch_size },
372- TensorProto_DataType_INT32);
373- verifyShapeInfo (
374- out_shape,
375- " F0_values_0" ,
376- ShapeInfo::DimType::SEQ,
377- {spec.max_batch_size * 200 },
378- TensorProto_DataType_INT64);
379- verifyShapeInfo (
380- out_shape,
381- " F1_lengths_0" ,
382- ShapeInfo::DimType::BATCH,
383- {spec.max_batch_size },
384- TensorProto_DataType_INT32);
385- verifyShapeInfo (
386- out_shape,
387- " F1_values_0" ,
388- ShapeInfo::DimType::SEQ,
389- {spec.max_batch_size * 400 },
390- TensorProto_DataType_INT64);
391- }
392-
393353TEST (BoundShapeInference, Combo0) {
394354 NetDef net;
395355 net.add_op ()->CopyFrom (CreateOperatorDef (
@@ -421,56 +381,9 @@ TEST(BoundShapeInference, Combo0) {
421381 " Indices" , makeTensorInfo (ShapeInfo::DimType::CONSTANT, {2 }));
422382 BoundShapeSpec spec (20 , 1000 );
423383 BoundShapeInferencer eng (spec);
424- eng.InferBoundShapeAndType (net, shape_map);
384+ eng.InferBoundShapeAndType (net, shape_map, nullptr );
425385 const auto & out_shape = eng.shape_info ();
426386 LOG (INFO) << eng.PrintShapeInfo ();
427387 verifyShapeInfo (
428388 out_shape, " Gout" , ShapeInfo::DimType::BATCH, {spec.max_batch_size , 2 });
429389}
430-
431- TEST (BoundShapeInference, Combo1) {
432- FLAGS_caffe2_extract_feature_length_for_shape_inference = true ;
433- NetDef net;
434- net.add_op ()->CopyFrom (CreateOperatorDef (
435- " ClipRangesGatherSigridHash" ,
436- " " ,
437- {" R0" , " V0" },
438- {" F0_lengths_0" , " F0_values_0" , " F1_lengths_0" , " F1_values_0" },
439- {MakeArgument<std::vector<int >>(" max_lengths" , {300 , 400 })}));
440-
441- net.add_op ()->CopyFrom (CreateOperatorDef (
442- " SparseLengthsSumFused8BitRowwise" ,
443- " " ,
444- {" Weights" , " F0_values_0" , " F0_lengths_0" },
445- {" Out" },
446- {}));
447- ShapeInfoMap shape_map;
448- shape_map.emplace (
449- " Weights" ,
450- makeTensorInfo (
451- ShapeInfo::DimType::CONSTANT, {1000 , 58 }, TensorProto_DataType_INT8));
452- BoundShapeSpec spec (20 , 1000 );
453- BoundShapeInferencer eng (spec);
454- eng.InferBoundShapeAndType (net, shape_map);
455- const auto & out_shape = eng.shape_info ();
456- verifyShapeInfo (
457- out_shape,
458- " Weights" ,
459- ShapeInfo::DimType::CONSTANT,
460- {1000 , 58 },
461- TensorProto_DataType_INT8);
462- verifyShapeInfo (
463- out_shape,
464- " F0_values_0" ,
465- ShapeInfo::DimType::SEQ,
466- {spec.max_batch_size * 300 },
467- TensorProto_DataType_INT64);
468- verifyShapeInfo (
469- out_shape,
470- " F0_lengths_0" ,
471- ShapeInfo::DimType::BATCH,
472- {spec.max_batch_size },
473- TensorProto_DataType_INT32);
474- verifyShapeInfo (
475- out_shape, " Out" , ShapeInfo::DimType::BATCH, {spec.max_batch_size , 50 });
476- }
0 commit comments