Skip to content

Commit 748df0a

Browse files
committed
NeuralNetworkService
1 parent ae3dc92 commit 748df0a

File tree

11 files changed

+448
-48
lines changed

11 files changed

+448
-48
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ add_library(Strife.Engine STATIC
161161
Net/FileTransfer.cpp
162162
Net/SyncVar.cpp
163163
Resource/ResourceManager.hpp Resource/SpriteResource.hpp Resource/SpriteResource.cpp Resource/ResourceManager.cpp Resource/TilemapResource.hpp Resource/TilemapResource.cpp Resource/SpriteFontResource.hpp Resource/SpriteFontResource.cpp
164-
Resource/ShaderResource.cpp Resource/ShaderResource.hpp Components/ParticleSystemComponent.hpp Components/ParticleSystemComponent.cpp Scene/Isometric.hpp Scene/Isometric.cpp Components/IsometricSpriteComponent.hpp Components/IsometricSpriteComponent.cpp Resource/FileResource.hpp Resource/FileResource.cpp ML/UtilityAI.hpp Resource/ScriptResource.hpp Resource/ScriptResource.cpp ML/GridSensor.hpp Renderer/SpriteEffect.hpp Renderer/SpriteEffect.cpp Renderer/Stage/RenderPipeline.hpp Renderer/Stage/RenderPipeline.cpp Resource/SpriteAtlasResource.hpp Resource/SpriteAtlasResource.cpp Resource/ResourceSettings.hpp Components/AnimatorComponent.hpp Components/AnimatorComponent.cpp)
164+
Resource/ShaderResource.cpp Resource/ShaderResource.hpp Components/ParticleSystemComponent.hpp Components/ParticleSystemComponent.cpp Scene/Isometric.hpp Scene/Isometric.cpp Components/IsometricSpriteComponent.hpp Components/IsometricSpriteComponent.cpp Resource/FileResource.hpp Resource/FileResource.cpp ML/UtilityAI.hpp Resource/ScriptResource.hpp Resource/ScriptResource.cpp ML/GridSensor.hpp Renderer/SpriteEffect.hpp Renderer/SpriteEffect.cpp Renderer/Stage/RenderPipeline.hpp Renderer/Stage/RenderPipeline.cpp Resource/SpriteAtlasResource.hpp Resource/SpriteAtlasResource.cpp Resource/ResourceSettings.hpp Components/AnimatorComponent.hpp Components/AnimatorComponent.cpp ML/NeuralNetworkService.hpp)
165165

166166
message("SDL DIRS: ${SDL2_INCLUDE_DIRS}")
167167

src/Components/NetComponent.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ struct PlayerCommandHandler
130130
std::unordered_map<PlayerCommand::Metadata*, std::function<void(const PlayerCommand&)>> handlerByMetadata;
131131

132132
BlockAllocator* blockAllocator;
133-
CircularQueue<ScheduledCommand, 32> localCommands;
133+
FixedSizeCircularQueue<ScheduledCommand, 32> localCommands;
134134
unsigned int fixedUpdateCount = 0;
135135
unsigned int nextCommandSequenceNumber = 0;
136136
};

src/ML/ML.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ struct NeuralNetworkManager
253253
}
254254

255255
template<typename TDecider, typename TTrainer>
256-
void CreateNetwork(const char* name, TDecider* decider, TTrainer* trainer)
256+
StrifeML::NetworkContext<typename TDecider::NetworkType>* CreateNetwork(const char* name, TDecider* decider, TTrainer* trainer)
257257
{
258258
static_assert(std::is_same_v<typename TDecider::NetworkType, typename TTrainer::NetworkType>, "Trainer and decider must accept the same type of neural network");
259259

@@ -272,6 +272,8 @@ struct NeuralNetworkManager
272272
trainer->OnCreateNewNetwork(trainer->network);
273273

274274
newContext->decider->networkContext = newContext;
275+
276+
return newContext.get();
275277
}
276278

277279
// TODO remove network method

