Skip to content

Commit 95ae233

Browse files
mrrytensorflower-gardener
authored andcommitted
Fix segfault in grpc_server_lib when an invalid hostname-port pair is passed.
Previously, if the port was undefined, an out-of-bounds access would be made. This change adds the appropriate checks. Change: 119424297
1 parent 264bac9 commit 95ae233

4 files changed

Lines changed: 23 additions & 7 deletions

File tree

tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,14 @@ Status GrpcServer::Init() {
103103
return errors::InvalidArgument("Task ", server_def_.task_index(),
104104
" was not defined in job \"",
105105
server_def_.job_name(), "\"");
106-
} else if (!strings::safe_strto32(str_util::Split(iter->second, ':')[1],
107-
&requested_port_)) {
108-
return errors::Internal("Could not parse port for local server from \"",
109-
iter->second, "\"");
106+
}
107+
const std::vector<string> hostname_port =
108+
str_util::Split(iter->second, ':');
109+
if (hostname_port.size() != 2 ||
110+
!strings::safe_strto32(hostname_port[1], &requested_port_)) {
111+
return errors::InvalidArgument(
112+
"Could not parse port for local server from \"", iter->second,
113+
"\"");
110114
} else {
111115
break;
112116
}

tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ class GrpcServer : public ServerInterface {
8989

9090
// Implementation of a TensorFlow master, and RPC polling thread.
9191
MasterEnv master_env_;
92-
AsyncServiceInterface* master_service_;
92+
AsyncServiceInterface* master_service_ = nullptr;
9393
std::unique_ptr<Thread> master_thread_ GUARDED_BY(mu_);
9494

9595
// Implementation of a TensorFlow worker, and RPC polling thread.
9696
WorkerEnv worker_env_;
97-
AsyncServiceInterface* worker_service_;
97+
AsyncServiceInterface* worker_service_ = nullptr;
9898
std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_);
9999

100100
std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_);

tensorflow/python/training/server_lib.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from tensorflow.core.protobuf import tensorflow_server_pb2
2424
from tensorflow.python import pywrap_tensorflow
25+
from tensorflow.python.framework import errors
2526
from tensorflow.python.util import compat
2627

2728

@@ -128,7 +129,12 @@ def __init__(self,
128129
"""
129130
server_def = _make_server_def(server_or_cluster_def,
130131
job_name, task_index, protocol)
131-
self._server = pywrap_tensorflow.NewServer(server_def.SerializeToString())
132+
try:
133+
self._server = pywrap_tensorflow.NewServer(server_def.SerializeToString())
134+
except pywrap_tensorflow.StatusNotOK as e:
135+
# pylint: disable=protected-access
136+
raise errors._make_specific_exception(None, None, e.error_message, e.code)
137+
# pylint: enable=protected-access
132138
if start:
133139
self.start()
134140

tensorflow/python/training/server_lib_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def testLargeFeed(self):
7979
self.assertEqual(0.5, min_val)
8080
self.assertEqual(0.5, max_val)
8181

82+
def testInvalidHostname(self):
83+
with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, "port"):
84+
_ = tf.train.Server({"local": ["localhost"]},
85+
job_name="local",
86+
task_index=0)
87+
8288

8389
class ServerDefTest(tf.test.TestCase):
8490

0 commit comments

Comments
 (0)