Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensorflow
Submodule tensorflow updated 967 files
28 changes: 27 additions & 1 deletion tensorflow_serving/apis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,28 @@ filegroup(
load("//tensorflow_serving:serving.bzl", "serving_proto_library")
load("//tensorflow_serving:serving.bzl", "serving_proto_library_py")

# Multi-language (C++ / Go / Java) proto library for the GetModelMetadata
# request/response messages defined in get_model_metadata.proto.
serving_proto_library(
name = "get_model_metadata_proto",
srcs = ["get_model_metadata.proto"],
cc_api_version = 2,
go_api_version = 2,
java_api_version = 2,
deps = [
# ModelSpec message used to identify the model being queried.
":model_proto",
# Provides SignatureDef (meta_graph.proto) used by SignatureDefMap.
"@org_tensorflow//tensorflow/core:protos_all_cc",
# Well-known types: google.protobuf.Any used in the response metadata map.
"@protobuf//:cc_wkt_protos",
],
)

# Python (pb2) bindings for get_model_metadata.proto; mirrors the
# cross-language rule ":get_model_metadata_proto" above.
serving_proto_library_py(
name = "get_model_metadata_proto_py_pb2",
srcs = ["get_model_metadata.proto"],
proto_library = "get_model_metadata_proto",
deps = [
# Python counterparts of the TensorFlow core protos (SignatureDef etc.).
"@org_tensorflow//tensorflow/core:protos_all_py",
],
)

serving_proto_library(
name = "model_proto",
srcs = ["model.proto"],
Expand Down Expand Up @@ -72,12 +94,16 @@ serving_proto_library(
go_api_version = 2,
java_api_version = 2,
deps = [
":get_model_metadata_proto",
":predict_proto",
],
)

# Python library wrapping the pre-generated prediction_service_pb2.py.
# NOTE: the scraped diff retained both the old single-entry `deps` line and
# the new list, producing a duplicated `deps` keyword (a Starlark syntax
# error); this is the merged, post-change rule.
py_library(
    name = "prediction_service_proto_py_pb2",
    srcs = ["prediction_service_pb2.py"],
    deps = [
        # GetModelMetadata request/response message bindings.
        ":get_model_metadata_proto_py_pb2",
        # Predict request/response message bindings.
        ":predict_proto_py_pb2",
    ],
)
29 changes: 29 additions & 0 deletions tensorflow_serving/apis/get_model_metadata.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
syntax = "proto3";

package tensorflow.serving;
option cc_enable_arenas = true;

import "google/protobuf/any.proto";
import "tensorflow/core/protobuf/meta_graph.proto";
import "tensorflow_serving/apis/model.proto";

// Message returned for "signature_def" field.
// Maps signature name to the SignatureDef describing that signature's
// inputs/outputs (SignatureDef comes from tensorflow/core meta_graph.proto).
message SignatureDefMap {
  map<string, SignatureDef> signature_def = 1;
}

message GetModelMetadataRequest {
  // Model Specification indicating which model we are querying for metadata.
  ModelSpec model_spec = 1;
  // Metadata fields to get. Currently supported: "signature_def".
  repeated string metadata_field = 2;
}

message GetModelMetadataResponse {
  // Model Specification indicating which model this metadata belongs to.
  ModelSpec model_spec = 1;
  // Map of metadata field name to metadata field. The options for metadata
  // field name are listed in GetModelMetadataRequest. Currently supported:
  // "signature_def".
  map<string, google.protobuf.Any> metadata = 2;
}
5 changes: 5 additions & 0 deletions tensorflow_serving/apis/prediction_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ syntax = "proto3";
package tensorflow.serving;
option cc_enable_arenas = true;

import "tensorflow_serving/apis/get_model_metadata.proto";
import "tensorflow_serving/apis/predict.proto";

// open source marker; do not remove
Expand All @@ -11,4 +12,8 @@ import "tensorflow_serving/apis/predict.proto";
// PredictionService provides access to machine-learned models loaded by
// model_servers.
service PredictionService {
// Predict -- provides access to loaded TensorFlow model.
rpc Predict(PredictRequest) returns (PredictResponse);

// GetModelMetadata - provides access to metadata for loaded models.
rpc GetModelMetadata(GetModelMetadataRequest)
returns (GetModelMetadataResponse);
}
50 changes: 44 additions & 6 deletions tensorflow_serving/apis/prediction_service_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@
_sym_db = _symbol_database.Default()


from tensorflow_serving.apis import get_model_metadata_pb2 as tensorflow__serving_dot_apis_dot_get__model__metadata__pb2
from tensorflow_serving.apis import predict_pb2 as tensorflow__serving_dot_apis_dot_predict__pb2


DESCRIPTOR = _descriptor.FileDescriptor(
name='tensorflow_serving/apis/prediction_service.proto',
package='tensorflow.serving',
syntax='proto3',
serialized_pb=_b('\n0tensorflow_serving/apis/prediction_service.proto\x12\x12tensorflow.serving\x1a%tensorflow_serving/apis/predict.proto2g\n\x11PredictionService\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponseB\x03\xf8\x01\x01\x62\x06proto3')
serialized_pb=_b('\n0tensorflow_serving/apis/prediction_service.proto\x12\x12tensorflow.serving\x1a\x30tensorflow_serving/apis/get_model_metadata.proto\x1a%tensorflow_serving/apis/predict.proto2\xd6\x01\n\x11PredictionService\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponse\x12m\n\x10GetModelMetadata\x12+.tensorflow.serving.GetModelMetadataRequest\x1a,.tensorflow.serving.GetModelMetadataResponseB\x03\xf8\x01\x01\x62\x06proto3')
,
dependencies=[tensorflow__serving_dot_apis_dot_predict__pb2.DESCRIPTOR,])
dependencies=[tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_predict__pb2.DESCRIPTOR,])
_sym_db.RegisterFileDescriptor(DESCRIPTOR)


Expand All @@ -54,7 +55,8 @@


class PredictionServiceStub(object):
"""PredictionService provides access to machine-learned models loaded by
"""open source marker; do not remove
PredictionService provides access to machine-learned models loaded by
model_servers.
"""

Expand All @@ -69,10 +71,16 @@ def __init__(self, channel):
request_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString,
response_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString,
)
self.GetModelMetadata = channel.unary_unary(
'/tensorflow.serving.PredictionService/GetModelMetadata',
request_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString,
response_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString,
)


class PredictionServiceServicer(object):
"""PredictionService provides access to machine-learned models loaded by
"""open source marker; do not remove
PredictionService provides access to machine-learned models loaded by
model_servers.
"""

Expand All @@ -83,6 +91,13 @@ def Predict(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetModelMetadata(self, request, context):
"""GetModelMetadata - provides access to metadata for loaded models.
"""
# Auto-generated grpc servicer default: concrete servicers override this.
# Until overridden, the RPC is reported to clients as UNIMPLEMENTED and the
# server-side call raises NotImplementedError.
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_PredictionServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -91,6 +106,11 @@ def add_PredictionServiceServicer_to_server(servicer, server):
request_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString,
response_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString,
),
'GetModelMetadata': grpc.unary_unary_rpc_method_handler(
servicer.GetModelMetadata,
request_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString,
response_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'tensorflow.serving.PredictionService', rpc_method_handlers)
Expand All @@ -103,13 +123,18 @@ class BetaPredictionServiceServicer(object):
It is recommended to use the GA API (classes and functions in this
file not marked beta) for all further purposes. This class was generated
only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0."""
"""PredictionService provides access to machine-learned models loaded by
"""open source marker; do not remove
PredictionService provides access to machine-learned models loaded by
model_servers.
"""
def Predict(self, request, context):
"""Predict -- provides access to loaded TensorFlow model.
"""
# Beta-API servicer default: signal UNIMPLEMENTED via the beta interface.
# (No exception raised here, unlike the GA servicer above.)
context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)
def GetModelMetadata(self, request, context):
"""GetModelMetadata - provides access to metadata for loaded models.
"""
# Beta-API servicer default: signal UNIMPLEMENTED via the beta interface.
context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)


