AlgoPlus v0.1.0
nn::Linear Class Reference

Linear module. This implementation mostly follows PyTorch's implementation.

#include <nn.h>

Public Member Functions

 Linear (int, int, bool bias=false)
 Constructor for the nn::Linear class.
 
std::vector< double > forward (std::vector< double > const &)
 Forward pass: feeds a 1D input tensor through the layer.
 
void update_weights (std::vector< double > const &, double, double)
 Updates the weight vector by a value.
 

Detailed Description

Linear module. This implementation mostly follows PyTorch's implementation.

Constructor & Destructor Documentation

◆ Linear()

nn::Linear::Linear ( int in_features,
int out_features,
bool bias = false )
inline explicit

Constructor for the nn::Linear class.

Parameters
in_features (int) The number of input features
out_features (int) The number of output features
bias (bool) If true, the bias is initialized from the uniform distribution U(-1.0, 1.0)
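As an illustration of the bias parameter described above, the following is a minimal sketch of drawing a bias vector from U(-1.0, 1.0). The helper names `uniform_vector` and `make_bias`, the Mersenne Twister generator, and the fixed seed are assumptions made for this example; the reference does not say how nn.h generates its random values, nor how the weights themselves are initialized.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical helper, not part of nn.h: draws `n` values from U(lo, hi).
// The fixed default seed is used only to keep the sketch reproducible.
std::vector<double> uniform_vector(std::size_t n, double lo, double hi,
                                   unsigned seed = 42) {
    assert(lo < hi);
    std::mt19937 gen(seed);
    std::uniform_real_distribution<double> dist(lo, hi);
    std::vector<double> values(n);
    for (double& v : values) {
        v = dist(gen);
    }
    return values;
}

// Hypothetical stand-in for the bias a Linear layer might hold:
// empty when bias == false, U(-1.0, 1.0) samples otherwise.
std::vector<double> make_bias(int out_features, bool bias) {
    assert(out_features > 0);
    if (!bias) {
        return {};
    }
    return uniform_vector(static_cast<std::size_t>(out_features), -1.0, 1.0);
}
```

Calling `make_bias(4, true)` yields four values in the range of U(-1.0, 1.0), while `make_bias(4, false)` yields an empty vector, mirroring the default `bias = false`.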

Member Function Documentation

◆ forward()

std::vector< double > nn::Linear::forward ( std::vector< double > const & input_tensor)
inline

Forward pass: feeds a 1D input tensor through the layer.

Parameters
input_tensor 1D vector, the input tensor
Returns
1D vector (w^T * x + bias)
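The returned expression, w^T * x + bias, can be sketched as a free function. This is an illustration under stated assumptions, not the code in nn.h: the name `affine_forward` is hypothetical, and the weight layout (`weights[i][j]` connecting input `i` to output `j`) is assumed because the reference does not document how the weights are stored.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumed layout: weights[i][j] connects input i to output j,
// so that output j is (w^T * x)[j] + bias[j].
// An empty bias vector is treated as "no bias".
std::vector<double> affine_forward(
    std::vector<std::vector<double>> const& weights,
    std::vector<double> const& bias,
    std::vector<double> const& input) {
    assert(!weights.empty());
    assert(weights.size() == input.size());
    std::size_t const out_features = weights[0].size();
    assert(bias.empty() || bias.size() == out_features);

    std::vector<double> output(out_features, 0.0);
    for (std::size_t i = 0; i < input.size(); ++i) {
        assert(weights[i].size() == out_features);
        for (std::size_t j = 0; j < out_features; ++j) {
            output[j] += weights[i][j] * input[i];
        }
    }
    for (std::size_t j = 0; j < bias.size(); ++j) {
        output[j] += bias[j];
    }
    return output;
}
```

With three inputs and two outputs, the output vector has two elements, each one a weighted sum of the three inputs plus the matching bias term.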

◆ update_weights()

void nn::Linear::update_weights ( std::vector< double > const & input,
double error,
double learning_rate )
inline

Updates the weight vector by a value.

Parameters
value double, the value that will be added to the weight vector
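The signature takes `input`, `error`, and `learning_rate`, while the parameter list only mentions a value that is added to the weights. One plausible reading, offered here as an assumption rather than the library's confirmed behavior, is a delta-rule step in which that value is `learning_rate * error * input[i]`. The function name `delta_rule_update`, the sign of the update, and the weight layout (`weights[i][j]` connecting input `i` to output `j`) are all hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumed rule (not confirmed by the reference): every weight attached
// to input i receives the increment learning_rate * error * input[i].
void delta_rule_update(std::vector<std::vector<double>>& weights,
                       std::vector<double> const& input,
                       double error,
                       double learning_rate) {
    assert(weights.size() == input.size());
    for (std::size_t i = 0; i < weights.size(); ++i) {
        double const value = learning_rate * error * input[i];
        for (double& w : weights[i]) {
            w += value;
        }
    }
}
```

Under this reading, a larger input or a larger error produces a proportionally larger change, and the learning rate scales the whole step.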

The documentation for this class was generated from the following file: