import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import optax

from jaxparrow.cyclogeostrophy import _iterative, _variational
from jaxparrow.geostrophy import _geostrophy
from jaxparrow.tools.kinematics import magnitude
from jaxparrow.tools.operators import interpolation
from jaxparrow.tools.sanitize import init_land_mask

sys.path.extend([os.path.join(os.path.dirname(os.getcwd()), "tests")])
from tests import gaussian_eddy as ge  # noqa

%reload_ext
autoreload
%autoreload
2

Method validation in the idealized Gaussian eddy scenario

We use a Gaussian eddy for our functional tests, because analytical solutions can be derived in that setting.

The Gaussian eddy we consider has the form $\eta = \eta_0 \exp\left(-(r/R_0)^2\right)$, where $R_0$ is its radius, $\eta_0$ is the sea surface height (SSH) anomaly at its center, and $r$ is the radial distance from its center. We use a constant spatial step, expressed in meters.

# Eddy parameters representative of the Alboran Sea.
R0 = 5.0e4    # eddy radius R_0 (presumably in meters, like dxy -- confirm)
ETA0 = 0.2    # SSH anomaly eta_0 at the eddy center
LAT = 36      # latitude handed to the eddy simulator (presumably degrees -- confirm)

dxy = 1.0e4   # constant spatial step, in meters

Simulating the eddy

# Build the idealized eddy. The simulator returns, in this order: the X/Y
# grids, the radial distance R, the spacing array dXY, the Coriolis factor,
# the SSH field, and the reference geostrophic and cyclogeostrophic
# velocity components (the `_t` suffix is kept from the simulator's naming).
(
    X, Y, R, dXY,
    coriolis_factor,
    ssh,
    u_geos_t, v_geos_t,
    u_cyclo_t, v_cyclo_t,
) = ge.simulate_gaussian_eddy(R0, dxy, ETA0, LAT)

We just make sure that the grids are correct.

# Sanity check: draw the X and Y coordinate grids side by side.
# Only the grid values are handed to pcolormesh (no coordinates).
_, panels = plt.subplots(1, 2, figsize=(12, 5))
for axis, label, grid in zip(panels, ("X", "Y"), (X, Y)):
    axis.set_title(label)
    mesh = axis.pcolormesh(grid, shading="auto")
    plt.colorbar(mesh, ax=axis)
plt.show()

png

# Radial distance and SSH fields, each drawn on the (X, Y) grid.
_, panels = plt.subplots(1, 2, figsize=(12, 5))
for axis, label, field in zip(panels, ("R", "ssh"), (R, ssh)):
    axis.set_title(label)
    mesh = axis.pcolormesh(X, Y, field, shading="auto")
    plt.colorbar(mesh, ax=axis)
plt.show()

png

Geostrophy

Analytical

$$u_g = 2y \frac{g \eta_0}{f R_0^2} \exp\left(-(r/R_0)^2\right) = 2y \frac{g \eta}{f R_0^2}$$

$$v_g = -2x \frac{g \eta_0}{f R_0^2} \exp\left(-(r/R_0)^2\right) = -2x \frac{g \eta}{f R_0^2}$$

# Speed of the analytical geostrophic flow; interpolation inside `magnitude`
# is switched off.
azim_geos = magnitude(u_geos_t, v_geos_t, interpolate=False)

# One panel per field: the two components, then the speed.
_, panels = plt.subplots(1, 3, figsize=(19, 5))
fields = (("U", u_geos_t), ("V", v_geos_t), ("azimuthal", azim_geos))
for axis, (label, field) in zip(panels, fields):
    axis.set_title(label)
    mesh = axis.pcolormesh(X, Y, field, shading="auto")
    plt.colorbar(mesh, ax=axis)

# Overlay the velocity vectors on the last panel, keeping every 5th point
# along both axes so the arrows stay readable.
every_5th = (slice(None, None, 5), slice(None, None, 5))
panels[-1].quiver(X[every_5th], Y[every_5th],
                  u_geos_t[every_5th], v_geos_t[every_5th], color='k')
plt.show()

png

# Radial profile of the analytical geostrophic speed.
radii = R.flatten()
# Radius at which the speed is largest; marked by the dashed red line.
peak_radius = radii[np.abs(azim_geos).flatten().argmax()]

