Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions tensorflow_serving/apis/prediction_service_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: tensorflow_serving/apis/prediction_service.proto
# To regenerate run
# python -m grpc.tools.protoc --python_out=. --grpc_python_out=. -I. tensorflow_serving/apis/prediction_service.proto

import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
Expand All @@ -27,17 +29,19 @@
_sym_db = _symbol_database.Default()


from tensorflow_serving.apis import classification_pb2 as tensorflow__serving_dot_apis_dot_classification__pb2
from tensorflow_serving.apis import get_model_metadata_pb2 as tensorflow__serving_dot_apis_dot_get__model__metadata__pb2
from tensorflow_serving.apis import predict_pb2 as tensorflow__serving_dot_apis_dot_predict__pb2
from tensorflow_serving.apis import regression_pb2 as tensorflow__serving_dot_apis_dot_regression__pb2


DESCRIPTOR = _descriptor.FileDescriptor(
name='tensorflow_serving/apis/prediction_service.proto',
package='tensorflow.serving',
syntax='proto3',
serialized_pb=_b('\n0tensorflow_serving/apis/prediction_service.proto\x12\x12tensorflow.serving\x1a\x30tensorflow_serving/apis/get_model_metadata.proto\x1a%tensorflow_serving/apis/predict.proto2\xd6\x01\n\x11PredictionService\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponse\x12m\n\x10GetModelMetadata\x12+.tensorflow.serving.GetModelMetadataRequest\x1a,.tensorflow.serving.GetModelMetadataResponseB\x03\xf8\x01\x01\x62\x06proto3')
serialized_pb=_b('\n0tensorflow_serving/apis/prediction_service.proto\x12\x12tensorflow.serving\x1a,tensorflow_serving/apis/classification.proto\x1a\x30tensorflow_serving/apis/get_model_metadata.proto\x1a%tensorflow_serving/apis/predict.proto\x1a(tensorflow_serving/apis/regression.proto2\x93\x03\n\x11PredictionService\x12\x61\n\x08\x43lassify\x12).tensorflow.serving.ClassificationRequest\x1a*.tensorflow.serving.ClassificationResponse\x12X\n\x07Regress\x12%.tensorflow.serving.RegressionRequest\x1a&.tensorflow.serving.RegressionResponse\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponse\x12m\n\x10GetModelMetadata\x12+.tensorflow.serving.GetModelMetadataRequest\x1a,.tensorflow.serving.GetModelMetadataResponseB\x03\xf8\x01\x01\x62\x06proto3')
,
dependencies=[tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_predict__pb2.DESCRIPTOR,])
dependencies=[tensorflow__serving_dot_apis_dot_classification__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_predict__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_regression__pb2.DESCRIPTOR,])
_sym_db.RegisterFileDescriptor(DESCRIPTOR)


Expand Down Expand Up @@ -66,6 +70,16 @@ def __init__(self, channel):
Args:
channel: A grpc.Channel.
"""
self.Classify = channel.unary_unary(
'/tensorflow.serving.PredictionService/Classify',
request_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.SerializeToString,
response_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.FromString,
)
self.Regress = channel.unary_unary(
'/tensorflow.serving.PredictionService/Regress',
request_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.SerializeToString,
response_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.FromString,
)
self.Predict = channel.unary_unary(
'/tensorflow.serving.PredictionService/Predict',
request_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString,
Expand All @@ -84,6 +98,20 @@ class PredictionServiceServicer(object):
model_servers.
"""

def Classify(self, request, context):
  """Classify.

  Default handler for the Classify RPC. It does no work: it sets the
  UNIMPLEMENTED status code and a 'Method not implemented!' detail message
  on `context`, then raises NotImplementedError.

  Args:
    request: the incoming request; not read by this default handler.
    context: call context exposing set_code() and set_details()
      (presumably a grpc.ServicerContext -- confirm against the grpc docs).

  Raises:
    NotImplementedError: always.
  """
  context.set_code(grpc.StatusCode.UNIMPLEMENTED)
  context.set_details('Method not implemented!')
  raise NotImplementedError('Method not implemented!')

