namespace Kadet.Quantum.Tests
{
    open Microsoft.Quantum.Convert;
    open Microsoft.Quantum.Canon;
    open Microsoft.Quantum.Intrinsic;
    open Microsoft.Quantum.Measurement;

    open Kadet.Quantum.Util;
    open Kadet.Quantum.Grover;

    /// # Summary
    /// Marking oracle built from a database lookup and an equality check.
    ///
    /// # Description
    /// Runs `db` to write into a temporary 6-qubit register, applies
    /// `IsEqualOracle` to that register with `value` and `output`, and then
    /// applies `Adjoint db` to undo the lookup before the temporary register
    /// is released.
    /// NOTE(review): `IsEqualOracle` is defined elsewhere; it is presumably the
    /// step that flips `output` when the register encodes `value`, and it must
    /// leave the register unchanged for the uncompute to return it to |0⟩ — confirm.
    ///
    /// # Input
    /// ## value
    /// Integer passed to `IsEqualOracle` for comparison.
    /// ## db
    /// Adjointable lookup writing into its second (6-qubit) register.
    /// ## input
    /// Register passed to `db` as its first argument.
    /// ## output
    /// Qubit passed to `IsEqualOracle` as its target.
    operation DatabaseOracle(value: Int, db: ((Qubit[], Qubit[]) => Unit is Adj), input: Qubit[], output: Qubit): Unit {
        using (record = Qubit[6]) {
            // Conjugation: the `within` block runs first and its adjoint runs
            // automatically after the `apply` block.
            within {
                db(input, record);
            }
            apply {
                IsEqualOracle(record, value, output);
            }
        }
    }

    /// # Summary
    /// Runs `GroverSearch` over a 4-qubit register for a database entry equal to
    /// `value` and returns the measured register as an integer.
    ///
    /// # Input
    /// ## database
    /// Values handed to `Records`/`PrepareDatabase` to build the lookup.
    /// ## value
    /// Value the oracle compares records against.
    ///
    /// # Output
    /// `ResultArrayAsInt` of the measured 4-qubit register after the search.
    operation TestDatabaseSearch(database: Int[], value: Int): Int {
        let lookup = PrepareDatabase(Records(database));
        mutable foundIndex = 0;

        using (address = Qubit[4]) {
            // Start from the uniform superposition over all 2^4 basis states.
            ApplyToEach(H, address);

            let oracle = FlipingOracle(DatabaseOracle(value, lookup, _, _));
            GroverSearch(address, oracle);

            let outcomes = MultiM(address);
            set foundIndex = ResultArrayAsInt(outcomes);

            // Qubits must be back in |0⟩ before they are released.
            ResetAll(address);
        }

        return foundIndex;
    }

    /// # Summary
    /// Applies the database lookup to a uniform superposition over a 4-qubit
    /// register and measures both that register and the 6-qubit output register.
    ///
    /// # Input
    /// ## database
    /// Values handed to `Records`/`PrepareDatabase` to build the lookup.
    ///
    /// # Output
    /// Pair `(input, output)` of measured registers, each converted with
    /// `ResultArrayAsInt`; the 4-qubit register is measured first.
    operation TestDatabaseSuperposition(database: Int[]): (Int, Int) {
        mutable measuredIndex = 0;
        mutable measuredRecord = 0;

        using ((address, record) = (Qubit[4], Qubit[6])) {
            let lookup = PrepareDatabase(Records(database));

            ApplyToEach(H, address);
            lookup(address, record);

            set measuredIndex = ResultArrayAsInt(MultiM(address));
            set measuredRecord = ResultArrayAsInt(MultiM(record));

            // Reset the address register, then the record register, before release.
            ResetAll(address + record);
        }

        return (measuredIndex, measuredRecord);
    }

    /// # Summary
    /// Encodes `index` into a 4-qubit register with `InitFromInt`, applies the
    /// database lookup, and returns the measured 6-qubit output register.
    ///
    /// # Input
    /// ## database
    /// Values handed to `Records`/`PrepareDatabase` to build the lookup.
    /// ## index
    /// Integer passed to `InitFromInt` for the 4-qubit register.
    /// NOTE(review): behavior for `index` outside 0..15 depends on `InitFromInt` — confirm.
    ///
    /// # Output
    /// `ResultArrayAsInt` of the measured 6-qubit output register.
    operation TestDatabase(database: Int[], index: Int): Int {
        mutable lookedUp = 0;

        using ((address, record) = (Qubit[4], Qubit[6])) {
            let lookup = PrepareDatabase(Records(database));

            InitFromInt(address, index);
            lookup(address, record);

            let outcomes = MultiM(record);
            set lookedUp = ResultArrayAsInt(outcomes);

            // Reset the address register, then the record register, before release.
            ResetAll(address + record);
        }

        return lookedUp;
    }

    /// # Summary
    /// Encodes `a` and `b` into two 5-qubit registers, applies `Add` with a
    /// 6-qubit target register, and returns that register measured as an integer.
    ///
    /// # Remarks
    /// Qubits allocated by `using` always start in |0⟩, so the registers are not
    /// reset before encoding; they are only reset before release.
    /// NOTE(review): `Add` is defined elsewhere and presumably writes `a + b` into
    /// `sum`; 6 qubits hold up to 63, which covers 31 + 31. How `InitFromInt`
    /// handles operands outside 0..31 is not visible here — confirm.
    ///
    /// # Input
    /// ## a
    /// First operand, encoded into a 5-qubit register.
    /// ## b
    /// Second operand, encoded into a 5-qubit register.
    ///
    /// # Output
    /// `ResultArrayAsInt` of the measured 6-qubit `sum` register.
    operation TestAddition(a: Int, b: Int) : Int {
        mutable result = 0;

        using ((aReg, bReg, sum) = (Qubit[5], Qubit[5], Qubit[6])) {
            InitFromInt(aReg, a);
            InitFromInt(bReg, b);

            Add(aReg, bReg, sum);
            set result = ResultArrayAsInt(MultiM(sum));

            // Return every register to |0⟩ before the qubits are released.
            ResetAll(aReg);
            ResetAll(bReg);
            ResetAll(sum);
        }

        return result;
    }
}