Skip to content

Commit 9f0ad84

Browse files
Pass sequenceLength to network context
1 parent 73fa49e commit 9f0ad84

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/ML/ML.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ struct NeuralNetworkManager
118118
}
119119

120120
template<typename TDecider, typename TTrainer>
121-
StrifeML::NetworkContext<typename TDecider::NetworkType>* CreateNetwork(const char* name, TDecider* decider, TTrainer* trainer)
121+
StrifeML::NetworkContext<typename TDecider::NetworkType>* CreateNetwork(const char* name, TDecider* decider, TTrainer* trainer, int sequenceLength)
122122
{
123123
static_assert(std::is_same_v<typename TDecider::NetworkType, typename TTrainer::NetworkType>, "Trainer and decider must accept the same type of neural network");
124124

@@ -128,7 +128,7 @@ struct NeuralNetworkManager
128128
throw StrifeML::StrifeException("Network already exists: " + std::string(name));
129129
}
130130

131-
auto newContext = std::make_shared<StrifeML::NetworkContext<typename TDecider::NetworkType>>(decider, trainer);
131+
auto newContext = std::make_shared<StrifeML::NetworkContext<typename TDecider::NetworkType>>(decider, trainer, sequenceLength);
132132
_networksByName[name] = newContext;
133133

134134
trainer->networkContext = newContext;

0 commit comments

Comments
 (0)