Skip to content

Commit f263ff7

Browse files
Merge pull request #68 from Strife-AI/nn-service
Remove tinycc and finish neural network service
2 parents 6f424b3 + 9f0ad84 commit f263ff7

File tree

10 files changed

+14
-346
lines changed

10 files changed

+14
-346
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ add_library(Strife.Engine STATIC
161161
Net/FileTransfer.cpp
162162
Net/SyncVar.cpp
163163
Resource/ResourceManager.hpp Resource/SpriteResource.hpp Resource/SpriteResource.cpp Resource/ResourceManager.cpp Resource/TilemapResource.hpp Resource/TilemapResource.cpp Resource/SpriteFontResource.hpp Resource/SpriteFontResource.cpp
164-
Resource/ShaderResource.cpp Resource/ShaderResource.hpp Components/ParticleSystemComponent.hpp Components/ParticleSystemComponent.cpp Scene/Isometric.hpp Scene/Isometric.cpp Components/IsometricSpriteComponent.hpp Components/IsometricSpriteComponent.cpp Resource/FileResource.hpp Resource/FileResource.cpp ML/UtilityAI.hpp Resource/ScriptResource.hpp Resource/ScriptResource.cpp ML/GridSensor.hpp Renderer/SpriteEffect.hpp Renderer/SpriteEffect.cpp Renderer/Stage/RenderPipeline.hpp Renderer/Stage/RenderPipeline.cpp Resource/SpriteAtlasResource.hpp Resource/SpriteAtlasResource.cpp Resource/ResourceSettings.hpp Components/AnimatorComponent.hpp Components/AnimatorComponent.cpp ML/NeuralNetworkService.hpp)
164+
Resource/ShaderResource.cpp Resource/ShaderResource.hpp Components/ParticleSystemComponent.hpp Components/ParticleSystemComponent.cpp Scene/Isometric.hpp Scene/Isometric.cpp Components/IsometricSpriteComponent.hpp Components/IsometricSpriteComponent.cpp Resource/FileResource.hpp Resource/FileResource.cpp ML/UtilityAI.hpp ML/GridSensor.hpp Renderer/SpriteEffect.hpp Renderer/SpriteEffect.cpp Renderer/Stage/RenderPipeline.hpp Renderer/Stage/RenderPipeline.cpp Resource/SpriteAtlasResource.hpp Resource/SpriteAtlasResource.cpp Resource/ResourceSettings.hpp Components/AnimatorComponent.hpp Components/AnimatorComponent.cpp ML/NeuralNetworkService.hpp)
165165

166166
message("SDL DIRS: ${SDL2_INCLUDE_DIRS}")
167167

src/Engine.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "Tools/MetricsManager.hpp"
2020
#include "Sound/SoundManager.hpp"
2121
#include "UI/UI.hpp"
22-
#include "Scripting/ScriptCompiler.hpp"
2322
#include "Net/ServerGame.hpp"
2423

2524
using namespace std::chrono;
@@ -39,18 +38,13 @@ void ExecuteOnGameThread(const std::function<void()>& function)
3938
g_workQueue.Enqueue(function);
4039
}
4140

42-
void RegisterScriptFunctions();
43-
4441
static void StrifeMlLog(const char* message)
4542
{
4643
Log("%s", message);
4744
}
4845

4946
Engine::Engine(const EngineConfig& config)
5047
{
51-
StrifeML::SetLogFunction(StrifeMlLog);
52-
RegisterScriptFunctions();
53-
5448
_config = config;
5549
_defaultBlockAllocator = std::make_unique<BlockAllocator>(config.blockAllocatorSizeBytes);
5650

@@ -197,8 +191,6 @@ void Engine::RunFrame()
197191
AccurateSleepFor(timeUntilUpdate);
198192
nextGameToRun->RunFrame(GetTimeSeconds());
199193
nextGameToRun->nextUpdateTime = nextGameToRun->nextUpdateTime + 1.0f / nextGameToRun->targetTickRate;
200-
201-
ScriptCompiler::GetInstance()->Update();
202194
}
203195

204196
void Engine::PauseGame()

src/ML/ML.hpp

Lines changed: 2 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -18,141 +18,6 @@ struct StrifeML::Serializer<Vector2>
1818
}
1919
};
2020

