Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 127 additions & 5 deletions reagent/test/preprocessing/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from copy import deepcopy
from typing import List
from unittest.mock import Mock, patch

Expand All @@ -11,9 +12,23 @@

class TestTransforms(unittest.TestCase):
def setUp(self):
# preparing various components for qr-dqn trainer initialization
# currently not needed
pass
# add custom compare function for torch.Tensor
self.addTypeEqualityFunc(torch.Tensor, TestTransforms.are_torch_tensor_equal)

@staticmethod
def are_torch_tensor_equal(tensor_0, tensor_1, msg=None):
if torch.all(tensor_0 == tensor_1):
return True
raise TestTransforms.failureException("non-equal pytorch tensors found", msg)

def assertTorchTensorEqual(self, tensor_0, tensor_1, msg=None):
self.assertIsInstance(
tensor_0, torch.Tensor, "first argument is not a torch.Tensor"
)
self.assertIsInstance(
tensor_1, torch.Tensor, "second argument is not a torch.Tensor"
)
self.assertEqual(tensor_0, tensor_1, msg=msg)

def assertDictComparatorEqual(self, a, b, cmp):
"""
Expand Down Expand Up @@ -167,8 +182,115 @@ def test_DenseNormalization(self, Preprocessor):
# ensure unnamed variables not changed
self.assertEqual(out["c"], c_out)
in_1, in_2 = [call_args.args for call_args in preprocessor.call_args_list]
self.assertTrue(torch.all(torch.stack(in_1) == torch.stack(a_in)))
self.assertTrue(torch.all(torch.stack(in_2) == torch.stack(b_in)))

self.assertEqual(torch.stack(in_1), torch.stack(a_in))
self.assertEqual(torch.stack(in_2), torch.stack(b_in))

@patch("reagent.preprocessing.transforms.Preprocessor")
def test_FixedLengthSequenceDenseNormalization(self, Preprocessor):
# test key mapping
rand_gen = torch.Generator().manual_seed(0)

a_batch_size = 2
b_batch_size = 3

a_dim = 13
b_dim = 11

expected_length = 7

a_T = (
torch.rand(
a_batch_size * expected_length, a_dim, generator=rand_gen
), # value
torch.rand(a_batch_size * expected_length, a_dim, generator=rand_gen)
> 0.5, # presence
)
b_T = (
torch.rand(
b_batch_size * expected_length, b_dim, generator=rand_gen
), # value
torch.rand(b_batch_size * expected_length, b_dim, generator=rand_gen)
> 0.5, # presence
)

# expected values after preprocessing
a_TN = a_T[0] + 1
b_TN = b_T[0] + 1

# copy used for checking inplace modifications
a_TN_copy = deepcopy(a_TN)
b_TN_copy = deepcopy(b_TN)

a_offsets = torch.arange(0, a_batch_size * expected_length, expected_length)
b_offsets = torch.arange(0, b_batch_size * expected_length, expected_length)

a_in = {1: (a_offsets, a_T), 2: 0}
b_in = {1: (b_offsets, b_T), 2: 1}

c_out = 2

# input data
data = {"a": a_in, "b": b_in, "c": c_out}

# copy used for checking inplace modifications
data_copy = deepcopy(data)

Preprocessor.return_value = Mock(side_effect=[a_TN, b_TN])

flsdn = transforms.FixedLengthSequenceDenseNormalization(
keys=["a", "b"],
sequence_id=1,
normalization_data=Mock(),
)

out = flsdn(data)

# data is modified inplace and returned
self.assertEqual(data, out)

# check preprocessor number of calls
self.assertEqual(Preprocessor.call_count, 1)
self.assertEqual(Preprocessor.return_value.call_count, 2)

# result contains original keys and new processed keys
self.assertSetEqual(set(out.keys()), {"a", "b", "c", "a:1", "b:1"})

def assertKeySeqIdItem(item_0, item_1):
self.assertTorchTensorEqual(item_0[0], item_1[0])
self.assertTorchTensorEqual(item_0[1][0], item_1[1][0])
self.assertTorchTensorEqual(item_0[1][1], item_1[1][1])

# original keys should keep their value
for key in ("a", "b"):
# no change in the output
assertKeySeqIdItem(out[key][1], data_copy[key][1])

# no change in untouched seq id
self.assertEqual(out[key][2], data_copy[key][2])

# no change in the non-processed key
self.assertEqual(out["c"], data_copy["c"])

# check output shapes
self.assertListEqual(
[*out["a:1"].shape], [a_batch_size, expected_length, a_dim]
)
self.assertListEqual(
[*out["b:1"].shape], [b_batch_size, expected_length, b_dim]
)

# no inplace change in normalized tensors
self.assertTorchTensorEqual(a_TN, a_TN_copy)
self.assertTorchTensorEqual(b_TN, b_TN_copy)

# check if output has been properly slated
self.assertTorchTensorEqual(
out["a:1"], a_TN.view(a_batch_size, expected_length, a_dim)
)
self.assertTorchTensorEqual(
out["b:1"], b_TN.view(b_batch_size, expected_length, b_dim)
)

@patch("reagent.preprocessing.transforms.make_sparse_preprocessor")
def test_MapIDListFeatures(self, mock_make_sparse_preprocessor):
Expand Down