"""
Unit tests for hive_keyboard.encoding module.
"""

import pytest
import networkx as nx
from hive_keyboard import encode_maxcut, create_hive_isomorphic_graph
from hive_keyboard.encoding import _maxcut_value


class TestCreateHiveIsomorphicGraph:
    """Tests for create_hive_isomorphic_graph function.

    The expected edge sets below assume integer node labels 0..n-1 and a graph
    that joins every node of one bit-set to every node of the other (the edge
    counts asserted here, e.g. 6*2 = 12 for pattern 126, imply this). Patterns
    are chosen so the expected partition is the same whether node i reads bit i
    (LSB-first) or bit n-1-i (MSB-first) of the pattern.
    """

    @staticmethod
    def _edge_set(graph):
        """Return the graph's edges as an order-independent set of frozensets."""
        return {frozenset(edge) for edge in graph.edges}

    def test_default_graph_creation(self):
        """Test default graph creation (8 nodes, pattern 126)."""
        graph = create_hive_isomorphic_graph()

        assert len(graph.nodes) == 8
        assert len(graph.edges) > 0
        # The defaults should produce exactly the explicit 8-node/126 graph.
        explicit = create_hive_isomorphic_graph(n_nodes=8, target_pattern=126)
        assert self._edge_set(graph) == self._edge_set(explicit)

    def test_custom_node_count(self):
        """Test graph creation with custom node count."""
        graph = create_hive_isomorphic_graph(n_nodes=16, target_pattern=32768)

        assert len(graph.nodes) == 16
        # 32768 = 2**15 has exactly one bit set: one node on one side and
        # 15 on the other, so 1 * 15 = 15 cross edges.
        assert len(graph.edges) == 15

    def test_graph_structure_for_pattern_126(self):
        """Test graph structure for pattern 126 (01111110)."""
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=126)

        # Pattern 126 = 01111110
        # Set A (1s): nodes 1,2,3,4,5,6
        # Set B (0s): nodes 0,7
        # Edges should connect across sets

        assert len(graph.nodes) == 8
        # Edges: 0-{1,2,3,4,5,6} and 7-{1,2,3,4,5,6} = 6+6 = 12 edges
        assert len(graph.edges) == 12
        expected = {frozenset((a, b)) for a in (0, 7) for b in range(1, 7)}
        assert self._edge_set(graph) == expected

    def test_graph_structure_for_pattern_0(self):
        """Test graph structure for pattern 0 (all zeros)."""
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=0)

        # Pattern 0 = 00000000 (all in same set)
        # No edges connecting different sets
        assert len(graph.edges) == 0

    def test_graph_structure_for_pattern_255(self):
        """Test graph structure for pattern 255 (all ones)."""
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=255)

        # Pattern 255 = 11111111 (all in same set)
        # No edges connecting different sets
        assert len(graph.edges) == 0

    def test_graph_structure_for_alternating(self):
        """Test graph structure for alternating pattern (85 = 01010101)."""
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=85)

        # Pattern 85 = 01010101 splits the nodes into odd and even labels
        # (which side holds the 1s depends on bit order, but the split
        # does not). Each of the 4 odd nodes connects to all 4 even nodes,
        # giving 4*4 = 16 edges.
        assert len(graph.edges) == 16
        expected = {
            frozenset((even, odd))
            for even in range(0, 8, 2)
            for odd in range(1, 8, 2)
        }
        assert self._edge_set(graph) == expected


class TestEncodeMaxcut:
    """Tests for encode_maxcut function.

    For graphs built by create_hive_isomorphic_graph, the target pattern cuts
    every edge, so its cut value equals the edge count and is the global
    maximum. Tests assert this explicitly before comparing against it.
    """

    def test_encode_pattern_126(self):
        """Test encoding graph for pattern 126 returns optimal solution."""
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=126)
        encoded = encode_maxcut(graph)

        # Pattern 126 is successfully recovered by spectral encoding
        assert encoded == 126

    def test_encode_returns_valid_pattern(self):
        """Test that encoding always returns a valid pattern in range."""
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=51)
        encoded = encode_maxcut(graph)

        assert 0 <= encoded <= 255

    def test_encode_small_graph_padded(self):
        """Test encoding a 4-node graph pads to n_qubits."""
        graph = nx.cycle_graph(4)
        encoded = encode_maxcut(graph, n_qubits=8)

        # Result should be in 8-bit range
        assert 0 <= encoded <= 255

    def test_roundtrip_isomorphic_graph_126(self):
        """Test that P126 roundtrips and the recovered pattern is a max cut.

        Beyond test_encode_pattern_126's equality check, this verifies that
        the recovered pattern cuts every edge of the isomorphic graph.
        """
        target = 126
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=target)
        encoded = encode_maxcut(graph)

        # P126 roundtrips perfectly
        assert encoded == target
        assert _maxcut_value(graph, encoded, 8) == graph.number_of_edges()

    def test_maxcut_quality_for_alternating(self):
        """Test that spectral encoding achieves optimal cut for alternating pattern."""
        target = 85  # 01010101
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=target)
        encoded = encode_maxcut(graph)

        # Both 85 and its complement 170 are optimal
        target_cut = _maxcut_value(graph, target, 8)
        encoded_cut = _maxcut_value(graph, encoded, 8)

        # The target cuts every edge, so no partition can beat it; this makes
        # the equality below a genuine optimality check.
        assert target_cut == graph.number_of_edges()
        assert encoded_cut == target_cut  # Optimal

    def test_maxcut_quality_reasonable(self):
        """Test that spectral encoding achieves at least 50% of optimal cut."""
        target = 51  # 00110011 - a more challenging, non-alternating pattern
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=target)
        encoded = encode_maxcut(graph)

        target_cut = _maxcut_value(graph, target, 8)
        encoded_cut = _maxcut_value(graph, encoded, 8)

        # target_cut is the optimum (it cuts every edge) and is non-zero, so
        # the ratio check below cannot pass vacuously as 0 >= 0.
        assert target_cut == graph.number_of_edges()
        assert target_cut > 0
        # Spectral should achieve at least 50% of optimal
        assert encoded_cut >= target_cut * 0.5


