"""
Hive Keyboard: Deterministic Quantum State Addressing.
(c) 2026 Subvurs Research. Proprietary and Confidential.
"""

from typing import Dict, List, Union
from .circuits import build_deterministic_circuit, get_circuit_info
from .backends import run_simulator, run_hardware, get_job_result
from .scaling import (
    calculate_inverse_n,
    generate_inverse_map,
    get_hive_layers_n,
    get_critical_patterns_n,
    get_scaling_stats,
)

try:
    from .hive_engine import inverse_engine, get_engine_layers, get_engine_critical
    _HAS_ENGINE = True
except ImportError:
    _HAS_ENGINE = False

def _get_layers():
    """Return the layer table used by this module.

    Without the compiled engine a single placeholder layer is returned;
    otherwise the table comes from ``get_engine_layers()``.
    """
    if not _HAS_ENGINE:
        return {0: {'name': 'L0', 'range': (0, 0), 'description': 'Ground'}}
    return get_engine_layers()

def _get_critical():
    """Return the critical-pattern table used by this module.

    Without the compiled engine a single placeholder entry is returned;
    otherwise the table comes from ``get_engine_critical()``.
    """
    if not _HAS_ENGINE:
        return {0: 'P0'}
    return get_engine_critical()

# Resolve both tables once at import time. The underscore names are the ones
# HiveKeyboard reads; the upper-case names are public aliases of the very
# same objects (engine-provided, or the placeholder when the engine is absent).
_LAYERS = HIVE_LAYERS = _get_layers()
_CRITICAL = CRITICAL_PATTERNS = _get_critical()


class HiveKeyboard:
    """Eight-qubit Hive Keyboard.

    Keeps a 256-entry table (``inverse_map``) that maps every target pattern
    to its initialization value, and offers helpers to run a pattern on the
    configured backend, build its circuit, and look up its layer.
    """

    def __init__(self, backend: str = 'simulator', double_twist: bool = True,
                 auto_mitigation: bool = False):
        """Store the configuration and build the inverse map.

        Args:
            backend: ``'simulator'`` selects the simulator path in
                :meth:`target`; any other value is handed to the hardware
                runner as the backend name.
            double_twist: Forwarded to the circuit builder and the runners.
            auto_mitigation: Default used by :meth:`target` when
                ``apply_mitigation`` is not given.

        Raises:
            RuntimeError: If the compiled engine could not be imported.
        """
        self.backend_name = backend
        self.double_twist = double_twist
        self.auto_mitigation = auto_mitigation
        self._init_map()

    def _init_map(self):
        """Fill ``self.inverse_map`` for patterns 0..255 using the engine."""
        if not _HAS_ENGINE:
            raise RuntimeError("Compiled engine not available")
        table = {}
        for pattern in range(256):
            table[pattern] = inverse_engine(pattern, 8)
        # Publish only once the whole table has been computed.
        self.inverse_map = table

    def _get_init(self, pattern_id: int) -> int:
        """Look up the initialization value stored for ``pattern_id``."""
        return self.inverse_map[pattern_id]

    def target(self, pattern_id: int, shots: int = None,
               wait: bool = True, apply_mitigation: bool = None) -> Dict:
        """Run one pattern on the configured backend.

        Args:
            pattern_id: Pattern to address, 0-255 inclusive.
            shots: Shot count; a falsy value selects 1024 on the simulator
                and 8192 on hardware.
            wait: Forwarded to the hardware runner only.
            apply_mitigation: Forwarded to the hardware runner; ``None``
                falls back to ``self.auto_mitigation``.

        Returns:
            Whatever ``run_simulator`` / ``run_hardware`` returns.

        Raises:
            ValueError: If ``pattern_id`` is outside 0-255.
        """
        if pattern_id < 0 or pattern_id > 255:
            raise ValueError(f"Pattern must be 0-255, got {pattern_id}")

        mitigate = self.auto_mitigation if apply_mitigation is None else apply_mitigation

        if self.backend_name != 'simulator':
            return run_hardware(pattern_id, self.inverse_map, self.backend_name,
                                shots or 8192, self.double_twist, wait,
                                apply_mitigation=mitigate)
        return run_simulator(pattern_id, self.inverse_map, shots or 1024,
                             self.double_twist)

    def sequence(self, pattern_ids: List[int], **kwargs) -> List[Dict]:
        """Call :meth:`target` for each pattern in order and collect the results."""
        results = []
        for pattern in pattern_ids:
            results.append(self.target(pattern, **kwargs))
        return results

    def creation_sequence(self, **kwargs) -> List[Dict]:
        """Run the standard Creation Sequence (patterns 51, 126, 155, 247)."""
        creation = [51, 126, 155, 247]
        return self.sequence(creation, **kwargs)

    def get_circuit(self, pattern_id: int):
        """Build the deterministic circuit for ``pattern_id``."""
        key = self._get_init(pattern_id)
        return build_deterministic_circuit(key, double_twist=self.double_twist)

    def inverse(self, pattern_id: int) -> int:
        """Return the initialization key stored for ``pattern_id``."""
        return self._get_init(pattern_id)

    def layer(self, pattern_id: int, detailed: bool = False) -> Union[str, Dict]:
        """Describe the layer whose inclusive range contains ``pattern_id``.

        Returns a one-line summary string, or a dict of layer fields when
        ``detailed`` is true. When no layer matches, returns ``"Unknown"``
        or ``{'layer': -1, 'name': 'Unknown'}`` respectively.
        """
        for number, spec in _LAYERS.items():
            low, high = spec['range']
            if not (low <= pattern_id <= high):
                continue

            critical = pattern_id in _CRITICAL
            if detailed:
                return {
                    'layer': number,
                    'name': spec['name'],
                    'description': spec['description'],
                    'range': spec['range'],
                    'pattern': pattern_id,
                    'is_critical': critical,
                    'critical_name': _CRITICAL.get(pattern_id),
                }

            suffix = f" [{_CRITICAL[pattern_id]}]" if critical else ""
            return f"Layer {number} ({spec['name']}): {spec['description']}{suffix}"

        if detailed:
            return {'layer': -1, 'name': 'Unknown'}
        return "Unknown"

    def info(self, pattern_id: int) -> Dict:
        """Collect the pattern, its key, its layer details and circuit info."""
        key = self._get_init(pattern_id)
        circuit = build_deterministic_circuit(key, double_twist=self.double_twist)

        report = {'pattern': pattern_id}
        report['binary'] = format(pattern_id, '08b')
        report['init_pattern'] = key
        report['init_binary'] = format(key, '08b')
        report['layer'] = self.layer(pattern_id, detailed=True)
        report['circuit'] = get_circuit_info(circuit)
        report['is_critical'] = pattern_id in _CRITICAL
        return report

    def get_job(self, job_id: str, target_pattern: int = None) -> Dict:
        """Fetch the result of a job via ``get_job_result`` on this backend."""
        return get_job_result(job_id, self.backend_name, target_pattern)

    @property
    def critical_patterns(self) -> Dict[int, str]:
        """A copy of the module's critical-pattern table."""
        return _CRITICAL.copy()

    def __repr__(self):
        backend, twist = self.backend_name, self.double_twist
        return f"HiveKeyboard(backend='{backend}', double_twist={twist})"


