"""
Backend abstraction for Hive Keyboard.

Supports local simulation (Aer) and IBM Quantum hardware.
"""

import time
from typing import Dict, Optional, List, Union
from qiskit_aer import AerSimulator
from .circuits import build_deterministic_circuit


def run_simulator(target_pattern: int, inverse_map: Dict[int, int],
                  shots: int = 1024, double_twist: bool = True) -> Dict:
    """
    Execute the pattern circuit on the local Aer simulator.

    The init pattern is looked up as ``inverse_map[target_pattern]``, the
    matching circuit is built and sampled ``shots`` times, and the share of
    shots that land exactly on the 8-bit target bitstring is reported.

    Args:
        target_pattern: Target pattern number (0-255)
        inverse_map: Mapping from target -> init pattern
        shots: Number of measurement shots
        double_twist: Use Protocol 6.0 double twist

    Returns:
        Dict with keys target, target_binary, init, init_binary, hits,
        shots, fidelity (percentage of shots equal to the target
        bitstring), counts, backend ('aer_simulator') and status
        ('COMPLETED').

    Raises:
        KeyError: if target_pattern is not a key of inverse_map.
    """
    init_pattern = inverse_map[target_pattern]
    circuit = build_deterministic_circuit(init_pattern, double_twist=double_twist)

    measured = AerSimulator().run(circuit, shots=shots).result().get_counts()

    # Fixed-width 8-bit string so it matches the keys in the counts dict.
    target_bits = format(target_pattern, '08b')
    hit_count = measured.get(target_bits, 0)

    return {
        'target': target_pattern,
        'target_binary': target_bits,
        'init': init_pattern,
        'init_binary': format(init_pattern, '08b'),
        'hits': hit_count,
        'shots': shots,
        'fidelity': hit_count / shots * 100,
        'counts': measured,
        'backend': 'aer_simulator',
        'status': 'COMPLETED',
    }


def _status_name(status) -> str:
    """Return a job status as its plain string name (e.g. 'DONE').

    ``RuntimeJobV2.status()`` (the job type returned by ``SamplerV2.run``)
    yields a plain string, while older job classes return a ``JobStatus``
    enum whose ``.name`` is that same string. Both forms are accepted.
    """
    return str(getattr(status, 'name', status))


def _extract_counts(pub_result) -> Dict[str, int]:
    """Pull a bitstring -> count dict out of a SamplerV2 pub result.

    Checks the common register names first ('meas', the register created by
    ``QuantumCircuit.measure_all``, then 'c'), then any field of the data
    container that exposes ``get_counts()``, and finally ``get_counts()`` on
    the data container itself. Returns {} if no counts can be found.
    """
    data = pub_result.data
    for register in ('meas', 'c'):
        if hasattr(data, register):
            return getattr(data, register).get_counts()
    # Registers with any other name: DataBin exposes its fields via values().
    values = getattr(data, 'values', None)
    if callable(values):
        for field in values():
            if hasattr(field, 'get_counts'):
                return field.get_counts()
    if hasattr(data, 'get_counts'):
        return data.get_counts()
    return {}


def _job_error_message(job) -> str:
    """Best-effort description of why a job ended in the ERROR state.

    Prefers ``job.error_message()`` when the job object provides it. Falls
    back to calling ``job.result()``, which raises for a failed job, and
    returns the text of that exception.
    """
    error_message = getattr(job, 'error_message', None)
    if callable(error_message):
        try:
            message = error_message()
        except Exception:  # fall through to the result() probe below
            message = None
        if message:
            return str(message)
    try:
        return str(job.result())
    except Exception as exc:  # result() raises for failed jobs; report why
        return str(exc)


