Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions sdks/python/apache_beam/testing/test_stream_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class TestStreamServiceController(TestStreamServiceServicer):
"""
def __init__(self, reader, endpoint=None, exception_handler=None):
self._server = grpc.server(ThreadPoolExecutor(max_workers=10))
self._server_started = False
self._server_stopped = False

if endpoint:
self.endpoint = endpoint
Expand All @@ -50,9 +52,18 @@ def __init__(self, reader, endpoint=None, exception_handler=None):
self._exception_handler = lambda _: False

def start(self):
# A server can only be started if never started and never stopped before.
if self._server_started or self._server_stopped:
return
self._server_started = True
self._server.start()

def stop(self):
# A server can only be stopped if already started and never stopped before.
if not self._server_started or self._server_stopped:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Write this in a format similar to start():

if self._server_started and not self._server_stopped:

or change start to similar format as

if self._server_started or self._server_stopped:
    return

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Formatted it.

return
self._server_started = False
self._server_stopped = True
self._server.stop(0)
# This was introduced in grpcio 1.24 and might be gone in the future. Keep
# this check in case the runtime is on a older, current or future grpcio.
Expand Down
81 changes: 81 additions & 0 deletions sdks/python/apache_beam/testing/test_stream_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from __future__ import absolute_import

import sys
import unittest

import grpc
Expand All @@ -30,6 +31,13 @@
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
from apache_beam.testing.test_stream_service import TestStreamServiceController

# TODO(BEAM-8288): clean up the work-around of nose tests using Python2 without
# unittest.mock module.
try:
from unittest.mock import patch
except ImportError:
from mock import patch # type: ignore[misc]

# Nose automatically detects tests if they match a regex. Here, it mistakens
# these protos as tests. For more info see the Nose docs at:
# https://nose.readthedocs.io/en/latest/writing_tests.html
Expand Down Expand Up @@ -116,5 +124,78 @@ def test_multiple_sessions(self):
self.assertEqual(events_b, expected_events)


@unittest.skipIf(
sys.version_info < (3, 6), 'The tests require at least Python 3.6 to work.')
class TestStreamServiceStartStopTest(unittest.TestCase):

# Weak internal use needs to be explicitly imported.
from grpc import _server

def setUp(self):
self.controller = TestStreamServiceController(
EventsReader(expected_key=[('full', EXPECTED_KEY)]))
self.assertFalse(self.controller._server_started)
self.assertFalse(self.controller._server_stopped)

def tearDown(self):
self.controller.stop()

def test_start_when_never_started(self):
with patch.object(self._server._Server,
'start',
wraps=self.controller._server.start) as mock_start:
self.controller.start()
mock_start.assert_called_once()
self.assertTrue(self.controller._server_started)
self.assertFalse(self.controller._server_stopped)

def test_start_noop_when_already_started(self):
with patch.object(self._server._Server,
'start',
wraps=self.controller._server.start) as mock_start:
self.controller.start()
mock_start.assert_called_once()
self.controller.start()
mock_start.assert_called_once()

def test_start_noop_when_already_stopped(self):
with patch.object(self._server._Server,
'start',
wraps=self.controller._server.start) as mock_start:
self.controller.start()
self.controller.stop()
mock_start.assert_called_once()
self.controller.start()
mock_start.assert_called_once()

def test_stop_noop_when_not_started(self):
with patch.object(self._server._Server,
'stop',
wraps=self.controller._server.stop) as mock_stop:
self.controller.stop()
mock_stop.assert_not_called()

def test_stop_when_already_started(self):
with patch.object(self._server._Server,
'stop',
wraps=self.controller._server.stop) as mock_stop:
self.controller.start()
mock_stop.assert_not_called()
self.controller.stop()
mock_stop.assert_called_once()
self.assertFalse(self.controller._server_started)
self.assertTrue(self.controller._server_stopped)

def test_stop_noop_when_already_stopped(self):
with patch.object(self._server._Server,
'stop',
wraps=self.controller._server.stop) as mock_stop:
self.controller.start()
self.controller.stop()
mock_stop.assert_called_once()
self.controller.stop()
mock_stop.assert_called_once()


if __name__ == '__main__':
unittest.main()