Skip to content

Pseudo-SWOT observations from eNATL60 model data

Show/hide code
# IPython `autoreload` extension: mode 2 reloads all imported modules before executing code.
%load_ext autoreload
%autoreload 2
Show/hide code
import cartopy.crs as ccrs
import cmocean.cm as cmo
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import xarray as xr
import xesmf as xe

from jaxparrow.cyclogeostrophy import fixed_point, minimization_based
from jaxparrow.utils import geometry, kinematics


# Enable 64-bit precision in JAX (JAX defaults to 32-bit types).
jax.config.update("jax_enable_x64", True)

Instead of considering idealized scenarios, we can use model data as validation material. To make the demonstration closer to real data, we will generate pseudo-SWOT swath observations from the submesoscale-permitting eNATL60 simulation.

Show/hide code
def get_karin_mask(swot_ds, *, min_distance=10.0, max_distance=60.0):
    """
    Return the mask of pixels lying inside the KaRin swaths.

    A pixel is selected when ``min_distance <= |cross_track_distance| <= max_distance``
    (both bounds inclusive). This excludes the nadir gap and anything beyond the
    outer edge of the swaths. Pixels whose distance is NaN are not selected.

    Parameters
    ----------
    swot_ds : xr.Dataset
        SWOT dataset providing ``cross_track_distance`` (km).
    min_distance : float, optional
        Half-width of the excluded nadir gap, in km (default 10).
    max_distance : float, optional
        Outer edge of the swaths, in km (default 60).

    Returns
    -------
    Boolean mask with the same shape as ``cross_track_distance``.
    """
    distance = np.abs(swot_ds.cross_track_distance)
    return (distance <= max_distance) & (distance >= min_distance)


def get_left_swath_mask(swot_ds):
    """Return mask for left swath pixels (-60 km <= cross_track_distance <= -10 km, inclusive)."""
    xtrack = swot_ds.cross_track_distance
    inside_outer_edge = xtrack >= -60.0
    outside_nadir_gap = xtrack <= -10.0
    return outside_nadir_gap & inside_outer_edge


def get_right_swath_mask(swot_ds):
    """Return mask for right swath pixels (10 km <= cross_track_distance <= 60 km, inclusive)."""
    xtrack = swot_ds.cross_track_distance
    inside_outer_edge = xtrack <= 60.0
    outside_nadir_gap = xtrack >= 10.0
    return outside_nadir_gap & inside_outer_edge


def split_swath_to_left_right(da, swot_ds):
    """
    Cut a full-swath DataArray into its left and right half swaths.

    Pixel columns are classified from ``cross_track_distance`` reduced over
    ``num_lines``. Columns at <= -10 km form the left half and columns at
    >= 10 km form the right half. Columns in between (the nadir gap) or
    entirely NaN are dropped.

    Parameters
    ----------
    da : xr.DataArray
        Data array on the full SWOT grid (num_lines, num_pixels)
    swot_ds : xr.Dataset
        SWOT dataset with cross_track_distance

    Returns
    -------
    da_left, da_right, left_pixels, right_pixels : xr.DataArray, array
        Left and right half swaths, followed by the integer column indices
        (along ``num_pixels``) that were selected for each side
    """
    # The per-pixel distance is assumed identical along track. The max over lines
    # skips NaN entries, so a column is classified if any line carries a value.
    per_pixel_distance = swot_ds.cross_track_distance.max(dim="num_lines")
    left_columns = np.flatnonzero((per_pixel_distance <= -10.0).values)
    right_columns = np.flatnonzero((per_pixel_distance >= 10.0).values)

    return (
        da.isel(num_pixels=left_columns),
        da.isel(num_pixels=right_columns),
        left_columns,
        right_columns,
    )


def combine_left_right_results(arr_left, arr_right, n_pixels_total, left_indices, right_indices):
    """
    Scatter left/right half-swath arrays back into one full-width array.

    Parameters
    ----------
    arr_left : jnp.ndarray
        Values of the left half swath, shape (num_lines, len(left_indices)).
    arr_right : jnp.ndarray
        Values of the right half swath, shape (num_lines, len(right_indices)).
    n_pixels_total : int
        Number of pixel columns of the full SWOT grid.
    left_indices, right_indices : array
        Column positions of the left and right half swaths in the full grid.

    Returns
    -------
    jnp.ndarray
        Array of shape (num_lines, n_pixels_total). The number of lines is taken
        from ``arr_left``. Columns not listed in either index array (e.g. the
        nadir gap) are NaN.
    """
    full = jnp.full((arr_left.shape[0], n_pixels_total), jnp.nan)
    for columns, values in ((left_indices, arr_left), (right_indices, arr_right)):
        # JAX arrays are immutable: `.at[...].set` returns an updated copy.
        full = full.at[:, columns].set(values)
    return full


