"""
Subvurs Mute Button - Decoder Module

Decodes measurement results from noise-resistant encoded circuits.
"""

from typing import Dict, Tuple, List, Optional
from dataclasses import dataclass

# Import protected core
try:
    from ._core import (
        _decode_ratio,
        _decode_differential,
        _confidence_ratio,
        _confidence_differential,
        _get_decision_threshold,
        _expected_p1_zero,
        _expected_p1_one,
    )
    CORE_AVAILABLE = True
except ImportError:
    CORE_AVAILABLE = False


def _check_core():
    """Fail fast when the compiled ``_core`` extension could not be imported.

    Does nothing if the guarded import at module load succeeded
    (``CORE_AVAILABLE`` is true).

    Raises:
        ImportError: the extension is missing; the message says how to build it.
    """
    if CORE_AVAILABLE:
        return
    raise ImportError(
        "Core module not compiled. Run: python setup.py build_ext --inplace"
    )


@dataclass
class DecodeResult:
    """Outcome of a single decode call.

    Fields are positional in the order listed, so
    ``DecodeResult(value, conf, metric, mode)`` is valid.
    """

    # Decoded logical bit: 0 or 1.
    logical_value: int
    # Decoder confidence, documented as ranging 0.0 to 1.0
    # (the value itself comes from the compiled core).
    confidence: float
    # Underlying metric: average P(|1⟩) in "ratio" mode,
    # |group_a - group_b| in "differential" mode.
    raw_metric: float
    # Which decoding method produced this result: "ratio" or "differential".
    mode: str


class Decoder:
    """
    Subvurs Mute Button Decoder

    Decodes measurement results from Mute Button encoded circuits.

    Modes:
        "ratio"        -- the average P(|1⟩) over qubits 0-6 is passed to the
                          compiled core's ratio decoder.
        "differential" -- the average P(|1⟩) of Group A (qubits 0-2) and of
                          Group B (qubits 3-5) are passed to the compiled
                          core's differential decoder.

    Counts are read Qiskit-style: ``{bitstring: shots}`` where the rightmost
    character of a bitstring is qubit 0. Spaces between classical registers
    are ignored.

    Usage:
        decoder = Decoder()
        result = decoder.decode(counts)
        print(f"Logical value: {result.logical_value}")
        print(f"Confidence: {result.confidence:.2%}")
    """

    def __init__(self, mode: str = "ratio"):
        """
        Initialize decoder.

        Args:
            mode: "ratio" (default) or "differential"

        Raises:
            ImportError: if the compiled core extension is unavailable.
            ValueError: if ``mode`` is not a supported mode.
        """
        _check_core()

        if mode not in ("ratio", "differential"):
            raise ValueError("Mode must be 'ratio' or 'differential'")

        self.mode = mode

        # Ratio mode reads 7 qubits; differential mode reads two groups of 3.
        self.n_qubits = 7 if mode == "ratio" else 6

    def decode(self, counts: Dict[str, int]) -> "DecodeResult":
        """
        Decode measurement counts to logical value.

        Args:
            counts: Dictionary of bitstring -> count from quantum measurement

        Returns:
            DecodeResult with logical value and confidence

        Raises:
            ValueError: if ``counts`` is empty or its shot total is not positive.
        """
        if self.mode == "ratio":
            return self._decode_ratio(counts)
        return self._decode_differential(counts)

    def _decode_ratio(self, counts: Dict[str, int]) -> "DecodeResult":
        """Decode using the ratio method (average P(|1⟩) over all qubits)."""
        avg_p1 = self._calculate_avg_p1(counts, self.n_qubits)

        # The bare names below resolve to the module-level core functions,
        # not to this method: class attributes are not in a method's scope.
        logical_value = _decode_ratio(avg_p1)
        confidence = _confidence_ratio(avg_p1)

        return DecodeResult(
            logical_value=logical_value,
            confidence=confidence,
            raw_metric=avg_p1,
            mode="ratio"
        )

    def _decode_differential(self, counts: Dict[str, int]) -> "DecodeResult":
        """Decode using the differential method (Group A vs Group B P(|1⟩))."""
        group_a, group_b = self._calculate_group_p1(counts)

        # Module-level core functions (see note in ``_decode_ratio``).
        logical_value = _decode_differential(group_a, group_b)
        confidence = _confidence_differential(group_a, group_b)

        # Reported metric is the magnitude of the group separation.
        diff = abs(group_a - group_b)

        return DecodeResult(
            logical_value=logical_value,
            confidence=confidence,
            raw_metric=diff,
            mode="differential"
        )

    @staticmethod
    def _qubit_one_probabilities(
        counts: Dict[str, int],
        n_qubits: int
    ) -> List[float]:
        """Return P(|1⟩) for each qubit ``0 .. n_qubits - 1``.

        The rightmost character of each bitstring is qubit 0 (Qiskit
        little-endian order). Spaces, which Qiskit inserts between classical
        registers, are removed first so they do not shift qubit indices.
        Bits at positions ``>= n_qubits`` are ignored.

        Raises:
            ValueError: if ``counts`` is empty or its counts do not sum to a
                positive number of shots.
        """
        total = sum(counts.values())
        if total <= 0:
            raise ValueError(
                "counts must contain at least one shot (total count > 0)"
            )

        ones = [0] * n_qubits
        for bitstring, count in counts.items():
            bits = bitstring.replace(" ", "")
            # Reverse so index i is qubit i; truncate to the qubits we track.
            for i, bit in enumerate(bits[::-1][:n_qubits]):
                if bit == '1':
                    ones[i] += count

        return [c / total for c in ones]

    def _calculate_avg_p1(
        self,
        counts: Dict[str, int],
        n_qubits: int
    ) -> float:
        """Calculate average P(|1⟩) across qubits ``0 .. n_qubits - 1``.

        Raises:
            ValueError: if ``n_qubits`` is not positive, or ``counts`` is
                empty / has a non-positive shot total.
        """
        if n_qubits <= 0:
            raise ValueError("n_qubits must be a positive integer")

        qubit_probs = self._qubit_one_probabilities(counts, n_qubits)
        return sum(qubit_probs) / len(qubit_probs)

    def _calculate_group_p1(
        self,
        counts: Dict[str, int]
    ) -> Tuple[float, float]:
        """Calculate P(|1⟩) for Group A (qubits 0-2) and Group B (qubits 3-5).

        Raises:
            ValueError: if ``counts`` is empty or has a non-positive shot total.
        """
        qubit_probs = self._qubit_one_probabilities(counts, self.n_qubits)

        group_a = sum(qubit_probs[:3]) / 3
        group_b = sum(qubit_probs[3:6]) / 3

        return group_a, group_b

    def batch_decode(
        self,
        counts_list: List[Dict[str, int]]
    ) -> List["DecodeResult"]:
        """
        Decode multiple measurement results.

        Args:
            counts_list: List of count dictionaries

        Returns:
            List of DecodeResults, in the same order as ``counts_list``
        """
        return [self.decode(counts) for counts in counts_list]


# Convenience functions

def decode_ratio(counts: Dict[str, int]) -> DecodeResult:
    """One-shot decode of ``counts`` with a fresh ratio-mode ``Decoder``."""
    ratio_decoder = Decoder(mode="ratio")
    result = ratio_decoder.decode(counts)
    return result


def decode_differential(counts: Dict[str, int]) -> DecodeResult:
    """One-shot decode of ``counts`` with a fresh differential-mode ``Decoder``."""
    differential_decoder = Decoder(mode="differential")
    result = differential_decoder.decode(counts)
    return result
