-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph.cpp
More file actions
38 lines (28 loc) · 1018 Bytes
/
graph.cpp
File metadata and controls
38 lines (28 loc) · 1018 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include "graph.h"
#include "xtensor/xio.hpp"
/// Appends a layer to the end of the graph's execution order.
///
/// Layers run in insertion order during `run()` and in reverse order during
/// `backwards()`.
///
/// @param layer Layer to append. The pointer is stored as-is (no copy and no
///              null check). NOTE(review): ownership is not visible here —
///              presumably the caller keeps the layer alive for the graph's
///              lifetime; confirm against graph.h.
void Graph::add_layer(Layer *layer) {
    layers.emplace_back(layer);
}
/// Performs a forward pass through every layer, in insertion order.
///
/// Each layer's output becomes the next layer's input.
///
/// @param input Tensor fed to the first layer.
/// @return Output of the last layer. When the graph has no layers, returns
///         `input` unchanged, so an empty graph acts as the identity.
xt::xarray<double> Graph::run(xt::xarray<double> input) {
    // `input` is already a by-value copy, so reuse it as the running
    // activation. This avoids the extra result->input copy per layer. It also
    // fixes the empty-graph case, which used to return a default-constructed
    // array instead of the input.
    for (const auto &layer : this->layers) {
        input = layer->forward(input);
    }
    return input;  // implicit move of the by-value parameter
}
/// Back-propagates a gradient through the layers, from last to first.
///
/// Each layer's `backward()` receives the gradient returned by the layer
/// after it. The last layer receives `loss_grad`. The gradient returned by the
/// first layer is discarded.
///
/// @param loss_grad Gradient of the loss with respect to the graph's output.
void Graph::backwards(xt::xarray<double> loss_grad) {
    xt::xarray<double> upstream = loss_grad;
    // Reverse iteration visits layers in the opposite order of run().
    for (auto it = layers.rbegin(); it != layers.rend(); ++it) {
        upstream = (*it)->backward(upstream);
    }
}
/// Repeats forward pass, loss evaluation and back-propagation `num_iter` times.
///
/// @param loss       Loss that scores the graph output against `target`.
///                   It is dereferenced on every iteration, so it must be
///                   non-null when `num_iter > 0`.
/// @param input      Tensor fed to the first layer on every iteration.
/// @param target     Expected output, passed to `loss->forward`.
/// @param num_iter   Number of iterations to perform.
/// @param print_loss When true, writes "Loss: <value>" to stdout after each
///                   loss evaluation.
///
/// NOTE(review): no explicit parameter-update step happens in this function.
/// Presumably the layers apply their updates inside `backward()`; confirm
/// against the Layer implementations.
void Graph::optimize(Loss *loss, xt::xarray<double> input, xt::xarray<double> target, size_t num_iter, bool print_loss) {
    for (size_t iter = 0; iter != num_iter; ++iter) {
        xt::xarray<double> prediction = run(input);
        xt::xarray<double> current_loss = loss->forward(prediction, target);
        if (print_loss) {
            std::cout << "Loss: " << current_loss << std::endl;
        }
        // The loss's gradient is requested only after forward() has seen
        // this iteration's prediction.
        xt::xarray<double> output_grad = loss->backward();
        backwards(output_grad);
    }
}