#ifndef NETWORK_H_
#define NETWORK_H_

#include "matrix.h"
#include <functional>
#include <iostream>

template <typename T, int in, int out>
struct layer {
    using filler_t   = std::function<T(const int&, const int&)>;
    using combiner_t = std::function<T(const T&, const T&)>;
    using mutator_t  = std::function<T(const T&)>;

    matrix<T, out, in> weights;
    vector<T, out> bias;

    // affine transform of the input; note the bias is subtracted,
    // i.e. it acts as a per-output threshold
    vector<T, out> evaluate(vector<T, in> input)
    {
        return weights * input - bias;
    }
    
    void fill(filler_t weight_filler, filler_t bias_filler)
    {
        ::fill(weights, weight_filler);
        ::fill(bias,    bias_filler);
    }

    void fill(filler_t filler)
    {
        // use same filler for both
        fill(filler, filler);
    }

    layer<T, in, out> combine(const layer<T, in, out> &rhs, combiner_t weight_combiner, combiner_t bias_combiner)
    {
        layer<T, in, out> result;

        result.weights = ::combine(weights, rhs.weights, weight_combiner);
        result.bias = ::combine(bias, rhs.bias, bias_combiner);

        return result;
    }

    layer<T, in, out> combine(const layer<T, in, out> &rhs, combiner_t combiner)
    {
        return combine(rhs, combiner, combiner);
    }

    layer<T, in, out> mutate(mutator_t weight_mutator, mutator_t bias_mutator)
    {
        layer<T, in, out> result;

        result.weights = ::map(weights, weight_mutator);
        result.bias = ::map(bias, bias_mutator);

        return result;
    }

    layer<T, in, out> mutate(mutator_t mutator)
    {
        return mutate(mutator, mutator);
    }

    std::ostream& save(std::ostream& stream)
    {
        weights.save(stream);
        bias.save(stream);

        return stream;
    }

    void load(std::istream& stream) 
    {
        weights.load(stream);
        bias.load(stream);
    }
};
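
// Rough usage sketch for a single layer (illustrative only; the sizes and
// fillers below are hypothetical, and matrix.h is assumed to provide the
// matrix/vector templates plus the free functions ::fill, ::combine and
// ::map used above):
//
//     layer<float, 3, 2> l;                                   // 3 inputs, 2 outputs
//     l.fill([](const int&, const int&) { return 0.5f; });    // same filler for weights and bias
//
//     vector<float, 3> x;
//     ::fill(x, [](const int&, const int&) { return 1.0f; });
//     vector<float, 2> y = l.evaluate(x);                     // weights * x - bias
//
//     // element-wise recombination and mutation
//     layer<float, 3, 2> averaged = l.combine(l, [](const float& a, const float& b) { return (a + b) / 2; });
//     layer<float, 3, 2> nudged   = l.mutate([](const float& v) { return v + 0.01f; });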

template <std::size_t N, typename T, int ...inputs> struct layer_types;
template <std::size_t N, typename T, int in, int out, int ...rest> 
struct layer_types<N, T, in, out, rest...> : layer_types<N-1, T, out, rest...> { };

template<typename T, int in, int out, int ...rest>
struct layer_types<0, T, in, out, rest...> {
    using type = layer<T, in, out>;
};
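
// layer_types<N, T, sizes...> peels one size off the pack per recursion step,
// so its nested ::type names the N-th layer of a network with those layer
// sizes. For example (hypothetical sizes), layer_types<1, float, 4, 3, 2>::type
// is layer<float, 3, 2>: layer 0 maps 4 -> 3 and layer 1 maps 3 -> 2.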

template <typename T, int ...layers> class network;

template <typename T, int in, int out> 
class network<T, in, out> {
protected:
    layer<T, in, out> current;

public:
    static const std::size_t layer_count = 1;
    
    using self = network<T, in, out>;

    using output = vector<T, out>;
    using input  = vector<T, in>;

    using filler_t     = typename layer<T, in, out>::filler_t;
    using combiner_t   = typename layer<T, in, out>::combiner_t;
    using normalizer_t = std::function<T(const T&)>;
    using mutator_t    = std::function<T(const T&)>;

    template <std::size_t N> using layer_type = typename layer_types<N, T, in, out>::type;

    normalizer_t normalizer;

    network(normalizer_t normalizer) : normalizer(normalizer) { }
    // default construction uses the identity normalizer
    network() : network([](const T& value) { return value; }) {}

    output evaluate(input inp)
    {
        // apply the normalizer element-wise to the layer's output
        return map(current.evaluate(inp), normalizer);
    }

    template <std::size_t N> layer_type<N>& get()
    {
        // single-layer base case: only layer 0 exists
        return current;
    }

    void fill(filler_t weights, filler_t bias)
    {
        current.fill(weights, bias);
    }

    void fill(filler_t filler)
    {
        current.fill(filler);
    }

    self combine(self& rhs, combiner_t weight_combiner, combiner_t bias_combiner) 
    {
        self result(normalizer);
        result.template get<0>() = get<0>().combine(rhs.template get<0>(), weight_combiner, bias_combiner);

        return result;
    }

    self combine(self& rhs, combiner_t combiner)
    {
        return combine(rhs, combiner, combiner);
    }

    self mutate(mutator_t weight_mutator, mutator_t bias_mutator) 
    {
        self result(normalizer);
        result.template get<0>() = get<0>().mutate(
                weight_mutator, 
                bias_mutator
        );

        return result;
    }

    self mutate(mutator_t mutator)
    {
        return mutate(mutator, mutator);
    }

    std::ostream& save(std::ostream& stream) 
    {
        current.save(stream);
        
        return stream;
    }

    void load(std::istream& stream)
    {
        current.load(stream);
    }
};

template <typename T, int in, int out, int ...layers> 
class network<T, in, out, layers...> : public network<T, in, out> {
    using base = network<T, in, out>;
    using self = network<T, in, out, layers...>;

    network<T, out, layers...> subnetwork;

public:
    network(typename base::normalizer_t normalizer) : base(normalizer), subnetwork(normalizer) {}
    network() : network([](const T& value) { return value; }) {}

    using output = typename network<T, out, layers...>::output;
    static const std::size_t layer_count = sizeof...(layers) + 1;

    template <std::size_t N> using layer_type = typename layer_types<N, T, in, out, layers...>::type;

    template <std::size_t N> layer_type<N>& get()
    {
        // an in-class explicit specialization of a member template is not
        // valid C++, so dispatch on the index with if constexpr (C++17)
        if constexpr (N == 0)
            return base::template get<0>();
        else
            return subnetwork.template get<N - 1>();
    }

    output evaluate(typename base::input inp)
    {
        auto result = base::evaluate(inp);
        return subnetwork.evaluate(result);
    }

    void fill(typename base::filler_t filler)
    {
        base::fill(filler);
        subnetwork.fill(filler);
    }

    void fill(typename base::filler_t weights, typename base::filler_t bias)
    {
        base::fill(weights, bias);
        subnetwork.fill(weights, bias);
    }
    
    self combine(self& rhs, typename base::combiner_t weight_combiner, typename base::combiner_t bias_combiner) 
    {
        self result((this->normalizer));

        result.template get<0>() = get<0>().combine(
                rhs.template get<0>(), 
                weight_combiner, 
                bias_combiner
        );
        result.subnetwork = subnetwork.combine(rhs.subnetwork, weight_combiner, bias_combiner);

        return result;
    }

    self combine(self& rhs, typename base::combiner_t combiner)
    {
        return combine(rhs, combiner, combiner);
    }

    self mutate(typename base::mutator_t weight_mutator, typename base::mutator_t bias_mutator) 
    {
        self result((this->normalizer));

        result.template get<0>() = get<0>().mutate(
                weight_mutator, 
                bias_mutator
        );
        result.subnetwork = subnetwork.mutate(weight_mutator, bias_mutator);

        return result;
    }

    self mutate(typename base::mutator_t mutator)
    {
        return mutate(mutator, mutator);
    }

    std::ostream& save(std::ostream& stream)
    {
        base::save(stream);
        subnetwork.save(stream);
        
        return stream;
    }

    void load(std::istream& stream)
    {
        base::load(stream);
        subnetwork.load(stream);
    }
};
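
// Rough end-to-end sketch (illustrative only; the sizes, fillers and sigmoid
// normalizer below are hypothetical, and matrix.h is assumed to provide the
// matrix/vector types plus ::fill):
//
//     #include "network.h"
//     #include <cmath>
//
//     int main()
//     {
//         // 4 inputs -> hidden layer of 3 -> 2 outputs, squashed per element
//         network<float, 4, 3, 2> net([](const float& v) { return 1.0f / (1.0f + std::exp(-v)); });
//         net.fill([](const int&, const int&) { return 0.1f; });
//
//         vector<float, 4> in;
//         ::fill(in, [](const int&, const int&) { return 1.0f; });
//         vector<float, 2> out = net.evaluate(in);
//
//         layer<float, 4, 3>& first = net.get<0>();            // per-layer access
//         network<float, 4, 3, 2> blended = net.combine(net,
//                 [](const float& a, const float& b) { return (a + b) / 2; });
//         network<float, 4, 3, 2> nudged = net.mutate([](const float& v) { return v + 0.01f; });
//         (void)first; (void)out; (void)blended; (void)nudged;
//     }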

#endif