From 4271c35678642f1cbd807b783568c6464a1b3ebd Mon Sep 17 00:00:00 2001 From: Pierre Gleize Date: Thu, 23 Sep 2021 16:14:39 -0700 Subject: [PATCH] Add unit test for FixedLengthSequenceDenseNormalization. (#545) Summary: Pull Request resolved: https://github.com/facebookresearch/ReAgent/pull/545 Reviewed By: igfox Differential Revision: D31136906 fbshipit-source-id: 767f57238380d996c6719be46255b89a3eff014c --- reagent/test/preprocessing/test_transforms.py | 132 +++++++++++++++++- 1 file changed, 127 insertions(+), 5 deletions(-) diff --git a/reagent/test/preprocessing/test_transforms.py b/reagent/test/preprocessing/test_transforms.py index e8cf887c1..db3ed09ac 100644 --- a/reagent/test/preprocessing/test_transforms.py +++ b/reagent/test/preprocessing/test_transforms.py @@ -1,4 +1,5 @@ import unittest +from copy import deepcopy from typing import List from unittest.mock import Mock, patch @@ -11,9 +12,23 @@ class TestTransforms(unittest.TestCase): def setUp(self): - # preparing various components for qr-dqn trainer initialization - # currently not needed - pass + # add custom compare function for torch.Tensor + self.addTypeEqualityFunc(torch.Tensor, TestTransforms.are_torch_tensor_equal) + + @staticmethod + def are_torch_tensor_equal(tensor_0, tensor_1, msg=None): + if torch.all(tensor_0 == tensor_1): + return True + raise TestTransforms.failureException("non-equal pytorch tensors found", msg) + + def assertTorchTensorEqual(self, tensor_0, tensor_1, msg=None): + self.assertIsInstance( + tensor_0, torch.Tensor, "first argument is not a torch.Tensor" + ) + self.assertIsInstance( + tensor_1, torch.Tensor, "second argument is not a torch.Tensor" + ) + self.assertEqual(tensor_0, tensor_1, msg=msg) def assertDictComparatorEqual(self, a, b, cmp): """ @@ -167,8 +182,115 @@ def test_DenseNormalization(self, Preprocessor): # ensure unnamed variables not changed self.assertEqual(out["c"], c_out) in_1, in_2 = [call_args.args for call_args in preprocessor.call_args_list] - self.assertTrue(torch.all(torch.stack(in_1) == torch.stack(a_in))) - self.assertTrue(torch.all(torch.stack(in_2) == torch.stack(b_in))) + + self.assertEqual(torch.stack(in_1), torch.stack(a_in)) + self.assertEqual(torch.stack(in_2), torch.stack(b_in)) + + @patch("reagent.preprocessing.transforms.Preprocessor") + def test_FixedLengthSequenceDenseNormalization(self, Preprocessor): + # test key mapping + rand_gen = torch.Generator().manual_seed(0) + + a_batch_size = 2 + b_batch_size = 3 + + a_dim = 13 + b_dim = 11 + + expected_length = 7 + + a_T = ( + torch.rand( + a_batch_size * expected_length, a_dim, generator=rand_gen + ), # value + torch.rand(a_batch_size * expected_length, a_dim, generator=rand_gen) + > 0.5, # presence + ) + b_T = ( + torch.rand( + b_batch_size * expected_length, b_dim, generator=rand_gen + ), # value + torch.rand(b_batch_size * expected_length, b_dim, generator=rand_gen) + > 0.5, # presence + ) + + # expected values after preprocessing + a_TN = a_T[0] + 1 + b_TN = b_T[0] + 1 + + # copy used for checking inplace modifications + a_TN_copy = deepcopy(a_TN) + b_TN_copy = deepcopy(b_TN) + + a_offsets = torch.arange(0, a_batch_size * expected_length, expected_length) + b_offsets = torch.arange(0, b_batch_size * expected_length, expected_length) + + a_in = {1: (a_offsets, a_T), 2: 0} + b_in = {1: (b_offsets, b_T), 2: 1} + + c_out = 2 + + # input data + data = {"a": a_in, "b": b_in, "c": c_out} + + # copy used for checking inplace modifications + data_copy = deepcopy(data) + + Preprocessor.return_value = Mock(side_effect=[a_TN, b_TN]) + + flsdn = transforms.FixedLengthSequenceDenseNormalization( + keys=["a", "b"], + sequence_id=1, + normalization_data=Mock(), + ) + + out = flsdn(data) + + # data is modified inplace and returned + self.assertEqual(data, out) + + # check preprocessor number of calls + self.assertEqual(Preprocessor.call_count, 1) + self.assertEqual(Preprocessor.return_value.call_count, 2) + + # result contains original keys and new processed keys + self.assertSetEqual(set(out.keys()), {"a", "b", "c", "a:1", "b:1"}) + + def assertKeySeqIdItem(item_0, item_1): + self.assertTorchTensorEqual(item_0[0], item_1[0]) + self.assertTorchTensorEqual(item_0[1][0], item_1[1][0]) + self.assertTorchTensorEqual(item_0[1][1], item_1[1][1]) + + # original keys should keep their value + for key in ("a", "b"): + # no change in the output + assertKeySeqIdItem(out[key][1], data_copy[key][1]) + + # no change in untouched seq id + self.assertEqual(out[key][2], data_copy[key][2]) + + # no change in the non-processed key + self.assertEqual(out["c"], data_copy["c"]) + + # check output shapes + self.assertListEqual( + [*out["a:1"].shape], [a_batch_size, expected_length, a_dim] + ) + self.assertListEqual( + [*out["b:1"].shape], [b_batch_size, expected_length, b_dim] + ) + + # no inplace change in normalized tensors + self.assertTorchTensorEqual(a_TN, a_TN_copy) + self.assertTorchTensorEqual(b_TN, b_TN_copy) + + # check if output has been properly slated + self.assertTorchTensorEqual( + out["a:1"], a_TN.view(a_batch_size, expected_length, a_dim) + ) + self.assertTorchTensorEqual( + out["b:1"], b_TN.view(b_batch_size, expected_length, b_dim) + ) @patch("reagent.preprocessing.transforms.make_sparse_preprocessor") def test_MapIDListFeatures(self, mock_make_sparse_preprocessor):