def run_hardware(target_pattern: int, inverse_map: Dict[int, int],
                 backend_name: str = 'ibm_torino',
                 shots: int = 8192,
                 double_twist: bool = True,
                 wait: bool = True,
                 timeout: int = 600,
                 apply_mitigation: bool = False) -> Dict:
    """
    Run on IBM Quantum hardware.

    Builds the circuit for ``inverse_map[target_pattern]``, transpiles it for
    the chosen backend (optimization level 1), and submits it via SamplerV2.
    If ``wait`` is set, the job status is polled every 2 seconds until it
    finishes or ``timeout`` elapses.

    Args:
        target_pattern: Target pattern number (0-255)
        inverse_map: Mapping from target -> init pattern
        backend_name: IBM backend name (e.g., 'ibm_torino', 'ibm_fez')
        shots: Number of measurement shots
        double_twist: Use Protocol 6.0 double twist
        wait: If True, wait for job completion and return results
              If False, return immediately with job_id
        timeout: Maximum wait time in seconds (default: 600 = 10 min)
        apply_mitigation: If True, apply M3 error mitigation (requires mthree)

    Returns:
        Dict with job info (target, init, job_id, backend, shots,
        transpiled_depth, transpiled_gates) and a 'status' of:
          - 'SUBMITTED' when wait=False;
          - 'TIMEOUT' (with 'error') if the job did not finish in time; the
            job is not cancelled, so it can be fetched later by job_id;
          - 'ERROR' (with 'error') if the job failed or counts could not be
            extracted;
          - 'CANCELLED' if the job was cancelled;
          - 'COMPLETED' with counts, hits, raw_fidelity and fidelity (percent
            of shots equal to the target bitstring), plus mitigation keys
            when apply_mitigation=True.

    Raises:
        ImportError: if qiskit-ibm-runtime (or qiskit) is not installed.
        KeyError: if target_pattern is not a key of inverse_map.
    """
    # Lazy import to avoid hard dependency if just using simulator
    try:
        from qiskit_ibm_runtime import QiskitRuntimeService, SamplerV2
        from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager
    except ImportError as err:
        raise ImportError(
            "qiskit-ibm-runtime is required for hardware backend. "
            "Install with: pip install qiskit-ibm-runtime"
        ) from err

    init = inverse_map[target_pattern]
    qc = build_deterministic_circuit(init, double_twist=double_twist)

    # Connect to IBM Quantum
    service = QiskitRuntimeService()
    backend = service.backend(backend_name)

    # Transpile for hardware topology
    pm = generate_preset_pass_manager(backend=backend, optimization_level=1)
    transpiled = pm.run(qc)

    # Submit job
    sampler = SamplerV2(backend)
    job = sampler.run([transpiled], shots=shots)

    target_binary = format(target_pattern, '08b')
    result = {
        'target': target_pattern,
        'target_binary': target_binary,
        'init': init,
        'init_binary': format(init, '08b'),
        'job_id': job.job_id(),
        'backend': backend_name,
        'shots': shots,
        'transpiled_depth': transpiled.depth(),
        'transpiled_gates': sum(transpiled.count_ops().values()),
        'status': 'SUBMITTED'
    }

    if not wait:
        return result

    # Poll until a terminal state; monotonic clock is immune to wall-clock jumps.
    deadline = time.monotonic() + timeout
    while True:
        status_name = _status_name(job.status())
        if status_name in ('DONE', 'ERROR', 'CANCELLED'):
            break
        if time.monotonic() > deadline:
            result['status'] = 'TIMEOUT'
            result['error'] = f'Job did not complete within {timeout}s'
            return result
        time.sleep(2)  # Poll every 2 seconds

    if status_name == 'ERROR':
        result['status'] = 'ERROR'
        result['error'] = _job_error_message(job)
        return result

    if status_name == 'CANCELLED':
        result['status'] = 'CANCELLED'
        return result

    # Extract results
    try:
        counts = _extract_counts(job.result()[0])
    except Exception as e:
        result['status'] = 'ERROR'
        result['error'] = f'Failed to extract counts: {str(e)}'
        return result

    # Calculate raw fidelity
    hits = counts.get(target_binary, 0)
    raw_fidelity = hits / shots * 100

    result['status'] = 'COMPLETED'
    result['counts'] = counts
    result['hits'] = hits
    result['raw_fidelity'] = raw_fidelity
    result['fidelity'] = raw_fidelity  # Will be updated if mitigation applied

    # Apply M3 mitigation if requested (best effort: failures are recorded,
    # not raised, so the raw results are still returned).
    if apply_mitigation:
        try:
            from .mitigation import HiveMitigation

            mitigator = HiveMitigation(backend)
            # NOTE(review): qubits 0-7 are logical indices; after transpilation
            # the measured physical qubits may differ. Confirm whether
            # HiveMitigation.calibrate expects physical qubits and, if so,
            # derive them from the transpiled circuit's measurement layout.
            mitigator.calibrate(qubits=list(range(8)))
            mitigated_fidelity = mitigator.mitigated_fidelity(counts, target_pattern)

            result['mitigated_fidelity'] = mitigated_fidelity
            result['fidelity'] = mitigated_fidelity
            result['mitigation_applied'] = True
        except ImportError:
            result['mitigation_applied'] = False
            result['mitigation_error'] = 'mthree not installed'
        except Exception as e:
            result['mitigation_applied'] = False
            result['mitigation_error'] = str(e)

    return result


