AlgoPlus v0.1.0
Loading...
Searching...
No Matches
mlp.h
1#pragma once
2
3#ifdef __cplusplus
4#include <cassert>
5#include <iostream>
6#include <vector>
7#include "../activation/activation_functions.h"
8#include "../metrics/metrics.h"
9#include "nn.h"
10#endif
11
18class MLP {
19 std::vector<std::vector<double>> data_;
20 std::vector<double> labels_;
21 std::vector<nn::Linear> seq_;
22 double binary_;
23 int epochs_;
24 double learning_rate_;
25
26 public:
36 explicit MLP(std::vector<std::vector<double>> const&, std::vector<std::pair<int, int>> const,
37 const int epochs = 100, const double learning_rate = 0.001);
38
42 void fit();
43
49 double predict(std::vector<double> const&);
50};
51
52inline MLP::MLP(std::vector<std::vector<double>> const& data,
53 std::vector<std::pair<int, int>> const arch, const int epochs,
54 const double learning_rate) {
55 assert(data.size() > 0);
56 assert(epochs > 0);
57 assert(learning_rate > 0);
58 assert(arch.size() > 0);
59 this->epochs_ = epochs;
60 this->data_ = data;
61 this->learning_rate_ = learning_rate;
62 this->binary_ = (arch.back().second == 1) ? true : false;
63 for (std::vector<double>& row : this->data_) {
64 this->labels_.push_back(row.back());
65 row.pop_back();
66 }
67
68 for (auto [in_features_, out_features_] : arch) {
69 assert(in_features_ > 0);
70 assert(out_features_ > 0);
71 this->seq_.push_back(nn::Linear(in_features_, out_features_, true));
72 }
73}
74
75inline void MLP::fit() {
76 for (int epoch = 0; epoch < this->epochs_; epoch++) {
77 std::vector<double> y_pred;
78 for (size_t i = 0; i < this->data_.size(); i++) {
79 std::vector<double> out_ = this->data_[i];
80 for (nn::Linear& layer : this->seq_) {
81 out_ = layer.forward(out_);
82 }
83
84 // double y_pred;
85 double y_pred_ = (out_[0] > 0.0) ? 1.0 : -1.0;
86 y_pred.push_back(y_pred_);
87 // TODO: Perform multiclass classification
88 // else {
89 // std::vector<double> logits = activation::softmax(out_);
90 // y_pred = std::max_element(logits.begin(), logits.end()) -
91 // logits.begin(); std::cout << y_pred << '\n';
92 // }
93 double err = y_pred_ - this->labels_[i];
94
95 if (err != 0) {
96 for (nn::Linear& layer : this->seq_) {
97 layer.update_weights(this->data_[i], err, this->learning_rate_);
98 }
99 }
100 }
101 std::cout << "Epoch: " << epoch + 1 << ": "
102 << "Accuracy: " << metrics::accuracy_score(this->labels_, y_pred)
103 << " | f1_score: " << metrics::f1_score(this->labels_, y_pred)
104 << " | Recall: " << metrics::recall(this->labels_, y_pred)
105 << " | Precision: " << metrics::precision(this->labels_, y_pred) << '\n';
106 }
107}
108
109inline double MLP::predict(std::vector<double> const& input) {
110 assert(input.size() == this->data_[0].size());
111 std::vector<double> out_ = input;
112 for (nn::Linear& layer : this->seq_) {
113 out_ = layer.forward(out_);
114 }
115
116 return (out_[0] > 0.0) ? 1.0 : -1.0;
117 // else {
118 // std::vector<double> logits = activation::softmax(out_);
119 // return std::max_element(logits.begin(), logits.end()) - logits.begin();
120 // }
121}
void fit()
fit an MLP on the input data
Definition mlp.h:75
double predict(std::vector< double > const &)
performs inference
Definition mlp.h:109
MLP(std::vector< std::vector< double > > const &, std::vector< std::pair< int, int > > const, const int epochs=100, const double learning_rate=0.001)
constructor for the MLP class (takes the training data and the layer architecture)
Definition mlp.h:52
Linear module. This implementation mostly follows PyTorch's implementation.
Definition nn.h:18
double f1_score(const std::vector< double > &y, const std::vector< double > &y_pred)
f1 score function: [2 * precision * recall / (precision + recall)]
Definition metrics.h:99
double accuracy_score(const std::vector< double > &y, const std::vector< double > &y_pred)
accuracy score function [(tp + tn) / (tp + tn + fp + fn)]
Definition metrics.h:81
double recall(const std::vector< double > &y, const std::vector< double > &y_pred)
recall function [tp / (tp + fn)]
Definition metrics.h:72
double precision(const std::vector< double > &y, const std::vector< double > &y_pred)
precision function [tp / (tp + fp)]
Definition metrics.h:90