from math import gcd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# Single seed shared by NumPy and the graph generator so reruns are identical.
seed = 234
np.random.seed(seed)

# Generate stochastic block model
n = 100  # Number of nodes
p = 0.1  # Probability of edge within a block
q = 0.01  # Probability of edge between blocks

# Use four communities. Spread any remainder of n over the first blocks so the
# sizes always add up to n (for n = 100 this is simply four blocks of 25).
base_size, leftover = divmod(n, 4)
block_sizes = [base_size + (1 if i < leftover else 0) for i in range(4)]

# Edge-probability matrix: p on the diagonal (same block), q everywhere else.
block_probs = [[p if i == j else q for i in range(4)] for j in range(4)]

# Seeding NumPy alone does not fix the graph: the generator takes its
# randomness from its own `seed` argument, so pass it explicitly.
G = nx.stochastic_block_model(block_sizes, block_probs, seed=seed)

# One colour per node: each block's colour repeated for every node in that
# block. This relies on the nodes being numbered block by block.
colors = [
    color
    for color, size in zip(('r', 'g', 'b', 'y'), block_sizes)
    for _ in range(size)
]

# Draw the graph with small node size, using a random layout per block
plt.figure(figsize=(10, 5))

# Random coordinates for every node; the fixed seed keeps the picture stable.
pos = nx.random_layout(G, seed=33)

# Offset applied to every node of a block so the blocks are drawn apart
block_positions = {
    0: (-1, 0),
    1: (1, 0),
    2: (0, 1),
    3: (0, -1)
}

# Shift each node's random position by the offset of the block it belongs to
node_positions = {
    node: (
        pos[node][0] + block_positions[block][0],
        pos[node][1] + block_positions[block][1],
    )
    for node, block in nx.get_node_attributes(G, 'block').items()
}

# Split the edges by whether both endpoints lie in the same block, so each
# group can be drawn with a single call instead of one call per edge.
intra_block_edges = []
inter_block_edges = []
for u, v in G.edges():
    if G.nodes[u]['block'] == G.nodes[v]['block']:
        intra_block_edges.append((u, v))
    else:
        inter_block_edges.append((u, v))

# Grey inter-block edges go first so the black intra-block edges stay on top.
nx.draw_networkx_edges(G, node_positions, edgelist=inter_block_edges, edge_color='grey')
nx.draw_networkx_edges(G, node_positions, edgelist=intra_block_edges, edge_color='black')

# Draw nodes with colors
nx.draw_networkx_nodes(G, node_positions, node_color=colors, node_size=20, alpha=0.8)

# Remove axis
plt.axis('off')

# plt.show()
plt.savefig("stochastic_block_model.png", dpi=300)

# Plot the adjacency matrix
fig = plt.figure(figsize=(5, 5))
# Without `fignum`, matshow opens a brand-new figure and the 5x5 one above is
# left empty; pointing it at the existing figure keeps the requested size.
plt.matshow(nx.to_numpy_array(G), fignum=fig.number, cmap='binary')

# Draw squares on diagonal blocks
start = 0
for size in block_sizes:
    end = start + size
    if size > 0:
        # Matrix cell k is centred on coordinate k, so the block covering
        # cells start..end-1 is bounded by start - 0.5 and end - 0.5.
        low, high = start - 0.5, end - 0.5
        plt.plot(
            [low, high, high, low, low],
            [low, low, high, high, low],
            color=colors[start],  # colour of the first node in this block
            linewidth=2,
        )
    start = end
plt.savefig("stochastic_block_model_adjacency.png", dpi=300)
#plt.show()