class HiveKeyboardN:
    """N-Qubit Hive Keyboard: Scalable Deterministic Quantum State Addressing.

    Works on ``2 ** n_qubits`` patterns. For ``n_qubits <= 16`` the full
    pattern -> initialization-key table is generated up front; above that
    each key is computed on demand with ``calculate_inverse_n``. Only the
    ``'simulator'`` backend can execute patterns.
    """

    def __init__(self, n_qubits: int = 8, backend: str = 'simulator',
                 double_twist: bool = True, auto_mitigation: bool = False):
        """Configure the keyboard and load its layer/critical tables.

        Args:
            n_qubits: Register width; must be at least 3.
            backend: Backend name; anything other than ``'simulator'`` makes
                :meth:`target` raise ``NotImplementedError``.
            double_twist: Forwarded to the circuit builder.
            auto_mitigation: Stored on the instance; the simulator path does
                not use it.

        Raises:
            ValueError: If ``n_qubits`` is smaller than 3.
        """
        if n_qubits < 3:
            raise ValueError("Hive requires at least 3 qubits")

        self.n_qubits = n_qubits
        self.num_patterns = 1 << n_qubits
        self.max_pattern = self.num_patterns - 1
        self.backend_name = backend
        self.double_twist = double_twist
        self.auto_mitigation = auto_mitigation

        self._hive_layers = get_hive_layers_n(n_qubits)
        self._critical_patterns = get_critical_patterns_n(n_qubits)

        # The table has 2**n entries, so it is only materialized up to 16
        # qubits; wider registers resolve keys lazily in inverse().
        if n_qubits <= 16:
            self.inverse_map = generate_inverse_map(n_qubits)
        else:
            self.inverse_map = None

    def inverse(self, pattern_id: int) -> int:
        """Return the initialization key for a pattern.

        Uses the precomputed table when there is one (a missing id raises
        ``KeyError``); otherwise ``pattern_id`` is passed unchecked to
        ``calculate_inverse_n``.
        """
        if self.inverse_map is not None:
            return self.inverse_map[pattern_id]
        return calculate_inverse_n(pattern_id, self.n_qubits)

    def target(self, pattern_id: int, shots: Union[int, None] = None,
               wait: bool = True,
               apply_mitigation: Union[bool, None] = None) -> Dict:
        """Target a specific pattern on the simulator.

        Args:
            pattern_id: Pattern to address, ``0 .. self.max_pattern``.
            shots: Shot count; a falsy value selects 1024.
            wait: Accepted for signature parity with ``HiveKeyboard.target``;
                not used by the simulator path.
            apply_mitigation: Accepted for signature parity; not used by the
                simulator path.

        Returns:
            The result dict produced by :meth:`_run_sim`.

        Raises:
            ValueError: If ``pattern_id`` is out of range.
            NotImplementedError: If the backend is not ``'simulator'``.
        """
        if pattern_id < 0 or pattern_id > self.max_pattern:
            raise ValueError(f"Pattern must be 0-{self.max_pattern}, got {pattern_id}")

        # Reject unsupported backends before resolving the key: for wide
        # registers the key is computed on demand and would be wasted work.
        if self.backend_name != 'simulator':
            raise NotImplementedError("Hardware execution for N > 8 qubits not yet available")

        init = self.inverse(pattern_id)
        return self._run_sim(pattern_id, init, shots or 1024)

    def _run_sim(self, target: int, init: int, shots: int) -> Dict:
        """Run the circuit for ``init`` on ``AerSimulator`` and score it.

        ``hits`` is the count recorded for the zero-padded binary string of
        ``target``; ``fidelity`` is ``hits / shots`` expressed in percent.
        """
        # Imported lazily so the module can be imported without qiskit_aer.
        from qiskit_aer import AerSimulator

        qc = build_deterministic_circuit(init, n_qubits=self.n_qubits,
                                         double_twist=self.double_twist)
        sim = AerSimulator()
        job = sim.run(qc, shots=shots)
        counts = job.result().get_counts()

        width = f'0{self.n_qubits}b'
        target_bin = format(target, width)
        hits = counts.get(target_bin, 0)

        return {
            'target': target,
            'target_binary': target_bin,
            'init': init,
            'init_binary': format(init, width),
            'hits': hits,
            'shots': shots,
            'fidelity': hits / shots * 100,
            'counts': counts,
            'backend': 'aer_simulator',
            'n_qubits': self.n_qubits,
            'status': 'COMPLETED'
        }

    def get_circuit(self, pattern_id: int):
        """Return the circuit for a pattern."""
        init = self.inverse(pattern_id)
        return build_deterministic_circuit(init, n_qubits=self.n_qubits,
                                           double_twist=self.double_twist)

    def layer(self, pattern_id: int, detailed: bool = False) -> Union[str, Dict]:
        """Get layer information for a pattern.

        Returns a one-line summary string, or a dict of layer fields when
        ``detailed`` is true. When no layer range contains ``pattern_id``,
        returns ``"Unknown"`` or ``{'layer': -1, 'name': 'Unknown'}``.
        """
        for layer_num, info in self._hive_layers.items():
            lo, hi = info['range']
            if lo <= pattern_id <= hi:
                is_crit = pattern_id in self._critical_patterns
                if detailed:
                    return {
                        'layer': layer_num,
                        'name': info['name'],
                        'description': info['description'],
                        'range': info['range'],
                        'pattern': pattern_id,
                        'is_critical': is_crit,
                        'critical_name': self._critical_patterns.get(pattern_id),
                    }
                c = f" [{self._critical_patterns[pattern_id]}]" if is_crit else ""
                return f"Layer {layer_num} ({info['name']}): {info['description']}{c}"
        return "Unknown" if not detailed else {'layer': -1, 'name': 'Unknown'}

    def info(self, pattern_id: int) -> Dict:
        """Get information about a pattern.

        Collects the pattern and its key (decimal and zero-padded binary),
        the detailed layer record, and ``get_circuit_info`` of its circuit.
        """
        init = self.inverse(pattern_id)
        # Build from the key already resolved above instead of going through
        # get_circuit(), which would resolve the inverse a second time.
        circuit = build_deterministic_circuit(init, n_qubits=self.n_qubits,
                                              double_twist=self.double_twist)
        width = f'0{self.n_qubits}b'
        return {
            'pattern': pattern_id,
            'binary': format(pattern_id, width),
            'init_pattern': init,
            'init_binary': format(init, width),
            'n_qubits': self.n_qubits,
            'layer': self.layer(pattern_id, detailed=True),
            'circuit': get_circuit_info(circuit),
            'is_critical': pattern_id in self._critical_patterns,
        }

    def sequence(self, pattern_ids: List[int], **kwargs) -> List[Dict]:
        """Run a sequence of patterns, returning one result per pattern."""
        return [self.target(p, **kwargs) for p in pattern_ids]

    def stats(self) -> Dict:
        """Get statistics about this Hive via ``get_scaling_stats``."""
        return get_scaling_stats(self.n_qubits)

    @property
    def hive_layers(self) -> Dict:
        """Return an independent copy of the layer definitions."""
        from copy import deepcopy

        # Each layer entry is itself a mutable mapping, so a shallow copy
        # would let callers modify this instance's tables through it.
        return deepcopy(self._hive_layers)

    @property
    def critical_patterns(self) -> Dict[int, str]:
        """Return a copy of the critical-pattern table."""
        return self._critical_patterns.copy()

    def __repr__(self):
        return f"HiveKeyboardN(n_qubits={self.n_qubits}, backend='{self.backend_name}')"
