"""
M3 (Matrix-free Measurement Mitigation) wrapper for Hive Keyboard.

Protocol 6.0 achieves 97.96% fidelity with M3 mitigation (vs ~64% raw).
"""

import numpy as np
from typing import Dict, Optional, List

try:
    from mthree import M3Mitigation
    M3_AVAILABLE = True
except ImportError:
    M3_AVAILABLE = False


class HiveMitigation:
    """
    Measurement error mitigation for Hive Keyboard hardware runs.

    Uses IBM's M3 (Matrix-free Measurement Mitigation) to correct
    readout errors and achieve ~97% fidelity on real hardware.

    Typical use::

        mit = HiveMitigation(backend).calibrate(qubits=list(range(8)))
        probs = mit.apply(raw_counts)
        fidelity = mit.mitigated_fidelity(raw_counts, target_pattern=42)
    """

    def __init__(self, backend=None):
        """
        Initialize the mitigation system.

        Args:
            backend: IBM backend object (optional, can set later via calibrate())

        Raises:
            ImportError: If the ``mthree`` package could not be imported.
        """
        if not M3_AVAILABLE:
            raise ImportError(
                "mthree is required for error mitigation. "
                "Install with: pip install mthree"
            )

        self.backend = backend
        self.mitigator = None
        self._calibrated = False
        self._calibration_qubits = None

    def calibrate(self, backend=None, qubits: Optional[List[int]] = None, shots: int = 8192):
        """
        Calibrate the mitigator for a specific backend and qubit set.

        Args:
            backend: IBM backend object. If given, it replaces the stored backend.
            qubits: List of qubit indices to calibrate (default: [0-7] for 8-qubit Hive)
            shots: Number of calibration shots (default: 8192)

        Returns:
            self for chaining

        Raises:
            ValueError: If no backend is available, or ``qubits`` is empty.
        """
        if backend is not None:
            self.backend = backend

        if self.backend is None:
            raise ValueError("No backend specified. Provide backend in __init__ or calibrate()")

        if qubits is None:
            qubits = list(range(8))  # Default 8-qubit Hive
        else:
            # Copy so later mutation of the caller's list cannot change which
            # qubits apply()/mitigated_fidelity() believe were calibrated.
            qubits = list(qubits)

        if not qubits:
            raise ValueError("qubits must contain at least one qubit index")

        # Build and calibrate locally, and publish only on success. A failed
        # calibration must not leave a fresh, uncalibrated mitigator paired
        # with a stale ``_calibrated = True`` from an earlier run.
        mitigator = M3Mitigation(self.backend)
        mitigator.cals_from_system(qubits, shots=shots)

        self.mitigator = mitigator
        self._calibration_qubits = qubits
        self._calibrated = True

        return self

    def apply(self, counts: Dict[str, int], qubits: Optional[List[int]] = None) -> Dict[str, float]:
        """
        Apply error mitigation to raw measurement counts.

        Args:
            counts: Raw measurement counts from hardware {'bitstring': count}
            qubits: Qubit indices (must be calibrated; defaults to the
                calibration qubits)

        Returns:
            Mitigated probabilities {'bitstring': probability}. The M3
            quasi-probabilities are projected onto the nearest valid
            probability distribution via mthree's
            ``nearest_probability_distribution()``.

        Raises:
            RuntimeError: If calibrate() has not completed successfully.
        """
        if not self._calibrated:
            raise RuntimeError("Mitigator not calibrated. Call calibrate() first.")

        if qubits is None:
            qubits = self._calibration_qubits

        quasi_probs = self.mitigator.apply_correction(counts, qubits)
        return quasi_probs.nearest_probability_distribution()

    @staticmethod
    def _target_bitstring(target_pattern: int, num_bits: int) -> str:
        """Return ``target_pattern`` as a zero-padded ``num_bits``-wide bitstring.

        Raises:
            ValueError: If ``target_pattern`` cannot be represented in ``num_bits`` bits.
        """
        upper = 1 << num_bits
        if not 0 <= target_pattern < upper:
            raise ValueError(
                f"target_pattern {target_pattern} out of range for {num_bits} qubits "
                f"(expected 0..{upper - 1})"
            )
        return format(target_pattern, f"0{num_bits}b")

    def mitigated_fidelity(self, counts: Dict[str, int], target_pattern: int,
                           qubits: Optional[List[int]] = None) -> float:
        """
        Calculate mitigated fidelity for a target pattern.

        Args:
            counts: Raw measurement counts
            target_pattern: Target pattern number, 0 .. 2**n - 1 where n is the
                number of qubits (0-255 for the default 8-qubit Hive)
            qubits: Qubit indices (defaults to the calibration qubits)

        Returns:
            Mitigated fidelity as percentage (0-100)

        Raises:
            RuntimeError: If calibrate() has not completed successfully.
            ValueError: If ``target_pattern`` does not fit in the qubit count.
        """
        # apply() performs the calibration check first, so the calibration
        # qubits are guaranteed to be set when we fall back to them below.
        mitigated = self.apply(counts, qubits)

        active_qubits = qubits if qubits is not None else self._calibration_qubits
        # Width follows the measured qubit count instead of a fixed 8 bits.
        # The leftmost character is the most significant bit, matching the
        # original '08b' formatting.
        # NOTE(review): assumes counts keys put the highest-index bit leftmost
        # (Qiskit convention); confirm if other bit orderings are used.
        target_binary = self._target_bitstring(target_pattern, len(active_qubits))

        prob = mitigated.get(target_binary, 0.0)

        return prob * 100.0

    @property
    def is_calibrated(self) -> bool:
        """Check if mitigator is calibrated."""
        return self._calibrated


def apply_m3_mitigation(counts: Dict[str, int], backend,
                        qubits: Optional[List[int]] = None,
                        target_pattern: Optional[int] = None,
                        shots: int = 8192) -> Dict:
    """
    Convenience function for one-shot M3 mitigation.

    Creates a fresh HiveMitigation, calibrates it on ``backend`` for
    ``qubits``, and applies it to ``counts``. Every call performs a new
    calibration, so reuse a HiveMitigation instance directly when
    mitigating several result sets from the same backend.

    Args:
        counts: Raw measurement counts
        backend: IBM backend object
        qubits: Qubit indices (default: [0-7])
        target_pattern: Optional target pattern for fidelity calculation
        shots: Number of calibration shots (default: 8192, matching
            HiveMitigation.calibrate)

    Returns:
        Dict with:
            'mitigated_counts': mitigated probabilities {'bitstring': probability}
            'raw_counts': the input ``counts`` object, unchanged
            'fidelity': mitigated fidelity percentage (only when
                ``target_pattern`` is given)
    """
    if qubits is None:
        qubits = list(range(8))

    mitigator = HiveMitigation(backend)
    mitigator.calibrate(qubits=qubits, shots=shots)

    mitigated = mitigator.apply(counts, qubits)

    result = {
        'mitigated_counts': mitigated,
        'raw_counts': counts,
    }

    if target_pattern is not None:
        result['fidelity'] = mitigator.mitigated_fidelity(counts, target_pattern, qubits)

    return result