21-
enum class NeuralNetworkMode
22-
{
23-
Disabled,
24-
Deciding,
25-
CollectingSamples,
26-
ReinforcementLearning,
27-
};
28-
29-
template<typename TNeuralNetwork>
30-
struct NeuralNetworkComponent : ComponentTemplate<NeuralNetworkComponent<TNeuralNetwork>>
31-
{
32-
using InputType = typename TNeuralNetwork::InputType;
33-
using OutputType = typename TNeuralNetwork::OutputType;
34-
using NetworkType = TNeuralNetwork;
35-
using SampleType = StrifeML::Sample<InputType, OutputType>;
36-
37-
explicit NeuralNetworkComponent(float decisionsPerSecond_ = 1)
38-
: decisionInput(StrifeML::MlUtil::SharedArray<InputType>(TNeuralNetwork::SequenceLength)),
39-
decisionsPerSecond(decisionsPerSecond_)
40-
{
41-
42-
}
43-
44-
virtual ~NeuralNetworkComponent() = default;
45-
46-
void Update(float deltaTime) override;
47-
48-
void SetNetwork(const char* name);
49-
50-
bool StartMakingDecision();
51-
52-
StrifeML::NetworkContext<NetworkType>* networkContext = nullptr;
53-
std::shared_ptr<StrifeML::MakeDecisionWorkItem<TNeuralNetwork>> decisionInProgress;
54-
StrifeML::MlUtil::SharedArray<InputType> decisionInput;
55-
float makeDecisionsTimer = 0;
56-
float decisionsPerSecond;
57-
58-
std::shared_ptr<StrifeML::RunTrainingBatchWorkItem<TNeuralNetwork>> trainingInProgress;
59-
60-
int inputsCollected = 0;
61-
InputType previousInput;
62-
std::function<void(InputType& input)> collectInput;
63-
std::function<void(OutputType& decision)> collectDecision;
64-
65-
std::function<void(OutputType& decision)> receiveDecision;
66-
67-
NeuralNetworkMode mode = NeuralNetworkMode::Disabled;
68-
};
69-
70-
template <typename TNeuralNetwork>
71-
void NeuralNetworkComponent<TNeuralNetwork>::Update(float deltaTime)
72-
{
73-
if (mode == NeuralNetworkMode::Disabled)
74-
{
75-
return;
76-
}
77-
else if (mode == NeuralNetworkMode::Deciding || mode == NeuralNetworkMode::ReinforcementLearning)
78-
{
79-
OutputType output;
80-
if (decisionInProgress != nullptr && decisionInProgress->TryGetResult(output))
81-
{
82-
if (receiveDecision != nullptr)
83-
{
84-
receiveDecision(output);
85-
}
86-
87-
decisionInProgress = nullptr;
88-
}
89-
90-
makeDecisionsTimer -= deltaTime;
91-
92-
if (makeDecisionsTimer <= 0)
93-
{
94-
makeDecisionsTimer = 1.0f / decisionsPerSecond;
95-
auto wasStarted = StartMakingDecision();
96-
97-
if (wasStarted && mode == NeuralNetworkMode::ReinforcementLearning && inputsCollected > TNeuralNetwork::SequenceLength)
98-
{
99-
SampleType sample;
100-
sample.input = previousInput;
101-
collectDecision(sample.output);
102-
networkContext->trainer->AddSample(sample);
103-
}
104-
}
105-
}
106-
else if (mode == NeuralNetworkMode::CollectingSamples)
107-
{
108-
SampleType sample;
109-
collectInput(sample.input);
110-
collectDecision(sample.output);
111-
networkContext->trainer->AddSample(sample);
112-
}
113-
}
114-
115-
template <typename TNeuralNetwork>
116-
bool NeuralNetworkComponent<TNeuralNetwork>::StartMakingDecision()
117-
{
118-
// Don't allow making more than one decision at a time
119-
if (decisionInProgress != nullptr
120-
&& !decisionInProgress->IsComplete())
121-
{
122-
return false;
123-
}
124-
125-
if (collectInput == nullptr
126-
|| networkContext == nullptr
127-
|| networkContext->decider == nullptr)
128-
{
129-
return false;
130-
}
131-
132-
if (mode == NeuralNetworkMode::ReinforcementLearning && inputsCollected > TNeuralNetwork::SequenceLength)
133-
{
134-
previousInput = decisionInput.data.get()[TNeuralNetwork::SequenceLength - 1];
135-
}
136-
137-
// Expire oldest input
138-
for (int i = 0; i < TNeuralNetwork::SequenceLength - 1; ++i)
139-
{
140-
decisionInput.data.get()[i] = std::move(decisionInput.data.get()[i + 1]);
141-
}
142-
143-
// Collect new input
144-
collectInput(decisionInput.data.get()[TNeuralNetwork::SequenceLength - 1]);
145-
inputsCollected++;
146-
147-
// Start making decision if we have enough data
148-
if (inputsCollected >= TNeuralNetwork::SequenceLength)
149-
{
150-
decisionInProgress = networkContext->decider->MakeDecision(decisionInput, TNeuralNetwork::SequenceLength);
151-
}
152-
153-
return true;
154-
}
155-
15621
struct SensorObjectDefinition
15722
{
15823
SensorObjectDefinition()
@@ -253,7 +118,7 @@ struct NeuralNetworkManager
253118
}
254119

