MNP01/network.h
2018-03-11 12:48:48 +01:00

172 lines
4.5 KiB
C++

#ifndef NETWORK_H_
#define NETWORK_H_
#include "matrix.h"
#include <functional>
// A single fully-connected layer mapping `in` inputs to `out` outputs.
// Its output is computed as `weights * input - bias` (note: bias is subtracted).
template <typename T, int in, int out>
struct layer {
    // Callback used by ::fill (matrix.h) to produce each element; it receives
    // two ints — presumably the element's indices, verify against matrix.h.
    using filler_t = std::function<T(const int&, const int&)>;
    // Callback used by ::combine (matrix.h) to merge an element of this layer
    // with the corresponding element of another layer.
    using combiner_t = std::function<T(const T&, const T&)>;

    // Multiplying by an `in`-vector yields an `out`-vector (see evaluate).
    matrix<T, out, in> weights;
    vector<T, out> bias;

    // Forward pass of this layer for one input vector.
    vector<T, out> evaluate(vector<T, in> input)
    {
        return (weights * input) - bias;
    }

    // Convenience overload: the same filler initialises weights and bias.
    void fill(filler_t filler)
    {
        this->fill(filler, filler);
    }

    // Initialise weights first, then bias, each from its own filler.
    void fill(filler_t weight_filler, filler_t bias_filler)
    {
        ::fill(weights, weight_filler);
        ::fill(bias, bias_filler);
    }

    // Convenience overload: the same combiner merges weights and bias.
    layer combine(const layer &rhs, combiner_t combiner)
    {
        return this->combine(rhs, combiner, combiner);
    }

    // Builds a new layer whose weights and bias are the element-wise merge
    // (via ::combine) of this layer's and rhs's. Weights are merged before bias.
    layer combine(const layer &rhs, combiner_t weight_combiner, combiner_t bias_combiner)
    {
        layer child;
        child.weights = ::combine(weights, rhs.weights, weight_combiner);
        child.bias = ::combine(bias, rhs.bias, bias_combiner);
        return child;
    }
};
// layer_types<N, T, sizes...>::type is the layer type of the N-th layer of a
// network whose consecutive layer sizes are `sizes...`, i.e.
// layer<T, sizes[N], sizes[N + 1]>. Naming `type` for an out-of-range N is a
// compile error (the primary template is never defined).
template <std::size_t N, typename T, int ...inputs> struct layer_types;

// Index 0: the first two sizes describe the layer.
template <typename T, int first, int second, int ...tail>
struct layer_types<0, T, first, second, tail...> {
    using type = layer<T, first, second>;
};

// Index > 0: drop the leading size and look for layer N - 1 in what remains.
template <std::size_t N, typename T, int first, int second, int ...tail>
struct layer_types<N, T, first, second, tail...> {
    using type = typename layer_types<N - 1, T, second, tail...>::type;
};
// A feed-forward network whose consecutive layer sizes are `layers...`
// (input size first, output size last). Only the partial specializations
// (one layer / two or more layers) are defined.
template <typename T, int ...layers> class network;

