+#include <c10/util/Exception.h>
 #include <torch/csrc/jit/frontend/ir_emitter.h>
+#include <torch/csrc/jit/ir/ir_views.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/inliner.h>
+#include <torch/csrc/jit/runtime/graph_iterator.h>
 #include <torch/csrc/jit/runtime/operator.h>
 #include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
 #include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
@@ -160,26 +163,121 @@ const at::optional<const FunctionSchema*> getInplaceVariant(
   return at::nullopt;
 }
 
-void registerSchema(
-    const FunctionSchema* schema_string,
-    const std::string& shape_compute_function_name,
-    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
-    const CompilationUnit& module) {
-  if (reused_functions.count(shape_compute_function_name)) {
-    auto graph = reused_functions[shape_compute_function_name];
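+// Shape functions take List[int] in place of Tensor arguments, so map
+// TensorType (including Tensors nested inside container types) to List[int]
+// when comparing schema types against graph types.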
+TypePtr mapTensorToListOfInts(TypePtr type) {
+  if (type->cast<TensorType>()) {
+    return ListType::ofInts();
+  }
+  at::ArrayRef<TypePtr> contained = type->containedTypes();
+  if (contained.empty()) {
+    return type;
+  }
+  return type->withContained(
+      fmap(type->containedTypes(), mapTensorToListOfInts));
+}
 
-    // allow extra unused arguments to map multiple functions to e.g. unary
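+// Warn on while-loops: only for-loops can currently be unrolled when a shape
+// function is partially evaluated.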
+void checkForWhileLoop(
+    const FunctionSchema* schema,
+    std::shared_ptr<Graph> graph) {
+  DepthFirstGraphNodeIterator graph_it(graph);
+  for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
+    if (node->kind() != prim::Loop) {
+      continue;
+    }
+    LoopView loop(node);
+    if (loop.loopType() != LoopView::For) {
+      TORCH_WARN(
189+ " While loops are not yet implemented in unrolling which may make this shape function difficult to partially evaluate: " ,
+          *node,
+          " for schema ",
+          *schema);
+    }
+  }
+}
+
+void checkInputReturnedAsOutput(
+    const FunctionSchema* schema,
+    const std::shared_ptr<Graph>& graph) {
+  // Could use alias db here as well but would have to warn because it's
+  // imprecise
+  for (size_t i : c10::irange(graph->inputs().size())) {
+    Value* input = graph->inputs().at(i);
+    for (size_t j : c10::irange(graph->outputs().size())) {
+      Value* output = graph->outputs().at(j);
+      TORCH_CHECK(
+          input != output,
+          "For schema: ",
+          *schema,
+          " input index ",
+          i,
+          " is returned as output index ",
+          j,
+          ". Shape functions must return new unaliased lists");
+    }
+  }
+}
+
+void checkInputAndOutputTypes(
+    const FunctionSchema* schema,
+    const std::shared_ptr<Graph>& graph) {
+  // allow extra unused arguments to map multiple functions to e.g. unary
+  TORCH_CHECK(
+      graph->inputs().size() <= schema->arguments().size(),
225+ " Shape function must have fewer arguments than schema. Got " ,
+      graph->inputs().size(),
+      " graph arguments and ",
+      schema->arguments().size(),
+      " schema arguments of schema: ",
+      *schema);
+
+  for (auto i : c10::irange(graph->inputs().size())) {
+    auto inp_type = schema->arguments().at(i).type();
+    auto mapped_type = mapTensorToListOfInts(inp_type);
+    auto graph_type = graph->inputs().at(i)->type();
     TORCH_INTERNAL_ASSERT(
-        graph->inputs().size() <= schema_string->arguments().size());
+        mapped_type->isSubtypeOf(graph->inputs().at(i)->type()),
+        "For schema type: ",
+        inp_type->str(),
+        " Expected supertype of ",
+        mapped_type->str(),
+        " but got graph_type ",
+        graph_type->str(),
+        " at index ",
+        i,
+        " of schema: ",
+        *schema);
+  }
 
-    cached_schema_to_graph[schema_string] = graph;
-    return;
+  TORCH_CHECK(
+      graph->outputs().size() == schema->returns().size(),
252+ " Shape function equal number of outputs as schema. Got " ,
+      graph->outputs().size(),
+      " graph outputs and ",
+      schema->returns().size(),
+      " schema returns of schema: ",
+      *schema);
+
+  for (auto i : c10::irange(schema->returns().size())) {
+    auto out_type = schema->returns().at(i).type();
+    auto mapped_type = mapTensorToListOfInts(out_type);
+    auto graph_type = graph->outputs().at(i)->type();
+    TORCH_INTERNAL_ASSERT(
+        mapped_type->isSubtypeOf(graph->outputs().at(i)->type()),
+        "For schema type: ",
+        out_type->str(),
+        " Expected supertype of ",
+        mapped_type->str(),
+        " but got graph_type ",
+        graph_type->str(),
+        " at output index ",
+        i,
+        " of schema: ",
+        *schema);
   }
+}
 
-  Function& shape_compute_function =
-      module.get_function(shape_compute_function_name);
-  std::shared_ptr<Graph> graph =
-      toGraphFunction(shape_compute_function).graph();
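+// Inline the shape function's graph and flatten a tuple return into
+// individual graph outputs so that the graph matches the (possibly
+// multi-output) ATen schema.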
+void transformShapeFunction(
+    const FunctionSchema* schema_string,
+    std::shared_ptr<Graph> graph) {
   Inline(*graph);
 
   // ATEN operators can return multiple unboxed values, this in contrast to
@@ -197,9 +295,33 @@ void registerSchema(
       graph->registerOutput(v);
     }
   }
-  // allow extra unused arguments to map multiple functions to e.g. unary
-  TORCH_INTERNAL_ASSERT(
-      graph->inputs().size() <= schema_string->arguments().size());
+}
+
+void registerSchema(
+    const FunctionSchema* schema_string,
+    const std::string& shape_compute_function_name,
+    std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
+    const CompilationUnit& module) {
+  if (reused_functions.count(shape_compute_function_name)) {
+    auto graph = reused_functions[shape_compute_function_name];
+
+    // allow extra unused arguments to map multiple functions to e.g. unary
+    TORCH_INTERNAL_ASSERT(
+        graph->inputs().size() <= schema_string->arguments().size());
+
+    cached_schema_to_graph[schema_string] = graph;
+    return;
+  }
+
+  Function& shape_compute_function =
+      module.get_function(shape_compute_function_name);
+  std::shared_ptr<Graph> graph =
+      toGraphFunction(shape_compute_function).graph();
+
+  transformShapeFunction(schema_string, graph);
+  // NB: we lint the shape functions registered in source
+  // in a test file
+  // LintShapeComputeGraph(schema_string, graph);
 
   cached_schema_to_graph[schema_string] = graph;
   reused_functions[shape_compute_function_name] = graph;
@@ -299,8 +421,34 @@ void RegisterShapeComputeGraphForSchema(
   if (cached_schema_to_graph.size() == 0) {
     loadFunctions();
   }
+  transformShapeFunction(&schema, g);
+  LintShapeComputeGraph(&schema, g);
+
   cached_schema_to_graph[&schema] = g;
 }
 
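+// Returns the schemas of all operators that currently have a registered
+// shape compute graph.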
+std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas() {
+  std::lock_guard<std::mutex> guard(lock);
+  if (cached_schema_to_graph.size() == 0) {
+    loadFunctions();
+  }
+
+  std::vector<const FunctionSchema*> schemas;
+  schemas.reserve(cached_schema_to_graph.size());
+  for (const auto& pair : cached_schema_to_graph) {
+    schemas.push_back(pair.first);
+  }
+  return schemas;
+}
+
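+// Sanity checks on a shape compute graph: argument and return types,
+// while-loops, and inputs aliased as outputs.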
+void LintShapeComputeGraph(
+    const FunctionSchema* schema,
+    const std::shared_ptr<Graph>& graph) {
+  checkInputAndOutputTypes(schema, graph);
+  checkForWhileLoop(schema, graph);
+  checkInputReturnedAsOutput(schema, graph);
+  // TODO: other checks? list ops which we don't symbolically optimize, etc.?
+}
+
 } // namespace jit
 } // namespace torch