def get_regridder(enatl60_ds, swot_ds):
    """Build a bilinear xESMF regridder from the eNATL60 grid onto the SWOT grid."""
    return xe.Regridder(enatl60_ds, swot_ds, "bilinear")


def regrid_enatl60_variable(enatl60_mb_da, regridder):
    """Apply ``regridder`` to an eNATL60 variable and return the regridded result."""
    return regridder(enatl60_mb_da)


def arr_to_da(arr, lat, lon, name, units=None):
    """
    Wrap a 2D array into an ``xr.DataArray`` with 2D ``lat``/``lon`` coordinates.

    Parameters
    ----------
    arr : array-like
        Values laid out along the dimensions ("y", "x").
    lat, lon : array-like
        2D coordinate arrays matching ``arr``. Any array-like is accepted
        (``xr.DataArray``, NumPy or JAX array); values are converted with
        ``np.asarray``.
    name : str
        Stored as the ``long_name`` attribute.
    units : str, optional
        Stored as the ``units`` attribute when given.

    Returns
    -------
    xr.DataArray
    """
    attrs = {"long_name": name}
    if units is not None:
        attrs["units"] = units

    # np.asarray yields the same values as `.values` for DataArrays, and also
    # works for JAX/NumPy arrays, which have no `.values` attribute.
    coords = {
        "lat": (("y", "x"), np.asarray(lat)),
        "lon": (("y", "x"), np.asarray(lon)),
    }
    return xr.DataArray(arr, dims=("y", "x"), coords=coords, attrs=attrs)
Show/hide code
# Input directories (relative to the current working directory): eNATL60 model output and SWOT swath files.
ENATL60_PATH, SWOT_PATH = "data/eNATL60", "data/SWOT"

Produce pseudo-SWOT SSH data from eNATL60

This is done by regridding eNATL60 data onto the SWOT KaRin-swath grid. Notes:

  • use of several masks to avoid regridding over land or outside of SWOT bounding box,
  • linear and nearest-neighbor reinterpolation of eNATL60 latitude and longitude whose values are (0, 0) to avoid potential regridding artifacts,
  • separate regridding for the T, U, and V grids.
Show/hide code
# SWOT pass geometries (passes 003 and 016) and the KaRin masks derived from them.
swot_003_ds = xr.open_dataset(f"{SWOT_PATH}/SWOT_L3_LR_SSH_Expert_003_v2.0.1-med.nc")
karin_003_mask = get_karin_mask(swot_003_ds)
swot_016_ds = xr.open_dataset(f"{SWOT_PATH}/SWOT_L3_LR_SSH_Expert_016_v2.0.1-med.nc")
karin_016_mask = get_karin_mask(swot_016_ds)

# eNATL60 fields for 2009-08-15: SSH (T grid), and the U- and V-grid surface current components.
enatl60_t_ds = xr.open_dataset(f"{ENATL60_PATH}/eNATL60-BLB002_y2009m08d15.1d_SSH-med.nc")
enatl60_u_ds = xr.open_dataset(f"{ENATL60_PATH}/eNATL60-BLB002_y2009m08d15.1d_SSU-med.nc")
enatl60_v_ds = xr.open_dataset(f"{ENATL60_PATH}/eNATL60-BLB002_y2009m08d15.1d_SSV-med.nc")

# T-point coordinates come from the file; U/V coordinates are derived from them.
enatl60_lat_t = jnp.asarray(enatl60_t_ds.lat)
enatl60_lon_t = jnp.asarray(enatl60_t_ds.lon)
(
    enatl60_lat_u, enatl60_lon_u,
    enatl60_lat_v, enatl60_lon_v,
) = geometry.compute_uv_grids(enatl60_lat_t, enatl60_lon_t)

# Surface current components, converted once and reused by both diagnostics below.
enatl60_ssu = jnp.asarray(enatl60_u_ds.sozocrtx.values)
enatl60_ssv = jnp.asarray(enatl60_v_ds.somecrty.values)

# Reference diagnostics added to the T-grid dataset: current speed, then normalized relative vorticity.
enatl60_speed = kinematics.magnitude(enatl60_ssu, enatl60_ssv, uv_on_t=False)
enatl60_t_ds["uv"] = (
    ("y", "x"),
    np.asarray(enatl60_speed),
    {"units": "m/s", "long_name": "surface current (eNATL60)"},
)

