Skip to content

Commit 6820604

Browse files
committed
添加 grpc Serving 测试的 grcp_client.py 代码
1 parent 3a65699 commit 6820604

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

caicloud.tensorflow/caicloud/clever/examples/ptb/README.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,22 @@ $ python ptb_word_lm.py --data_path=./simple-examples/data/ --model=small
2020

2121
## TaaS 平台任务
2222

23-
ptb\_platform.py 代码文件将 ptb\_word\_lm.py 文件中执行模型训练的相关处理逻辑按照 TaaS 的模型训练任务框架进行了调整
23+
ptb\_platform.py 代码文件将 ptb\_word\_lm.py 文件中执行模型训练的相关处理逻辑按照 CaiCloud TaaS 深度学习平台的模型训练任务框架进行了调整。我们可以直接使用该代码文件在 TaaS 平台上启动一个分布式 TensorFlow 模型训练任务
2424

2525
提供了 train-model.sh 文件用于验证单机版模型训练任务的执行。
2626

2727
```shell
2828
$ ./train-model.sh
2929
```
3030

31+
## TaaS gRPC Serving API 测试
32+
33+
ptb\_platform.py 文件中提供了训练模型导出的实现。我们可以将导出的模型在 CaiCloud TaaS 深度学习平台上启动一个模型托管 Serving。grpc\_client.py 文件提供了访问 TaaS 平台 gRPC Serving API 的 client 代码,该代码中使用 PTB 测试数据集的第一个样本数据来调用 gRPC Serving 的预测方法。
34+
35+
通过运行命令来测试。
36+
37+
```shell
38+
$ python grpc_client.py --data_path=/path/to/ptb/data
39+
```
40+
41+
test-grpc.sh 脚本提供了一个快速运行该命令的入口。
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# coding=utf-8
2+
# Copyright 2017 Caicloud authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
from __future__ import print_function
18+
19+
import tensorflow as tf
20+
import ptb_word_lm
21+
import reader
22+
23+
from caicloud.clever.serving.client import grpc_client as serving_grpc_client
24+
25+
FLAGS = tf.flags.FLAGS
26+
27+
def run():
28+
client = serving_grpc_client.GRPCClient('localhost:50051')
29+
30+
# 读取 PTB 数据集
31+
print("Loading ptb data...")
32+
train_data, valid_data, test_data, _ = reader.ptb_raw_data(FLAGS.data_path)
33+
34+
inputs = {
35+
'input': tf.contrib.util.make_tensor_proto(test_data[0], shape=[1,1]),
36+
}
37+
outputs = client.call_predict(inputs)
38+
result = tf.contrib.util.make_ndarray(outputs["logits"])
39+
print('logits: {0}'.format(result))
40+
41+
if __name__ == '__main__':
42+
run()
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
3+
python grpc_client.py --data_path=./simple-examples/data
4+

0 commit comments

Comments
 (0)