Skip to content

Commit a191480

Browse files
Ban Kawasfacebook-github-bot
authored andcommitted
Create get_data_module() on OSS WorldModelBase AND on FB FbWorldModel
Summary: As titled. See T83887308 & T83886520 for more details. Reviewed By: kaiwenw Differential Revision: D26498062 fbshipit-source-id: 2202b8a3dcde22bc97d03231228d518e947ca7db
1 parent 02244c5 commit a191480

File tree

1 file changed

+75
-5
lines changed

1 file changed

+75
-5
lines changed

reagent/model_managers/world_model_base.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55
from reagent.core.dataclasses import dataclass
66
from reagent.core.parameters import NormalizationData, NormalizationKey
77
from reagent.data.data_fetcher import DataFetcher
8+
from reagent.data.manual_data_module import ManualDataModule
89
from reagent.data.reagent_data_module import ReAgentDataModule
910
from reagent.gym.policies.policy import Policy
1011
from reagent.preprocessing.batch_preprocessor import BatchPreprocessor
12+
from reagent.preprocessing.types import InputColumn
13+
from reagent.workflow.identify_types_flow import identify_normalization_parameters
1114
from reagent.workflow.types import (
1215
Dataset,
16+
PreprocessingOptions,
1317
ReaderOptions,
1418
ResourceOptions,
1519
RewardOptions,
@@ -40,7 +44,7 @@ def create_policy(self) -> Policy:
4044

4145
@property
4246
def should_generate_eval_dataset(self) -> bool:
43-
return False
47+
raise RuntimeError
4448

4549
@property
4650
def required_normalization_keys(self) -> List[str]:
@@ -49,7 +53,7 @@ def required_normalization_keys(self) -> List[str]:
4953
def run_feature_identification(
5054
self, input_table_spec: TableSpec
5155
) -> Dict[str, NormalizationData]:
52-
raise NotImplementedError()
56+
raise RuntimeError
5357

5458
def query_data(
5559
self,
@@ -58,10 +62,30 @@ def query_data(
5862
reward_options: RewardOptions,
5963
data_fetcher: DataFetcher,
6064
) -> Dataset:
61-
raise NotImplementedError()
65+
raise RuntimeError
6266

63-
def build_batch_preprocessor(self, use_gpu: bool) -> BatchPreprocessor:
64-
raise NotImplementedError()
67+
def build_batch_preprocessor(self) -> BatchPreprocessor:
68+
raise RuntimeError
69+
70+
def get_data_module(
71+
self,
72+
*,
73+
input_table_spec: Optional[TableSpec] = None,
74+
reward_options: Optional[RewardOptions] = None,
75+
reader_options: Optional[ReaderOptions] = None,
76+
setup_data: Optional[Dict[str, bytes]] = None,
77+
saved_setup_data: Optional[Dict[str, bytes]] = None,
78+
resource_options: Optional[ResourceOptions] = None,
79+
) -> Optional[ReAgentDataModule]:
80+
return WorldModelDataModule(
81+
input_table_spec=input_table_spec,
82+
reward_options=reward_options,
83+
setup_data=setup_data,
84+
saved_setup_data=saved_setup_data,
85+
reader_options=reader_options,
86+
resource_options=resource_options,
87+
model_manager=self,
88+
)
6589

6690
def train(
6791
self,
@@ -84,3 +108,49 @@ def train(
84108
- validation_output
85109
"""
86110
raise NotImplementedError()
111+
112+
113+
class WorldModelDataModule(ManualDataModule):
114+
@property
115+
def should_generate_eval_dataset(self) -> bool:
116+
return True
117+
118+
@property
119+
def required_normalization_keys(self) -> List[str]:
120+
return [NormalizationKey.STATE]
121+
122+
def run_feature_identification(
123+
self, input_table_spec: TableSpec
124+
) -> Dict[str, NormalizationData]:
125+
# Run state feature identification
126+
state_preprocessing_options = PreprocessingOptions()
127+
state_features = [
128+
ffi.feature_id
129+
for ffi in self.model_manager.state_feature_config.float_feature_infos
130+
]
131+
logger.info(f"state allowedlist_features: {state_features}")
132+
state_preprocessing_options = state_preprocessing_options._replace(
133+
allowedlist_features=state_features
134+
)
135+
136+
state_normalization_parameters = identify_normalization_parameters(
137+
input_table_spec, InputColumn.STATE_FEATURES, state_preprocessing_options
138+
)
139+
140+
return {
141+
NormalizationKey.STATE: NormalizationData(
142+
dense_normalization_parameters=state_normalization_parameters
143+
)
144+
}
145+
146+
def query_data(
147+
self,
148+
input_table_spec: TableSpec,
149+
sample_range: Optional[Tuple[float, float]],
150+
reward_options: RewardOptions,
151+
data_fetcher: DataFetcher,
152+
) -> Dataset:
153+
raise NotImplementedError()
154+
155+
def build_batch_preprocessor(self) -> BatchPreprocessor:
156+
raise NotImplementedError()

0 commit comments

Comments
 (0)