"""
Unit tests for hive_keyboard.scaling module.
"""

import pytest
from hive_keyboard import (
    HiveKeyboardN,
    calculate_inverse_n,
    generate_inverse_map,
    get_hive_layers_n,
    get_critical_patterns_n,
    get_scaling_stats,
    verify_bijection,
)


class TestCalculateInverseN:
    """Tests for calculate_inverse_n function."""

    def test_8_qubit_matches_original(self):
        """Test that the 8-qubit inverse matches HiveKeyboard for every target.

        The 8-qubit space has only 256 targets, so every one is compared
        instead of a handful of samples; a mismatch anywhere is caught.
        """
        from hive_keyboard import HiveKeyboard

        hive = HiveKeyboard()
        for target in range(1 << 8):
            expected = hive.inverse(target)
            actual = calculate_inverse_n(target, 8)
            assert actual == expected, f"Mismatch at target {target}"

    def test_minimum_qubits(self):
        """Test minimum qubit count (3): every target's inverse stays in range."""
        max_val = (1 << 3) - 1
        # Only 8 targets exist at 3 qubits, so check them all rather than one.
        for target in range(max_val + 1):
            result = calculate_inverse_n(target, 3)
            assert 0 <= result <= max_val, f"Out of range at target {target}"

    def test_invalid_qubit_count(self):
        """Test that every qubit count below 3 raises ValueError."""
        for n in (0, 1, 2):
            with pytest.raises(ValueError):
                calculate_inverse_n(0, n)

    def test_large_qubit_count(self):
        """Test inverse calculation for large qubit counts."""
        for n in [16, 32, 64]:
            max_val = (1 << n) - 1
            result = calculate_inverse_n(max_val // 2, n)
            assert 0 <= result <= max_val


class TestGenerateInverseMap:
    """Tests for generate_inverse_map function."""

    def test_map_size(self):
        """Test that map has correct size."""
        for n in [3, 4, 8]:
            inv_map = generate_inverse_map(n)
            assert len(inv_map) == (1 << n)

    def test_map_values_in_range(self):
        """Test that all values lie in [0, 2**n - 1] for each qubit count.

        The upper bound is derived from n (instead of a hard-coded 255), so
        the same check runs for every size exercised by test_map_size.
        """
        for n in [3, 4, 8]:
            inv_map = generate_inverse_map(n)
            max_val = (1 << n) - 1
            assert all(0 <= v <= max_val for v in inv_map.values()), (
                f"{n}-qubit map has out-of-range values"
            )


class TestVerifyBijection:
    """Check that verify_bijection reports success for several qubit counts."""

    def test_small_qubits_valid(self):
        """Every qubit count from 3 through 8 must yield a bijection."""
        for qubits in range(3, 9):
            ok = verify_bijection(qubits)
            assert ok, f"{qubits}-qubit bijection failed"

    def test_medium_qubits_valid(self):
        """Qubit counts 10 and 12 must also yield a bijection."""
        for qubits in (10, 12):
            ok = verify_bijection(qubits)
            assert ok, f"{qubits}-qubit bijection failed"


class TestGetHiveLayersN:
    """Tests for get_hive_layers_n function."""

    def test_layer_count(self):
        """Test that 8 layers are generated."""
        for n in [8, 16, 32]:
            layers = get_hive_layers_n(n)
            assert len(layers) == 8

    def test_layer_coverage(self):
        """Test that the layer ranges cover every pattern 0..2**n - 1 with no gaps.

        Each layer's 'range' is treated as an inclusive (low, high) pair.
        Rather than materialising every covered pattern in a set (2**n
        entries), the ranges are sorted and swept while tracking the first
        pattern not yet covered. This keeps the check O(#layers), so it can
        also run at 32 qubits.
        """
        for n in [8, 16, 32]:
            layers = get_hive_layers_n(n)
            max_pattern = (1 << n) - 1

            ranges = sorted(tuple(info['range']) for info in layers.values())
            next_uncovered = 0  # smallest pattern not covered by any range so far
            for low, high in ranges:
                if next_uncovered > max_pattern:
                    break  # full space already covered; ignore any overhang
                assert low <= next_uncovered, (
                    f"{n}-qubit layers leave a gap at pattern {next_uncovered}"
                )
                next_uncovered = max(next_uncovered, high + 1)

            assert next_uncovered > max_pattern, (
                f"{n}-qubit layers stop covering at pattern {next_uncovered}"
            )

    def test_layer_structure(self):
        """Test that each layer has required fields."""
        layers = get_hive_layers_n(16)
        for layer_num, info in layers.items():
            for field in ('name', 'range', 'description'):
                assert field in info, f"Layer {layer_num} missing {field!r}"


class TestGetCriticalPatternsN:
    """Check the critical-pattern table returned for several qubit counts."""

    def test_zero_always_critical(self):
        """Pattern 0 must be present in the table for every qubit count."""
        for qubits in (8, 16, 32):
            critical = get_critical_patterns_n(qubits)
            assert 0 in critical
            # IP-protected: names are obfuscated as 'Pxx'
            name = critical[0]
            assert name.startswith('P')

    def test_max_always_critical(self):
        """The all-ones pattern 2**n - 1 must be present in the table."""
        for qubits in (8, 16, 32):
            top = (1 << qubits) - 1
            critical = get_critical_patterns_n(qubits)
            assert top in critical
            assert critical[top].startswith('P')

    def test_critical_patterns_exist(self):
        """The table must be non-empty for every qubit count."""
        for qubits in (8, 16, 32):
            critical = get_critical_patterns_n(qubits)
            # Should have at least some critical patterns
            assert len(critical) > 0, f"No critical patterns for {qubits} qubits"


class TestGetScalingStats:
    """Check the summary returned by get_scaling_stats."""

    def test_stats_structure(self):
        """The 8-qubit stats must expose every expected key."""
        stats = get_scaling_stats(8)

        required = ('n_qubits', 'num_patterns', 'num_layers', 'inverse_complexity')
        for key in required:
            assert key in stats

    def test_pattern_count_correct(self):
        """num_patterns must equal 2**n for each qubit count."""
        for qubits in (8, 16, 32):
            expected = 1 << qubits
            assert get_scaling_stats(qubits)['num_patterns'] == expected


class TestHiveKeyboardN:
    """Tests for HiveKeyboardN class."""

    def test_8_qubit_100_fidelity(self):
        """Test 8-qubit achieves 100% fidelity."""
        hive = HiveKeyboardN(n_qubits=8)
        result = hive.target(126, shots=100)
        assert result['fidelity'] == 100.0

    def test_16_qubit_100_fidelity(self):
        """Test 16-qubit achieves 100% fidelity."""
        hive = HiveKeyboardN(n_qubits=16)
        core = hive.max_pattern // 2
        result = hive.target(core, shots=100)
        assert result['fidelity'] == 100.0

    def test_12_qubit_all_patterns(self):
        """Test a spread of patterns across the 12-qubit (4096-pattern) space.

        Only a sample is targeted, to keep the test fast: both endpoints, the
        midpoint and two interior values. It does not enumerate all 4096
        patterns.
        """
        hive = HiveKeyboardN(n_qubits=12)
        assert hive.num_patterns == 4096

        # Endpoints/midpoint come from the instance instead of hard-coded
        # 2048/4095, so they always match the configured qubit count.
        test_patterns = [0, 1000, hive.num_patterns // 2, 3000, hive.max_pattern]
        for p in test_patterns:
            result = hive.target(p, shots=50)
            assert result['fidelity'] == 100.0, f"Pattern {p} failed"

    def test_invalid_pattern_range(self):
        """Test that patterns outside [0, num_patterns) raise ValueError."""
        hive = HiveKeyboardN(n_qubits=8)

        # Just below the valid range.
        with pytest.raises(ValueError):
            hive.target(-1)

        # One past the largest valid pattern.
        with pytest.raises(ValueError):
            hive.target(hive.num_patterns)

    def test_minimum_qubits(self):
        """Test minimum qubit count (3)."""
        hive = HiveKeyboardN(n_qubits=3)
        assert hive.num_patterns == 8
        result = hive.target(4, shots=100)
        assert result['fidelity'] == 100.0

    def test_invalid_qubit_count(self):
        """Test that < 3 qubits raises error."""
        with pytest.raises(ValueError):
            HiveKeyboardN(n_qubits=2)

    def test_layer_method(self):
        """Test layer() method works for N qubits."""
        hive = HiveKeyboardN(n_qubits=16)
        core = hive.max_pattern // 2
        layer_info = hive.layer(core)
        # NOTE(review): layer names are reportedly obfuscated as 'Lx'; this
        # assertion relies only on the returned value containing the
        # substring 'Layer'. Confirm against HiveKeyboardN.layer().
        assert 'Layer' in layer_info

    def test_info_method(self):
        """Test info() method works for N qubits."""
        hive = HiveKeyboardN(n_qubits=16)
        # Midpoint of the 16-qubit space (32767), derived like the sibling tests.
        info = hive.info(hive.max_pattern // 2)

        assert info['n_qubits'] == 16
        assert len(info['binary']) == 16
        assert 'circuit' in info

    def test_stats_method(self):
        """Test stats() method."""
        hive = HiveKeyboardN(n_qubits=16)
        stats = hive.stats()

        assert stats['n_qubits'] == 16
        assert stats['num_patterns'] == 65536

    def test_repr(self):
        """Test string representation."""
        hive = HiveKeyboardN(n_qubits=16)
        assert 'HiveKeyboardN' in repr(hive)
        assert '16' in repr(hive)
