@@ -75,30 +75,38 @@ Status InitializeSession(int num_threads, const string& graph,
7575 return Status::OK ();
7676}
7777
78+ template <class T >
79+ void InitializeTensor (const std::vector<float >& initialization_values,
80+ Tensor* input_tensor) {
81+ auto type_tensor = input_tensor->flat <T>();
82+ type_tensor = type_tensor.constant (0 );
83+ if (!initialization_values.empty ()) {
84+ for (int i = 0 ; i < initialization_values.size (); ++i) {
85+ type_tensor (i) = static_cast <T>(initialization_values[i]);
86+ }
87+ }
88+ }
89+
7890void CreateTensorsFromInputInfo (
7991 const std::vector<InputLayerInfo>& inputs,
8092 std::vector<std::pair<string, tensorflow::Tensor> >* input_tensors) {
8193 for (const InputLayerInfo& input : inputs) {
8294 Tensor input_tensor (input.data_type , input.shape );
8395 switch (input.data_type ) {
8496 case DT_INT32 : {
85- auto int_tensor = input_tensor.flat <int32>();
86- int_tensor = int_tensor.constant (0.0 );
97+ InitializeTensor<int32>(input.initialization_values , &input_tensor);
8798 break ;
8899 }
89100 case DT_FLOAT : {
90- auto float_tensor = input_tensor.flat <float >();
91- float_tensor = float_tensor.constant (0.0 );
101+ InitializeTensor<float >(input.initialization_values , &input_tensor);
92102 break ;
93103 }
94104 case DT_QUINT8 : {
95- auto int_tensor = input_tensor.flat <quint8>();
96- int_tensor = int_tensor.constant (0.0 );
105+ InitializeTensor<quint8>(input.initialization_values , &input_tensor);
97106 break ;
98107 }
99108 case DT_UINT8 : {
100- auto int_tensor = input_tensor.flat <uint8>();
101- int_tensor = int_tensor.constant (0.0 );
109+ InitializeTensor<uint8>(input.initialization_values , &input_tensor);
102110 break ;
103111 }
104112 default :
@@ -248,6 +256,7 @@ int Main(int argc, char** argv) {
248256 string input_layer_string = " input:0" ;
249257 string input_layer_shape_string = " 1,224,224,3" ;
250258 string input_layer_type_string = " float" ;
259+ string input_layer_values_string = " " ;
251260 string output_layer_string = " output:0" ;
252261 int num_runs = 50 ;
253262 string run_delay = " -1.0" ;
@@ -270,6 +279,8 @@ int Main(int argc, char** argv) {
270279 Flag (" input_layer" , &input_layer_string, " input layer names" ),
271280 Flag (" input_layer_shape" , &input_layer_shape_string, " input layer shape" ),
272281 Flag (" input_layer_type" , &input_layer_type_string, " input layer type" ),
282+ Flag (" input_layer_values" , &input_layer_values_string,
283+ " values to initialize the inputs with" ),
273284 Flag (" output_layer" , &output_layer_string, " output layer name" ),
274285 Flag (" num_runs" , &num_runs, " number of runs" ),
275286 Flag (" run_delay" , &run_delay, " delay between runs in seconds" ),
@@ -304,6 +315,8 @@ int Main(int argc, char** argv) {
304315 str_util::Split (input_layer_shape_string, ' :' );
305316 std::vector<string> input_layer_types =
306317 str_util::Split (input_layer_type_string, ' ,' );
318+ std::vector<string> input_layer_values =
319+ str_util::Split (input_layer_values_string, ' :' );
307320 std::vector<string> output_layers = str_util::Split (output_layer_string, ' ,' );
308321 if ((input_layers.size () != input_layer_shapes.size ()) ||
309322 (input_layers.size () != input_layer_types.size ())) {
@@ -374,6 +387,12 @@ int Main(int argc, char** argv) {
374387 input.shape .AddDim (sizes[i]);
375388 }
376389 input.name = input_layers[n];
390+ if (n < input_layer_values.size ()) {
391+ CHECK (str_util::SplitAndParseAsFloats (input_layer_values[n], ' ,' ,
392+ &input.initialization_values ))
393+ << " Incorrect initialization values string specified: "
394+ << input_layer_values[n];
395+ }
377396 inputs.push_back (input);
378397 }
379398
0 commit comments