profile = plt.subplot()
profile.set_title("analytical geostrophy")
profile.set_xlabel("radial distance (m)")
profile.set_ylabel("azimuthal velocity (m/s)")
profile.scatter(radii, azim_geos.flatten(), s=1)
profile.vlines(
    peak_radius,
    ymin=azim_geos.min(),
    ymax=azim_geos.max(),
    colors="r",
    linestyles="dashed",
)
plt.show()

png

Numerical

$f\mathbf{k} \times \mathbf{u_g} = -g \nabla \eta$

# Numerical geostrophy from the SSH field. dXY is passed twice
# (presumably as the spacing along each grid direction -- confirm).
u_geos_est, v_geos_est = _geostrophy(
    ssh,
    dXY,
    dXY,
    coriolis_factor,
)

# Interpolate each component along its own axis so it can be compared with
# the analytical fields (presumably moving U/V points to the T grid, per the
# `_t` suffix -- confirm against jaxparrow).
u_geos_est_t = interpolation(u_geos_est, padding="left", axis=1)
v_geos_est_t = interpolation(v_geos_est, padding="left", axis=0)

# Speed of the numerically estimated geostrophic flow (inputs were already
# interpolated above, so `magnitude` is told not to interpolate again).
azim_geos_est = magnitude(u_geos_est_t, v_geos_est_t, interpolate=False)

# One panel per field: the two components, then the speed.
_, panels = plt.subplots(1, 3, figsize=(19, 5))
fields = (("U", u_geos_est_t), ("V", v_geos_est_t), ("azimuthal", azim_geos_est))
for axis, (label, field) in zip(panels, fields):
    axis.set_title(label)
    mesh = axis.pcolormesh(X, Y, field, shading="auto")
    plt.colorbar(mesh, ax=axis)

# Overlay every 5th velocity vector (along both axes) on the last panel.
every_5th = (slice(None, None, 5), slice(None, None, 5))
panels[-1].quiver(X[every_5th], Y[every_5th],
                  u_geos_est_t[every_5th], v_geos_est_t[every_5th], color='k')
plt.show()

png

# Left: radial profile of the estimated geostrophic speed.
# Right: its difference to the analytical speed.
radii = R.flatten()
estimate = azim_geos_est.flatten()
# Radius at which the estimated speed is largest (dashed red line).
peak_radius = radii[np.abs(azim_geos_est).flatten().argmax()]

_, (profile_ax, error_ax) = plt.subplots(1, 2, figsize=(12, 5))

profile_ax.set_title("numerical geostrophy")
profile_ax.set_xlabel("radial distance (m)")
profile_ax.set_ylabel("azimuthal velocity (m/s)")
profile_ax.scatter(radii, estimate, s=1)
profile_ax.vlines(
    peak_radius,
    ymin=azim_geos_est.min(),
    ymax=azim_geos_est.max(),
    colors="r",
    linestyles="dashed",
)

# NOTE(review): the y-label reads "absolute error", but the plotted values
# are the signed difference (estimate - analytical), not its magnitude.
error_ax.set_title("numerical error")
error_ax.set_xlabel("radial distance (m)")
error_ax.set_ylabel("absolute error (m/s)")
error_ax.scatter(radii, estimate - azim_geos.flatten(), s=1)
plt.show()

png

# Score each estimated geostrophic component against its analytical
# counterpart with ge.compute_rmse; the cell displays the (u, v) pair.
(
    ge.compute_rmse(u_geos_t, u_geos_est_t),
    ge.compute_rmse(v_geos_t, v_geos_est_t),
)
(Array(0.0068815, dtype=float32), Array(0.0068815, dtype=float32))

Cyclogeostrophic azimuthal velocity

Analytical

$$V_{gr}=\frac{2V_g}{1+\sqrt{1+4V_g/(fr)}}$$

$$u_{gr} = u_g + \sin(\theta) \frac{V_{gr}^2}{fr}$$

$$v_{gr} = v_g - \cos(\theta) \frac{V_{gr}^2}{fr}$$

