Skip to content

Commit 32d7940

Browse files
authored
Merge pull request #66 from Strife-AI/neural-network-service
Neural network service
2 parents 233a908 + f136232 commit 32d7940

File tree

11 files changed

+476
-48
lines changed

11 files changed

+476
-48
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ add_library(Strife.Engine STATIC
161161
Net/FileTransfer.cpp
162162
Net/SyncVar.cpp
163163
Resource/ResourceManager.hpp Resource/SpriteResource.hpp Resource/SpriteResource.cpp Resource/ResourceManager.cpp Resource/TilemapResource.hpp Resource/TilemapResource.cpp Resource/SpriteFontResource.hpp Resource/SpriteFontResource.cpp
164-
Resource/ShaderResource.cpp Resource/ShaderResource.hpp Components/ParticleSystemComponent.hpp Components/ParticleSystemComponent.cpp Scene/Isometric.hpp Scene/Isometric.cpp Components/IsometricSpriteComponent.hpp Components/IsometricSpriteComponent.cpp Resource/FileResource.hpp Resource/FileResource.cpp ML/UtilityAI.hpp Resource/ScriptResource.hpp Resource/ScriptResource.cpp ML/GridSensor.hpp Renderer/SpriteEffect.hpp Renderer/SpriteEffect.cpp Renderer/Stage/RenderPipeline.hpp Renderer/Stage/RenderPipeline.cpp Resource/SpriteAtlasResource.hpp Resource/SpriteAtlasResource.cpp Resource/ResourceSettings.hpp Components/AnimatorComponent.hpp Components/AnimatorComponent.cpp)
164+
Resource/ShaderResource.cpp Resource/ShaderResource.hpp Components/ParticleSystemComponent.hpp Components/ParticleSystemComponent.cpp Scene/Isometric.hpp Scene/Isometric.cpp Components/IsometricSpriteComponent.hpp Components/IsometricSpriteComponent.cpp Resource/FileResource.hpp Resource/FileResource.cpp ML/UtilityAI.hpp Resource/ScriptResource.hpp Resource/ScriptResource.cpp ML/GridSensor.hpp Renderer/SpriteEffect.hpp Renderer/SpriteEffect.cpp Renderer/Stage/RenderPipeline.hpp Renderer/Stage/RenderPipeline.cpp Resource/SpriteAtlasResource.hpp Resource/SpriteAtlasResource.cpp Resource/ResourceSettings.hpp Components/AnimatorComponent.hpp Components/AnimatorComponent.cpp ML/NeuralNetworkService.hpp)
165165

166166
message("SDL DIRS: ${SDL2_INCLUDE_DIRS}")
167167

src/Components/NetComponent.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ struct PlayerCommandHandler
130130
std::unordered_map<PlayerCommand::Metadata*, std::function<void(const PlayerCommand&)>> handlerByMetadata;
131131

132132
BlockAllocator* blockAllocator;
133-
CircularQueue<ScheduledCommand, 32> localCommands;
133+
FixedSizeCircularQueue<ScheduledCommand, 32> localCommands;
134134
unsigned int fixedUpdateCount = 0;
135135
unsigned int nextCommandSequenceNumber = 0;
136136
};

src/ML/ML.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ struct NeuralNetworkManager
253253
}
254254

255255
template<typename TDecider, typename TTrainer>
256-
void CreateNetwork(const char* name, TDecider* decider, TTrainer* trainer)
256+
StrifeML::NetworkContext<typename TDecider::NetworkType>* CreateNetwork(const char* name, TDecider* decider, TTrainer* trainer)
257257
{
258258
static_assert(std::is_same_v<typename TDecider::NetworkType, typename TTrainer::NetworkType>, "Trainer and decider must accept the same type of neural network");
259259

@@ -272,6 +272,8 @@ struct NeuralNetworkManager
272272
trainer->OnCreateNewNetwork(trainer->network);
273273

274274
newContext->decider->networkContext = newContext;
275+
276+
return newContext.get();
275277
}
276278

277279
// TODO remove network method

src/ML/NeuralNetworkService.hpp

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
#pragma once
2+
3+
#include "Memory/CircularQueue.hpp"
4+
#include "Scene/Scene.hpp"
5+
#include "Scene/IEntityEvent.hpp"
6+
#include "ML/ML.hpp"
7+
8+
/// Fixed-size pool of circular input buffers, one per tracked entity.
/// Storage for all buffers' elements is allocated up front in a single
/// contiguous array; individual buffers are handed out via a free list.
template<typename TInput>
struct InputCircularBufferAllocator
{
    using InputCircularBuffer = CircularQueue<TInput>;