enatl60_vorticity = kinematics.vorticity(
    enatl60_ssu, enatl60_ssv,
    enatl60_lat_u, enatl60_lon_u, enatl60_lat_v, enatl60_lon_v,
    uv_on_t=False,
)
enatl60_t_ds["nrv"] = (
    ("y", "x"),
    np.asarray(enatl60_vorticity),
    {"long_name": "normalized relative vorticity (eNATL60)"},
)

We consider the Balearic Sea on 15 August 2009.

Show/hide code
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, subplot_kw={"projection": ccrs.PlateCarree()}, figsize=(14, 5))

fig.suptitle("Original eNATL60 surface fields")

# LaTeX titles use raw strings: "\|" inside a regular string literal is an invalid
# escape sequence (SyntaxWarning since Python 3.12). The rendered text is unchanged.
ax1.set_title(r"$\eta$")
im = ax1.pcolormesh(enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.sossheig, cmap=cmo.deep_r)
ax1.coastlines()
cbar = plt.colorbar(im, ax=ax1)
cbar.ax.set_title("m")

ax2.set_title(r"$\|\mathbf{u}\|$")
im = ax2.pcolormesh(enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.uv, cmap=cmo.speed, vmin=0)
ax2.coastlines()
cbar = plt.colorbar(im, ax=ax2)
cbar.ax.set_title("m/s")

ax3.set_title(r"$\zeta / f$")
im = ax3.pcolormesh(enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.nrv, cmap=cmo.curl, vmin=-1, vmax=1)
ax3.coastlines()
cbar = plt.colorbar(im, ax=ax3, extend="both")

fig.tight_layout()
plt.show()
No description has been provided for this image
Show/hide code
# One bilinear regridder per eNATL60 grid (T, U, V) and per SWOT pass.
t_003_regridder = get_regridder(enatl60_t_ds, swot_003_ds)
u_003_regridder = get_regridder(enatl60_u_ds, swot_003_ds)
v_003_regridder = get_regridder(enatl60_v_ds, swot_003_ds)

t_016_regridder = get_regridder(enatl60_t_ds, swot_016_ds)
u_016_regridder = get_regridder(enatl60_u_ds, swot_016_ds)
v_016_regridder = get_regridder(enatl60_v_ds, swot_016_ds)

# Regrid SSH onto SWOT grid
ssh_003_regridded = regrid_enatl60_variable(enatl60_t_ds.sossheig, t_003_regridder)
ssh_016_regridded = regrid_enatl60_variable(enatl60_t_ds.sossheig, t_016_regridder)

# Apply KaRin mask
ssh_003_da = ssh_003_regridded.where(karin_003_mask)
ssh_016_da = ssh_016_regridded.where(karin_016_mask)

# Split into left and right swaths (no gap filling needed)
ssh_003_left_da, ssh_003_right_da, left_pixels_003, right_pixels_003 = split_swath_to_left_right(
    ssh_003_da, swot_003_ds
)
ssh_016_left_da, ssh_016_right_da, left_pixels_016, right_pixels_016 = split_swath_to_left_right(
    ssh_016_da, swot_016_ds
)

# Transpose for jaxparrow input
ssh_003_left_da = ssh_003_left_da.transpose("num_lines", "num_pixels")
ssh_003_right_da = ssh_003_right_da.transpose("num_lines", "num_pixels")
ssh_016_left_da = ssh_016_left_da.transpose("num_lines", "num_pixels")
ssh_016_right_da = ssh_016_right_da.transpose("num_lines", "num_pixels")

# Half-swath coordinates, using the same pixel columns as the SSH split above.
swot_003_lat_t = jnp.asarray(swot_003_ds.lat)
swot_003_lon_t = jnp.asarray(swot_003_ds.lon)
swot_003_lat_left = swot_003_lat_t[:, left_pixels_003]
swot_003_lon_left = swot_003_lon_t[:, left_pixels_003]
swot_003_lat_right = swot_003_lat_t[:, right_pixels_003]
swot_003_lon_right = swot_003_lon_t[:, right_pixels_003]

swot_016_lat_t = jnp.asarray(swot_016_ds.lat)
swot_016_lon_t = jnp.asarray(swot_016_ds.lon)
swot_016_lat_left = swot_016_lat_t[:, left_pixels_016]
swot_016_lon_left = swot_016_lon_t[:, left_pixels_016]
swot_016_lat_right = swot_016_lat_t[:, right_pixels_016]
swot_016_lon_right = swot_016_lon_t[:, right_pixels_016]

# Store total number of pixels for recombination.
# `Dataset.sizes` maps dimension names to lengths; indexing `Dataset.dims` by name
# is deprecated (FutureWarning) because `.dims` will become a set of names.
n_pixels_003 = swot_003_ds.sizes["num_pixels"]
n_pixels_016 = swot_016_ds.sizes["num_pixels"]