# Speed of the analytical cyclogeostrophic flow; interpolation inside
# `magnitude` is switched off.
azim_cyclo = magnitude(u_cyclo_t, v_cyclo_t, interpolate=False)

# One panel per field: the two components, then the speed.
_, panels = plt.subplots(1, 3, figsize=(19, 5))
fields = (("U", u_cyclo_t), ("V", v_cyclo_t), ("azimuthal", azim_cyclo))
for axis, (label, field) in zip(panels, fields):
    axis.set_title(label)
    mesh = axis.pcolormesh(X, Y, field, shading="auto")
    plt.colorbar(mesh, ax=axis)

# Overlay every 5th velocity vector (along both axes) on the last panel.
every_5th = (slice(None, None, 5), slice(None, None, 5))
panels[-1].quiver(X[every_5th], Y[every_5th],
                  u_cyclo_t[every_5th], v_cyclo_t[every_5th], color='k')
plt.show()

png

# Radial profile of the analytical cyclogeostrophic speed.
radii = R.flatten()
# Radius at which the speed is largest; marked by the dashed red line.
peak_radius = radii[np.abs(azim_cyclo).flatten().argmax()]

profile = plt.subplot()
profile.set_title("analytical cyclogeostrophy")
profile.set_xlabel("radial distance (m)")
profile.set_ylabel("azimuthal velocity (m/s)")
profile.scatter(radii, azim_cyclo.flatten(), s=1)
profile.vlines(
    peak_radius,
    ymin=azim_cyclo.min(),
    ymax=azim_cyclo.max(),
    colors="r",
    linestyles="dashed",
)
plt.show()

png

Numerical

$\mathbf{u} - \frac{\mathbf{k}}{f} \times (\mathbf{u} \cdot \nabla \mathbf{u}) = \mathbf{u_g}$

# Inputs shared by the cyclogeostrophic solvers below: the numerically
# estimated geostrophic components (before interpolation), plus a land
# mask initialised from the analytical geostrophic field.
u_geos_u, v_geos_v = u_geos_est, v_geos_est
mask = init_land_mask(u_geos_t)

Variational estimation

# Variational estimation driven by plain SGD, run for 20 iterations.
optim = optax.sgd(learning_rate=0.05)
# dXY is passed for all four spacing-like arguments and coriolis_factor for
# both Coriolis arguments (presumably one per U/V grid -- confirm against
# jaxparrow). The third return value is discarded.
u_cyclo_est, v_cyclo_est, _ = _variational(
    u_geos_u, v_geos_v,
    dXY, dXY, dXY, dXY,
    coriolis_factor, coriolis_factor,
    mask,
    n_it=20,
    optim=optim,
)

# Interpolate each component along its own axis for comparison with the
# analytical fields.
u_cyclo_est_t = interpolation(u_cyclo_est, padding="left", axis=1)
v_cyclo_est_t = interpolation(v_cyclo_est, padding="left", axis=0)

# Speed of the variational cyclogeostrophic estimate.
azim_cyclo_est = magnitude(u_cyclo_est_t, v_cyclo_est_t, interpolate=False)

# One panel per field: the two components, then the speed.
_, panels = plt.subplots(1, 3, figsize=(19, 5))
fields = (("U", u_cyclo_est_t), ("V", v_cyclo_est_t), ("azimuthal", azim_cyclo_est))
for axis, (label, field) in zip(panels, fields):
    axis.set_title(label)
    mesh = axis.pcolormesh(X, Y, field, shading="auto")
    plt.colorbar(mesh, ax=axis)

# Overlay every 5th velocity vector (along both axes) on the last panel.
every_5th = (slice(None, None, 5), slice(None, None, 5))
panels[-1].quiver(X[every_5th], Y[every_5th],
                  u_cyclo_est_t[every_5th], v_cyclo_est_t[every_5th], color='k')
plt.show()

png

# Left: radial profile of the variational estimate.
# Right: its difference to the analytical cyclogeostrophic speed.
radii = R.flatten()
estimate = azim_cyclo_est.flatten()
# Radius at which the estimated speed is largest (dashed red line).
peak_radius = radii[np.abs(azim_cyclo_est).flatten().argmax()]

_, (profile_ax, error_ax) = plt.subplots(1, 2, figsize=(12, 5))

