@@ -18,141 +18,6 @@ struct StrifeML::Serializer<Vector2>
1818 }
1919};
2020
21- enum class NeuralNetworkMode
22- {
23- Disabled,
24- Deciding,
25- CollectingSamples,
26- ReinforcementLearning,
27- };
28-
29- template <typename TNeuralNetwork>
30- struct NeuralNetworkComponent : ComponentTemplate<NeuralNetworkComponent<TNeuralNetwork>>
31- {
32- using InputType = typename TNeuralNetwork::InputType;
33- using OutputType = typename TNeuralNetwork::OutputType;
34- using NetworkType = TNeuralNetwork;
35- using SampleType = StrifeML::Sample<InputType, OutputType>;
36-
37- explicit NeuralNetworkComponent (float decisionsPerSecond_ = 1 )
38- : decisionInput(StrifeML::MlUtil::SharedArray<InputType>(TNeuralNetwork::SequenceLength)),
39- decisionsPerSecond(decisionsPerSecond_)
40- {
41-
42- }
43-
44- virtual ~NeuralNetworkComponent () = default ;
45-
46- void Update (float deltaTime) override ;
47-
48- void SetNetwork (const char * name);
49-
50- bool StartMakingDecision ();
51-
52- StrifeML::NetworkContext<NetworkType>* networkContext = nullptr ;
53- std::shared_ptr<StrifeML::MakeDecisionWorkItem<TNeuralNetwork>> decisionInProgress;
54- StrifeML::MlUtil::SharedArray<InputType> decisionInput;
55- float makeDecisionsTimer = 0 ;
56- float decisionsPerSecond;
57-
58- std::shared_ptr<StrifeML::RunTrainingBatchWorkItem<TNeuralNetwork>> trainingInProgress;
59-
60- int inputsCollected = 0 ;
61- InputType previousInput;
62- std::function<void (InputType& input)> collectInput;
63- std::function<void (OutputType& decision)> collectDecision;
64-
65- std::function<void (OutputType& decision)> receiveDecision;
66-
67- NeuralNetworkMode mode = NeuralNetworkMode::Disabled;
68- };
69-
70- template <typename TNeuralNetwork>
71- void NeuralNetworkComponent<TNeuralNetwork>::Update(float deltaTime)
72- {
73- if (mode == NeuralNetworkMode::Disabled)
74- {
75- return ;
76- }
77- else if (mode == NeuralNetworkMode::Deciding || mode == NeuralNetworkMode::ReinforcementLearning)
78- {
79- OutputType output;
80- if (decisionInProgress != nullptr && decisionInProgress->TryGetResult (output))
81- {
82- if (receiveDecision != nullptr )
83- {
84- receiveDecision (output);
85- }
86-
87- decisionInProgress = nullptr ;
88- }
89-
90- makeDecisionsTimer -= deltaTime;
91-
92- if (makeDecisionsTimer <= 0 )
93- {
94- makeDecisionsTimer = 1 .0f / decisionsPerSecond;
95- auto wasStarted = StartMakingDecision ();
96-
97- if (wasStarted && mode == NeuralNetworkMode::ReinforcementLearning && inputsCollected > TNeuralNetwork::SequenceLength)
98- {
99- SampleType sample;
100- sample.input = previousInput;
101- collectDecision (sample.output );
102- networkContext->trainer ->AddSample (sample);
103- }
104- }
105- }
106- else if (mode == NeuralNetworkMode::CollectingSamples)
107- {
108- SampleType sample;
109- collectInput (sample.input );
110- collectDecision (sample.output );
111- networkContext->trainer ->AddSample (sample);
112- }
113- }
114-
115- template <typename TNeuralNetwork>
116- bool NeuralNetworkComponent<TNeuralNetwork>::StartMakingDecision()
117- {
118- // Don't allow making more than one decision at a time
119- if (decisionInProgress != nullptr
120- && !decisionInProgress->IsComplete ())
121- {
122- return false ;
123- }
124-
125- if (collectInput == nullptr
126- || networkContext == nullptr
127- || networkContext->decider == nullptr )
128- {
129- return false ;
130- }
131-
132- if (mode == NeuralNetworkMode::ReinforcementLearning && inputsCollected > TNeuralNetwork::SequenceLength)
133- {
134- previousInput = decisionInput.data .get ()[TNeuralNetwork::SequenceLength - 1 ];
135- }
136-
137- // Expire oldest input
138- for (int i = 0 ; i < TNeuralNetwork::SequenceLength - 1 ; ++i)
139- {
140- decisionInput.data .get ()[i] = std::move (decisionInput.data .get ()[i + 1 ]);
141- }
142-
143- // Collect new input
144- collectInput (decisionInput.data .get ()[TNeuralNetwork::SequenceLength - 1 ]);
145- inputsCollected++;
146-
147- // Start making decision if we have enough data
148- if (inputsCollected >= TNeuralNetwork::SequenceLength)
149- {
150- decisionInProgress = networkContext->decider ->MakeDecision (decisionInput, TNeuralNetwork::SequenceLength);
151- }
152-
153- return true ;
154- }
155-
15621struct SensorObjectDefinition
15722{
15823 SensorObjectDefinition ()
@@ -253,7 +118,7 @@ struct NeuralNetworkManager
253118 }
254119
255120 template <typename TDecider, typename TTrainer>
256- StrifeML::NetworkContext<typename TDecider::NetworkType>* CreateNetwork (const char * name, TDecider* decider, TTrainer* trainer)
121+ StrifeML::NetworkContext<typename TDecider::NetworkType>* CreateNetwork (const char * name, TDecider* decider, TTrainer* trainer, int sequenceLength )
257122 {
258123 static_assert (std::is_same_v<typename TDecider::NetworkType, typename TTrainer::NetworkType>, " Trainer and decider must accept the same type of neural network" );
259124
@@ -263,7 +128,7 @@ struct NeuralNetworkManager
263128 throw StrifeML::StrifeException (" Network already exists: " + std::string (name));
264129 }
265130
266- auto newContext = std::make_shared<StrifeML::NetworkContext<typename TDecider::NetworkType>>(decider, trainer);
131+ auto newContext = std::make_shared<StrifeML::NetworkContext<typename TDecider::NetworkType>>(decider, trainer, sequenceLength );
267132 _networksByName[name] = newContext;
268133
269134 trainer->networkContext = newContext;
@@ -303,13 +168,6 @@ struct NeuralNetworkManager
303168 std::unordered_map<std::string, std::shared_ptr<StrifeML::INetworkContext>> _networksByName;
304169};
305170
306- template <typename TNeuralNetwork>
307- void NeuralNetworkComponent<TNeuralNetwork>::SetNetwork(const char * name)
308- {
309- auto nnManager = this ->owner ->GetEngine ()->GetNeuralNetworkManager ();
310- networkContext = nnManager->template GetNetwork <TNeuralNetwork>(name);
311- }
312-
313171gsl::span<uint64_t > ReadGridSensorRectangles (
314172 Scene* scene,
315173 Vector2 center,
0 commit comments