# Validation data: keep KaRin mask (no gap filling)
uv_003_da = regrid_enatl60_variable(enatl60_t_ds.uv, t_003_regridder).where(karin_003_mask)
nrv_003_da = regrid_enatl60_variable(enatl60_t_ds.nrv, t_003_regridder).where(karin_003_mask)
u_003_da = regrid_enatl60_variable(enatl60_u_ds.sozocrtx, u_003_regridder).where(karin_003_mask)
v_003_da = regrid_enatl60_variable(enatl60_v_ds.somecrty, v_003_regridder).where(karin_003_mask)

uv_016_da = regrid_enatl60_variable(enatl60_t_ds.uv, t_016_regridder).where(karin_016_mask)
nrv_016_da = regrid_enatl60_variable(enatl60_t_ds.nrv, t_016_regridder).where(karin_016_mask)
u_016_da = regrid_enatl60_variable(enatl60_u_ds.sozocrtx, u_016_regridder).where(karin_016_mask)
v_016_da = regrid_enatl60_variable(enatl60_v_ds.somecrty, v_016_regridder).where(karin_016_mask)
/var/folders/xc/bksmt58x2nq8jshz2jbf9b_m0000gn/T/ipykernel_45399/1634369195.py:46: FutureWarning: The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.
  n_pixels_003 = swot_003_ds.dims["num_pixels"]
/var/folders/xc/bksmt58x2nq8jshz2jbf9b_m0000gn/T/ipykernel_45399/1634369195.py:47: FutureWarning: The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.
  n_pixels_016 = swot_016_ds.dims["num_pixels"]

Show/hide code
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, subplot_kw={"projection": ccrs.PlateCarree()}, figsize=(14, 5))

fig.suptitle("Regridded eNATL60 surface fields")

# Colour limits come from the full eNATL60 fields, so both passes share one scale.
ssh_min = np.nanmin(enatl60_t_ds.sossheig)
ssh_max = np.nanmax(enatl60_t_ds.sossheig)
uv_max = np.nanmax(enatl60_t_ds.uv)

# LaTeX titles use raw strings: "\|" inside a regular string literal is an invalid
# escape sequence (SyntaxWarning since Python 3.12). The rendered text is unchanged.
ax1.set_title(r"$\eta$")
_ = ax1.contourf(enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.sossheig, cmap=cmo.deep_r, levels=20, alpha=0.3)
_ = ax1.pcolormesh(ssh_003_da.lon, ssh_003_da.lat, ssh_003_da, cmap=cmo.deep_r, vmin=ssh_min, vmax=ssh_max)
im = ax1.pcolormesh(ssh_016_da.lon, ssh_016_da.lat, ssh_016_da, cmap=cmo.deep_r, vmin=ssh_min, vmax=ssh_max)
ax1.coastlines()
cbar = plt.colorbar(im, ax=ax1)
cbar.ax.set_title("m")

ax2.set_title(r"$\|\mathbf{u}\|$")
_ = ax2.contourf(enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.sossheig, cmap=cmo.deep_r, levels=20, alpha=0.3)
_ = ax2.pcolormesh(uv_003_da.lon, uv_003_da.lat, uv_003_da, cmap=cmo.speed, vmin=0, vmax=uv_max)
im = ax2.pcolormesh(uv_016_da.lon, uv_016_da.lat, uv_016_da, cmap=cmo.speed, vmin=0, vmax=uv_max)
ax2.coastlines()
cbar = plt.colorbar(im, ax=ax2)
cbar.ax.set_title("m/s")

ax3.set_title(r"$\zeta / f$")
_ = ax3.contourf(enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.sossheig, cmap=cmo.deep_r, levels=20, alpha=0.3)
_ = ax3.pcolormesh(nrv_003_da.lon, nrv_003_da.lat, nrv_003_da, cmap=cmo.curl, vmin=-1, vmax=1)
im = ax3.pcolormesh(nrv_016_da.lon, nrv_016_da.lat, nrv_016_da, cmap=cmo.curl, vmin=-1, vmax=1)
ax3.coastlines()
cbar = plt.colorbar(im, ax=ax3, extend="both")

fig.tight_layout()
plt.show()
No description has been provided for this image

Surface currents inversion of the pseudo-SWOT SSH data

We then apply the minimization-based and the fixed-point inversion methods to the pseudo-SWOT SSH data to estimate the cyclogeostrophic surface currents by calling the functions minimization_based and fixed_point.

For the minimization-based method, we minimize the cyclogeostrophic imbalance using gradient descent and clipping to avoid large updates.

