55from reagent .core .dataclasses import dataclass
66from reagent .core .parameters import NormalizationData , NormalizationKey
77from reagent .data .data_fetcher import DataFetcher
8+ from reagent .data .manual_data_module import ManualDataModule
89from reagent .data .reagent_data_module import ReAgentDataModule
910from reagent .gym .policies .policy import Policy
1011from reagent .preprocessing .batch_preprocessor import BatchPreprocessor
12+ from reagent .preprocessing .types import InputColumn
13+ from reagent .workflow .identify_types_flow import identify_normalization_parameters
1114from reagent .workflow .types import (
1215 Dataset ,
16+ PreprocessingOptions ,
1317 ReaderOptions ,
1418 ResourceOptions ,
1519 RewardOptions ,
@@ -40,7 +44,7 @@ def create_policy(self) -> Policy:
4044
@property
def should_generate_eval_dataset(self) -> bool:
    """Unsupported on the model manager; always raises.

    The data module returned by ``get_data_module()`` defines its own
    ``should_generate_eval_dataset`` property, so this accessor only
    signals that it must not be used.

    Raises:
        RuntimeError: on every access.
    """
    raise RuntimeError(
        "should_generate_eval_dataset is not supported on the model manager; "
        "use the data module returned by get_data_module() instead."
    )
4448
4549 @property
4650 def required_normalization_keys (self ) -> List [str ]:
@@ -49,7 +53,7 @@ def required_normalization_keys(self) -> List[str]:
def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Unsupported on the model manager; always raises.

    The data module returned by ``get_data_module()`` defines its own
    ``run_feature_identification``; this method only signals that it
    must not be called on the manager.

    Args:
        input_table_spec: ignored; present to keep the signature intact.

    Raises:
        RuntimeError: on every call.
    """
    raise RuntimeError(
        "run_feature_identification is not supported on the model manager; "
        "use the data module returned by get_data_module() instead."
    )
5357
5458 def query_data (
5559 self ,
@@ -58,10 +62,30 @@ def query_data(
5862 reward_options : RewardOptions ,
5963 data_fetcher : DataFetcher ,
6064 ) -> Dataset :
61- raise NotImplementedError ()
65+ raise RuntimeError
6266
def build_batch_preprocessor(self) -> BatchPreprocessor:
    """Unsupported on the model manager; always raises.

    The data module returned by ``get_data_module()`` defines its own
    ``build_batch_preprocessor``; this method only signals that it must
    not be called on the manager.

    Raises:
        RuntimeError: on every call.
    """
    raise RuntimeError(
        "build_batch_preprocessor is not supported on the model manager; "
        "use the data module returned by get_data_module() instead."
    )
69+
def get_data_module(
    self,
    *,
    input_table_spec: Optional[TableSpec] = None,
    reward_options: Optional[RewardOptions] = None,
    reader_options: Optional[ReaderOptions] = None,
    setup_data: Optional[Dict[str, bytes]] = None,
    saved_setup_data: Optional[Dict[str, bytes]] = None,
    resource_options: Optional[ResourceOptions] = None,
) -> Optional[ReAgentDataModule]:
    """Build the ``WorldModelDataModule`` bound to this model manager.

    All arguments are keyword-only and are forwarded unchanged to the
    ``WorldModelDataModule`` constructor; this manager is passed along
    as ``model_manager``.

    Returns:
        A newly constructed ``WorldModelDataModule``.
    """
    data_module = WorldModelDataModule(
        model_manager=self,
        input_table_spec=input_table_spec,
        reward_options=reward_options,
        reader_options=reader_options,
        resource_options=resource_options,
        setup_data=setup_data,
        saved_setup_data=saved_setup_data,
    )
    return data_module
6589
6690 def train (
6791 self ,
@@ -84,3 +108,49 @@ def train(
84108 - validation_output
85109 """
86110 raise NotImplementedError ()
111+
112+
class WorldModelDataModule(ManualDataModule):
    """Data module for the world model.

    Declares that an evaluation dataset should be generated, that only
    state normalization is required, and how state normalization
    parameters are identified. Querying data and building a batch
    preprocessor are not implemented here.
    """

    @property
    def should_generate_eval_dataset(self) -> bool:
        """Return ``True``: an evaluation dataset is always requested."""
        return True

    @property
    def required_normalization_keys(self) -> List[str]:
        """Return the single required key, ``NormalizationKey.STATE``."""
        return [NormalizationKey.STATE]

    def run_feature_identification(
        self, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        """Identify normalization parameters for the state features.

        The allowed feature ids are read from the float feature infos of
        ``self.model_manager.state_feature_config`` and passed, through
        ``PreprocessingOptions``, to ``identify_normalization_parameters``
        for the ``InputColumn.STATE_FEATURES`` column.

        Args:
            input_table_spec: table to run feature identification on.

        Returns:
            A one-entry mapping from ``NormalizationKey.STATE`` to the
            ``NormalizationData`` wrapping the identified parameters.
        """
        options = PreprocessingOptions()

        # Restrict identification to the features the manager is configured for.
        state_features = []
        for info in self.model_manager.state_feature_config.float_feature_infos:
            state_features.append(info.feature_id)
        logger.info(f"state allowedlist_features: { state_features } ")
        options = options._replace(allowedlist_features=state_features)

        identified = identify_normalization_parameters(
            input_table_spec, InputColumn.STATE_FEATURES, options
        )
        state_data = NormalizationData(dense_normalization_parameters=identified)
        return {NormalizationKey.STATE: state_data}

    def query_data(
        self,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
        data_fetcher: DataFetcher,
    ) -> Dataset:
        """Not implemented for this data module.

        Raises:
            NotImplementedError: on every call.
        """
        raise NotImplementedError()

    def build_batch_preprocessor(self) -> BatchPreprocessor:
        """Not implemented for this data module.

        Raises:
            NotImplementedError: on every call.
        """
        raise NotImplementedError()
0 commit comments