// Base case: a network with exactly one layer mapping `in` inputs to `out`
// outputs. The multi-layer specialization derives from this one and uses it
// to hold its first layer.
template <typename T, int in, int out>
class network<T, in, out> {
protected:
    // The only layer at this level.
    layer<T, in, out> current;
public:
    static const std::size_t layer_count = 1;
    using self = network<T, in, out>;
    using output = vector<T, out>;
    using input = vector<T, in>;
    using filler_t = typename layer<T, in, out>::filler_t;
    using combiner_t = typename layer<T, in, out>::combiner_t;
    // Applied to the layer's raw output through `map` (matrix.h; presumably
    // element-wise — verify there).
    using normalizer_t = std::function<T(const T&)>;
    // Type of the N-th layer; only N == 0 is valid for a single-layer network.
    template <std::size_t N> using layer_type = typename layer_types<N, T, in, out>::type;
    normalizer_t normalizer;
    network(normalizer_t normalizer) : normalizer(normalizer) { }
    // Default: identity normalizer.
    network() : network([](const T& value) { return value; }) { };
    // Runs the layer on `inp` and normalizes the result.
    output evaluate(input inp)
    {
        return map(current.evaluate(inp), normalizer);
    }
    // Returns a reference to the (only) layer.
    template <std::size_t N> layer_type<N>& get()
    {
        return current;
    }
    void fill(filler_t weights, filler_t bias)
    {
        current.fill(weights, bias);
    }
    void fill(filler_t filler)
    {
        current.fill(filler);
    }
    // Returns a new network whose layer merges this network's layer with
    // rhs's layer. The result keeps this network's normalizer; it used to be
    // default-constructed and so silently fell back to the identity normalizer.
    self combine(self& rhs, combiner_t weight_combiner, combiner_t bias_combiner) {
        self result(normalizer);
        result.current = current.combine(rhs.current, weight_combiner, bias_combiner);
        return result;
    }
    // Same as above, using one combiner for both weights and bias.
    self combine(self& rhs, combiner_t combiner)
    {
        return combine(rhs, combiner, combiner);
    }
};
// Recursive case: a network with two or more layers. The first layer
// (in -> out) is held by the inherited single-layer base; the remaining layers
// form `subnetwork`, whose input size is this first layer's output size.
template <typename T, int in, int out, int ...layers>
class network<T, in, out, layers...> : public network<T, in, out> {
    using base = network<T, in, out>;
    using self = network<T, in, out, layers...>;
    // Layers 1.. of this network, fed with the first layer's output.
    network<T, out, layers...> subnetwork;
public:
    // Every level is constructed with the same normalizer.
    network(typename base::normalizer_t normalizer) : base(normalizer), subnetwork(normalizer) {}
    // Default: identity normalizer on every layer.
    network() : network([](const T& value) { return value; }) { };
    using output = typename network<T, out, layers...>::output;
    static const std::size_t layer_count = sizeof...(layers) + 1;
    template <std::size_t N> using layer_type = typename layer_types<N, T, in, out, layers...>::type;
private:
    // Tag dispatch for get<N>(): `true` selects the first layer (held by the
    // base), `false` forwards to the subnetwork. This replaces an in-class
    // explicit specialization of get<0>, which GCC rejects
    // ("explicit specialization in non-namespace scope").
    template <bool is_first> struct first_layer_tag { };
    template <std::size_t N> layer_type<N>& get_layer(first_layer_tag<true>)
    {
        return base::template get<0>();
    }
    template <std::size_t N> layer_type<N>& get_layer(first_layer_tag<false>)
    {
        return subnetwork.template get<N - 1>();
    }
public:
    // Returns a reference to the N-th layer (0 = first layer).
    template <std::size_t N> layer_type<N>& get()
    {
        return get_layer<N>(first_layer_tag<(N == 0)>());
    }
    // Runs the first layer (normalized by the base), then feeds its output
    // through the remaining layers.
    output evaluate(typename base::input inp)
    {
        auto result = base::evaluate(inp);
        return subnetwork.evaluate(result);
    }
    void fill(typename base::filler_t filler)
    {
        base::fill(filler);
        subnetwork.fill(filler);
    }
    void fill(typename base::filler_t weights, typename base::filler_t bias)
    {
        base::fill(weights, bias);
        subnetwork.fill(weights, bias);
    }
    // Returns a new network whose every layer merges this network's layer
    // with rhs's corresponding layer. The first layer is combined before the
    // subnetwork. The result keeps this network's normalizer.
    self combine(self& rhs, typename base::combiner_t weight_combiner, typename base::combiner_t bias_combiner) {
        self result(this->normalizer);
        result.template get<0>() = get<0>().combine(rhs.template get<0>(), weight_combiner, bias_combiner);
        // Forward both combiners. The original call passed none, which matched
        // no overload and failed to compile once instantiated.
        result.subnetwork = subnetwork.combine(rhs.subnetwork, weight_combiner, bias_combiner);
        return result;
    }
    // Same as above, using one combiner for weights and bias. Declared here
    // because the combine above hides the base class's two-argument overload.
    self combine(self& rhs, typename base::combiner_t combiner)
    {
        return combine(rhs, combiner, combiner);
    }
};
#endif