|
37 | 37 | from tensorflow.core.protobuf import config_pb2 |
38 | 38 | from tensorflow.core.protobuf import rewriter_config_pb2 |
39 | 39 | from tensorflow.python.client import session |
| 40 | +from tensorflow.python.distribute import distribute_coordinator as dc |
40 | 41 | from tensorflow.python.estimator import run_config |
41 | 42 | from tensorflow.python.platform import test |
42 | 43 | from tensorflow.python.platform import tf_logging as logging |
43 | 44 | from tensorflow.python.training import coordinator |
44 | 45 | from tensorflow.python.training import server_lib |
45 | 46 |
|
| 47 | + |
| 48 | +original_run_std_server = dc._run_std_server # pylint: disable=protected-access |
| 49 | + |
46 | 50 | ASSIGNED_PORTS = set() |
47 | 51 | lock = threading.Lock() |
48 | 52 |
|
@@ -357,6 +361,22 @@ def __len__(self): |
357 | 361 | class IndependentWorkerTestBase(test.TestCase): |
358 | 362 | """Testing infra for independent workers.""" |
359 | 363 |
|
| 364 | + def _make_mock_run_std_server(self): |
| 365 | + thread_local = threading.local() |
| 366 | + |
| 367 | + def _mock_run_std_server(*args, **kwargs): |
| 368 | + ret = original_run_std_server(*args, **kwargs) |
| 369 | + # Wait for all std servers to be brought up in order to reduce the chance |
| 370 | + # of remote sessions taking local ports that have been assigned to std |
| 371 | + # servers. Only call this barrier the first time this function is run for |
| 372 | + # each thread. |
| 373 | + if not getattr(thread_local, 'server_started', False): |
| 374 | + self._barrier.wait() |
| 375 | + thread_local.server_started = True |
| 376 | + return ret |
| 377 | + |
| 378 | + return _mock_run_std_server |
| 379 | + |
360 | 380 | def setUp(self): |
361 | 381 | self._mock_os_env = MockOsEnv() |
362 | 382 | self._mock_context = test.mock.patch.object(os, 'environ', |
|
0 commit comments