@@ -39,6 +39,22 @@ namespace experimental {
3939 PrivateThreadPoolDatasetOp::kDatasetType ;
4040/* static */ constexpr const char * const PrivateThreadPoolDatasetOp::kDatasetOp ;
4141
42+ namespace {
43+ // To prevent integer overflow issues when allocating threadpool memory for an
44+ // unreasonable number of threads.
45+ constexpr int kThreadLimit = 65536 ;
46+
47+ Status ValidateNumThreads (int32_t num_threads) {
48+ if (num_threads < 0 ) {
49+ return errors::InvalidArgument (" `num_threads` must be >= 0" );
50+ }
51+ if (num_threads >= kThreadLimit ) {
52+ return errors::InvalidArgument (" `num_threads` must be < " , kThreadLimit );
53+ }
54+ return Status::OK ();
55+ }
56+ } // namespace
57+
4258class ThreadPoolResource : public ResourceBase {
4359 public:
4460 ThreadPoolResource (Env* env, const ThreadOptions& thread_options,
@@ -83,9 +99,7 @@ class ThreadPoolHandleOp : public OpKernel {
8399 OP_REQUIRES_OK (ctx, ctx->GetAttr (" num_threads" , &num_threads_));
84100 OP_REQUIRES_OK (ctx, ctx->GetAttr (" max_intra_op_parallelism" ,
85101 &max_intra_op_parallelism_));
86- OP_REQUIRES (
87- ctx, num_threads_ > 0 ,
88- errors::InvalidArgument (" `num_threads` must be greater than zero." ));
102+ OP_REQUIRES_OK (ctx, ValidateNumThreads (num_threads_));
89103 }
90104
91105 // The resource is deleted from the resource manager only when it is private
@@ -531,8 +545,7 @@ void PrivateThreadPoolDatasetOp::MakeDatasetFromOptions(OpKernelContext* ctx,
531545 DatasetBase* input,
532546 int32_t num_threads,
533547 DatasetBase** output) {
534- OP_REQUIRES (ctx, num_threads >= 0 ,
535- errors::InvalidArgument (" `num_threads` must be >= 0" ));
548+ OP_REQUIRES_OK (ctx, ValidateNumThreads (num_threads));
536549 *output = new Dataset (ctx,
537550 DatasetContext (DatasetContext::Params (
538551 {PrivateThreadPoolDatasetOp::kDatasetType ,
@@ -546,8 +559,7 @@ void PrivateThreadPoolDatasetOp::MakeDataset(OpKernelContext* ctx,
546559 int64_t num_threads = 0 ;
547560 OP_REQUIRES_OK (
548561 ctx, ParseScalarArgument<int64_t >(ctx, " num_threads" , &num_threads));
549- OP_REQUIRES (ctx, num_threads >= 0 ,
550- errors::InvalidArgument (" `num_threads` must be >= 0" ));
562+ OP_REQUIRES_OK (ctx, ValidateNumThreads (num_threads));
551563 *output = new Dataset (ctx, input, num_threads);
552564}
553565
0 commit comments