#pragma once

#include <cstddef>      // std::ptrdiff_t
#include <cstdio>       // printf
#include <functional>   // std::function
#include <memory>       // std::unique_ptr, std::make_unique
#include <new>          // placement new
#include <type_traits>  // std::is_trivially_destructible
#include <vector>       // std::vector

#include "ML/ML.hpp"
#include "Memory/CircularQueue.hpp"
#include "Scene/IEntityEvent.hpp"
#include "Scene/Scene.hpp"
7+
8+ template <typename TInput>
9+ struct InputCircularBufferAllocator
10+ {
11+ using InputCircularBuffer = CircularQueue<TInput>;
12+
13+ InputCircularBufferAllocator (int maxBuffers, int circularBufferSize)
14+ : bufferPool(std::make_unique<InputCircularBuffer[]>(maxBuffers)),
15+ inputs (std::make_unique<TInput[]>(maxBuffers * circularBufferSize)),
16+ freeBuffers(bufferPool.get(), maxBuffers),
17+ circularBufferSize(circularBufferSize)
18+ {
19+
20+ }
21+
22+ InputCircularBuffer* Allocate ()
23+ {
24+ auto buffer = freeBuffers.Borrow ();
25+
26+ // Returning a node to the free list overwrites its data, so this has to be initialized every time its allocated
27+ int sampleId = buffer - bufferPool.get ();
28+ TInput* storageStart = inputs.get () + sampleId * circularBufferSize;
29+
30+ new (buffer) InputCircularBuffer (storageStart, circularBufferSize);
31+
32+ return buffer;
33+ }
34+
35+ void Free (InputCircularBuffer* buffer)
36+ {
37+ freeBuffers.Return (buffer);
38+ }
39+
40+ std::unique_ptr<InputCircularBuffer[]> bufferPool;
41+ std::unique_ptr<TInput[]> inputs;
42+ FreeList<InputCircularBuffer> freeBuffers;
43+ int circularBufferSize;
44+ };
45+
46+ template <typename TNetwork>
47+ struct DecisionBatch
48+ {
49+ using InputType = typename TNetwork::InputType;
50+ using OutputType = typename TNetwork::OutputType;
51+ using InputCircularBuffer = CircularQueue<InputType>;
52+
53+ DecisionBatch (int maxBatchSize, int sequenceLength, StrifeML::NetworkContext<TNetwork>* networkContext)
54+ : sequenceLength(sequenceLength),
55+ decisionInput (maxBatchSize * sequenceLength),
56+ decisionOutput(maxBatchSize),
57+ networkContext(networkContext)
58+ {
59+ entitiesInBatch.reserve (maxBatchSize);
60+ }
61+
62+ void ResetBatch ()
63+ {
64+ entitiesInBatch.clear ();
65+ decisionInProgress = nullptr ;
66+ }
67+
68+ void AddToBatch (Entity* entity, InputCircularBuffer& buffer)
69+ {
70+ int row = entitiesInBatch.size ();
71+ entitiesInBatch.emplace_back (entity);
72+
73+ int col = 0 ;
74+ for (PlayerInput& sample : buffer)
75+ {
76+ decisionInput.data .get ()[row * sequenceLength + col] = sample;
77+ ++col;
78+ }
79+ }
80+
81+ bool HasBatchInProgress () const
82+ {
83+ return decisionInProgress != nullptr ;
84+ }
85+
86+ bool BatchIsComplete () const
87+ {
88+ return decisionInProgress->IsComplete ();
89+ }
90+
91+ void StartBatchIfAnyEntities ()
92+ {
93+ if (entitiesInBatch.size () > 0 )
94+ {
95+ printf (" Start making decision\n " );
96+
97+ decisionInProgress = networkContext->decider ->MakeDecision (
98+ decisionInput,
99+ decisionOutput,
100+ networkContext->sequenceLength ,
101+ entitiesInBatch.size ());
102+ }
103+ }
104+
105+ int sequenceLength;
106+ std::vector<EntityReference<Entity>> entitiesInBatch;
107+ std::shared_ptr<StrifeML::MakeDecisionWorkItem<TNetwork>> decisionInProgress;
108+ StrifeML::MlUtil::SharedArray<InputType> decisionInput;
109+ StrifeML::MlUtil::SharedArray<OutputType> decisionOutput;
110+ StrifeML::NetworkContext<TNetwork>* networkContext;
111+ };
112+
113+ template <typename TEntity, typename TNetwork>
114+ struct NeuralNetworkService : ISceneService, IEntityObserver
115+ {
116+ using InputType = typename TNetwork::InputType;
117+ using OutputType = typename TNetwork::OutputType;
118+ using TrainerType = StrifeML::Trainer<TNetwork>;
119+
120+ using InputCircularBuffer = CircularQueue<InputType>;
121+
122+ NeuralNetworkService (StrifeML::NetworkContext<TNetwork>* networkContext, int maxEntitiesInBatch)
123+ : networkContext(networkContext),
124+ bufferAllocator (maxEntitiesInBatch, networkContext->sequenceLength + 1 ),
125+ decisionBatch(maxEntitiesInBatch, networkContext->sequenceLength, networkContext)
126+ {
127+
128+ }
129+
130+ void OnAdded () override
131+ {
132+ scene->AddEntityObserver <TEntity>(this );
133+ }
134+
135+ void ReceiveEvent (const IEntityEvent& ev) override
136+ {
137+ if (ev.Is <UpdateEvent>())
138+ {
139+ // Collect inputs
140+ {
141+ collectInputTimer -= scene->deltaTime ;
142+ if (collectInputTimer <= 0 )
143+ {
144+ CollectInputs ();
145+ collectInputTimer = 1 .0f / collectInputFrequency;
146+ }
147+ }
148+
149+ // Make decisions
150+ {
151+ makeDecisionTimer -= scene->deltaTime ;
152+
153+ if (decisionBatch.HasBatchInProgress ())
154+ {
155+ if (decisionBatch.BatchIsComplete ())
156+ {
157+ BroadcastDecisions ();
158+ decisionBatch.ResetBatch ();
159+ }
160+ }
161+ else
162+ {
163+ if (makeDecisionTimer <= 0 )
164+ {
165+ StartMakingDecision ();
166+ makeDecisionTimer = 1 .0f / makeDecisionFrequency;
167+ }
168+ }
169+ }
170+
171+ // Collect training samples
172+ {
173+ collectTrainingSampleTimer -= scene->deltaTime ;
174+ if (collectTrainingSampleTimer <= 0 .0f )
175+ {
176+ CollectTrainingSamples (networkContext->trainer );
177+ collectTrainingSampleTimer = 1 .0f / collectTrainingSampleFrequency;
178+ }
179+ }
180+ }
181+ }
182+
183+ protected:
184+ virtual void CollectInput (TEntity* entity, InputType& input)
185+ {
186+
187+ }
188+
189+ void ForEachEntity (const std::function<void (TEntity*)>& func)
190+ {
191+ for (auto & entityBufferPair : samplesByEntity)
192+ {
193+ func (entityBufferPair.first );
194+ }
195+ }
196+
197+ private:
198+ void CollectInputs ()
199+ {
200+ for (auto & entityBufferPair : samplesByEntity)
201+ {
202+ TEntity* entity = entityBufferPair.first ;
203+
204+ if (!IncludeEntityInBatch (entity))
205+ {
206+ continue ;
207+ }
208+
209+ InputCircularBuffer* buffer = entityBufferPair.second ;
210+ InputType* input = buffer->DequeueHeadIfFullAndAllocate ();
211+ CollectInput (entityBufferPair.first , *input);
212+ }
213+ }
214+
215+ void StartMakingDecision ()
216+ {
217+ for (auto entityBufferPair : samplesByEntity)
218+ {
219+ TEntity* entity = entityBufferPair.first ;
220+
221+ if (!IncludeEntityInBatch (entity))
222+ {
223+ continue ;
224+ }
225+
226+ InputCircularBuffer* buffer = entityBufferPair.second ;
227+
228+ // Include in batch if there are enough inputs in the sequence
229+ bool includeInBatch = buffer->IsFull ();
230+ if (includeInBatch)
231+ {
232+ decisionBatch.AddToBatch (entity, *buffer);
233+ }
234+ }
235+
236+ decisionBatch.StartBatchIfAnyEntities ();
237+ }
238+
239+ void BroadcastDecisions ()
240+ {
241+ for (int i = 0 ; i < decisionBatch.entitiesInBatch .size (); ++i)
242+ {
243+ Entity* entity;
244+
245+ // Make sure the entity wasn't destroyed in the middle of making a decision
246+ if (decisionBatch.entitiesInBatch [i].TryGetValue (entity))
247+ {
248+ TEntity* entityAsTEntity = static_cast <TEntity*>(entity);
249+ ReceiveDecision (entityAsTEntity, decisionBatch.decisionOutput .data .get ()[i]);
250+ }
251+ }
252+ }
253+
254+ virtual void ReceiveDecision (TEntity* entity, OutputType& output)
255+ {
256+
257+ }
258+
259+ virtual void CollectTrainingSamples (TrainerType* trainer)
260+ {
261+
262+ }
263+
264+ void OnEntityAdded (Entity* entity) override
265+ {
266+ // Safe to do static_cast<> since we're only subscribing to entities of one type
267+ TEntity* entityAsTEntity = static_cast <TEntity*>(entity);
268+ if (TrackEntity (entityAsTEntity))
269+ {
270+ auto buffer = bufferAllocator.Allocate ();
271+ samplesByEntity[entityAsTEntity] = buffer;
272+ }
273+ }
274+
275+ void OnEntityRemoved (Entity* entity) override
276+ {
277+ TEntity* entityAsTEntity = static_cast <TEntity*>(entity);
278+ auto it = samplesByEntity.find (entityAsTEntity);
279+ if (it != samplesByEntity.end ())
280+ {
281+ bufferAllocator.Free (it->second );
282+ }
283+
284+ samplesByEntity.erase (it);
285+ }
286+
287+ virtual bool IncludeEntityInBatch (TEntity* entity)
288+ {
289+ return true ;
290+ }
291+
292+ virtual bool TrackEntity (TEntity* entity)
293+ {
294+ return true ;
295+ }
296+
297+ float makeDecisionTimer = 0 .0f ;
298+ float makeDecisionFrequency = 1 .0f ;
299+
300+ float collectInputTimer = 0 .0f ;
301+ float collectInputFrequency = 1 .0f ;
302+
303+ float collectTrainingSampleTimer = 0 .0f ;
304+ float collectTrainingSampleFrequency = 1 .0f ;
305+
306+ StrifeML::NetworkContext<TNetwork>* networkContext;
307+
308+ robin_hood::unordered_flat_map<TEntity*, InputCircularBuffer*> samplesByEntity;
309+
310+ InputCircularBufferAllocator<InputType> bufferAllocator;
311+ DecisionBatch<TNetwork> decisionBatch;
312+ };