255120
template<typename TDecider, typename TTrainer>
256-
StrifeML::NetworkContext<typename TDecider::NetworkType>* CreateNetwork(const char* name, TDecider* decider, TTrainer* trainer)
121+
StrifeML::NetworkContext<typename TDecider::NetworkType>* CreateNetwork(const char* name, TDecider* decider, TTrainer* trainer, int sequenceLength)
257122
{
258123
static_assert(std::is_same_v<typename TDecider::NetworkType, typename TTrainer::NetworkType>, "Trainer and decider must accept the same type of neural network");
259124

@@ -263,7 +128,7 @@ struct NeuralNetworkManager
263128
throw StrifeML::StrifeException("Network already exists: " + std::string(name));
264129
}
265130

266-
auto newContext = std::make_shared<StrifeML::NetworkContext<typename TDecider::NetworkType>>(decider, trainer);
131+
auto newContext = std::make_shared<StrifeML::NetworkContext<typename TDecider::NetworkType>>(decider, trainer, sequenceLength);
267132
_networksByName[name] = newContext;
268133

269134
trainer->networkContext = newContext;
@@ -303,13 +168,6 @@ struct NeuralNetworkManager
303168
std::unordered_map<std::string, std::shared_ptr<StrifeML::INetworkContext>> _networksByName;
304169
};
305170

306-
template<typename TNeuralNetwork>
307-
void NeuralNetworkComponent<TNeuralNetwork>::SetNetwork(const char* name)
308-
{
309-
auto nnManager = this->owner->GetEngine()->GetNeuralNetworkManager();
310-
networkContext = nnManager->template GetNetwork<TNeuralNetwork>(name);
311-
}
312-
313171
gsl::span<uint64_t> ReadGridSensorRectangles(
314172
Scene* scene,
315173
Vector2 center,

src/ML/NeuralNetworkService.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ void DecisionBatch<TNetwork>::AddToBatch(Entity* entity, DecisionBatch::InputCir
9999
entitiesInBatch.emplace_back(entity);
100100

101101
int col = 0;
102-
for (PlayerInput& sample : buffer)
102+
for (InputType& sample : buffer)
103103
{
104104
decisionInput.data.get()[row * sequenceLength + col] = sample;
105105
++col;
@@ -124,6 +124,7 @@ struct NeuralNetworkService : ISceneService, IEntityObserver
124124
{
125125
using InputType = typename TNetwork::InputType;
126126
using OutputType = typename TNetwork::OutputType;
127+
using SampleType = typename TNetwork::SampleType;
127128
using TrainerType = StrifeML::Trainer<TNetwork>;
128129

129130
using InputCircularBuffer = CircularQueue<InputType>;
@@ -140,10 +141,7 @@ struct NeuralNetworkService : ISceneService, IEntityObserver
140141
void ReceiveEvent(const IEntityEvent& ev) override;
141142

142143
protected:
143-
virtual void CollectInput(TEntity* entity, InputType& input)
144-
{
145-
146-
}
144+
virtual void CollectInput(TEntity* entity, InputType& input) = 0;
147145

148146
void ForEachEntity(const std::function<void(TEntity*)>& func);
149147

src/ML/ScriptNetwork.cpp

Lines changed: 0 additions & 23 deletions
This file was deleted.

src/ML/ScriptNetwork.hpp

Lines changed: 0 additions & 124 deletions
This file was deleted.

0 commit comments

Comments
 (0)