profile_ax.set_title("variational cyclogeostrophy")
profile_ax.set_xlabel("radial distance (m)")
profile_ax.set_ylabel("azimuthal velocity (m/s)")
profile_ax.scatter(radii, estimate, s=1)
profile_ax.vlines(
    peak_radius,
    ymin=azim_cyclo_est.min(),
    ymax=azim_cyclo_est.max(),
    colors="r",
    linestyles="dashed",
)

# NOTE(review): the y-label reads "absolute error", but the plotted values
# are the signed difference (estimate - analytical), not its magnitude.
error_ax.set_title("numerical error")
error_ax.set_xlabel("radial distance (m)")
error_ax.set_ylabel("absolute error (m/s)")
error_ax.scatter(radii, estimate - azim_cyclo.flatten(), s=1)
plt.show()

png

# Score the variational estimate against the analytical cyclogeostrophic
# components with ge.compute_rmse; the cell displays the (u, v) pair.
(
    ge.compute_rmse(u_cyclo_t, u_cyclo_est_t),
    ge.compute_rmse(v_cyclo_t, v_cyclo_est_t),
)
(Array(0.00562905, dtype=float32), Array(0.00562905, dtype=float32))

Iterative estimation

$\mathbf{u}^{(n+1)} = \mathbf{u_g} + \frac{\mathbf{k}}{f} \times (\mathbf{u}^{(n)} \cdot \nabla \mathbf{u}^{(n)})$

Ioannou

Use of a convolution filter when computing the residuals.

# Iterative estimation with the residual filter enabled (size 3), for 20
# iterations. This rebinds u_cyclo_est / v_cyclo_est, replacing the
# variational results. The third return value is discarded.
u_cyclo_est, v_cyclo_est, _ = _iterative(
    u_geos_u, v_geos_v,
    dXY, dXY, dXY, dXY,
    coriolis_factor, coriolis_factor,
    mask,
    n_it=20,
    res_eps=0.01,
    use_res_filter=True,
    res_filter_size=3,
    return_losses=False,
)

# Interpolate each component along its own axis for comparison with the
# analytical fields.
u_cyclo_est_t = interpolation(u_cyclo_est, padding="left", axis=1)
v_cyclo_est_t = interpolation(v_cyclo_est, padding="left", axis=0)

# Speed of the iterative (filtered) cyclogeostrophic estimate.
azim_cyclo_est = magnitude(u_cyclo_est_t, v_cyclo_est_t, interpolate=False)

# One panel per field: the two components, then the speed.
_, panels = plt.subplots(1, 3, figsize=(19, 5))
fields = (("U", u_cyclo_est_t), ("V", v_cyclo_est_t), ("azimuthal", azim_cyclo_est))
for axis, (label, field) in zip(panels, fields):
    axis.set_title(label)
    mesh = axis.pcolormesh(X, Y, field, shading="auto")
    plt.colorbar(mesh, ax=axis)

# Overlay every 5th velocity vector (along both axes) on the last panel.
every_5th = (slice(None, None, 5), slice(None, None, 5))
panels[-1].quiver(X[every_5th], Y[every_5th],
                  u_cyclo_est_t[every_5th], v_cyclo_est_t[every_5th], color='k')
plt.show()

png

# Left: radial profile of the iterative (filtered) estimate.
# Right: its difference to the analytical cyclogeostrophic speed.
radii = R.flatten()
estimate = azim_cyclo_est.flatten()
# Radius at which the estimated speed is largest (dashed red line).
peak_radius = radii[np.abs(azim_cyclo_est).flatten().argmax()]

_, (profile_ax, error_ax) = plt.subplots(1, 2, figsize=(12, 5))

profile_ax.set_title("iterative (filter) cyclogeostrophy")
profile_ax.set_xlabel("radial distance (m)")
profile_ax.set_ylabel("azimuthal velocity (m/s)")
profile_ax.scatter(radii, estimate, s=1)
profile_ax.vlines(
    peak_radius,
    ymin=azim_cyclo_est.min(),
    ymax=azim_cyclo_est.max(),
    colors="r",
    linestyles="dashed",
)

