# cython: language_level=3
# cython: boundscheck=False
# cython: wraparound=False
"""
Nyx Consensus Dynamics - Core Engine
Compiled module for optimization testing.
QIP 2026 Validation - Subvurs Research
"""

cimport cython
from libc.math cimport tanh, exp, sqrt, fabs, sin, cos, log
from libc.stdlib cimport rand, RAND_MAX
import numpy as np
cimport numpy as np

np.import_array()

# =============================================================================
# INTERNAL - DO NOT MODIFY
# =============================================================================

# XOR keys consumed by _d() when decoding the masked constants in _c0/_c2/_c3.
cdef unsigned long _k0 = 0x5a3c6e2f
cdef unsigned long _k1 = 0x7b4d8e1a
cdef unsigned long _k2 = 0x9c5f0e3b

cdef inline double _d(unsigned long v, unsigned long k, double s):
    # Decode a masked constant: XOR `v` with key `k`, divide by 0xFFFFFFFF
    # and scale by `s`. For 32-bit inputs and s >= 0 the result is in [0, s].
    cdef double frac = (<double>(v ^ k)) / <double>0xffffffff
    return frac * s

cdef double _c0():
    # Decoded constant (key _k0, scale 2.0) times 0.4285.
    # Used by _init_m as the 'a' value of mode index 2.
    return _d(0x6e8b2a4f, _k0, 2.0) * 0.4285

cdef double _c1():
    # Square of _c0(). NOTE(review): not referenced elsewhere in this file.
    cdef double base = _c0()
    return base * base

cdef double _c2():
    # Decoded constant (key _k1, scale 1.0) times 0.504.
    # Used by _f1 as the centre of its Gaussian bump.
    return _d(0x4a2c8e6f, _k1, 1.0) * 0.504

cdef double _c3():
    # Decoded constant (key _k2, scale 0.2) times 0.29.
    # NOTE(review): not referenced elsewhere in this file.
    cdef double raw = _d(0x3b5d7e9a, _k2, 0.2)
    return raw * 0.29

cdef double _c4():
    # Exponent gain used by _f1's Gaussian bump (a fixed negative constant).
    return -5.0e1

# Small additive guard on the denominator in _f1.
cdef double _EPS = 0.000001

# Mode index -> parameter dict; filled in by _init_m().
cdef dict _M = dict()

cdef void _init_m():
    # Populate the module-level table _M with one parameter dict per mode
    # index. optimize() reads 'a' as the engine's ta, 'b' as ps, 'c' as cm.
    # Indices 0 and 1 carry identical values (kept as separate dicts).
    global _M
    cdef double shared_a = 0.93
    cdef double mixed_a = _c0()
    cdef double damped_a = 0.65
    _M = {
        0: dict(a=shared_a, b=0.7, c=-1.0),
        1: dict(a=shared_a, b=0.7, c=-1.0),
        2: dict(a=mixed_a, b=1.0, c=1.0),
        3: dict(a=damped_a, b=1.0, c=-1.0),
    }


# Build the table once at import time.
_init_m()

# Public mode name -> index into _M. Insertion order is the order reported
# by get_available_modes() and by optimize()'s unknown-mode error.
cdef dict _MODE_MAP = dict(
    mode_a=0,
    mode_b=1,
    mode_c=2,
    mode_d=3,
)

# =============================================================================
# INTERNAL FUNCTIONS
# =============================================================================

cdef inline double _ru() nogil:
    # Uniform draw in [0, 1] (both ends reachable) from the C library rand().
    # NOTE(review): srand() is not called anywhere in this file, so the
    # sequence is whatever the C runtime's default seed produces.
    cdef double draw = <double>rand()
    return draw / <double>RAND_MAX

cdef double _rn(double m, double s):
    # Draw one sample from Normal(mean=m, std=s) using the Box-Muller
    # transform on two uniform draws from _ru().
    #
    # This runs once per state per engine step, so it stays entirely in C:
    # libc log/cos instead of Python-level np.log/np.cos/np.pi calls.
    cdef double u1 = _ru()
    cdef double u2 = _ru()
    cdef double z
    if u1 < 1e-10:
        # _ru() can return exactly 0.0; clamp to keep log() finite.
        u1 = 1e-10
    # 3.141592653589793 is the same double as np.pi.
    z = sqrt(-2.0 * log(u1)) * cos(2.0 * 3.141592653589793 * u2)
    return m + s * z