class BetaPredictionServiceStub(object):
Expand All @@ -118,14 +143,20 @@ class BetaPredictionServiceStub(object):
It is recommended to use the GA API (classes and functions in this
file not marked beta) for all further purposes. This class was generated
only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0."""
"""PredictionService provides access to machine-learned models loaded by
"""open source marker; do not remove
PredictionService provides access to machine-learned models loaded by
model_servers.
"""
def Predict(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
"""Predict -- provides access to loaded TensorFlow model.
"""
# Beta-API stub placeholder: real stubs are created dynamically by
# beta_create_PredictionService_stub; this class only documents the shape.
raise NotImplementedError()
# .future attribute mirrors the async-call handle exposed by real stubs.
Predict.future = None
def GetModelMetadata(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
"""GetModelMetadata - provides access to metadata for loaded models.
"""
# Beta-API stub placeholder: real stubs are created dynamically by
# beta_create_PredictionService_stub; this class only documents the shape.
raise NotImplementedError()
# .future attribute mirrors the async-call handle exposed by real stubs.
GetModelMetadata.future = None


def beta_create_PredictionService_server(servicer, pool=None, pool_size=None, default_timeout=None, maximum_timeout=None):
Expand All @@ -135,12 +166,15 @@ def beta_create_PredictionService_server(servicer, pool=None, pool_size=None, de
file not marked beta) for all further purposes. This function was
generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"""
request_deserializers = {
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString,
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString,
}
response_serializers = {
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString,
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString,
}
method_implementations = {
('tensorflow.serving.PredictionService', 'GetModelMetadata'): face_utilities.unary_unary_inline(servicer.GetModelMetadata),
('tensorflow.serving.PredictionService', 'Predict'): face_utilities.unary_unary_inline(servicer.Predict),
}
server_options = beta_implementations.server_options(request_deserializers=request_deserializers, response_serializers=response_serializers, thread_pool=pool, thread_pool_size=pool_size, default_timeout=default_timeout, maximum_timeout=maximum_timeout)
Expand All @@ -154,13 +188,17 @@ def beta_create_PredictionService_stub(channel, host=None, metadata_transformer=
file not marked beta) for all further purposes. This function was
generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"""
request_serializers = {
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString,
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString,
}
response_deserializers = {
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString,
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString,
}
cardinalities = {
'GetModelMetadata': cardinality.Cardinality.UNARY_UNARY,
'Predict': cardinality.Cardinality.UNARY_UNARY,
}
stub_options = beta_implementations.stub_options(host=host, metadata_transformer=metadata_transformer, request_serializers=request_serializers, response_deserializers=response_deserializers, thread_pool=pool, thread_pool_size=pool_size)
return beta_implementations.dynamic_stub(channel, 'tensorflow.serving.PredictionService', cardinalities, options=stub_options)
# @@protoc_insertion_point(module_scope)
20 changes: 10 additions & 10 deletions tensorflow_serving/batching/batching_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ TEST(BatchingSessionTest, TensorSignatureFromSignatureDefs) {
const SignatureDef signature_def_0 =
CreateSignatureDef({{"x0", "x1"}, {"y0", "y1"}});
const SignatureDef signature_def_1 =
CreateSignatureDef({{"x1", "x2"}, {"y1", "y2"}});
CreateSignatureDef({{"x1", "x2"}, {"y1", "y3"}});
const TensorSignature tensor_signature =
TensorSignatureFromSignatureDefs({signature_def_0, signature_def_1});
EXPECT_THAT(tensor_signature.input_tensors,
UnorderedElementsAre("x0", "x1", "x2"));
EXPECT_THAT(tensor_signature.output_tensors,
UnorderedElementsAre("y0", "y1", "y2"));
UnorderedElementsAre("y0", "y1", "y3"));
}

TEST(BatchingSessionTest, Basic) {
Expand Down Expand Up @@ -188,9 +188,9 @@ TEST(BatchingSessionTest, RequestThatDoesntMatchSignatureGetsRunAnyway) {
std::unique_ptr<Session> batching_session;
BatchingSessionOptions batching_session_options;
TF_ASSERT_OK(CreateBasicBatchingSession(
schedule_options, batching_session_options, {{"x2"}, {"y2"}},
schedule_options, batching_session_options, {{"x2"}, {"y3"}},
CreateHalfPlusTwoSession(), &batching_session));
// Issue a request using x/y, which doesn't match the x2/y2 signature.
// Issue a request using x/y, which doesn't match the x2/y3 signature.
TestSingleRequest(100.0f, 42.0f, batching_session.get());
}

Expand Down Expand Up @@ -288,7 +288,7 @@ TEST(BatchingSessionTest, DifferentOrderForInputAndOutputTensors) {
BatchingSessionOptions batching_session_options;
std::unique_ptr<Session> batching_session;
TF_ASSERT_OK(CreateBasicBatchingSession(
schedule_options, batching_session_options, {{"x", "x2"}, {"y", "y2"}},
schedule_options, batching_session_options, {{"x", "x2"}, {"y", "y3"}},
CreateHalfPlusTwoSession(), &batching_session));

const Tensor input0 = test::AsTensor<float>({8.0f, 6.0f}, {2});
Expand All @@ -300,7 +300,7 @@ TEST(BatchingSessionTest, DifferentOrderForInputAndOutputTensors) {
Env::Default()->StartThread(ThreadOptions(), "first_request_thread", [&] {
std::vector<Tensor> outputs;
TF_ASSERT_OK(batching_session->Run({{"x", input0}, {"x2", input1}},
{"y", "y2"} /* outputs */,
{"y", "y3"} /* outputs */,
{} /* target nodes */, &outputs));
ASSERT_EQ(2, outputs.size());
test::ExpectTensorEqual<float>(expected_output0, outputs[0]);
Expand All @@ -310,7 +310,7 @@ TEST(BatchingSessionTest, DifferentOrderForInputAndOutputTensors) {
ThreadOptions(), "second_request_thread", [&] {
std::vector<Tensor> outputs;
TF_ASSERT_OK(batching_session->Run({{"x2", input1}, {"x", input0}},
{"y2", "y"} /* outputs */,
{"y3", "y"} /* outputs */,
{} /* target nodes */, &outputs));
ASSERT_EQ(2, outputs.size());
test::ExpectTensorEqual<float>(expected_output1, outputs[0]);
Expand All @@ -320,7 +320,7 @@ TEST(BatchingSessionTest, DifferentOrderForInputAndOutputTensors) {
Env::Default()->StartThread(ThreadOptions(), "third_request_thread", [&] {
std::vector<Tensor> outputs;
TF_ASSERT_OK(batching_session->Run({{"x2", input1}, {"x", input0}},
{"y", "y2"} /* outputs */,
{"y", "y3"} /* outputs */,
{} /* target nodes */, &outputs));
ASSERT_EQ(2, outputs.size());
test::ExpectTensorEqual<float>(expected_output0, outputs[0]);
Expand Down Expand Up @@ -349,7 +349,7 @@ TEST(BatchingSessionTest, MultipleSignatures) {
std::unique_ptr<Session> batching_session;
TF_CHECK_OK(CreateBatchingSession(
batching_session_options, {{{{"x"}, {"y"}}, create_scheduler},
{{{"x2"}, {"y2"}}, create_scheduler}},
{{{"x2"}, {"y3"}}, create_scheduler}},
CreateHalfPlusTwoSession(), &batching_session));
ASSERT_EQ(2, schedulers.size());

Expand All @@ -367,7 +367,7 @@ TEST(BatchingSessionTest, MultipleSignatures) {
Tensor input = test::AsTensor<float>({100.0f, 42.0f}, {2});
Tensor expected_output = test::AsTensor<float>({53.0f, 24.0f}, {2});
std::vector<Tensor> outputs;
TF_ASSERT_OK(batching_session->Run({{"x2", input}}, {"y2"} /* outputs */,
TF_ASSERT_OK(batching_session->Run({{"x2", input}}, {"y3"} /* outputs */,
{} /* target nodes */, &outputs));
ASSERT_EQ(1, outputs.size());
test::ExpectTensorEqual<float>(expected_output, outputs[0]);
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_serving/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ cc_test(
srcs = ["aspired_versions_manager_builder_test.cc"],
deps = [
":aspired_versions_manager_builder",
":eager_load_policy",
":availability_preserving_policy",
":servable_data",
":servable_handle",
":servable_state_monitor",
Expand Down Expand Up @@ -534,7 +534,7 @@ cc_test(
deps = [
":aspired_version_policy",
":aspired_versions_manager",
":eager_load_policy",
":availability_preserving_policy",
":loader",
":manager",
":servable_data",
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_serving/core/aspired_versions_manager_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow_serving/core/aspired_version_policy.h"
#include "tensorflow_serving/core/aspired_versions_manager.h"
#include "tensorflow_serving/core/eager_load_policy.h"
#include "tensorflow_serving/core/availability_preserving_policy.h"
#include "tensorflow_serving/core/loader.h"
#include "tensorflow_serving/core/manager.h"
#include "tensorflow_serving/core/servable_data.h"
Expand Down Expand Up @@ -74,7 +74,7 @@ class BenchmarkState {
AspiredVersionsManager::Options options;
// Do policy thread won't be run automatically.
options.manage_state_interval_micros = -1;
options.aspired_version_policy.reset(new EagerLoadPolicy());
options.aspired_version_policy.reset(new AvailabilityPreservingPolicy());
TF_CHECK_OK(AspiredVersionsManager::Create(std::move(options), &manager_));
}

Expand Down Expand Up @@ -304,7 +304,7 @@ static void BM_GetServableHandle(const int iters) {
AspiredVersionsManager::Options options;
// Do policy thread won't be run automatically.
options.manage_state_interval_micros = -1;
options.aspired_version_policy.reset(new EagerLoadPolicy());
options.aspired_version_policy.reset(new AvailabilityPreservingPolicy());
std::unique_ptr<AspiredVersionsManager> manager;
TF_CHECK_OK(AspiredVersionsManager::Create(std::move(options), &manager));
auto aspired_versions_callback = manager->GetAspiredVersionsCallback();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow_serving/core/eager_load_policy.h"
#include "tensorflow_serving/core/availability_preserving_policy.h"
#include "tensorflow_serving/core/servable_data.h"
#include "tensorflow_serving/core/servable_handle.h"
#include "tensorflow_serving/core/servable_state_monitor.h"
Expand Down Expand Up @@ -46,7 +46,8 @@ class AspiredVersionsManagerBuilderTest : public ::testing::Test {
servable_state_monitor_(servable_event_bus_.get()) {
AspiredVersionsManagerBuilder::Options manager_options;
manager_options.servable_event_bus = servable_event_bus_.get();
manager_options.aspired_version_policy.reset(new EagerLoadPolicy());
manager_options.aspired_version_policy.reset(
new AvailabilityPreservingPolicy());
TF_CHECK_OK(AspiredVersionsManagerBuilder::Create(
std::move(manager_options), &builder_));
}
Expand Down
Loading