    /// @param maxBuffers          Maximum number of buffers that can be live at once.
    /// @param circularBufferSize  Capacity (in samples) of each buffer.
    InputCircularBufferAllocator(int maxBuffers, int circularBufferSize)
        : bufferPool(std::make_unique<InputCircularBuffer[]>(maxBuffers)),
        inputs(std::make_unique<TInput[]>(maxBuffers * circularBufferSize)),
        freeBuffers(bufferPool.get(), maxBuffers),
        circularBufferSize(circularBufferSize)
    {

    }

    /// Borrows a buffer from the pool and (re)initializes it over its slice of
    /// the shared storage. Defined out of line below.
    InputCircularBuffer* Allocate();

    /// Returns a buffer to the free list. NOTE(review): the buffer's destructor
    /// is not run here; Allocate() re-constructs it in place on reuse — confirm
    /// CircularQueue holds no resources that need destruction.
    void Free(InputCircularBuffer* buffer)
    {
        freeBuffers.Return(buffer);
    }

    std::unique_ptr<InputCircularBuffer[]> bufferPool;   // backing storage for the buffer objects themselves
    std::unique_ptr<TInput[]> inputs;                    // contiguous element storage shared by all buffers
    FreeList<InputCircularBuffer> freeBuffers;           // tracks which pool slots are available
    int circularBufferSize;                              // per-buffer capacity in samples
};
34+
35+
template<typename TInput>
36+
CircularQueue<TInput>* InputCircularBufferAllocator<TInput>::Allocate()
37+
{
38+
auto buffer = freeBuffers.Borrow();
39+
40+
// Returning a node to the free list overwrites its data, so this has to be initialized every time its allocated
41+
int sampleId = buffer - bufferPool.get();
42+
TInput* storageStart = inputs.get() + sampleId * circularBufferSize;
43+
44+
new (buffer) InputCircularBuffer(storageStart, circularBufferSize);
45+
46+
return buffer;
47+
}
48+
49+
/// Accumulates the input histories of several entities into one batch and
/// tracks the asynchronous work item that evaluates the batch on the network.
template<typename TNetwork>
struct DecisionBatch
{
    using InputType = typename TNetwork::InputType;
    using OutputType = typename TNetwork::OutputType;
    using InputCircularBuffer = CircularQueue<InputType>;

    /// @param maxBatchSize    Maximum number of entities per batch (one row each).
    /// @param sequenceLength  Samples per entity; the row stride of decisionInput.
    /// @param networkContext  Network whose decider evaluates the batch.
    DecisionBatch(int maxBatchSize, int sequenceLength, StrifeML::NetworkContext<TNetwork>* networkContext)
        : sequenceLength(sequenceLength),
        decisionInput(maxBatchSize * sequenceLength),
        decisionOutput(maxBatchSize),
        networkContext(networkContext)
    {
        // Reserve up front so AddToBatch never reallocates mid-frame.
        entitiesInBatch.reserve(maxBatchSize);
    }

    // Clears batch membership and drops the in-flight work item (if any).
    void ResetBatch();
    // Appends one entity's input history as the next row of the batch.
    void AddToBatch(Entity* entity, InputCircularBuffer& buffer);

    bool HasBatchInProgress() const
    {
        return decisionInProgress != nullptr;
    }

    /// Precondition: HasBatchInProgress() — dereferences decisionInProgress
    /// without a null check.
    bool BatchIsComplete() const
    {
        return decisionInProgress->IsComplete();
    }

    // Kicks off the async decision if at least one entity was added.
    void StartBatchIfAnyEntities();

    int sequenceLength;                                                      // samples per entity (row stride)
    std::vector<EntityReference<Entity>> entitiesInBatch;                    // row i belongs to entitiesInBatch[i]
    std::shared_ptr<StrifeML::MakeDecisionWorkItem<TNetwork>> decisionInProgress;  // null when no batch is in flight
    StrifeML::MlUtil::SharedArray<InputType> decisionInput;                  // maxBatchSize * sequenceLength samples
    StrifeML::MlUtil::SharedArray<OutputType> decisionOutput;                // one output per row
    StrifeML::NetworkContext<TNetwork>* networkContext;
};
87+
88+
template<typename TNetwork>
89+
void DecisionBatch<TNetwork>::ResetBatch()
90+
{
91+
entitiesInBatch.clear();
92+
decisionInProgress = nullptr;
93+
}
94+
95+
template<typename TNetwork>
96+
void DecisionBatch<TNetwork>::AddToBatch(Entity* entity, DecisionBatch::InputCircularBuffer& buffer)
97+
{
98+
int row = entitiesInBatch.size();
99+
entitiesInBatch.emplace_back(entity);
100+
101+
int col = 0;
102+
for (PlayerInput& sample : buffer)
103+
{
104+
decisionInput.data.get()[row * sequenceLength + col] = sample;
105+
++col;
106+
}
107+
}
108+
109+
template<typename TNetwork>
110+
void DecisionBatch<TNetwork>::StartBatchIfAnyEntities()
111+
{
112+
if (entitiesInBatch.size() > 0)
113+
{
114+
decisionInProgress = networkContext->decider->MakeDecision(
115+
decisionInput,
116+
decisionOutput,
117+
networkContext->sequenceLength,
118+
entitiesInBatch.size());
119+
}
120+
}
121+
122+
/// Scene service that drives a neural network for every tracked entity of
/// type TEntity: it periodically collects inputs into per-entity circular
/// buffers, batches them into asynchronous decisions, broadcasts the results,
/// and periodically hands training samples to the trainer.
/// Subclasses override the protected/private virtuals to supply game logic.
template<typename TEntity, typename TNetwork>
struct NeuralNetworkService : ISceneService, IEntityObserver
{
    using InputType = typename TNetwork::InputType;
    using OutputType = typename TNetwork::OutputType;
    using TrainerType = StrifeML::Trainer<TNetwork>;