cdef inline double _f1(double x, double y, double z):
    # 100 * x^2 * z^1.2 times a blend of 1 and a Gaussian bump in y/x:
    #   (1 - y) + y * exp(_c4() * (y/(x+_EPS) - _c2())^2)
    # NOTE(review): not referenced elsewhere in this file.
    cdef double ratio = y / (x + _EPS)
    cdef double centre = _c2()
    cdef double gain = _c4()
    cdef double off = ratio - centre
    cdef double blend = (1.0 - y) + y * exp(gain * off * off)
    return 100.0 * (x * x) * blend * (z ** 1.2)

cdef inline double _f2(double a, double b, double c, double d):
    # Weighted mix a*b + sqrt(1 - a^2)*c; |a| > 1 makes the sqrt argument
    # negative (NaN). The trailing 0.0*d term is kept as-is: it only affects
    # the result when d is NaN/inf or when it turns a -0.0 sum into +0.0.
    cdef double side = sqrt(1.0 - a * a)
    return a * b + side * c + 0.0 * d

# =============================================================================
# ENGINE
# =============================================================================

cdef class _E:
    """
    Noisy mean-field engine over ``n`` real-valued states, all starting at 0.

    Each step (``_s``) synchronously updates every state ``x_i`` as::

        x_i' = ta * (tanh(ca * x_i) + cm * 0.1 * (mean(x) - x_i))
               + sqrt(1 - ta**2) * noise_i,   noise_i ~ Normal(0, pa * ps)

    ``gw`` returns the logistic sigmoid of each state. ``ta`` is expected to
    satisfy |ta| <= 1; otherwise the sqrt argument is negative and states
    become NaN.
    """
    cdef int n          # number of states
    cdef double ta      # weight on the deterministic drive (see update rule)
    cdef double ca      # gain inside tanh
    cdef double pa      # noise factor; noise std is pa * ps
    cdef double ps      # noise factor; noise std is pa * ps
    cdef double cm      # coefficient on (mean - x_i), further scaled by 0.1;
                        # positive pulls toward the mean, negative pushes away
    cdef double[:] st   # current state vector
    cdef double[:] _nx  # scratch buffer for the next state, swapped with st

    def __init__(self, int n, double ta, double ca, double pa, double ps, double cm):
        self.n = n
        self.ta = ta
        self.ca = ca
        self.pa = pa
        self.ps = ps
        self.cm = cm
        self.st = np.zeros(n, dtype=np.float64)
        # Allocated once here rather than on every call to _s().
        self._nx = np.zeros(n, dtype=np.float64)

    cdef void _s(self):
        # Advance every state by one synchronous update step.
        cdef int i
        cdef double pe, ms, cr, cs, dt, sc
        cdef double[:] ns = self._nx

        # Nothing to update, and mean(x) would otherwise divide by zero.
        if self.n <= 0:
            return

        pe = self.pa * self.ps
        ms = 0.0
        for i in range(self.n):
            ms += self.st[i]
        ms /= self.n

        for i in range(self.n):
            cr = tanh(self.ca * self.st[i])
            cs = self.cm * 0.1 * (ms - self.st[i])
            dt = cr + cs
            sc = _rn(0.0, pe)
            ns[i] = _f2(self.ta, dt, sc, 0.0)

        # Every entry of ns was overwritten above, so swap buffers instead of
        # copying; the old state buffer becomes next step's scratch space.
        self._nx = self.st
        self.st = ns

    cpdef void run(self, int it):
        """Apply ``it`` update steps (none if ``it`` <= 0)."""
        cdef int i
        for i in range(it):
            self._s()

    cpdef list gw(self):
        """Return ``[1 / (1 + exp(-x_i))]`` for each state, as a new list."""
        cdef int i
        cdef list w = []
        for i in range(self.n):
            w.append(1.0 / (1.0 + exp(-self.st[i])))
        return w


