import jax.numpy as jnp
import numpy.random as npr
from jax import jit, grad, vmap
from jax.experimental.ode import odeint
import matplotlib.pyplot as plt

def mlp(params, inputs):
  """A multi-layer perceptron (fully-connected neural network).

  Each layer is a linear transform followed by tanh; the final layer's
  linear output is returned without the nonlinearity applied.
  """
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b  # Linear transform
    inputs = jnp.tanh(outputs)        # Nonlinearity (unused after last layer)
  return outputs

def resnet(params, inputs, depth):
  """Apply the residual block mlp(x) + x `depth` times.

  Fix: the original loop computed `outputs = mlp(params, inputs) + inputs`
  without ever updating `inputs`, so every iteration recomputed the same
  value and `depth` had no effect. The trajectory-plotting code later in
  this file iterates the block (current = mlp(...) + current), so the same
  iterated semantics are used here. Returns `inputs` unchanged if depth == 0.
  """
  outputs = inputs
  for _ in range(depth):
    outputs = mlp(params, outputs) + outputs
  return outputs

# Toy 1D regression dataset: y = x^3 - 0.1*x sampled at 10 points on [-2, 2].
inputs = jnp.linspace(-2.0, 2.0, 10).reshape((10, 1))
targets = inputs ** 3 - 0.1 * inputs

# Hyperparameters.
layer_sizes = [1, 20, 1]  # widths: 1 input -> 20 hidden -> 1 output
param_scale = 1.0         # stddev multiplier for initial random weights
step_size = 0.01          # gradient-descent learning rate
train_iters = 1000        # number of optimizer steps

resnet_depth = 3  # number of residual-block applications
def resnet_squared_loss(params, inputs, targets):
  """Mean over the batch of the summed squared prediction error."""
  residual = resnet(params, inputs, resnet_depth) - targets
  return jnp.mean(jnp.sum(residual ** 2, axis=1))

def init_random_params(scale, layer_sizes, rng=None):
  """Build a list of per-layer (weights, biases) pairs with Gaussian init.

  Fix: the default used to be `rng=npr.RandomState(42)` — a mutable default
  evaluated once at definition time, so successive no-argument calls
  silently shared and advanced the same random stream. Creating a fresh
  seeded RandomState per call makes every call reproducible on its own.

  Args:
    scale: standard-deviation multiplier for the random draws.
    layer_sizes: layer widths, e.g. [1, 20, 1].
    rng: optional numpy RandomState; a RandomState(42) is created if omitted.

  Returns:
    List of (w, b) tuples, w of shape (m, n) and b of shape (n,).
  """
  if rng is None:
    rng = npr.RandomState(42)
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

# A simple gradient-descent optimizer.
@jit
def resnet_update(params, inputs, targets):
  """Take one SGD step on the ResNet loss; returns the updated params."""
  gradients = grad(resnet_squared_loss)(params, inputs, targets)
  return [(w - step_size * gw, b - step_size * gb)
          for (w, b), (gw, gb) in zip(params, gradients)]

# Fit the ResNet by plain gradient descent.
resnet_params = init_random_params(param_scale, layer_sizes)
for _ in range(train_iters):
  resnet_params = resnet_update(resnet_params, inputs, targets)

def nn_dynamics(state, time, params):
  """ODE right-hand side: append time to the state, then apply the MLP."""
  augmented = jnp.hstack([state, jnp.array(time)])
  return mlp(params, augmented)

def odenet(params, input):
  """Integrate the neural ODE from t=0 to t=1 and return the final state."""
  integration_times = jnp.array([0.0, 1.0])
  states = odeint(nn_dynamics, input, integration_times, params)
  return states[1]  # state at t=1.0 (states[0] is the initial condition)

# Vectorize the single-example ODE-Net over a batch of inputs;
# in_axes=(None, 0) shares params across the batch and maps over axis 0.
batched_odenet = vmap(odenet, in_axes=(None, 0))

# We need to change the input dimension to 2, to allow time-dependent dynamics.
odenet_layer_sizes = [2, 20, 1]

def odenet_loss(params, inputs, targets):
  """Mean over the batch of the summed squared ODE-Net prediction error."""
  errors = batched_odenet(params, inputs) - targets
  return jnp.mean(jnp.sum(errors * errors, axis=1))

@jit
def odenet_update(params, inputs, targets):
  """Take one SGD step on the ODE-Net loss; returns the updated params."""
  gradients = grad(odenet_loss)(params, inputs, targets)
  return [(w - step_size * gw, b - step_size * gb)
          for (w, b), (gw, gb) in zip(params, gradients)]

# Initialize and train ODE-Net by plain gradient descent.
odenet_params = init_random_params(param_scale, odenet_layer_sizes)
for _ in range(train_iters):
  odenet_params = odenet_update(odenet_params, inputs, targets)

# A finer input grid, wider than the training range [-2, 2].
# NOTE(review): fine_inputs is never used below — confirm it isn't needed
# (tutorials typically use it to plot model predictions on a dense grid).
fine_inputs = jnp.reshape(jnp.linspace(-3.0, 3.0, 100), (100, 1))

@jit
def odenet_times(params, input, times):
  """Integrate the neural ODE and return the state at each time in `times`."""
  def rhs(state, t, p):
    # Same dynamics as nn_dynamics: time-augmented state through the MLP.
    return mlp(p, jnp.hstack([state, jnp.array(t)]))
  return odeint(rhs, input, times, params)

times = jnp.linspace(0.0, 1.0, 200)  # evaluation times for the trajectories

fig, axs = plt.subplots(figsize=(8, 4), nrows=1, ncols=2, dpi=150)
fig.subplots_adjust(wspace=0.4)  # Add space between the two plots

# Left panel: continuous ODE-Net trajectory for each training input.
for x0 in inputs:
    axs[0].plot(odenet_times(odenet_params, x0, times), times, lw=0.5)

axs[0].set_xlabel('parameters')
axs[0].set_ylabel('time')

# Right panel: discrete ResNet iterates of the residual block for each
# training input, plotted against relative depth in [0, 1].
for x0 in inputs:
    states = [x0]
    current = x0
    for _ in range(resnet_depth):
        current = mlp(resnet_params, current) + current
        states.append(current)

    states = jnp.array(states)
    rel_depths = jnp.linspace(0, 1, len(states))
    axs[1].plot(states, rel_depths, lw=0.5)
    # Add markers at each evaluation point
    axs[1].scatter(states, rel_depths, color='black', s=20)

axs[1].set_xlabel('parameters')
axs[1].set_ylabel('(rel) depth')

# Save the figure, then display it. (The original `if True:` guard was a
# no-op and has been removed; behavior is unchanged.)
plt.savefig('resnet-vs-node.png', dpi=150)
plt.show()