    using InputCircularBuffer = CircularQueue<InputType>;

    /// @param networkContext     Network (decider + trainer) this service drives.
    /// @param maxEntitiesInBatch Upper bound on simultaneously tracked entities.
    // Buffers hold sequenceLength + 1 samples; presumably the extra slot lets a
    // new sample be written while a full sequence is read — TODO confirm.
    NeuralNetworkService(StrifeML::NetworkContext<TNetwork>* networkContext, int maxEntitiesInBatch)
        : networkContext(networkContext),
        bufferAllocator(maxEntitiesInBatch, networkContext->sequenceLength + 1),
        decisionBatch(maxEntitiesInBatch, networkContext->sequenceLength, networkContext)
    {

    }

    void OnAdded() override;
    void ReceiveEvent(const IEntityEvent& ev) override;

protected:
    /// Override to fill in one freshly allocated input sample for an entity.
    virtual void CollectInput(TEntity* entity, InputType& input)
    {

    }

    /// Invokes func on every entity currently tracked by this service.
    void ForEachEntity(const std::function<void(TEntity*)>& func);

private:
    void CollectInputs();
    void StartMakingDecision();
    void BroadcastDecisions();
    void OnEntityAdded(Entity* entity) override;
    void OnEntityRemoved(Entity* entity) override;

    /// Override to act on the network's output for an entity.
    virtual void ReceiveDecision(TEntity* entity, OutputType& output)
    {

    }

    /// Override to feed training samples to the trainer.
    virtual void CollectTrainingSamples(TrainerType* trainer)
    {

    }

    /// Override to exclude an entity from input collection and batching.
    virtual bool IncludeEntityInBatch(TEntity* entity)
    {
        return true;
    }

    /// Override to decide whether a newly added entity gets a buffer at all.
    virtual bool TrackEntity(TEntity* entity)
    {
        return true;
    }

    // Countdown timers (seconds) and their frequencies (Hz); each timer is
    // reset to 1 / frequency when it fires.
    float makeDecisionTimer = 0.0f;
    float makeDecisionFrequency = 1.0f;

    float collectInputTimer = 0.0f;
    float collectInputFrequency = 1.0f;

    float collectTrainingSampleTimer = 0.0f;
    float collectTrainingSampleFrequency = 1.0f;

    StrifeML::NetworkContext<TNetwork>* networkContext;

    // Input history buffer for each tracked entity; values are owned by
    // bufferAllocator and returned to it in OnEntityRemoved.
    robin_hood::unordered_flat_map<TEntity*, InputCircularBuffer*> samplesByEntity;

