File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed
torch/csrc/jit/tensorexpr Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -868,15 +868,24 @@ void CudaCodeGen::call(const std::vector<CallArg>& args) {
868868
869869 std::vector<int > gpu_block_extents_v (3 , 1 );
870870 std::vector<int > gpu_thread_extents_v (3 , 1 );
871+
871872 // evaluate all the block/thread extents into values
872873 // TODO: eventually, codegen these calculations and make them part of the
873874 // module.
874875 for (size_t i = 0 ; i < gpu_block_extents.size (); i++) {
876+ if (gpu_block_extents[i]->isConstant ()) {
877+ gpu_block_extents_v[i] = immediateAs<int >(gpu_block_extents[i]);
878+ continue ;
879+ }
875880 ExprEval<SimpleIREvaluator> eval (
876881 ExprHandle (gpu_block_extents[i]), buffer_args ());
877882 gpu_block_extents_v[i] = eval.value <int >(args);
878883 }
879884 for (size_t i = 0 ; i < gpu_thread_extents.size (); i++) {
885+ if (gpu_thread_extents[i]->isConstant ()) {
886+ gpu_thread_extents_v[i] = immediateAs<int >(gpu_thread_extents[i]);
887+ continue ;
888+ }
880889 ExprEval<SimpleIREvaluator> eval (
881890 ExprHandle (gpu_thread_extents[i]), buffer_args ());
882891 gpu_thread_extents_v[i] = eval.value <int >(args);
You can’t perform that action at this time.
0 commit comments