def run_batch(target_patterns: List[int], inverse_map: Dict[int, int],
              backend: str = 'simulator',
              shots: int = 1024,
              double_twist: bool = True,
              **kwargs) -> List[Dict]:
    """
    Run several target patterns one after another.

    Each pattern is executed independently and sequentially, in input order.
    The literal backend name 'simulator' selects the local Aer path. Any
    other value is treated as an IBM backend name and forwarded to
    run_hardware together with ``kwargs``. Extra ``kwargs`` are ignored on
    the simulator path.

    Args:
        target_patterns: List of target pattern numbers
        inverse_map: Mapping from target -> init pattern
        backend: 'simulator' or IBM backend name
        shots: Number of shots per pattern
        double_twist: Use Protocol 6.0 double twist
        **kwargs: Additional arguments passed to run_hardware

    Returns:
        List of result dicts, one per pattern, in the same order as
        target_patterns.
    """
    if backend == 'simulator':
        return [
            run_simulator(pattern, inverse_map, shots, double_twist)
            for pattern in target_patterns
        ]

    return [
        run_hardware(pattern, inverse_map, backend, shots, double_twist, **kwargs)
        for pattern in target_patterns
    ]


def get_job_result(job_id: str, backend_name: str = 'ibm_torino',
                   target_pattern: Optional[int] = None) -> Dict:
    """
    Retrieve results for a previously submitted job.

    Args:
        job_id: IBM job ID
        backend_name: IBM backend name (only echoed back in the result dict)
        target_pattern: Optional target pattern for fidelity calculation

    Returns:
        Dict with job_id, backend and status. If the job is not finished,
        'status' is the job's own status name (e.g. 'QUEUED', 'RUNNING',
        'ERROR'). Once done, 'status' is 'COMPLETED' and 'counts' is added.
        When target_pattern is given, 'fidelity' (percent of shots equal to
        the 8-bit target bitstring), 'target', 'hits' and 'shots' are added
        as well. If result extraction fails, 'status' is 'ERROR' with
        'error' set.

    Raises:
        ImportError: if qiskit-ibm-runtime is not installed.
    """
    try:
        from qiskit_ibm_runtime import QiskitRuntimeService
    except ImportError as err:
        raise ImportError("qiskit-ibm-runtime is required") from err

    service = QiskitRuntimeService()
    job = service.job(job_id)

    # SamplerV2 jobs (RuntimeJobV2) report status as a plain string such as
    # 'DONE'; older job classes return an enum whose .name is that string.
    status = job.status()
    status_name = str(getattr(status, 'name', status))
    result = {
        'job_id': job_id,
        'backend': backend_name,
        'status': status_name
    }

    if status_name != 'DONE':
        return result

    try:
        data = job.result()[0].data
        if hasattr(data, 'meas'):
            counts = data.meas.get_counts()
        elif hasattr(data, 'c'):
            counts = data.c.get_counts()
        else:
            # Registers with any other name: DataBin exposes its fields via
            # values(); take the first one that can produce counts.
            values = getattr(data, 'values', None)
            fields = values() if callable(values) else []
            counts = next(
                (field.get_counts() for field in fields if hasattr(field, 'get_counts')),
                data.get_counts() if hasattr(data, 'get_counts') else {},
            )

        result['counts'] = counts
        result['status'] = 'COMPLETED'

        if target_pattern is not None:
            target_binary = format(target_pattern, '08b')
            total_shots = sum(counts.values())
            hits = counts.get(target_binary, 0)
            # Guard against an empty counts dict (no shots recovered).
            result['fidelity'] = hits / total_shots * 100 if total_shots > 0 else 0
            result['target'] = target_pattern
            result['hits'] = hits
            result['shots'] = total_shots

    except Exception as e:
        result['status'] = 'ERROR'
        result['error'] = str(e)

    return result