From the estimated surface current velocity components \(u\) and \(v\), we compute the velocity magnitude and the normalized relative vorticity using the magnitude and vorticity functions of jaxparrow.utils.kinematics.

Show/hide code
# Plain gradient descent; optax.clip(1) bounds each update element to [-1, 1] to avoid large steps.
optim = optax.chain(optax.clip(1), optax.sgd(learning_rate=5e-3))


def _invert_half_swath(lat_t, lon_t, ssh_da):
    """Run the minimization-based, then the fixed-point, inversion on one half swath."""
    mb_results = minimization_based(
        lat_t=lat_t, lon_t=lon_t, ssh_t=jnp.asarray(ssh_da.values), optim=optim, return_geos=True
    )
    fp_results = fixed_point(
        lat_t=lat_t, lon_t=lon_t, ssh_t=jnp.asarray(ssh_da.values), return_geos=False
    )
    return mb_results, fp_results


# Each half swath (left/right of the nadir gap) is inverted independently.
mb_results_003_left, fp_results_003_left = _invert_half_swath(
    swot_003_lat_left, swot_003_lon_left, ssh_003_left_da
)
mb_results_003_right, fp_results_003_right = _invert_half_swath(
    swot_003_lat_right, swot_003_lon_right, ssh_003_right_da
)
mb_results_016_left, fp_results_016_left = _invert_half_swath(
    swot_016_lat_left, swot_016_lon_left, ssh_016_left_da
)
mb_results_016_right, fp_results_016_right = _invert_half_swath(
    swot_016_lat_right, swot_016_lon_right, ssh_016_right_da
)

# Stitch every velocity component back onto the full SWOT pixel grid (NaN in the nadir gap).
u_g_003, v_g_003, u_cg_mb_003, v_cg_mb_003, u_cg_fp_003, v_cg_fp_003 = (
    combine_left_right_results(
        getattr(left, field), getattr(right, field), n_pixels_003, left_pixels_003, right_pixels_003
    )
    for left, right, field in (
        (mb_results_003_left, mb_results_003_right, "ug"),
        (mb_results_003_left, mb_results_003_right, "vg"),
        (mb_results_003_left, mb_results_003_right, "ucg"),
        (mb_results_003_left, mb_results_003_right, "vcg"),
        (fp_results_003_left, fp_results_003_right, "ucg"),
        (fp_results_003_left, fp_results_003_right, "vcg"),
    )
)
u_g_016, v_g_016, u_cg_mb_016, v_cg_mb_016, u_cg_fp_016, v_cg_fp_016 = (
    combine_left_right_results(
        getattr(left, field), getattr(right, field), n_pixels_016, left_pixels_016, right_pixels_016
    )
    for left, right, field in (
        (mb_results_016_left, mb_results_016_right, "ug"),
        (mb_results_016_left, mb_results_016_right, "vg"),
        (mb_results_016_left, mb_results_016_right, "ucg"),
        (mb_results_016_left, mb_results_016_right, "vcg"),
        (fp_results_016_left, fp_results_016_right, "ucg"),
        (fp_results_016_left, fp_results_016_right, "vcg"),
    )
)

# Speed and normalized relative vorticity of (geostrophy, minimization-based, fixed-point).
_velocity_pairs_003 = ((u_g_003, v_g_003), (u_cg_mb_003, v_cg_mb_003), (u_cg_fp_003, v_cg_fp_003))
uv_g_003, uv_cg_mb_003, uv_cg_fp_003 = (kinematics.magnitude(u, v) for u, v in _velocity_pairs_003)
nrv_g_003, nrv_cg_mb_003, nrv_cg_fp_003 = (
    kinematics.vorticity(u, v, lat_t=swot_003_lat_t, lon_t=swot_003_lon_t) for u, v in _velocity_pairs_003
)

_velocity_pairs_016 = ((u_g_016, v_g_016), (u_cg_mb_016, v_cg_mb_016), (u_cg_fp_016, v_cg_fp_016))
uv_g_016, uv_cg_mb_016, uv_cg_fp_016 = (kinematics.magnitude(u, v) for u, v in _velocity_pairs_016)
nrv_g_016, nrv_cg_mb_016, nrv_cg_fp_016 = (
    kinematics.vorticity(u, v, lat_t=swot_016_lat_t, lon_t=swot_016_lon_t) for u, v in _velocity_pairs_016
)

Comparison with eNATL60

To compute difference metrics against the SWOT-like reference data from eNATL60, the reconstructed \(u\) and \(v\) components have to be expressed on the T grid; this can be achieved using the interpolation function.

Surface currents

