172 lines
4.5 KiB
C++
172 lines
4.5 KiB
C++
#ifndef NETWORK_H_
|
|
#define NETWORK_H_
|
|
|
|
#include "matrix.h"
|
|
#include <functional>
|
|
|
|
template <typename T, int in, int out>
|
|
struct layer {
|
|
using filler_t = std::function<T(const int&, const int&)>;
|
|
using combiner_t = std::function<T(const T&, const T&)>;
|
|
|
|
matrix<T, out, in> weights;
|
|
vector<T, out> bias;
|
|
|
|
vector<T, out> evaluate(vector<T, in> input)
|
|
{
|
|
return weights * input - bias;
|
|
}
|
|
|
|
void fill(filler_t weight_filler, filler_t bias_filler)
|
|
{
|
|
::fill(weights, weight_filler);
|
|
::fill(bias, bias_filler);
|
|
}
|
|
|
|
void fill(filler_t filler)
|
|
{
|
|
// use same filler for both
|
|
fill(filler, filler);
|
|
}
|
|
|
|
layer<T, in, out> combine(const layer<T, in, out> &rhs, combiner_t weight_combiner, combiner_t bias_combiner)
|
|
{
|
|
layer<T, in, out> result;
|
|
|
|
result.weights = ::combine(weights, rhs.weights, weight_combiner);
|
|
result.bias = ::combine(bias, rhs.bias, bias_combiner);
|
|
|
|
return result;
|
|
}
|
|
|
|
layer<T, in, out> combine(const layer<T, in, out> &rhs, combiner_t combiner)
|
|
{
|
|
return combine(rhs, combiner, combiner);
|
|
}
|
|
};
|
|
|
|
template <std::size_t N, typename T, int ...inputs> struct layer_types;
|
|
template <std::size_t N, typename T, int in, int out, int ...rest>
|
|
struct layer_types<N, T, in, out, rest...> : layer_types<N-1, T, out, rest...> { };
|
|
|
|
template<typename T, int in, int out, int ...rest>
|
|
struct layer_types<0, T, in, out, rest...> {
|
|
using type = layer<T, in, out>;
|
|
};
|
|
|
|
// Primary template for a feed-forward network whose consecutive layer sizes
// are `layers...` (input size first). It is only declared: the partial
// specializations in this header define the single-layer case (two sizes) and
// the recursive case (more than two sizes).
template <typename T, int ...layers> class network;
|
|
|
|
template <typename T, int in, int out>
|
|
class network<T, in, out> {
|
|
protected:
|
|
layer<T, in, out> current;
|
|
|
|
public:
|
|
static const std::size_t layer_count = 1;
|
|
|
|
using self = network<T, in, out>;
|
|
|
|
using output = vector<T, out>;
|
|
using input = vector<T, in>;
|
|
|
|
using filler_t = typename layer<T, in, out>::filler_t;
|
|
using combiner_t = typename layer<T, in, out>::combiner_t;
|
|
using normalizer_t = std::function<T(const T&)>;
|
|
|
|
template <std::size_t N> using layer_type = typename layer_types<N, T, in, out>::type;
|
|
|
|
normalizer_t normalizer;
|
|
|
|
network(normalizer_t normalizer) : normalizer(normalizer) { }
|
|
network() : network([](const T& value) { return value; }) { };
|
|
|
|
output evaluate(input inp)
|
|
{
|
|
return map(current.evaluate(inp), normalizer);
|
|
}
|
|
|
|
template <std::size_t N> layer_type<N>& get()
|
|
{
|
|
return current;
|
|
}
|
|
|
|
void fill(filler_t weights, filler_t bias)
|
|
{
|
|
current.fill(weights, bias);
|
|
}
|
|
|
|
void fill(filler_t filler)
|
|
{
|
|
current.fill(filler);
|
|
}
|
|
|
|
self combine(self& rhs, combiner_t weight_combiner, combiner_t bias_combiner) {
|
|
self result;
|
|
result.template get<0>() = get<0>().combine(rhs.template get<0>(), weight_combiner, bias_combiner);
|
|
|
|
return result;
|
|
}
|
|
|
|
self combine(self& rhs, combiner_t combiner)
|
|
{
|
|
return combine(rhs, combiner, combiner);
|
|
}
|
|
};
|
|
|
|
template <typename T, int in, int out, int ...layers>
|
|
class network<T, in, out, layers...> : public network<T, in, out> {
|
|
using base = network<T, in, out>;
|
|
using self = network<T, in, out, layers...>;
|
|
|
|
network<T, out, layers...> subnetwork;
|
|
|
|
public:
|
|
network(typename base::normalizer_t normalizer) : base(normalizer), subnetwork(normalizer) {}
|
|
network() : network([](const T& value) { return value; }) { };
|
|
|
|
|
|
using output = typename network<T, out, layers...>::output;
|
|
static const std::size_t layer_count = sizeof...(layers) + 1;
|
|
|
|
template <std::size_t N> using layer_type = typename layer_types<N, T, in, out, layers...>::type;
|
|
|
|
template <std::size_t N> layer_type<N>& get()
|
|
{
|
|
return subnetwork.template get<N-1>();
|
|
}
|
|
|
|
template<> layer_type<0>& get<0>()
|
|
{
|
|
return base::template get<0>();
|
|
}
|
|
|
|
output evaluate(typename base::input inp)
|
|
{
|
|
auto result = base::evaluate(inp);
|
|
return subnetwork.evaluate(result);
|
|
}
|
|
|
|
void fill(typename base::filler_t filler)
|
|
{
|
|
base::fill(filler);
|
|
subnetwork.fill(filler);
|
|
}
|
|
|
|
void fill(typename base::filler_t weights, typename base::filler_t bias)
|
|
{
|
|
base::fill(weights, bias);
|
|
subnetwork.fill(weights, bias);
|
|
}
|
|
|
|
self combine(self& rhs, typename base::combiner_t weight_combiner, typename base::combiner_t bias_combiner) {
|
|
self result;
|
|
|
|
result.template get<0>() = get<0>().combine(rhs.template get<0>(), weight_combiner, bias_combiner);
|
|
result.subnetwork = subnetwork.combine(rhs.subnetwork);
|
|
|
|
return result;
|
|
}
|
|
};
|
|
|
|
#endif
|