Skip to content

Commit e3749a6

Browse files
aaudibertensorflower-gardener
authored andcommitted
[tf.data] Set limit on number of threads used in threadpool_dataset.
PiperOrigin-RevId: 410922677 Change-Id: Ib25814a99043ab10805b5d2d7088ae0e0b7b04fd
1 parent dc94fe9 commit e3749a6

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,22 @@ namespace experimental {
3939
PrivateThreadPoolDatasetOp::kDatasetType;
4040
/* static */ constexpr const char* const PrivateThreadPoolDatasetOp::kDatasetOp;
4141

42+
namespace {
43+
// To prevent integer overflow issues when allocating threadpool memory for an
44+
// unreasonable number of threads.
45+
constexpr int kThreadLimit = 65536;
46+
47+
Status ValidateNumThreads(int32_t num_threads) {
48+
if (num_threads < 0) {
49+
return errors::InvalidArgument("`num_threads` must be >= 0");
50+
}
51+
if (num_threads >= kThreadLimit) {
52+
return errors::InvalidArgument("`num_threads` must be < ", kThreadLimit);
53+
}
54+
return Status::OK();
55+
}
56+
} // namespace
57+
4258
class ThreadPoolResource : public ResourceBase {
4359
public:
4460
ThreadPoolResource(Env* env, const ThreadOptions& thread_options,
@@ -83,9 +99,7 @@ class ThreadPoolHandleOp : public OpKernel {
8399
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_));
84100
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
85101
&max_intra_op_parallelism_));
86-
OP_REQUIRES(
87-
ctx, num_threads_ > 0,
88-
errors::InvalidArgument("`num_threads` must be greater than zero."));
102+
OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads_));
89103
}
90104

91105
// The resource is deleted from the resource manager only when it is private
@@ -531,8 +545,7 @@ void PrivateThreadPoolDatasetOp::MakeDatasetFromOptions(OpKernelContext* ctx,
531545
DatasetBase* input,
532546
int32_t num_threads,
533547
DatasetBase** output) {
534-
OP_REQUIRES(ctx, num_threads >= 0,
535-
errors::InvalidArgument("`num_threads` must be >= 0"));
548+
OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads));
536549
*output = new Dataset(ctx,
537550
DatasetContext(DatasetContext::Params(
538551
{PrivateThreadPoolDatasetOp::kDatasetType,
@@ -546,8 +559,7 @@ void PrivateThreadPoolDatasetOp::MakeDataset(OpKernelContext* ctx,
546559
int64_t num_threads = 0;
547560
OP_REQUIRES_OK(
548561
ctx, ParseScalarArgument<int64_t>(ctx, "num_threads", &num_threads));
549-
OP_REQUIRES(ctx, num_threads >= 0,
550-
errors::InvalidArgument("`num_threads` must be >= 0"));
562+
OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads));
551563
*output = new Dataset(ctx, input, num_threads);
552564
}
553565

0 commit comments

Comments
 (0)