Show/hide code
fig, axes = plt.subplots(2, 2, subplot_kw={"projection": ccrs.PlateCarree()}, figsize=(10, 10))
ax1, ax2, ax3, ax4 = axes.flatten()

fig.suptitle("Pseudo-SWOT surface currents magnitude")

# (axis, title, speed on pass 003, speed on pass 016); all panels share one speed colour scale.
speed_panels = (
    (ax1, "eNATL60", uv_003_da, uv_016_da),
    (ax2, "Cyclogeostrophy (minimization-based)", uv_cg_mb_003, uv_cg_mb_016),
    (ax3, "Geostrophy", uv_g_003, uv_g_016),
    (ax4, "Cyclogeostrophy (fixed-point)", uv_cg_fp_003, uv_cg_fp_016),
)
for panel_ax, panel_title, field_003, field_016 in speed_panels:
    panel_ax.set_title(panel_title)
    # Faint SSH contours as background under the swaths.
    _ = panel_ax.contourf(
        enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.sossheig, cmap=cmo.deep_r, levels=20, alpha=0.3
    )
    _ = panel_ax.pcolormesh(uv_003_da.lon, uv_003_da.lat, field_003, cmap=cmo.speed, vmin=0, vmax=uv_max)
    im = panel_ax.pcolormesh(uv_016_da.lon, uv_016_da.lat, field_016, cmap=cmo.speed, vmin=0, vmax=uv_max)
    panel_ax.coastlines()
    cbar = plt.colorbar(im, ax=panel_ax)
    cbar.set_label("m/s")

fig.tight_layout()
plt.show()
No description has been provided for this image
Show/hide code
# Reference eNATL60 velocity components on the pass-003 SWOT pixels.
u_003_ref = jnp.asarray(u_003_da.values)
v_003_ref = jnp.asarray(v_003_da.values)

# Magnitude of the vector error for geostrophy, minimization-based and fixed-point estimates.
uv_err_g_003, uv_err_cg_mb_003, uv_err_cg_fp_003 = (
    kinematics.magnitude(u_est - u_003_ref, v_est - v_003_ref)
    for u_est, v_est in ((u_g_003, v_g_003), (u_cg_mb_003, v_cg_mb_003), (u_cg_fp_003, v_cg_fp_003))
)

# Same for pass 016.
u_016_ref = jnp.asarray(u_016_da.values)
v_016_ref = jnp.asarray(v_016_da.values)

uv_err_g_016, uv_err_cg_mb_016, uv_err_cg_fp_016 = (
    kinematics.magnitude(u_est - u_016_ref, v_est - v_016_ref)
    for u_est, v_est in ((u_g_016, v_g_016), (u_cg_mb_016, v_cg_mb_016), (u_cg_fp_016, v_cg_fp_016))
)

# Common upper bound of the error colour scale (NaN-aware maxima).
uv_err_max = max(
    jnp.nanmax(err)
    for err in (
        uv_err_g_003, uv_err_g_016,
        uv_err_cg_mb_003, uv_err_cg_mb_016,
        uv_err_cg_fp_003, uv_err_cg_fp_016,
    )
)
Show/hide code
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, subplot_kw={"projection": ccrs.PlateCarree()}, figsize=(14, 5))

fig.suptitle("Surface current error")

# (axis, title, error on pass 003, error on pass 016), drawn on a shared [0, uv_err_max] scale.
for panel_ax, panel_title, err_003, err_016 in (
    (ax1, "Geostrophy", uv_err_g_003, uv_err_g_016),
    (ax2, "Cyclogeostrophy (minimization-based)", uv_err_cg_mb_003, uv_err_cg_mb_016),
    (ax3, "Cyclogeostrophy (fixed-point)", uv_err_cg_fp_003, uv_err_cg_fp_016),
):
    panel_ax.set_title(panel_title)
    _ = panel_ax.contourf(
        enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.sossheig, cmap=cmo.deep_r, levels=20, alpha=0.3
    )
    _ = panel_ax.pcolormesh(ssh_003_da.lon, ssh_003_da.lat, err_003, cmap=cmo.amp, vmin=0, vmax=uv_err_max)
    im = panel_ax.pcolormesh(ssh_016_da.lon, ssh_016_da.lat, err_016, cmap=cmo.amp, vmin=0, vmax=uv_err_max)
    panel_ax.coastlines()
    cbar = plt.colorbar(im, ax=panel_ax)
    cbar.ax.set_title("m/s")

