# Bernoulli random graphs on a clustered 2D point cloud, compared across several edge-probability scales.
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# When True, also save snapshots of the intermediate stages (nodes only,
# nodes + edges, fully coloured graph) for the first scale value.
steps = True

# Seed NumPy's global RNG so the point cloud and the random edges are reproducible.
np.random.seed(42)

# Step 1: sample a clustered 2D point cloud (neither normal nor uniform) as a
# mixture of four isotropic Gaussians with standard deviation 0.05.
num_points = 200
centers = np.array(
    [
        [0.3, 0.3],
        [0.7, 0.7],
        [0.3, 0.7],
        [0.7, 0.3],
    ]
)
num_clusters = centers.shape[0]

# Each point picks one of the clusters uniformly at random ...
cluster_assignments = np.random.choice(num_clusters, num_points)

# ... and is then jittered around that cluster's center, one point at a time
# in assignment order.
points = np.empty((num_points, 2))
for k, cluster in enumerate(cluster_assignments):
    points[k] = centers[cluster] + 0.05 * np.random.randn(2)

# Step 2: Define the scale values for the Bernoulli edge probability; each one
# gets its own panel in the comparison figure.
scale_values = [0.001, 0.01, 0.1]

# One row of subplots, one per scale. squeeze=False + ravel() keeps ``axes`` a
# 1-D array even when there is only a single scale value.
fig, axes = plt.subplots(1, len(scale_values), figsize=(16, 5), squeeze=False)
axes = axes.ravel()

# Node positions are the same for every scale, so the layout is built once.
pos_var = {i: tuple(points[i]) for i in range(num_points)}

# All unordered node pairs (i, j) with i < j, enumerated row by row -- the same
# order as ``for i in range(n): for j in range(i + 1, n)``. Their squared
# Euclidean distances do not depend on the scale, so compute them once.
pair_i, pair_j = np.triu_indices(num_points, k=1)
pair_sq_dist = np.linalg.norm(points[pair_i] - points[pair_j], axis=1) ** 2


def _save_snapshot(graph, pos, filename, **draw_kwargs):
    """Draw ``graph`` at ``pos`` on a throwaway figure and save it to ``filename``.

    The figure is written at 300 dpi and closed right away so it does not stay
    registered with pyplot (and is not displayed by ``plt.show()``).
    """
    snap_fig, snap_ax = plt.subplots()
    nx.draw(graph, pos, ax=snap_ax, **draw_kwargs)
    snap_fig.savefig(filename, dpi=300)
    plt.close(snap_fig)


for idx, scale in enumerate(scale_values):
    # Intermediate-stage snapshots are only written for the first scale value.
    save_steps = steps and idx == 0

    G_var = nx.Graph()

    # Add nodes, storing each node's coordinates as its ``pos`` attribute.
    for i in range(num_points):
        G_var.add_node(i, pos=tuple(points[i]))

    if save_steps:
        _save_snapshot(G_var, pos_var, "random-graph-models-scale-nodes.png",
                       node_size=50, node_color='black')

    # Add edges: every pair is an independent Bernoulli trial with success
    # probability scale * exp(-||p_i - p_j||^2). One uniform draw per pair,
    # taken in the pair order above.
    edge_prob = scale * np.exp(-pair_sq_dist)
    is_edge = np.random.rand(edge_prob.size) < edge_prob
    G_var.add_edges_from(zip(pair_i[is_edge].tolist(), pair_j[is_edge].tolist()))

    if save_steps:
        _save_snapshot(G_var, pos_var, "random-graph-models-scale-edges.png",
                       node_size=50, node_color='black')

    # Vector field: for each node, the mean displacement from the node to its
    # neighbours; isolated nodes keep a zero vector.
    vector_field = np.zeros_like(points)
    for i in range(num_points):
        neighbors = list(G_var.neighbors(i))
        if neighbors:
            vector_field[i] = np.mean(points[neighbors] - points[i], axis=0)

    # Colour nodes by the magnitude of their displacement vector (not by the
    # norm of their position, which would ignore the graph entirely).
    vector_norms = np.linalg.norm(vector_field, axis=1)

    # Shared styling for the comparison panel and the "full" snapshot.
    draw_kwargs = dict(node_color=vector_norms, cmap=plt.cm.plasma,
                       node_size=50, edge_color='gray')

    # Draw the graph on this scale's panel, labelled so panels are distinguishable.
    ax = axes[idx]
    nx.draw(G_var, pos_var, ax=ax, **draw_kwargs)
    ax.set_title(f"scale = {scale}")

    if save_steps:
        _save_snapshot(G_var, pos_var, "random-graph-models-scale-full.png",
                       **draw_kwargs)

# Save the combined comparison figure explicitly, rather than relying on
# pyplot's implicit "current figure" after temporary snapshot figures were
# created and closed. Then display it.
fig.savefig("random-graph-models.png", dpi=300)

plt.show()
