11# -*- coding: utf-8 -*-
2- # Copyright 2016 OpenMarket Ltd
2+ # Copyright 2016-2021 The Matrix.org Foundation C.I.C.
33#
44# Licensed under the Apache License, Version 2.0 (the "License");
55# you may not use this file except in compliance with the License.
1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- from twisted .internet import defer
17-
1816import synapse .api .errors
1917
20- import tests .unittest
21- import tests .utils
22-
23-
24- class DeviceStoreTestCase (tests .unittest .TestCase ):
25- def __init__ (self , * args , ** kwargs ):
26- super ().__init__ (* args , ** kwargs )
27- self .store = None # type: synapse.storage.DataStore
18+ from tests .unittest import HomeserverTestCase
2819
29- @defer .inlineCallbacks
30- def setUp (self ):
31- hs = yield tests .utils .setup_test_homeserver (self .addCleanup )
3220
21+ class DeviceStoreTestCase (HomeserverTestCase ):
22+ def prepare (self , reactor , clock , hs ):
3323 self .store = hs .get_datastore ()
3424
35- @defer .inlineCallbacks
3625 def test_store_new_device (self ):
37- yield defer . ensureDeferred (
26+ self . get_success (
3827 self .store .store_device ("user_id" , "device_id" , "display_name" )
3928 )
4029
41- res = yield defer . ensureDeferred (self .store .get_device ("user_id" , "device_id" ))
30+ res = self . get_success (self .store .get_device ("user_id" , "device_id" ))
4231 self .assertDictContainsSubset (
4332 {
4433 "user_id" : "user_id" ,
@@ -48,19 +37,18 @@ def test_store_new_device(self):
4837 res ,
4938 )
5039
51- @defer .inlineCallbacks
5240 def test_get_devices_by_user (self ):
53- yield defer . ensureDeferred (
41+ self . get_success (
5442 self .store .store_device ("user_id" , "device1" , "display_name 1" )
5543 )
56- yield defer . ensureDeferred (
44+ self . get_success (
5745 self .store .store_device ("user_id" , "device2" , "display_name 2" )
5846 )
59- yield defer . ensureDeferred (
47+ self . get_success (
6048 self .store .store_device ("user_id2" , "device3" , "display_name 3" )
6149 )
6250
63- res = yield defer . ensureDeferred (self .store .get_devices_by_user ("user_id" ))
51+ res = self . get_success (self .store .get_devices_by_user ("user_id" ))
6452 self .assertEqual (2 , len (res .keys ()))
6553 self .assertDictContainsSubset (
6654 {
@@ -79,43 +67,41 @@ def test_get_devices_by_user(self):
7967 res ["device2" ],
8068 )
8169
82- @defer .inlineCallbacks
8370 def test_count_devices_by_users (self ):
84- yield defer . ensureDeferred (
71+ self . get_success (
8572 self .store .store_device ("user_id" , "device1" , "display_name 1" )
8673 )
87- yield defer . ensureDeferred (
74+ self . get_success (
8875 self .store .store_device ("user_id" , "device2" , "display_name 2" )
8976 )
90- yield defer . ensureDeferred (
77+ self . get_success (
9178 self .store .store_device ("user_id2" , "device3" , "display_name 3" )
9279 )
9380
94- res = yield defer . ensureDeferred (self .store .count_devices_by_users ())
81+ res = self . get_success (self .store .count_devices_by_users ())
9582 self .assertEqual (0 , res )
9683
97- res = yield defer . ensureDeferred (self .store .count_devices_by_users (["unknown" ]))
84+ res = self . get_success (self .store .count_devices_by_users (["unknown" ]))
9885 self .assertEqual (0 , res )
9986
100- res = yield defer . ensureDeferred (self .store .count_devices_by_users (["user_id" ]))
87+ res = self . get_success (self .store .count_devices_by_users (["user_id" ]))
10188 self .assertEqual (2 , res )
10289
103- res = yield defer . ensureDeferred (
90+ res = self . get_success (
10491 self .store .count_devices_by_users (["user_id" , "user_id2" ])
10592 )
10693 self .assertEqual (3 , res )
10794
108- @defer .inlineCallbacks
10995 def test_get_device_updates_by_remote (self ):
11096 device_ids = ["device_id1" , "device_id2" ]
11197
11298 # Add two device updates with a single stream_id
113- yield defer . ensureDeferred (
99+ self . get_success (
114100 self .store .add_device_change_to_streams ("user_id" , device_ids , ["somehost" ])
115101 )
116102
117103 # Get all device updates ever meant for this remote
118- now_stream_id , device_updates = yield defer . ensureDeferred (
104+ now_stream_id , device_updates = self . get_success (
119105 self .store .get_device_updates_by_remote ("somehost" , - 1 , limit = 100 )
120106 )
121107
@@ -131,37 +117,35 @@ def _check_devices_in_updates(self, expected_device_ids, device_updates):
131117 }
132118 self .assertEqual (received_device_ids , set (expected_device_ids ))
133119
134- @defer .inlineCallbacks
135120 def test_update_device (self ):
136- yield defer . ensureDeferred (
121+ self . get_success (
137122 self .store .store_device ("user_id" , "device_id" , "display_name 1" )
138123 )
139124
140- res = yield defer . ensureDeferred (self .store .get_device ("user_id" , "device_id" ))
125+ res = self . get_success (self .store .get_device ("user_id" , "device_id" ))
141126 self .assertEqual ("display_name 1" , res ["display_name" ])
142127
143128 # do a no-op first
144- yield defer . ensureDeferred (self .store .update_device ("user_id" , "device_id" ))
145- res = yield defer . ensureDeferred (self .store .get_device ("user_id" , "device_id" ))
129+ self . get_success (self .store .update_device ("user_id" , "device_id" ))
130+ res = self . get_success (self .store .get_device ("user_id" , "device_id" ))
146131 self .assertEqual ("display_name 1" , res ["display_name" ])
147132
148133 # do the update
149- yield defer . ensureDeferred (
134+ self . get_success (
150135 self .store .update_device (
151136 "user_id" , "device_id" , new_display_name = "display_name 2"
152137 )
153138 )
154139
155140 # check it worked
156- res = yield defer . ensureDeferred (self .store .get_device ("user_id" , "device_id" ))
141+ res = self . get_success (self .store .get_device ("user_id" , "device_id" ))
157142 self .assertEqual ("display_name 2" , res ["display_name" ])
158143
159- @defer .inlineCallbacks
160144 def test_update_unknown_device (self ):
161- with self .assertRaises ( synapse . api . errors . StoreError ) as cm :
162- yield defer . ensureDeferred (
163- self . store . update_device (
164- "user_id" , "unknown_device_id" , new_display_name = "display_name 2"
165- )
166- )
167- self .assertEqual (404 , cm . exception .code )
145+ exc = self .get_failure (
146+ self . store . update_device (
147+ "user_id" , "unknown_device_id" , new_display_name = "display_name 2"
148+ ),
149+ synapse . api . errors . StoreError ,
150+ )
151+ self .assertEqual (404 , exc . value .code )
0 commit comments