class TestEncodeMaxcutRealGraphs:
    """Tests for encode_maxcut on realistic graph types.

    Note: The simple spectral encoding (adjacency eigenvector) works best for
    graphs with clear bipartite structure. For connected regular graphs
    (cycles, K_{n,n}, the Petersen graph) the principal adjacency eigenvector
    is constant and provides no partitioning information. MaxCut-oriented
    spectral heuristics use the eigenvector of the *smallest* adjacency
    eigenvalue instead. For regular graphs this is equivalently the largest
    Laplacian eigenvalue; the Laplacian Fiedler vector targets minimum cuts,
    not maximum ones. SDP relaxation (Goemans-Williamson) is the standard
    approach for general MaxCut.
    """

    def test_cycle_graph_returns_valid_pattern(self):
        """Test encoding a cycle graph returns a valid pattern."""
        graph = nx.cycle_graph(8)
        encoded = encode_maxcut(graph)

        # Cycle graphs are 2-regular, so the principal eigenvector is
        # constant; the spectral method may not find the optimum, but it
        # should return a valid pattern.
        assert 0 <= encoded <= 255

    def test_complete_bipartite_graph(self):
        """Test encoding a complete bipartite graph K_{4,4}."""
        graph = nx.complete_bipartite_graph(4, 4)
        encoded = encode_maxcut(graph)

        # K_{4,4} is bipartite (optimal cut = all 16 edges) but also
        # 4-regular, so its principal eigenvector is constant. This test
        # only checks the result is not the degenerate all-one-side cut.
        cut_value = _maxcut_value(graph, encoded, 8)
        assert cut_value >= 4

    def test_petersen_graph_returns_valid(self):
        """Test encoding the Petersen graph returns valid pattern."""
        graph = nx.petersen_graph()
        encoded = encode_maxcut(graph, n_qubits=10)

        # Should return a valid 10-bit pattern
        assert 0 <= encoded <= 1023

    def test_empty_graph(self):
        """Test encoding an empty graph returns 0."""
        graph = nx.Graph()
        encoded = encode_maxcut(graph)

        assert encoded == 0

    def test_random_gnp_graph(self):
        """Test encoding a random G(n,p) graph."""
        # Use seed for reproducibility
        graph = nx.gnp_random_graph(8, 0.5, seed=42)
        # Precondition: a non-zero cut is only possible if edges exist.
        assert graph.number_of_edges() > 0
        encoded = encode_maxcut(graph)

        # Should return valid pattern with non-zero cut
        assert 0 <= encoded <= 255
        cut_value = _maxcut_value(graph, encoded, 8)
        # Random graphs typically have some edges cut; a cut can never
        # exceed the total edge count.
        assert 0 < cut_value <= graph.number_of_edges()


class TestMaxcutValueHelper:
    """Tests for _maxcut_value helper function."""

    def test_maxcut_value_alternating(self):
        """Test MaxCut value calculation for alternating pattern."""
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=85)

        # Pattern 85 = 01010101, should cut all 16 edges
        value = _maxcut_value(graph, 85, 8)
        assert value == 16.0

    def test_maxcut_value_same_set(self):
        """Test MaxCut value when all nodes in same set."""
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=126)

        # Pattern 0 = 00000000 and 255 = 11111111 each put every node on
        # one side, so no edge is cut.
        assert _maxcut_value(graph, 0, 8) == 0.0
        assert _maxcut_value(graph, 255, 8) == 0.0

    def test_maxcut_value_complement_equal(self):
        """Test that pattern and complement have same MaxCut value."""
        graph = create_hive_isomorphic_graph(n_nodes=8, target_pattern=126)

        # 126 and 129 (complement) should give same cut
        value_126 = _maxcut_value(graph, 126, 8)
        value_129 = _maxcut_value(graph, 129, 8)
        # Pin the value so the equality cannot pass trivially (e.g. both 0):
        # 126 separates {0, 7} from {1..6}, cutting all 12 edges.
        assert value_126 == 12.0
        assert value_126 == value_129
