Skip to content

Commit e087ccc

Browse files
Laurawlyzhiics
authored andcommitted
[Fix] Fix get_valid_count flaky test for cuda (apache#4901)
* get_valid_count accuracy issue fixed for individual tests but not for all tests running together * minor fix * initialize valid_count and PrefixSum buffers * test updated * udpate relay test as well * update document * fix lint * address comment * fix lint * correct atomicAdd identifier name
1 parent b9dc7db commit e087ccc

3 files changed

Lines changed: 166 additions & 255 deletions

File tree

tests/python/relay/test_op_level5.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,6 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
221221
func = relay.Function([x], z.astuple())
222222
func = run_infer_type(func)
223223
for target, ctx in ctx_list():
224-
if target == 'cuda':
225-
return
226224
intrp = relay.create_executor("debug", ctx=ctx, target=target)
227225
out = intrp.evaluate(func)(np_data)
228226
tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)

0 commit comments

Comments
 (0)