#ifndef NETWORK_H_ #define NETWORK_H_ #include "matrix.h" #include #include template struct layer { using filler_t = std::function; using combiner_t = std::function; using mutator_t = std::function; matrix weights; vector bias; vector evaluate(vector input) { return weights * input - bias; } void fill(filler_t weight_filler, filler_t bias_filler) { ::fill(weights, weight_filler); ::fill(bias, bias_filler); } void fill(filler_t filler) { // use same filler for both fill(filler, filler); } layer combine(const layer &rhs, combiner_t weight_combiner, combiner_t bias_combiner) { layer result; result.weights = ::combine(weights, rhs.weights, weight_combiner); result.bias = ::combine(bias, rhs.bias, bias_combiner); return result; } layer combine(const layer &rhs, combiner_t combiner) { return combine(rhs, combiner, combiner); } layer mutate(mutator_t weight_mutator, mutator_t bias_mutator) { layer result; result.weights = ::map(weights, weight_mutator); result.bias = ::map(bias, bias_mutator); return result; } layer mutate(const layer &rhs, mutator_t mutator) { return mutate(mutator, mutator); } std::ostream& save(std::ostream& stream) { weights.save(stream); bias.save(stream); return stream; } void load(std::istream& stream) { weights.load(stream); bias.load(stream); } }; template struct layer_types; template struct layer_types : layer_types { }; template struct layer_types<0, T, in, out, rest...> { using type = layer; }; template class network; template class network { protected: layer current; public: static const std::size_t layer_count = 1; using self = network; using output = vector; using input = vector; using filler_t = typename layer::filler_t; using combiner_t = typename layer::combiner_t; using normalizer_t = std::function; using mutator_t = std::function; template using layer_type = typename layer_types::type; normalizer_t normalizer; network(normalizer_t normalizer) : normalizer(normalizer) { } network() : network([](const T& value) { return value; }) { }; output evaluate(input inp) { return map(current.evaluate(inp), normalizer); } template layer_type& get() { return current; } void fill(filler_t weights, filler_t bias) { current.fill(weights, bias); } void fill(filler_t filler) { current.fill(filler); } self combine(self& rhs, combiner_t weight_combiner, combiner_t bias_combiner) { self result(normalizer); result.template get<0>() = get<0>().combine(rhs.template get<0>(), weight_combiner, bias_combiner); return result; } self combine(self& rhs, combiner_t combiner) { return combine(rhs, combiner, combiner); } self mutate(mutator_t weight_mutator, mutator_t bias_mutator) { self result(normalizer); result.template get<0>() = get<0>().mutate( weight_mutator, bias_mutator ); return result; } self mutate(mutator_t mutator) { return mutate(mutator, mutator); } std::ostream& save(std::ostream& stream) { current.save(stream); return stream; } void load(std::istream& stream) { current.load(stream); } }; template class network : public network { using base = network; using self = network; network subnetwork; public: network(typename base::normalizer_t normalizer) : base(normalizer), subnetwork(normalizer) {} network() : network([](const T& value) { return value; }) { }; using output = typename network::output; static const std::size_t layer_count = sizeof...(layers) + 1; template using layer_type = typename layer_types::type; template typename std::enable_if>::type& get() { return subnetwork.template get(); } template typename std::enable_if>::type& get() { return base::template get(); } output evaluate(typename base::input inp) { auto result = base::evaluate(inp); return subnetwork.evaluate(result); } void fill(typename base::filler_t filler) { base::fill(filler); subnetwork.fill(filler); } void fill(typename base::filler_t weights, typename base::filler_t bias) { base::fill(weights, bias); subnetwork.fill(weights, bias); } self combine(self& rhs, typename base::combiner_t weight_combiner, typename base::combiner_t bias_combiner) { self result((this->normalizer)); result.template get<0>() = get<0>().combine( rhs.template get<0>(), weight_combiner, bias_combiner ); result.subnetwork = subnetwork.combine(rhs.subnetwork, weight_combiner, bias_combiner); return result; } self combine(self& rhs, typename base::combiner_t combiner) { return combine(rhs, combiner, combiner); } self mutate(typename base::mutator_t weight_mutator, typename base::mutator_t bias_mutator) { self result((this->normalizer)); result.template get<0>() = get<0>().mutate( weight_mutator, bias_mutator ); result.subnetwork = subnetwork.mutate(weight_mutator, bias_mutator); return result; } self mutate(typename base::mutator_t mutator) { return mutate(mutator, mutator); } std::ostream& save(std::ostream& stream) { base::save(stream); subnetwork.save(stream); return stream; } void load(std::istream& stream) { base::load(stream); subnetwork.load(stream); } }; #endif