AlgoPlus v0.1.0
Loading...
Searching...
No Matches
nn.h
1#pragma once
2
3#ifdef __cplusplus
4#include <cassert>
5#include <iostream>
6#include <optional>
7#include <random>
8#include <vector>
9#include "../../algorithms/math/multiply.h"
10#endif
11
namespace nn {

/**
 * @brief Fully connected layer computing output = W * input (+ bias).
 *
 * The weight matrix has one row per output feature and one column per input
 * feature. When enabled, the bias is a single scalar added to every output.
 */
class Linear {
 public:
  /**
   * @brief Creates a layer of the given size with randomly initialized parameters.
   * @param in_features  number of input features
   * @param out_features number of output features
   * @param bias         whether the layer carries a scalar bias term
   */
  explicit Linear(int in_features, int out_features, bool bias = false);

  /**
   * @brief Runs one 1D input vector through the layer.
   * @param input_tensor input values (indexed up to in_features)
   * @return one value per output feature
   */
  std::vector<double> forward(std::vector<double> const& input_tensor);

  /**
   * @brief Applies one update step to the weights (and the bias, if present).
   * @param input         the input the step is computed against
   * @param error         scalar error signal
   * @param learning_rate step size
   */
  void update_weights(std::vector<double> const& input, double error, double learning_rate);

 private:
  std::vector<std::vector<double>> weight;  // [out_features][in_features]
  std::optional<double> bias;               // empty when the layer has no bias
  int in_features_;
  int out_features_;
};

}  // namespace nn
49
50inline nn::Linear::Linear(int in_features, int out_features, bool bias)
51 : in_features_(in_features), out_features_(out_features) {
52 assert(in_features != 0);
53 assert(out_features != 0);
54 std::random_device rd;
55 std::mt19937 gen(rd());
56 std::uniform_real_distribution<double> dist(-1.0, 1.0);
57 this->weight =
58 std::vector<std::vector<double>>(out_features, std::vector<double>(in_features, 0.0));
59 for (auto& w_vec : this->weight) {
60 for (auto& w : w_vec) {
61 w = dist(gen);
62 }
63 }
64
65 if (bias) {
66 this->bias = dist(gen);
67 } else {
68 this->bias = std::nullopt;
69 }
70}
71
72inline std::vector<double> nn::Linear::forward(std::vector<double> const& input_tensor) {
73 std::vector<double> output(out_features_, 0.0);
74
75 for (int i = 0; i < this->out_features_; i++) {
76 for (int j = 0; j < this->in_features_; j++) {
77 output[i] += weight[i][j] * input_tensor[j];
78 }
79
80 if (bias.has_value()) {
81 output[i] += bias.value();
82 }
83 }
84
85 return output;
86}
87
88inline void nn::Linear::update_weights(std::vector<double> const& input, double error,
89 double learning_rate) {
90 for (int i = 0; i < this->out_features_; i++) {
91 for (int j = 0; j < this->in_features_; j++) {
92 weight[i][j] -= learning_rate * error * input[j];
93 }
94
95 if (bias.has_value()) {
96 bias = bias.value() - learning_rate * error;
97 }
98 }
99}
Linear(int, int, bool bias=false)
Constructor for the nn::Linear class (input and output sizes are required).
Definition nn.h:50
std::vector< double > forward(std::vector< double > const &)
forward function: passes a 1D input tensor through the layer and returns its output.
Definition nn.h:72
void update_weights(std::vector< double > const &, double, double)
updates the weights (and the bias, if enabled) by a step of learning_rate * error.
Definition nn.h:88