# NOTE(review): the y-label reads "absolute error", but the plotted values
# are the signed difference (estimate - analytical), not its magnitude.
error_ax.set_title("numerical error")
error_ax.set_xlabel("radial distance (m)")
error_ax.set_ylabel("absolute error (m/s)")
error_ax.scatter(radii, estimate - azim_cyclo.flatten(), s=1)
plt.show()

png

# Score the iterative (filtered) estimate against the analytical
# cyclogeostrophic components with ge.compute_rmse; the cell displays the
# (u, v) pair.
(
    ge.compute_rmse(u_cyclo_t, u_cyclo_est_t),
    ge.compute_rmse(v_cyclo_t, v_cyclo_est_t),
)
(Array(0.00847729, dtype=float32), Array(0.00847729, dtype=float32))

Penven

No convolution filter, original approach.

# Iterative estimation with the residual filter disabled, for 20 iterations.
# This rebinds u_cyclo_est / v_cyclo_est, replacing the previous estimate.
# The third return value is discarded.
u_cyclo_est, v_cyclo_est, _ = _iterative(
    u_geos_u, v_geos_v,
    dXY, dXY, dXY, dXY,
    coriolis_factor, coriolis_factor,
    mask,
    n_it=20,
    res_eps=0.01,
    use_res_filter=False,
    res_filter_size=1,
    return_losses=False,
)

# Interpolate each component along its own axis for comparison with the
# analytical fields.
u_cyclo_est_t = interpolation(u_cyclo_est, padding="left", axis=1)
v_cyclo_est_t = interpolation(v_cyclo_est, padding="left", axis=0)

# Speed of the iterative (unfiltered) cyclogeostrophic estimate.
azim_cyclo_est = magnitude(u_cyclo_est_t, v_cyclo_est_t, interpolate=False)

# One panel per field: the two components, then the speed.
_, panels = plt.subplots(1, 3, figsize=(19, 5))
fields = (("U", u_cyclo_est_t), ("V", v_cyclo_est_t), ("azimuthal", azim_cyclo_est))
for axis, (label, field) in zip(panels, fields):
    axis.set_title(label)
    mesh = axis.pcolormesh(X, Y, field, shading="auto")
    plt.colorbar(mesh, ax=axis)

# Overlay every 5th velocity vector (along both axes) on the last panel.
every_5th = (slice(None, None, 5), slice(None, None, 5))
panels[-1].quiver(X[every_5th], Y[every_5th],
                  u_cyclo_est_t[every_5th], v_cyclo_est_t[every_5th], color='k')
plt.show()

png

# Left: radial profile of the iterative (unfiltered) estimate.
# Right: its difference to the analytical cyclogeostrophic speed.
radii = R.flatten()
estimate = azim_cyclo_est.flatten()
# Radius at which the estimated speed is largest (dashed red line).
peak_radius = radii[np.abs(azim_cyclo_est).flatten().argmax()]

_, (profile_ax, error_ax) = plt.subplots(1, 2, figsize=(12, 5))

profile_ax.set_title("iterative cyclogeostrophy")
profile_ax.set_xlabel("radial distance (m)")
profile_ax.set_ylabel("azimuthal velocity (m/s)")
profile_ax.scatter(radii, estimate, s=1)
profile_ax.vlines(
    peak_radius,
    ymin=azim_cyclo_est.min(),
    ymax=azim_cyclo_est.max(),
    colors="r",
    linestyles="dashed",
)

# NOTE(review): the y-label reads "absolute error", but the plotted values
# are the signed difference (estimate - analytical), not its magnitude.
error_ax.set_title("numerical error")
error_ax.set_xlabel("radial distance (m)")
error_ax.set_ylabel("absolute error (m/s)")
error_ax.scatter(radii, estimate - azim_cyclo.flatten(), s=1)
plt.show()

png

# Score the iterative (unfiltered) estimate against the analytical
# cyclogeostrophic components with ge.compute_rmse; the cell displays the
# (u, v) pair.
(
    ge.compute_rmse(u_cyclo_t, u_cyclo_est_t),
    ge.compute_rmse(v_cyclo_t, v_cyclo_est_t),
)
(Array(0.00861186, dtype=float32), Array(0.00861186, dtype=float32))