fig.tight_layout()
plt.show()
No description has been provided for this image
Show/hide code
# Signed differences of error magnitudes: "A - B" is positive where method B has the smaller error.
uv_err_diff_geos_mb_003, uv_err_diff_geos_fp_003, uv_err_diff_fp_mb_003 = (
    uv_err_g_003 - uv_err_cg_mb_003,
    uv_err_g_003 - uv_err_cg_fp_003,
    uv_err_cg_fp_003 - uv_err_cg_mb_003,
)

uv_err_diff_geos_mb_016, uv_err_diff_geos_fp_016, uv_err_diff_fp_mb_016 = (
    uv_err_g_016 - uv_err_cg_mb_016,
    uv_err_g_016 - uv_err_cg_fp_016,
    uv_err_cg_fp_016 - uv_err_cg_mb_016,
)

# Symmetric colour limit: largest absolute difference over all comparisons and passes.
uv_err_diff_max = max(
    np.nanmax(np.abs(diff))
    for diff in (
        uv_err_diff_geos_mb_003,
        uv_err_diff_geos_fp_003,
        uv_err_diff_fp_mb_003,
        uv_err_diff_geos_mb_016,
        uv_err_diff_geos_fp_016,
        uv_err_diff_fp_mb_016,
    )
)
Show/hide code
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, subplot_kw={"projection": ccrs.PlateCarree()}, figsize=(14, 5))

fig.suptitle("Surface current error differences")

# (axis, title, difference on pass 003, difference on pass 016), symmetric diverging scale.
for panel_ax, panel_title, diff_003, diff_016 in (
    (ax1, "Geos. - Fixed-point", uv_err_diff_geos_fp_003, uv_err_diff_geos_fp_016),
    (ax2, "Geos. - Minimization-based", uv_err_diff_geos_mb_003, uv_err_diff_geos_mb_016),
    (ax3, "Fixed-point - Minimization-based", uv_err_diff_fp_mb_003, uv_err_diff_fp_mb_016),
):
    panel_ax.set_title(panel_title)
    _ = panel_ax.contourf(
        enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.sossheig, cmap=cmo.deep_r, levels=20, alpha=0.3
    )
    _ = panel_ax.pcolormesh(
        ssh_003_da.lon, ssh_003_da.lat, diff_003,
        cmap=cmo.balance_r, vmin=-uv_err_diff_max, vmax=uv_err_diff_max
    )
    im = panel_ax.pcolormesh(
        ssh_016_da.lon, ssh_016_da.lat, diff_016,
        cmap=cmo.balance_r, vmin=-uv_err_diff_max, vmax=uv_err_diff_max
    )
    panel_ax.coastlines()
    cbar = plt.colorbar(im, ax=panel_ax)
    cbar.ax.set_title("m/s")

fig.tight_layout()
plt.show()
No description has been provided for this image

Both cyclogeostrophic inversion methods give similar results, but it is already possible to notice that the minimization-based approach produces spatially smoother estimates.

Normalized relative vorticity

The grainy aspect of the fixed-point estimate is even more visible in the normalized relative vorticity fields, since computing vorticity involves spatial derivatives.

Show/hide code
fig, axes = plt.subplots(2, 2, subplot_kw={"projection": ccrs.PlateCarree()}, figsize=(10, 10))
ax1, ax2, ax3, ax4 = axes.flatten()

fig.suptitle("Pseudo-SWOT normalized relative vorticity")

# (axis, title, NRV on pass 003, NRV on pass 016); every panel uses the [-1, 1] curl scale.
nrv_panels = (
    (ax1, "eNATL60", nrv_003_da, nrv_016_da),
    (ax2, "Cyclogeostrophy (minimization-based)", nrv_cg_mb_003, nrv_cg_mb_016),
    (ax3, "Geostrophy", nrv_g_003, nrv_g_016),
    (ax4, "Cyclogeostrophy (fixed-point)", nrv_cg_fp_003, nrv_cg_fp_016),
)
for panel_ax, panel_title, field_003, field_016 in nrv_panels:
    panel_ax.set_title(panel_title)
    _ = panel_ax.contourf(
        enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.sossheig, cmap=cmo.deep_r, levels=20, alpha=0.3
    )
    _ = panel_ax.pcolormesh(nrv_003_da.lon, nrv_003_da.lat, field_003, cmap=cmo.curl, vmin=-1, vmax=1)
    im = panel_ax.pcolormesh(nrv_016_da.lon, nrv_016_da.lat, field_016, cmap=cmo.curl, vmin=-1, vmax=1)
    panel_ax.coastlines()
    cbar = plt.colorbar(im, ax=panel_ax, extend="both")

fig.tight_layout()
plt.show()
No description has been provided for this image
Show/hide code
# Absolute NRV error of geostrophy, minimization-based and fixed-point estimates, per pass.
nrv_err_g_003, nrv_err_cg_mb_003, nrv_err_cg_fp_003 = [
    np.abs(nrv_003_da.values - nrv_est) for nrv_est in (nrv_g_003, nrv_cg_mb_003, nrv_cg_fp_003)
]

