Skip to content

Commit 6dfd733

Browse files
Make collect input pure virtual
1 parent ca40fe6 commit 6dfd733

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

src/ML/NeuralNetworkService.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ void DecisionBatch<TNetwork>::AddToBatch(Entity* entity, DecisionBatch::InputCir
9999
entitiesInBatch.emplace_back(entity);
100100

101101
int col = 0;
102-
for (PlayerInput& sample : buffer)
102+
for (InputType& sample : buffer)
103103
{
104104
decisionInput.data.get()[row * sequenceLength + col] = sample;
105105
++col;
@@ -124,6 +124,7 @@ struct NeuralNetworkService : ISceneService, IEntityObserver
124124
{
125125
using InputType = typename TNetwork::InputType;
126126
using OutputType = typename TNetwork::OutputType;
127+
using SampleType = typename TNetwork::SampleType;
127128
using TrainerType = StrifeML::Trainer<TNetwork>;
128129

129130
using InputCircularBuffer = CircularQueue<InputType>;
@@ -140,10 +141,7 @@ struct NeuralNetworkService : ISceneService, IEntityObserver
140141
void ReceiveEvent(const IEntityEvent& ev) override;
141142

142143
protected:
143-
virtual void CollectInput(TEntity* entity, InputType& input)
144-
{
145-
146-
}
144+
virtual void CollectInput(TEntity* entity, InputType& input) = 0;
147145

148146
void ForEachEntity(const std::function<void(TEntity*)>& func);
149147

0 commit comments

Comments
 (0)