@@ -28,25 +28,25 @@ limitations under the License.
2828#include " tensorflow_serving/core/eager_load_policy.h"
2929#include " tensorflow_serving/core/servable_state.h"
3030#include " tensorflow_serving/core/test_util/dynamic_manager_test_util.h"
31+ #include " tensorflow_serving/core/test_util/mock_loader.h"
3132#include " tensorflow_serving/util/any_ptr.h"
3233#include " tensorflow_serving/util/event_bus.h"
3334
3435namespace tensorflow {
3536namespace serving {
3637
3738using ::testing::Eq;
39+ using ::testing::Invoke;
40+ using ::testing::NiceMock;
41+ using ::testing::Return;
3842using ::testing::UnorderedElementsAreArray;
3943
4044namespace {
4145
4246class FakeLoader : public Loader {
4347 public:
44- explicit FakeLoader (int64 servable) : FakeLoader(servable, false ) {}
45-
46- FakeLoader (int64 servable, const bool errors_on_load)
47- : servable_(servable), errors_on_load_(errors_on_load) {
48- ++num_fake_loaders_;
49- }
48+ explicit FakeLoader (int64 servable, const bool errors_on_load = false )
49+ : servable_(servable), errors_on_load_(errors_on_load) {}
5050 ~FakeLoader () override { --num_fake_loaders_; }
5151
5252 Status Load () override {
@@ -83,13 +83,15 @@ class DynamicManagerTest : public ::testing::Test {
8383 LOG (INFO) << " Published state: " << state.DebugString ();
8484 last_published_servable_state_ = state;
8585 });
86- DynamicManager::Options options;
8786 // The state manager thread won't be run automatically.
88- options.manage_state_interval_micros = -1 ;
89- options.env = Env::Default ();
90- options.version_policy .reset (new EagerLoadPolicy ());
91- options.servable_event_bus = servable_event_bus_.get ();
92- manager_.reset (new DynamicManager (std::move (options)));
87+ dynamic_manager_options_.manage_state_interval_micros = -1 ;
88+ dynamic_manager_options_.env = Env::Default ();
89+ dynamic_manager_options_.version_policy .reset (new EagerLoadPolicy ());
90+ dynamic_manager_options_.servable_event_bus = servable_event_bus_.get ();
91+ dynamic_manager_options_.max_num_load_tries = 2 ;
92+ dynamic_manager_options_.load_retry_interval_micros = 0 ;
93+ // dynamic_manager_options_.load_retry_interval_micros = 0;
94+ manager_.reset (new DynamicManager (std::move (dynamic_manager_options_)));
9395 }
9496
9597 // Creates an aspired-versions entry with 'id' and a FakeLoader whose servable
@@ -138,6 +140,7 @@ class DynamicManagerTest : public ::testing::Test {
138140 std::unique_ptr<EventBus<ServableState>::Subscription>
139141 servable_state_subscription_;
140142 ServableState last_published_servable_state_;
143+ DynamicManager::Options dynamic_manager_options_;
141144 std::unique_ptr<DynamicManager> manager_;
142145};
143146
@@ -578,6 +581,51 @@ TEST_F(DynamicManagerTest, NoEventBus) {
578581 std::move (aspired_versions));
579582}
580583
584+ TEST_F (DynamicManagerTest, RetryOnLoadErrorFinallySucceeds) {
585+ std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
586+
587+ test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
588+ // Prevents it being changed without our knowledge.
589+ CHECK_EQ (dynamic_manager_options_.max_num_load_tries , 2 );
590+ // We succeed on the last load, before the manager gives up.
591+ EXPECT_CALL (*loader, Load ())
592+ .WillOnce (Return (errors::Internal (" Error on load." )))
593+ .WillOnce (Return (Status::OK ()));
594+
595+ const ServableId id = {kServableName , 7 };
596+ aspired_versions.push_back ({id, std::unique_ptr<Loader>(loader)});
597+ manager_->GetAspiredVersionsCallback ()(kServableName ,
598+ std::move (aspired_versions));
599+
600+ RunManageState ();
601+
602+ const ServableState available_state = {
603+ {kServableName , 7 },
604+ true ,
605+ ServableState::ManagerState::kAvailable ,
606+ Status::OK ()};
607+ EXPECT_THAT (last_published_servable_state_,
608+ EqualsServableState (available_state));
609+ }
610+
611+ TEST_F (DynamicManagerTest, RetryOnLoadErrorFinallyFails) {
612+ std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
613+ const ServableId id = {kServableName , 7 };
614+ // We always fail.
615+ std::unique_ptr<Loader> loader (new FakeLoader (7 , true /* errors_on_load */ ));
616+ aspired_versions.push_back ({id, std::move (loader)});
617+ manager_->GetAspiredVersionsCallback ()(kServableName ,
618+ std::move (aspired_versions));
619+
620+ RunManageState ();
621+
622+ const ServableState error_state = {{kServableName , 7 },
623+ true ,
624+ ServableState::ManagerState::kEnd ,
625+ errors::Internal (" Error on load." )};
626+ EXPECT_THAT (last_published_servable_state_, EqualsServableState (error_state));
627+ }
628+
581629} // namespace
582630} // namespace serving
583631} // namespace tensorflow
0 commit comments