    InputCircularBufferAllocator<InputType> bufferAllocator;
    DecisionBatch<TNetwork> decisionBatch;
};
193+
194+
template<typename TEntity, typename TNetwork>
void NeuralNetworkService<TEntity, TNetwork>::OnAdded()
{
    // Subscribe to add/remove notifications for TEntity so the service can
    // maintain a per-entity input buffer (see OnEntityAdded/OnEntityRemoved).
    scene->AddEntityObserver<TEntity>(this);
}
199+
200+
template<typename TEntity, typename TNetwork>
201+
void NeuralNetworkService<TEntity, TNetwork>::ReceiveEvent(const IEntityEvent& ev)
202+
{
203+
if (ev.Is<UpdateEvent>())
204+
{
205+
// Collect inputs
206+
{
207+
collectInputTimer -= scene->deltaTime;
208+
if (collectInputTimer <= 0)
209+
{
210+
CollectInputs();
211+
collectInputTimer = 1.0f / collectInputFrequency;
212+
}
213+
}
214+
215+
// Make decisions
216+
{
217+
makeDecisionTimer -= scene->deltaTime;
218+
219+
if (decisionBatch.HasBatchInProgress())
220+
{
221+
if (decisionBatch.BatchIsComplete())
222+
{
223+
BroadcastDecisions();
224+
decisionBatch.ResetBatch();
225+
}
226+
}
227+
else
228+
{
229+
if (makeDecisionTimer <= 0)
230+
{
231+
StartMakingDecision();
232+
makeDecisionTimer = 1.0f / makeDecisionFrequency;
233+
}
234+
}
235+
}
236+
237+
// Collect training samples
238+
{
239+
collectTrainingSampleTimer -= scene->deltaTime;
240+
if (collectTrainingSampleTimer <= 0.0f)
241+
{
242+
CollectTrainingSamples(networkContext->trainer);
243+
collectTrainingSampleTimer = 1.0f / collectTrainingSampleFrequency;
244+
}
245+
}
246+
}
247+
}
248+
249+
template<typename TEntity, typename TNetwork>
250+
void NeuralNetworkService<TEntity, TNetwork>::ForEachEntity(const std::function<void(TEntity*)>& func)
251+
{
252+
for (auto& entityBufferPair : samplesByEntity)
253+
{
254+
func(entityBufferPair.first);
255+
}
256+
}
257+
258+
template<typename TEntity, typename TNetwork>
259+
void NeuralNetworkService<TEntity, TNetwork>::CollectInputs()
260+
{
261+
for (auto& entityBufferPair : samplesByEntity)
262+
{
263+
TEntity* entity = entityBufferPair.first;
264+
265+
if (!IncludeEntityInBatch(entity))
266+
{
267+
continue;
268+
}
269+
270+
InputCircularBuffer* buffer = entityBufferPair.second;
271+
InputType* input = buffer->DequeueHeadIfFullAndAllocate();
272+
CollectInput(entityBufferPair.first, *input);
273+
}
274+
}
275+
276+
template<typename TEntity, typename TNetwork>
277+
void NeuralNetworkService<TEntity, TNetwork>::StartMakingDecision()
278+
{
279+
for (auto entityBufferPair : samplesByEntity)
280+
{
281+
TEntity* entity = entityBufferPair.first;
282+
283+
if (!IncludeEntityInBatch(entity))
284+
{
285+
continue;
286+
}
287+
288+
InputCircularBuffer* buffer = entityBufferPair.second;
289+
290+
// Include in batch if there are enough inputs in the sequence
291+
bool includeInBatch = buffer->IsFull();
292+
if (includeInBatch)
293+
{
294+
decisionBatch.AddToBatch(entity, *buffer);
295+
}
296+
}
297+
298+
decisionBatch.StartBatchIfAnyEntities();
299+
}
300+
301+
template<typename TEntity, typename TNetwork>
302+
void NeuralNetworkService<TEntity, TNetwork>::BroadcastDecisions()
303+
{
304+
for (int i = 0; i < decisionBatch.entitiesInBatch.size(); ++i)
305+
{
306+
Entity* entity;
307+
308+
// Make sure the entity wasn't destroyed in the middle of making a decision
309+
if (decisionBatch.entitiesInBatch[i].TryGetValue(entity))
310+
{
311+
TEntity* entityAsTEntity = static_cast<TEntity*>(entity);
312+
ReceiveDecision(entityAsTEntity, decisionBatch.decisionOutput.data.get()[i]);
313+
}
314+
}
315+
}
316+
317+
template<typename TEntity, typename TNetwork>
318+
void NeuralNetworkService<TEntity, TNetwork>::OnEntityAdded(Entity* entity)
319+
{
320+
// Safe to do static_cast<> since we're only subscribing to entities of one type
321+
TEntity* entityAsTEntity = static_cast<TEntity*>(entity);
322+
if (TrackEntity(entityAsTEntity))
323+
{
324+
auto buffer = bufferAllocator.Allocate();
325+
samplesByEntity[entityAsTEntity] = buffer;
326+
}
327+
}
328+
329+
template<typename TEntity, typename TNetwork>
330+
void NeuralNetworkService<TEntity, TNetwork>::OnEntityRemoved(Entity* entity)
331+
{
332+
TEntity* entityAsTEntity = static_cast<TEntity*>(entity);
333+
auto it = samplesByEntity.find(entityAsTEntity);
334+
if (it != samplesByEntity.end())
335+
{
336+
bufferAllocator.Free(it->second);
337+
}
338+
339+
samplesByEntity.erase(it);
340+
}

0 commit comments

Comments
 (0)