# =============================================================================
# PUBLIC INTERFACE
# =============================================================================

cdef bint _is_hit(object sc, object optimal_value,
                  double rel_tol, double abs_tol) except *:
    # Decide whether one scorer result counts as reaching `optimal_value`.
    # `sc` is either a (score, valid) tuple or a bare numeric score.
    cdef bint hit
    if isinstance(sc, tuple):
        v, vl = sc
        hit = vl and v >= optimal_value
    elif fabs(optimal_value) < abs_tol:
        # Optimum is ~0: a relative tolerance would be ~0, so use absolute.
        hit = fabs(sc - optimal_value) < abs_tol
    else:
        # Within rel_tol of the optimum, or at/above it (higher = better).
        hit = (fabs(sc - optimal_value) < rel_tol * fabs(optimal_value)
               or sc >= optimal_value)
    return hit


def optimize(scorer, optimal_value, str mode, int n_bits=4,
             int iterations=20, int samples=5000,
             double rel_tol=0.1, double abs_tol=0.01):
    """
    Run Nyx optimization.

    For each of ``samples`` trials, a fresh engine with ``n_bits`` states is
    run for ``iterations`` steps. One bitstring is drawn from it, with bit i
    set to 1 with probability sigmoid(state_i), and the bitstring is scored.
    The result is the fraction of trials whose score counts as a hit.

    Note: ``scorer`` is only applied to the final sampled bitstring. Its
    result is never fed back into the engine, so the dynamics (and the
    bitstring distribution) do not depend on the scorer.

    Hit criterion:
        * scorer returns ``(score, valid)``: hit iff ``valid`` and
          ``score >= optimal_value``.
        * scorer returns a bare score and ``|optimal_value| < abs_tol``:
          hit iff ``|score - optimal_value| < abs_tol``.
        * otherwise: hit iff ``|score - optimal_value| < rel_tol *
          |optimal_value|`` or ``score >= optimal_value``. Near-optimal
          scores therefore count, and higher scores are treated as better
          (maximization).

    Parameters:
        scorer: Function taking list of 0/1, returns score or (score, valid)
        optimal_value: Known optimal to compare against
        mode: One of 'mode_a', 'mode_b', 'mode_c', 'mode_d'
        n_bits: Number of decision variables
        iterations: Dynamics iterations
        samples: Number of samples; must be >= 1
        rel_tol: Relative tolerance for bare scores (default 0.1)
        abs_tol: Absolute tolerance, and the "optimum is ~0" threshold,
            for bare scores (default 0.01)

    Returns:
        float: Fraction of samples counted as hits (0.0 to 1.0)

    Raises:
        ValueError: if ``mode`` is unknown or ``samples`` < 1.
    """
    if mode not in _MODE_MAP:
        raise ValueError(f"Unknown mode: {mode}. Use: {list(_MODE_MAP.keys())}")
    if samples < 1:
        # Previously: ZeroDivisionError for 0, a meaningless -0.0 for < 0.
        raise ValueError(f"samples must be >= 1, got {samples}")

    cdef int mi = _MODE_MAP[mode]
    cdef dict cfg = _M[mi]
    cdef double ta = cfg['a']
    cdef double ps = cfg['b']
    cdef double cm = cfg['c']
    cdef double ca = 0.5
    cdef double pa = 0.15

    cdef int oc = 0
    cdef int s, i
    cdef list wt, st
    cdef double w
    cdef _E eng  # typed so run()/gw() dispatch as C calls

    for s in range(samples):
        eng = _E(n_bits, ta, ca, pa, ps, cm)
        eng.run(iterations)
        wt = eng.gw()

        # Draw bit i as 1 with probability wt[i].
        st = []
        for i in range(n_bits):
            w = wt[i]
            st.append(1 if _ru() < w else 0)

        if _is_hit(scorer(st), optimal_value, rel_tol, abs_tol):
            oc += 1

    return <double>oc / <double>samples


def get_available_modes():
    """Return available modes."""
    return [name for name in _MODE_MAP]


def get_version():
    """Return the version string of this module."""
    version = "1.0.0-qip2026"
    return version