def Regress(self, request, context):
  """Regress.

  Default handler for the Regress RPC. It does no work: it sets the
  UNIMPLEMENTED status code and a 'Method not implemented!' detail message
  on `context`, then raises NotImplementedError.

  Args:
    request: the incoming request; not read by this default handler.
    context: call context exposing set_code() and set_details()
      (presumably a grpc.ServicerContext -- confirm against the grpc docs).

  Raises:
    NotImplementedError: always.
  """
  context.set_code(grpc.StatusCode.UNIMPLEMENTED)
  context.set_details('Method not implemented!')
  raise NotImplementedError('Method not implemented!')

def Predict(self, request, context):
"""Predict -- provides access to loaded TensorFlow model.
"""
Expand All @@ -101,6 +129,16 @@ def GetModelMetadata(self, request, context):

def add_PredictionServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'Classify': grpc.unary_unary_rpc_method_handler(
servicer.Classify,
request_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.FromString,
response_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.SerializeToString,
),
'Regress': grpc.unary_unary_rpc_method_handler(
servicer.Regress,
request_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.FromString,
response_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.SerializeToString,
),
'Predict': grpc.unary_unary_rpc_method_handler(
servicer.Predict,
request_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString,
Expand All @@ -127,6 +165,14 @@ class BetaPredictionServiceServicer(object):
PredictionService provides access to machine-learned models loaded by
model_servers.
"""
def Classify(self, request, context):
  """Classify.

  Beta-API default handler: reports UNIMPLEMENTED by calling
  `context.code(...)`. It does not raise and has no return statement, so the
  call yields None. `request` is not read.
  """
  context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)
def Regress(self, request, context):
  """Regress.

  Beta-API default handler: reports UNIMPLEMENTED by calling
  `context.code(...)`. It does not raise and has no return statement, so the
  call yields None. `request` is not read.
  """
  context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)
def Predict(self, request, context):
"""Predict -- provides access to loaded TensorFlow model.
"""
Expand All @@ -147,6 +193,16 @@ class BetaPredictionServiceStub(object):
PredictionService provides access to machine-learned models loaded by
model_servers.
"""
def Classify(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
  """Classify.

  Interface placeholder for the beta stub: none of the arguments are read
  and the call always raises NotImplementedError.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError()
# Attach a `future` attribute, initialised to None, to the method object.
# NOTE(review): presumably a concrete stub supplies the real value -- confirm.
Classify.future = None
def Regress(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
  """Regress.

  Interface placeholder for the beta stub: none of the arguments are read
  and the call always raises NotImplementedError.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError()
# Attach a `future` attribute, initialised to None, to the method object.
# NOTE(review): presumably a concrete stub supplies the real value -- confirm.
Regress.future = None
def Predict(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
"""Predict -- provides access to loaded TensorFlow model.
"""
Expand All @@ -166,16 +222,22 @@ def beta_create_PredictionService_server(servicer, pool=None, pool_size=None, de
file not marked beta) for all further purposes. This function was
generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"""
request_deserializers = {
('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.FromString,
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString,
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString,
('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.FromString,
}
response_serializers = {
('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.SerializeToString,
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString,
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString,
('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.SerializeToString,
}
method_implementations = {
('tensorflow.serving.PredictionService', 'Classify'): face_utilities.unary_unary_inline(servicer.Classify),
('tensorflow.serving.PredictionService', 'GetModelMetadata'): face_utilities.unary_unary_inline(servicer.GetModelMetadata),
('tensorflow.serving.PredictionService', 'Predict'): face_utilities.unary_unary_inline(servicer.Predict),
('tensorflow.serving.PredictionService', 'Regress'): face_utilities.unary_unary_inline(servicer.Regress),
}
server_options = beta_implementations.server_options(request_deserializers=request_deserializers, response_serializers=response_serializers, thread_pool=pool, thread_pool_size=pool_size, default_timeout=default_timeout, maximum_timeout=maximum_timeout)
return beta_implementations.server(method_implementations, options=server_options)
Expand All @@ -188,16 +250,22 @@ def beta_create_PredictionService_stub(channel, host=None, metadata_transformer=
file not marked beta) for all further purposes. This function was
generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"""
request_serializers = {
('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.SerializeToString,
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString,
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString,
('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.SerializeToString,
}
response_deserializers = {
('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.FromString,
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString,
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString,
('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.FromString,
}
cardinalities = {
'Classify': cardinality.Cardinality.UNARY_UNARY,
'GetModelMetadata': cardinality.Cardinality.UNARY_UNARY,
'Predict': cardinality.Cardinality.UNARY_UNARY,
'Regress': cardinality.Cardinality.UNARY_UNARY,
}
stub_options = beta_implementations.stub_options(host=host, metadata_transformer=metadata_transformer, request_serializers=request_serializers, response_deserializers=response_deserializers, thread_pool=pool, thread_pool_size=pool_size)
return beta_implementations.dynamic_stub(channel, 'tensorflow.serving.PredictionService', cardinalities, options=stub_options)
Expand Down
81 changes: 66 additions & 15 deletions tensorflow_serving/batching/batching_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,14 @@ class BatchingSession : public ServingSession {
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs) override;

// TODO(b/34971139): at the moment this method ignores run_options and
// run_metadata and behaves exactly like Run.
// RunOptions handling:
// Since multiple of these Run() calls get batched into a single call to the
// underlying Session's Run(), we select an arbitrary 'run_options' (typically
// they are the same across calls). The exception is the timeout; we take the
// largest value (after subtracting time spent in the batching queue).
//
// RunMetadata:
// We copy the batched call's RunMetadata to each non-batched call's output.
Status Run(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
Expand Down Expand Up @@ -210,21 +216,21 @@ Status BatchingSession::Create(
}

Status BatchingSession::Run(
const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
LOG(WARNING) << "Currently both run_options and run_metadata are ignored, "
<< "see b/34971139";
return Run(inputs, output_tensor_names, target_node_names, outputs);
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs) {
RunMetadata run_metadata;
return Run(RunOptions(), inputs, output_tensor_names, target_node_names,
outputs, &run_metadata);
}

Status BatchingSession::Run(
const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs) {
const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
if (!target_node_names.empty()) {
return errors::PermissionDenied(
"BatchingSession does not support target nodes");
Expand All @@ -239,8 +245,8 @@ Status BatchingSession::Run(
LOG(WARNING) << "Request doesn't match any declared signature. Bypassing "
"batcher. Request signature is: "
<< TensorSignatureDebugString(signature);
return wrapped_->Run(inputs, output_tensor_names, target_node_names,
outputs);
return wrapped_->Run(run_options, inputs, output_tensor_names,
target_node_names, outputs, run_metadata);
}
BatchScheduler<BatchingSessionTask>* batch_scheduler =
batch_scheduler_it->second.get();
Expand All @@ -250,12 +256,15 @@ Status BatchingSession::Run(
Notification done;
Status status;
auto task = std::unique_ptr<BatchingSessionTask>(new BatchingSessionTask);
task->enqueue_time_micros = Env::Default()->NowMicros();
task->run_options = run_options;
TF_RETURN_IF_ERROR(ComputeInputSize(inputs, &task->zeroth_dim_size));
task->inputs = &inputs;
task->output_tensor_names = &output_tensor_names;
task->done = &done;
task->status = &status;
task->outputs = outputs;
task->run_metadata = run_metadata;

TF_RETURN_IF_ERROR(batch_scheduler->Schedule(&task));
done.WaitForNotification();
Expand Down Expand Up @@ -457,18 +466,55 @@ void BatchingSession::ProcessBatch(
return;
}

Status status;
const uint64 dequeue_time_micros = Env::Default()->NowMicros();

// Regardless of the outcome, we need to propagate the status to the
// individual tasks and signal that they are done. We use MakeCleanup() to
// ensure that this happens no matter how we exit the method below.
Status status;
auto finally = MakeCleanup([&status, &batch] {
for (int i = 0; i < batch->num_tasks(); ++i) {
*batch->mutable_task(i)->status = status;
batch->mutable_task(i)->done->Notify();
}
});

// Decide whether the batch should run at all, and with which timeout.
//
// A task whose RunOptions timeout is <= 0 has "no timeout" (this is the
// default when the caller doesn't populate RunOptions). Such a task can never
// time out while queued, and because the batch uses the most permissive
// deadline among its tasks, it makes the whole batch unbounded. That case is
// tracked with an explicit flag instead of a sentinel timeout value: once a
// sentinel has been added to the enqueue time it can no longer be recognized
// by an equality test, and "no timeout" would silently turn into a finite
// timeout.
bool all_tasks_timeout_exceeded = true;
bool batch_has_no_timeout = false;
uint64 batch_deadline_micros = 0;
for (int i = 0; i < batch->num_tasks(); ++i) {
  const BatchingSessionTask& task = batch->task(i);
  const int64 task_timeout_ms = task.run_options.timeout_in_ms();
  if (task_timeout_ms <= 0) {
    // No timeout requested: never expired, and the batch must be unbounded.
    all_tasks_timeout_exceeded = false;
    batch_has_no_timeout = true;
    continue;
  }
  const uint64 task_deadline_micros =
      task.enqueue_time_micros + static_cast<uint64>(task_timeout_ms) * 1000;
  if (task_deadline_micros > dequeue_time_micros) {
    all_tasks_timeout_exceeded = false;
    if (task_deadline_micros > batch_deadline_micros) {
      batch_deadline_micros = task_deadline_micros;
    }
  }
}
if (all_tasks_timeout_exceeded) {
  status = Status(error::RESOURCE_EXHAUSTED,
                  "Run() timeout exceeded while waiting in batching queue");
  return;
}

// Start from the first task's options (an arbitrary choice; callers typically
// pass the same options) and override only the timeout.
RunOptions run_options = batch->task(0).run_options;
if (batch_has_no_timeout) {
  run_options.set_timeout_in_ms(0);
} else {
  // Reaching this branch means some task had a deadline later than the dequeue
  // time, so the subtraction cannot underflow. Round up to whole milliseconds:
  // truncating a sub-millisecond remainder would produce 0, which means
  // "no timeout" rather than "almost expired".
  const uint64 remaining_micros = batch_deadline_micros - dequeue_time_micros;
  run_options.set_timeout_in_ms((remaining_micros + 999) / 1000);
}

std::vector<std::pair<string, Tensor>> merged_inputs;
status = MergeInputTensors(signature, *batch, &merged_inputs);
if (!status.ok()) {
Expand All @@ -478,8 +524,13 @@ void BatchingSession::ProcessBatch(
const std::vector<string> output_tensor_names(
signature.output_tensors.begin(), signature.output_tensors.end());
std::vector<Tensor> combined_outputs;
status = wrapped_->Run(merged_inputs, output_tensor_names,
{} /* target node names */, &combined_outputs);
RunMetadata run_metadata;
status = wrapped_->Run(run_options, merged_inputs, output_tensor_names,
{} /* target node names */, &combined_outputs,
&run_metadata);
for (int i = 0; i < batch->num_tasks(); ++i) {
*(batch->mutable_task(i)->run_metadata) = run_metadata;
}
if (!status.ok()) {
return;
}
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_serving/batching/batching_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ struct BatchingSessionTask : public BatchTask {
size_t size() const override { return zeroth_dim_size; }

// Fields populated when a task is received.
uint64 enqueue_time_micros;
RunOptions run_options;
size_t zeroth_dim_size;
const std::vector<std::pair<string, Tensor>>* inputs;
const std::vector<string>* output_tensor_names;
Expand All @@ -175,6 +177,7 @@ struct BatchingSessionTask : public BatchTask {
Notification* done;
Status* status;
std::vector<Tensor>* outputs;
RunMetadata* run_metadata;
};

} // namespace serving
Expand Down
Loading