src/ML/NeuralNetworkService.hpp

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
#pragma once
2+
3+
#include "Memory/CircularQueue.hpp"
4+
#include "Scene/Scene.hpp"
5+
#include "Scene/IEntityEvent.hpp"
6+
#include "ML/ML.hpp"
7+
8+
template<typename TInput>
9+
struct InputCircularBufferAllocator
10+
{
11+
using InputCircularBuffer = CircularQueue<TInput>;
12+
13+
InputCircularBufferAllocator(int maxBuffers, int circularBufferSize)
14+
: bufferPool(std::make_unique<InputCircularBuffer[]>(maxBuffers)),
15+
inputs(std::make_unique<TInput[]>(maxBuffers * circularBufferSize)),
16+
freeBuffers(bufferPool.get(), maxBuffers),
17+
circularBufferSize(circularBufferSize)
18+
{
19+
20+
}
21+
22+
InputCircularBuffer* Allocate()
23+
{
24+
auto buffer = freeBuffers.Borrow();
25+
26+
// Returning a node to the free list overwrites its data, so this has to be initialized every time its allocated
27+
int sampleId = buffer - bufferPool.get();
28+
TInput* storageStart = inputs.get() + sampleId * circularBufferSize;
29+
30+
new (buffer) InputCircularBuffer(storageStart, circularBufferSize);
31+
32+
return buffer;
33+
}
34+
35+
void Free(InputCircularBuffer* buffer)
36+
{
37+
freeBuffers.Return(buffer);
38+
}
39+
40+
std::unique_ptr<InputCircularBuffer[]> bufferPool;
41+
std::unique_ptr<TInput[]> inputs;
42+
FreeList<InputCircularBuffer> freeBuffers;
43+
int circularBufferSize;
44+
};
45+
46+
template<typename TNetwork>
47+
struct DecisionBatch
48+
{
49+
using InputType = typename TNetwork::InputType;
50+
using OutputType = typename TNetwork::OutputType;
51+
using InputCircularBuffer = CircularQueue<InputType>;
52+
53+
DecisionBatch(int maxBatchSize, int sequenceLength, StrifeML::NetworkContext<TNetwork>* networkContext)
54+
: sequenceLength(sequenceLength),
55+
decisionInput(maxBatchSize * sequenceLength),
56+
decisionOutput(maxBatchSize),
57+
networkContext(networkContext)
58+
{
59+
entitiesInBatch.reserve(maxBatchSize);
60+
}
61+
62+
void ResetBatch()
63+
{
64+
entitiesInBatch.clear();
65+
decisionInProgress = nullptr;
66+
}
67+
68+
void AddToBatch(Entity* entity, InputCircularBuffer& buffer)
69+
{
70+
int row = entitiesInBatch.size();
71+
entitiesInBatch.emplace_back(entity);
72+
73+
int col = 0;
74+
for (PlayerInput& sample : buffer)
75+
{
76+
decisionInput.data.get()[row * sequenceLength + col] = sample;
77+
++col;
78+
}
79+
}
80+
81+
bool HasBatchInProgress() const
82+
{
83+
return decisionInProgress != nullptr;
84+
}
85+
86+
bool BatchIsComplete() const
87+
{
88+
return decisionInProgress->IsComplete();
89+
}
90+
91+
void StartBatchIfAnyEntities()
92+
{
93+
if (entitiesInBatch.size() > 0)
94+
{
95+
printf("Start making decision\n");
96+
97+
decisionInProgress = networkContext->decider->MakeDecision(
98+
decisionInput,
99+
decisionOutput,
100+
networkContext->sequenceLength,
101+
entitiesInBatch.size());
102+
}
103+
}
104+
105+
int sequenceLength;
106+
std::vector<EntityReference<Entity>> entitiesInBatch;
107+
std::shared_ptr<StrifeML::MakeDecisionWorkItem<TNetwork>> decisionInProgress;
108+
StrifeML::MlUtil::SharedArray<InputType> decisionInput;
109+
StrifeML::MlUtil::SharedArray<OutputType> decisionOutput;
110+
StrifeML::NetworkContext<TNetwork>* networkContext;
111+
};
112+
113+
template<typename TEntity, typename TNetwork>
114+
struct NeuralNetworkService : ISceneService, IEntityObserver
115+
{
116+
using InputType = typename TNetwork::InputType;
117+
using OutputType = typename TNetwork::OutputType;
118+
using TrainerType = StrifeML::Trainer<TNetwork>;
119+
120+
using InputCircularBuffer = CircularQueue<InputType>;
121+
122+
NeuralNetworkService(StrifeML::NetworkContext<TNetwork>* networkContext, int maxEntitiesInBatch)
123+
: networkContext(networkContext),
124+
bufferAllocator(maxEntitiesInBatch, networkContext->sequenceLength + 1),
125+
decisionBatch(maxEntitiesInBatch, networkContext->sequenceLength, networkContext)
126+
{
127+
128+
}
129+
130+
void OnAdded() override
131+
{
132+
scene->AddEntityObserver<TEntity>(this);
133+
}
134+
135+
void ReceiveEvent(const IEntityEvent& ev) override
136+
{
137+
if (ev.Is<UpdateEvent>())
138+
{
139+
// Collect inputs
140+
{
141+
collectInputTimer -= scene->deltaTime;
142+
if (collectInputTimer <= 0)
143+
{
144+
CollectInputs();
145+
collectInputTimer = 1.0f / collectInputFrequency;
146+
}
147+
}
148+
149+
// Make decisions
150+
{
151+
makeDecisionTimer -= scene->deltaTime;
152+
153+
if (decisionBatch.HasBatchInProgress())
154+
{
155+
if (decisionBatch.BatchIsComplete())
156+
{
157+
BroadcastDecisions();
158+
decisionBatch.ResetBatch();
159+
}
160+
}
161+
else
162+
{
163+
if (makeDecisionTimer <= 0)
164+
{
165+
StartMakingDecision();
166+
makeDecisionTimer = 1.0f / makeDecisionFrequency;
167+
}
168+
}
169+
}
170+
171+
// Collect training samples
172+
{
173+
collectTrainingSampleTimer -= scene->deltaTime;
174+
if (collectTrainingSampleTimer <= 0.0f)
175+
{
176+
CollectTrainingSamples(networkContext->trainer);
177+
collectTrainingSampleTimer = 1.0f / collectTrainingSampleFrequency;
178+
}
179+
}
180+
}
181+
}
182+
183+
protected:
184+
virtual void CollectInput(TEntity* entity, InputType& input)
185+
{
186+
187+
}
188+
189+
void ForEachEntity(const std::function<void(TEntity*)>& func)
190+
{
191+
for (auto& entityBufferPair : samplesByEntity)
192+
{
193+
func(entityBufferPair.first);
194+
}
195+
}
196+
197+
private:
198+
void CollectInputs()
199+
{
200+
for (auto& entityBufferPair : samplesByEntity)
201+
{
202+
TEntity* entity = entityBufferPair.first;
203+
204+
if (!IncludeEntityInBatch(entity))
205+
{
206+
continue;
207+
}
208+
209+
InputCircularBuffer* buffer = entityBufferPair.second;
210+
InputType* input = buffer->DequeueHeadIfFullAndAllocate();
211+
CollectInput(entityBufferPair.first, *input);
212+
}
213+
}
214+
215+
void StartMakingDecision()
216+
{
217+
for (auto entityBufferPair : samplesByEntity)
218+
{
219+
TEntity* entity = entityBufferPair.first;
220+
221+
if (!IncludeEntityInBatch(entity))
222+
{
223+
continue;
224+
}
225+
226+
InputCircularBuffer* buffer = entityBufferPair.second;
227+
228+
// Include in batch if there are enough inputs in the sequence
229+
bool includeInBatch = buffer->IsFull();
230+
if (includeInBatch)
231+
{
232+
decisionBatch.AddToBatch(entity, *buffer);
233+
}
234+
}
235+
236+
decisionBatch.StartBatchIfAnyEntities();
237+
}
238+
239+
void BroadcastDecisions()
240+
{
241+
for (int i = 0; i < decisionBatch.entitiesInBatch.size(); ++i)
242+
{
243+
Entity* entity;
244+
245+
// Make sure the entity wasn't destroyed in the middle of making a decision
246+
if (decisionBatch.entitiesInBatch[i].TryGetValue(entity))
247+
{
248+
TEntity* entityAsTEntity = static_cast<TEntity*>(entity);
249+
ReceiveDecision(entityAsTEntity, decisionBatch.decisionOutput.data.get()[i]);
250+
}
251+
}
252+
}
253+
254+
virtual void ReceiveDecision(TEntity* entity, OutputType& output)
255+
{
256+
257+
}
258+
259+
virtual void CollectTrainingSamples(TrainerType* trainer)
260+
{
261+
262+
}
263+
264+
void OnEntityAdded(Entity* entity) override
265+
{
266+
// Safe to do static_cast<> since we're only subscribing to entities of one type
267+
TEntity* entityAsTEntity = static_cast<TEntity*>(entity);
268+
if (TrackEntity(entityAsTEntity))
269+
{
270+
auto buffer = bufferAllocator.Allocate();
271+
samplesByEntity[entityAsTEntity] = buffer;
272+
}
273+
}
274+
275+
void OnEntityRemoved(Entity* entity) override
276+
{
277+
TEntity* entityAsTEntity = static_cast<TEntity*>(entity);
278+
auto it = samplesByEntity.find(entityAsTEntity);
279+
if (it != samplesByEntity.end())
280+
{
281+
bufferAllocator.Free(it->second);
282+
}
283+
284+
samplesByEntity.erase(it);
285+
}
286+
287+
virtual bool IncludeEntityInBatch(TEntity* entity)
288+
{
289+
return true;
290+
}
291+
292+
virtual bool TrackEntity(TEntity* entity)
293+
{
294+
return true;
295+
}
296+
297+
float makeDecisionTimer = 0.0f;
298+
float makeDecisionFrequency = 1.0f;
299+
300+
float collectInputTimer = 0.0f;
301+
float collectInputFrequency = 1.0f;
302+
303+
float collectTrainingSampleTimer = 0.0f;
304+
float collectTrainingSampleFrequency = 1.0f;
305+
306+
StrifeML::NetworkContext<TNetwork>* networkContext;
307+
308+
robin_hood::unordered_flat_map<TEntity*, InputCircularBuffer*> samplesByEntity;
309+
310+
InputCircularBufferAllocator<InputType> bufferAllocator;
311+
DecisionBatch<TNetwork> decisionBatch;
312+
};

0 commit comments

Comments
 (0)