Skip to content

Commit e0ac732

Browse files
Pass grpc deadline to session Run via RunOptions.
Change: 146411764
1 parent d75aab5 commit e0ac732

File tree

4 files changed

+77
-35
lines changed

4 files changed

+77
-35
lines changed

tensorflow_serving/model_servers/main.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,12 @@ ModelServerConfig BuildSingleModelConfig(
138138
return config;
139139
}
140140

141+
int DeadlineToTimeoutMillis(const gpr_timespec deadline) {
142+
return gpr_time_to_millis(
143+
gpr_time_sub(gpr_convert_clock_type(deadline, GPR_CLOCK_MONOTONIC),
144+
gpr_now(GPR_CLOCK_MONOTONIC)));
145+
}
146+
141147
grpc::Status ToGRPCStatus(const tensorflow::Status& status) {
142148
const int kErrorMessageLimit = 1024;
143149
string error_message;
@@ -161,8 +167,12 @@ class PredictionServiceImpl final : public PredictionService::Service {
161167

162168
grpc::Status Predict(ServerContext* context, const PredictRequest* request,
163169
PredictResponse* response) override {
164-
const grpc::Status status =
165-
ToGRPCStatus(predictor_->Predict(core_.get(), *request, response));
170+
tensorflow::RunOptions run_options = tensorflow::RunOptions();
171+
// By default, this is infinite which is the same default as RunOptions.
172+
run_options.set_timeout_in_ms(
173+
DeadlineToTimeoutMillis(context->raw_deadline()));
174+
const grpc::Status status = ToGRPCStatus(
175+
predictor_->Predict(run_options, core_.get(), *request, response));
166176
if (!status.ok()) {
167177
VLOG(1) << "Predict failed: " << status.error_message();
168178
}

tensorflow_serving/servables/tensorflow/predict_impl.cc

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ namespace serving {
3434
namespace {
3535

3636
// Implementation of Predict using the legacy SessionBundle GenericSignature.
37-
Status SessionBundlePredict(ServerCore* core, const PredictRequest& request,
37+
Status SessionBundlePredict(const RunOptions& run_options, ServerCore* core,
38+
const PredictRequest& request,
3839
PredictResponse* response) {
3940
// Validate signatures.
4041
ServableHandle<SessionBundle> bundle;
@@ -111,8 +112,9 @@ Status SessionBundlePredict(ServerCore* core, const PredictRequest& request,
111112

112113
// Run session.
113114
std::vector<Tensor> outputs;
114-
TF_RETURN_IF_ERROR(
115-
bundle->session->Run(inputs, output_tensor_names, {}, &outputs));
115+
RunMetadata run_metadata;
116+
TF_RETURN_IF_ERROR(bundle->session->Run(
117+
run_options, inputs, output_tensor_names, {}, &outputs, &run_metadata));
116118

117119
// Validate and return output.
118120
if (outputs.size() != output_tensor_names.size()) {
@@ -232,7 +234,8 @@ Status PostProcessPredictionResult(
232234
}
233235

234236
// Implementation of Predict using the SavedModel SignatureDef format.
235-
Status SavedModelPredict(ServerCore* core, const PredictRequest& request,
237+
Status SavedModelPredict(const RunOptions& run_options, ServerCore* core,
238+
const PredictRequest& request,
236239
PredictResponse* response) {
237240
// Validate signatures.
238241
ServableHandle<SavedModelBundle> bundle;
@@ -255,26 +258,29 @@ Status SavedModelPredict(ServerCore* core, const PredictRequest& request,
255258
&output_tensor_names,
256259
&output_tensor_aliases));
257260
std::vector<Tensor> outputs;
258-
TF_RETURN_IF_ERROR(
259-
bundle->session->Run(input_tensors, output_tensor_names, {}, &outputs));
261+
RunMetadata run_metadata;
262+
TF_RETURN_IF_ERROR(bundle->session->Run(run_options, input_tensors,
263+
output_tensor_names, {}, &outputs,
264+
&run_metadata));
260265

261266
return PostProcessPredictionResult(signature, output_tensor_aliases, outputs,
262267
response);
263268
}
264269

265270
} // namespace
266271

267-
Status TensorflowPredictor::Predict(ServerCore* core,
272+
Status TensorflowPredictor::Predict(const RunOptions& run_options,
273+
ServerCore* core,
268274
const PredictRequest& request,
269275
PredictResponse* response) {
270276
if (!request.has_model_spec()) {
271277
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
272278
"Missing ModelSpec");
273279
}
274280
if (use_saved_model_) {
275-
return SavedModelPredict(core, request, response);
281+
return SavedModelPredict(run_options, core, request, response);
276282
}
277-
return SessionBundlePredict(core, request, response);
283+
return SessionBundlePredict(run_options, core, request, response);
278284
}
279285

280286
} // namespace serving

tensorflow_serving/servables/tensorflow/predict_impl.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_PREDICT_IMPL_H_
1818

1919
#include "tensorflow/core/lib/core/status.h"
20+
#include "tensorflow/core/protobuf/config.pb.h"
2021
#include "tensorflow_serving/apis/predict.pb.h"
2122
#include "tensorflow_serving/model_servers/server_core.h"
2223

@@ -29,8 +30,8 @@ class TensorflowPredictor {
2930
explicit TensorflowPredictor(bool use_saved_model)
3031
: use_saved_model_(use_saved_model) {}
3132

32-
Status Predict(ServerCore* core, const PredictRequest& request,
33-
PredictResponse* response);
33+
Status Predict(const RunOptions& run_options, ServerCore* core,
34+
const PredictRequest& request, PredictResponse* response);
3435

3536
private:
3637
// If use_saved_model_ is true, a SavedModelBundle handle will be retrieved

tensorflow_serving/servables/tensorflow/predict_impl_test.cc

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ class PredictImplTest : public ::testing::TestWithParam<bool> {
9999
return server_core_bad_model_.get();
100100
}
101101

102+
RunOptions GetRunOptions() { return RunOptions(); }
103+
102104
private:
103105
static std::unique_ptr<ServerCore> server_core_;
104106
static std::unique_ptr<ServerCore> server_core_bad_model_;
@@ -117,20 +119,26 @@ TEST_P(PredictImplTest, MissingOrEmptyModelSpec) {
117119

118120
// Empty request is invalid.
119121
TensorflowPredictor predictor(GetParam());
120-
EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
121-
predictor.Predict(GetServerCore(), request, &response).code());
122+
EXPECT_EQ(
123+
tensorflow::error::INVALID_ARGUMENT,
124+
predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
125+
.code());
122126

123127
ModelSpec* model_spec = request.mutable_model_spec();
124128
model_spec->clear_name();
125129

126130
// Model name is not specified.
127-
EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
128-
predictor.Predict(GetServerCore(), request, &response).code());
131+
EXPECT_EQ(
132+
tensorflow::error::INVALID_ARGUMENT,
133+
predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
134+
.code());
129135

130136
// Model name is wrong, not found.
131137
model_spec->set_name("test");
132-
EXPECT_EQ(tensorflow::error::NOT_FOUND,
133-
predictor.Predict(GetServerCore(), request, &response).code());
138+
EXPECT_EQ(
139+
tensorflow::error::NOT_FOUND,
140+
predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
141+
.code());
134142
}
135143

136144
TEST_P(PredictImplTest, EmptyInputList) {
@@ -143,8 +151,10 @@ TEST_P(PredictImplTest, EmptyInputList) {
143151

144152
TensorflowPredictor predictor(GetParam());
145153
// The input is empty.
146-
EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
147-
predictor.Predict(GetServerCore(), request, &response).code());
154+
EXPECT_EQ(
155+
tensorflow::error::INVALID_ARGUMENT,
156+
predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
157+
.code());
148158
}
149159

150160
TEST_P(PredictImplTest, InputTensorsDontMatchModelSpecInputs) {
@@ -163,8 +173,10 @@ TEST_P(PredictImplTest, InputTensorsDontMatchModelSpecInputs) {
163173
TensorflowPredictor predictor(GetParam());
164174
auto inputs = request.mutable_inputs();
165175
(*inputs)["key"] = tensor_proto;
166-
EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
167-
predictor.Predict(GetServerCore(), request, &response).code());
176+
EXPECT_EQ(
177+
tensorflow::error::INVALID_ARGUMENT,
178+
predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
179+
.code());
168180
}
169181

170182
TEST_P(PredictImplTest, OutputFiltersDontMatchModelSpecOutputs) {
@@ -183,17 +195,22 @@ TEST_P(PredictImplTest, OutputFiltersDontMatchModelSpecOutputs) {
183195

184196
TensorflowPredictor predictor(GetParam());
185197
// Output filter like this doesn't exist.
186-
EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
187-
predictor.Predict(GetServerCore(), request, &response).code());
198+
EXPECT_EQ(
199+
tensorflow::error::INVALID_ARGUMENT,
200+
predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
201+
.code());
188202

189203
request.clear_output_filter();
190204
request.add_output_filter(kOutputTensorKey);
191-
TF_EXPECT_OK(predictor.Predict(GetServerCore(), request, &response));
205+
TF_EXPECT_OK(
206+
predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
192207
request.add_output_filter(kOutputTensorKey);
193208

194209
// Duplicate output filter specified.
195-
EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
196-
predictor.Predict(GetServerCore(), request, &response).code());
210+
EXPECT_EQ(
211+
tensorflow::error::INVALID_ARGUMENT,
212+
predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
213+
.code());
197214
}
198215

199216
TEST_P(PredictImplTest, InputTensorsHaveWrongType) {
@@ -213,8 +230,10 @@ TEST_P(PredictImplTest, InputTensorsHaveWrongType) {
213230

214231
TensorflowPredictor predictor(GetParam());
215232
// Input tensors are all wrong.
216-
EXPECT_EQ(tensorflow::error::INTERNAL,
217-
predictor.Predict(GetServerCore(), request, &response).code());
233+
EXPECT_EQ(
234+
tensorflow::error::INTERNAL,
235+
predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
236+
.code());
218237
}
219238

220239
TEST_P(PredictImplTest, ModelMissingSignatures) {
@@ -228,7 +247,9 @@ TEST_P(PredictImplTest, ModelMissingSignatures) {
228247
// Model is missing signatures.
229248
TensorflowPredictor predictor(GetParam());
230249
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION,
231-
predictor.Predict(GetServerCoreWithBadModel(), request, &response)
250+
predictor
251+
.Predict(GetRunOptions(), GetServerCoreWithBadModel(), request,
252+
&response)
232253
.code());
233254
}
234255

@@ -246,7 +267,8 @@ TEST_P(PredictImplTest, PredictionSuccess) {
246267
(*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
247268

248269
TensorflowPredictor predictor(GetParam());
249-
TF_EXPECT_OK(predictor.Predict(GetServerCore(), request, &response));
270+
TF_EXPECT_OK(
271+
predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
250272
TensorProto output_tensor_proto;
251273
output_tensor_proto.add_float_val(3);
252274
output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
@@ -276,10 +298,13 @@ TEST_P(PredictImplTest, PredictionWithNamedRegressionSignature) {
276298
TensorflowPredictor predictor(GetParam());
277299
// This request is expected to work with SavedModel, but not SessionBundle.
278300
if (GetParam()) {
279-
TF_ASSERT_OK(predictor.Predict(GetServerCore(), request, &response));
301+
TF_ASSERT_OK(predictor.Predict(GetRunOptions(), GetServerCore(), request,
302+
&response));
280303
} else {
281-
ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
282-
predictor.Predict(GetServerCore(), request, &response).code());
304+
ASSERT_EQ(
305+
tensorflow::error::INVALID_ARGUMENT,
306+
predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
307+
.code());
283308
return;
284309
}
285310
TensorProto output_tensor_proto;

0 commit comments

Comments (0)