nrv_err_g_016, nrv_err_cg_mb_016, nrv_err_cg_fp_016 = [
    np.abs(nrv_016_da.values - nrv_est) for nrv_est in (nrv_g_016, nrv_cg_mb_016, nrv_cg_fp_016)
]

# Fixed upper bound of the error colour scale (values above it are shown with an "extend" arrow).
nrv_err_max = 0.25
Show/hide code
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, subplot_kw={"projection": ccrs.PlateCarree()}, figsize=(14, 5))

fig.suptitle("Normalized relative vorticity error")

# (axis, title, error on pass 003, error on pass 016), shared [0, nrv_err_max] scale.
for panel_ax, panel_title, err_003, err_016 in (
    (ax1, "Geostrophy", nrv_err_g_003, nrv_err_g_016),
    (ax2, "Cyclogeostrophy (minimization-based)", nrv_err_cg_mb_003, nrv_err_cg_mb_016),
    (ax3, "Cyclogeostrophy (fixed-point)", nrv_err_cg_fp_003, nrv_err_cg_fp_016),
):
    panel_ax.set_title(panel_title)
    _ = panel_ax.contourf(
        enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.sossheig, cmap=cmo.deep_r, levels=20, alpha=0.3
    )
    _ = panel_ax.pcolormesh(ssh_003_da.lon, ssh_003_da.lat, err_003, cmap=cmo.amp, vmin=0, vmax=nrv_err_max)
    im = panel_ax.pcolormesh(ssh_016_da.lon, ssh_016_da.lat, err_016, cmap=cmo.amp, vmin=0, vmax=nrv_err_max)
    panel_ax.coastlines()
    cbar = plt.colorbar(im, ax=panel_ax, extend="max")

fig.tight_layout()
plt.show()
No description has been provided for this image
Show/hide code
# Signed differences of absolute NRV errors: "A - B" is positive where method B has the smaller error.
nrv_err_diff_geos_mb_003, nrv_err_diff_geos_fp_003, nrv_err_diff_fp_mb_003 = (
    nrv_err_g_003 - nrv_err_cg_mb_003,
    nrv_err_g_003 - nrv_err_cg_fp_003,
    nrv_err_cg_fp_003 - nrv_err_cg_mb_003,
)

nrv_err_diff_geos_mb_016, nrv_err_diff_geos_fp_016, nrv_err_diff_fp_mb_016 = (
    nrv_err_g_016 - nrv_err_cg_mb_016,
    nrv_err_g_016 - nrv_err_cg_fp_016,
    nrv_err_cg_fp_016 - nrv_err_cg_mb_016,
)

# Fixed half-width of the symmetric colour scale.
nrv_err_diff_max = 0.25
Show/hide code
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, subplot_kw={"projection": ccrs.PlateCarree()}, figsize=(15, 5))

fig.suptitle("Normalized relative vorticity error differences")

# (axis, title, difference on pass 003, difference on pass 016), symmetric diverging scale.
for panel_ax, panel_title, diff_003, diff_016 in (
    (ax1, "Geos. - Fixed-point", nrv_err_diff_geos_fp_003, nrv_err_diff_geos_fp_016),
    (ax2, "Geos. - Minimization-based", nrv_err_diff_geos_mb_003, nrv_err_diff_geos_mb_016),
    (ax3, "Fixed-point - Minimization-based", nrv_err_diff_fp_mb_003, nrv_err_diff_fp_mb_016),
):
    panel_ax.set_title(panel_title)
    _ = panel_ax.contourf(
        enatl60_t_ds.lon, enatl60_t_ds.lat, enatl60_t_ds.sossheig, cmap=cmo.deep_r, levels=20, alpha=0.3
    )
    _ = panel_ax.pcolormesh(
        ssh_003_da.lon, ssh_003_da.lat, diff_003,
        cmap=cmo.balance_r, vmin=-nrv_err_diff_max, vmax=nrv_err_diff_max
    )
    im = panel_ax.pcolormesh(
        ssh_016_da.lon, ssh_016_da.lat, diff_016,
        cmap=cmo.balance_r, vmin=-nrv_err_diff_max, vmax=nrv_err_diff_max
    )
    panel_ax.coastlines()
    cbar = plt.colorbar(im, ax=panel_ax, extend="both")

fig.tight_layout()
plt.show()
No description has been provided for this image

As anticipated, the minimization-based approach produces spatially smoother estimates of the normalized relative vorticity.