Hands-on jaxparrow with Duacs data

The aim of this notebook is to illustrate how jaxparrow can be employed to derive geostrophic and cyclogeostrophic currents on a C-grid from Sea Surface Height (SSH) observations. The demo focuses on the Alboran Sea, a highly energetic area of the Mediterranean Sea (Ioannou et al. 2019).

We use the European Seas Gridded L4 Sea Surface Heights And Derived Variables Reprocessed dataset (description, reference). This product provides daily averages of SSH and geostrophic currents on a rectilinear A-grid, with a spatial resolution of 1/8°.
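To put the 1/8° resolution in perspective at the latitude of the Alboran Sea (about 36° N), it can be converted to metres. This is a minimal sketch assuming a spherical Earth with a mean radius of 6371 km; grid_spacing_m is a hypothetical helper, not part of jaxparrow or the Copernicus Marine Toolbox.

```python
import numpy as np

EARTH_RADIUS_M = 6371e3  # mean Earth radius, assuming a spherical Earth


def grid_spacing_m(resolution_deg: float, lat_deg: float) -> tuple[float, float]:
    """Hypothetical helper: approximate (dx, dy) in metres of a regular lon/lat grid."""
    dlat_rad = np.deg2rad(resolution_deg)
    dy = EARTH_RADIUS_M * dlat_rad  # meridional spacing is independent of latitude
    dx = dy * np.cos(np.deg2rad(lat_deg))  # meridians converge towards the poles
    return float(dx), float(dy)


dx, dy = grid_spacing_m(1 / 8, 36.0)
print(f"dx ≈ {dx / 1e3:.1f} km, dy ≈ {dy / 1e3:.1f} km")  # prints "dx ≈ 11.2 km, dy ≈ 13.9 km"
```

The zonal and meridional spacings therefore differ noticeably in this region, even though the grid is regular in degrees.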

We need to install some dependencies first:

!pip install ipympl matplotlib cartopy
!pip install copernicusmarine jaxparrow jaxtyping numpy xarray

%reload_ext autoreload
%autoreload 2

Accessing Duacs data

Copernicus Marine datasets can be accessed through the Copernicus Marine Toolbox API.

import copernicusmarine as cm

The API lets us retrieve a subset of the dataset by restricting the spatial and temporal domains, as well as the variables.

import numpy as np

spatial_extent = (-5.3538, -1.1883, 35.0707, 36.8415)  # spatial extent (lon0, lon1, lat0, lat1) of the Alboran Sea
temporal_slice = (np.datetime64("2019-01-01T00:00:00"), np.datetime64("2019-12-31T23:59:59"))  # we look at the 2019 data for our demo
variables = ["adt", "ugos", "vgos"]  # SSH (absolute dynamic topography) and, for comparison, the Duacs geostrophic currents

dataset_options = {
    "dataset_id": "cmems_obs-sl_eur_phy-ssh_my_allsat-l4-duacs-0.125deg_P1D",
    "variables": variables,
    "minimum_longitude": spatial_extent[0],
    "maximum_longitude": spatial_extent[1],
    "minimum_latitude": spatial_extent[2],
    "maximum_latitude": spatial_extent[3],
    "start_datetime": temporal_slice[0],
    "end_datetime": temporal_slice[1]
}
duacs_ds = cm.open_dataset(**dataset_options)

Visualisation

Let's visualise how the SSH and the magnitude of the geostrophic currents evolve over the time period.

# magnitude of the Duacs geostrophic currents
duacs_ds = duacs_ds.assign(uvgos=np.sqrt(duacs_ds.ugos ** 2 + duacs_ds.vgos ** 2))
import matplotlib.pyplot as plt

from duacs_visualisation import AnimatedSSHCurrent
%matplotlib widget

anim = AnimatedSSHCurrent(duacs_ds, ("adt", "uvgos"), ("Duacs", "Duacs"))
plt.show()


Geostrophic currents using jaxparrow

jaxparrow uses C-grids, following the NEMO convention. The U, V, and F points are automatically derived from the T points.
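The staggering can be illustrated with a minimal sketch that places U, V and F points at the midpoints of the T grid (U halfway along x, V halfway along y, F at the cell corners, as in NEMO). The helpers t_to_u, t_to_v and t_to_f are hypothetical and simply average neighbouring T values; each one drops a row or a column, whereas jaxparrow may treat the domain edges differently.

```python
import jax.numpy as jnp


def t_to_u(field_t):
    """Hypothetical helper: midpoints between T points along x (last axis)."""
    return 0.5 * (field_t[:, :-1] + field_t[:, 1:])


def t_to_v(field_t):
    """Hypothetical helper: midpoints between T points along y (first axis)."""
    return 0.5 * (field_t[:-1, :] + field_t[1:, :])


def t_to_f(field_t):
    """Hypothetical helper: cell corners, i.e. midpoints along both axes."""
    return t_to_v(t_to_u(field_t))


# Toy 2x3 T grid with a 1/8° spacing
lat_t = jnp.array([[35.0, 35.0, 35.0], [35.125, 35.125, 35.125]])
lon_t = jnp.array([[-5.0, -4.875, -4.75], [-5.0, -4.875, -4.75]])

lon_u = t_to_u(lon_t)  # U points sit halfway between T points in longitude
lat_v = t_to_v(lat_t)  # V points sit halfway between T points in latitude
lat_f = t_to_f(lat_t)  # F points sit at the cell corners
```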

import jax.numpy as jnp  # we manipulate jax.Array

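# 2D (latitude, longitude) grids of the T-point coordinates, built by broadcasting the 1D coordinates
lat_t = jnp.ones((duacs_ds.latitude.size, duacs_ds.longitude.size)) * duacs_ds.latitude.data.reshape(-1, 1)
lon_t = jnp.ones((duacs_ds.latitude.size, duacs_ds.longitude.size)) * duacs_ds.longitude.data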
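The broadcasting above turns the 1D Duacs coordinates into 2D (latitude, longitude) arrays. As a sketch on toy coordinates, jnp.meshgrid with indexing="xy" produces the same arrays and may read more explicitly.

```python
import jax.numpy as jnp

# Toy 1D coordinates standing in for duacs_ds.latitude and duacs_ds.longitude
lat_1d = jnp.array([35.0, 35.125, 35.25])
lon_1d = jnp.array([-5.0, -4.875])

# Broadcasting approach: arrays of shape (n_lat, n_lon)
lat_b = jnp.ones((lat_1d.size, lon_1d.size)) * lat_1d.reshape(-1, 1)
lon_b = jnp.ones((lat_1d.size, lon_1d.size)) * lon_1d

# meshgrid approach: with indexing="xy", outputs have shape (len(lat_1d), len(lon_1d))
lon_m, lat_m = jnp.meshgrid(lon_1d, lat_1d, indexing="xy")
```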

The spatial domain covers both sea and land, so we derive a mask from the invalid (non-finite) adt values to exclude the land points from the computation.

adt_t = jnp.asarray(duacs_ds.adt.data)
mask = ~(jnp.isfinite(adt_t))
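On a toy field, the masking step looks like this: land points are stored as NaN in adt, so the mask is True wherever the value is not finite. The array values below are made up for illustration.

```python
import jax.numpy as jnp

# Toy SSH field with shape (time=1, latitude=2, longitude=3); NaNs mark land points
adt_toy = jnp.array([[[0.10, 0.12, jnp.nan],
                      [0.11, jnp.nan, jnp.nan]]])

mask_toy = ~jnp.isfinite(adt_toy)  # True on land (invalid values), False at sea
n_land = int(mask_toy.sum())  # number of masked points
land_fraction = float(mask_toy.mean())  # share of the domain that is masked
```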

We then compute the geostrophic currents using the jaxparrow geostrophy function.
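For reference, geostrophic balance relates the surface velocity to the SSH gradient, u_g = -(g/f) ∂η/∂y and v_g = (g/f) ∂η/∂x, with f = 2Ω sin φ the Coriolis parameter. The sketch below applies these relations with centred finite differences (numpy.gradient) on a uniform grid and a synthetic Gaussian SSH bump; it is a simplified A-grid illustration, not jaxparrow's C-grid implementation, and geostrophic_velocity is a hypothetical helper.

```python
import numpy as np

G = 9.81  # gravitational acceleration (m s-2)
OMEGA = 7.2921e-5  # Earth's rotation rate (rad s-1)


def geostrophic_velocity(ssh, dx, dy, lat_deg):
    """Hypothetical helper: geostrophic (u, v) from SSH on a uniform grid.

    ssh has shape (ny, nx), with y along axis 0 and x along axis 1;
    dx and dy are the grid spacings in metres.
    """
    f = 2 * OMEGA * np.sin(np.deg2rad(lat_deg))  # Coriolis parameter
    deta_dy, deta_dx = np.gradient(ssh, dy, dx)  # centred differences along y then x
    u = -G / f * deta_dy
    v = G / f * deta_dx
    return u, v


# Synthetic Gaussian SSH bump (20 cm high, 50 km e-folding scale) on a 10 km grid
dx = dy = 10e3
y, x = np.mgrid[-20:21, -20:21] * dx
ssh = 0.2 * np.exp(-(x**2 + y**2) / (2 * (50e3) ** 2))
u, v = geostrophic_velocity(ssh, dx, dy, lat_deg=36.0)
```

For this positive SSH anomaly in the northern hemisphere the flow is clockwise (anticyclonic): eastward on the northern flank of the eddy and southward on its eastern flank.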

Rather than looping over the time indices, we vectorise the geostrophy function over the time axis with jax.vmap, and compute the geostrophic currents at every time step with the vectorised version.

import jax
import jaxparrow as jpw

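# map over the time axis (axis 0) of the SSH and of the mask; the coordinate grids are shared by all
# time steps, so the U and V grids returned by geostrophy are not batched (out_axes=None)
vmap_geostrophy = jax.vmap(jpw.geostrophy, in_axes=(0, None, None, 0), out_axes=(0, 0, None, None, None, None))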

ug_jpw_u, vg_jpw_v, lat_u, lon_u, lat_v, lon_v = vmap_geostrophy(adt_t, lat_t, lon_t, mask)
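The in_axes and out_axes arguments of jax.vmap follow a simple rule: 0 for arguments (and outputs) that carry a leading time axis, None for those shared by every time step. A toy example, where toy_operator is a hypothetical stand-in for geostrophy:

```python
import jax
import jax.numpy as jnp


def toy_operator(field, scale, mask):
    """Hypothetical per-snapshot computation: scale the field, zero it where masked."""
    return jnp.where(mask, 0.0, field * scale), scale.sum()


fields = jnp.arange(12.0).reshape(3, 2, 2)  # (time, ny, nx)
masks = fields > 9.0  # one mask per time step
scale = jnp.ones((2, 2)) * 2.0  # shared by all time steps

# map over axis 0 of fields and masks; scale is broadcast; the second output is not batched
batched = jax.vmap(toy_operator, in_axes=(0, None, 0), out_axes=(0, None))
out, total = batched(fields, scale, masks)
```

Setting out_axes=None is only valid for outputs that do not depend on the mapped inputs (here the sum of the shared scale, in the notebook the grids derived from lat_t and lon_t); JAX raises an error otherwise.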

To visualise the results, we compute the magnitude of the velocity.

from jaxparrow.tools.kinematics import magnitude

uvg_jpw_t = jax.vmap(magnitude, in_axes=(0, 0))(ug_jpw_u, vg_jpw_v)
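Because u and v live on different points of the C-grid, they must be brought back to the T points before the speed can be computed. The sketch below shows one way of doing this, assuming the NEMO indexing where U[i] sits at i+1/2 and V[j] at j+1/2; u_to_t, v_to_t and speed_on_t are hypothetical helpers, and jaxparrow's magnitude function may interpolate differently, in particular at the domain edges.

```python
import jax.numpy as jnp


def u_to_t(u_u):
    """Hypothetical helper: average adjacent U points onto T points (first column left as NaN)."""
    u_t = jnp.full_like(u_u, jnp.nan)
    return u_t.at[:, 1:].set(0.5 * (u_u[:, :-1] + u_u[:, 1:]))


def v_to_t(v_v):
    """Hypothetical helper: average adjacent V points onto T points (first row left as NaN)."""
    v_t = jnp.full_like(v_v, jnp.nan)
    return v_t.at[1:, :].set(0.5 * (v_v[:-1, :] + v_v[1:, :]))


def speed_on_t(u_u, v_v):
    """Hypothetical helper: speed at T points from staggered u and v."""
    u_t, v_t = u_to_t(u_u), v_to_t(v_v)
    return jnp.sqrt(u_t**2 + v_t**2)


u_u = jnp.full((3, 3), 0.3)  # uniform eastward flow on U points
v_v = jnp.full((3, 3), 0.4)  # uniform northward flow on V points
speed = speed_on_t(u_u, v_v)  # 0.5 m/s on interior points, NaN on the first row and column
```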

We store everything in an xarray Dataset.

import xarray as xr

gos_jpw_ds = xr.Dataset(
    {
        "adt": (["time", "latitude", "longitude"], adt_t),
        "ug": (["time", "latitude_u", "longitude_u"], ug_jpw_u),
        "vg": (["time", "latitude_v", "longitude_v"], vg_jpw_v),
        "uvg": (["time", "latitude", "longitude"], uvg_jpw_t)
    },
    coords={
        "time": duacs_ds.time,
        "latitude": duacs_ds.latitude, "longitude": duacs_ds.longitude,
        "latitude_u": np.unique(lat_u).astype(np.float32), "longitude_u": np.unique(lon_u).astype(np.float32),
        "latitude_v": np.unique(lat_v).astype(np.float32), "longitude_v": np.unique(lon_v).astype(np.float32)
    }
)

Visualisation

anim = AnimatedSSHCurrent(gos_jpw_ds, ("adt", "uvg"), ("Duacs", "jaxparrow (geos)"))
plt.show()


Geostrophic inter-comparison

As a sanity check, we can compare the Duacs and jaxparrow geostrophic reconstructions.

from duacs_visualisation import AnimatedCurrents

gos_ds = xr.Dataset(
    {
        "uvg": (["time", "latitude", "longitude"], duacs_ds.uvgos.data),
        "uvg_jpw": (["time", "latitude", "longitude"], uvg_jpw_t)
    },
    coords={
        "time": duacs_ds.time,
        "latitude": duacs_ds.latitude, "longitude": duacs_ds.longitude
    }
)

anim = AnimatedCurrents(gos_ds, ("uvg", "uvg_jpw"), ("Duacs", "jaxparrow"))
plt.show()
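Beyond the animation, a single number per day can summarise how far apart the two reconstructions are. A minimal sketch computing a daily root-mean-square difference over the points where both fields are valid, on made-up toy arrays standing in for gos_ds.uvg and gos_ds.uvg_jpw:

```python
import numpy as np
import xarray as xr

# Toy speed fields (time=2, latitude=2, longitude=2); NaNs mark land points
a = xr.DataArray(
    np.array([[[0.1, 0.2], [np.nan, 0.4]],
              [[0.1, 0.1], [0.1, np.nan]]]),
    dims=("time", "latitude", "longitude"),
    coords={"time": np.arange(2)},
)
b = a + 0.05  # second reconstruction, offset by 5 cm/s everywhere

# Daily RMS difference; skipna=True ignores points that are NaN in either field
rmse_t = np.sqrt(((a - b) ** 2).mean(dim=("latitude", "longitude"), skipna=True))
```

In the notebook, the same expression would read np.sqrt(((gos_ds.uvg - gos_ds.uvg_jpw) ** 2).mean(dim=("latitude", "longitude"))). Some difference is to be expected, since Duacs works on an A-grid while jaxparrow works on a C-grid.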


Cyclogeostrophic currents using jaxparrow

Now, let's look at the results of the variational inversion of the cyclogeostrophic currents.
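For reference, the cyclogeostrophic balance adds the advective term to the geostrophic balance. Starting from the steady, inviscid momentum equation $(\mathbf{u}\cdot\nabla)\mathbf{u} + f\,\mathbf{k}\times\mathbf{u} = -g\nabla\eta$, with $\mathbf{k}$ the vertical unit vector, the cyclogeostrophic velocity $\mathbf{u}_c$ satisfies the relations below, and a variational inversion looks for the velocity minimising a cost $J$ that measures the residual of this balance. This is a sketch of the formulation; jaxparrow's exact cost function and discretisation may differ.

```latex
\begin{aligned}
\mathbf{u}_g &= \frac{g}{f}\,\mathbf{k}\times\nabla\eta,
\qquad
\mathbf{u}_c - \frac{1}{f}\,\mathbf{k}\times(\mathbf{u}_c\cdot\nabla)\,\mathbf{u}_c = \mathbf{u}_g, \\
u_c + \frac{1}{f}\left(u_c\,\frac{\partial v_c}{\partial x} + v_c\,\frac{\partial v_c}{\partial y}\right) &= u_g,
\qquad
v_c - \frac{1}{f}\left(u_c\,\frac{\partial u_c}{\partial x} + v_c\,\frac{\partial u_c}{\partial y}\right) = v_g, \\
J(\mathbf{u}) &= \sum \left\lVert \mathbf{u} - \frac{1}{f}\,\mathbf{k}\times(\mathbf{u}\cdot\nabla)\,\mathbf{u} - \mathbf{u}_g \right\rVert^2 .
\end{aligned}
```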

vmap_cyclogeostrophy = jax.vmap(jpw.cyclogeostrophy, in_axes=(0, None, None, 0), out_axes=(0, 0, None, None, None, None))

uc_jpw_u, vc_jpw_v, lat_u, lon_u, lat_v, lon_v = vmap_cyclogeostrophy(adt_t, lat_t, lon_t, mask)

uvc_jpw_t = jax.vmap(magnitude, in_axes=(0, 0))(uc_jpw_u, vc_jpw_v)
cgos_jpw_ds = xr.Dataset(
    {
        "adt": (["time", "latitude", "longitude"], adt_t),
        "uc": (["time", "latitude_u", "longitude_u"], uc_jpw_u),
        "vc": (["time", "latitude_v", "longitude_v"], vc_jpw_v),
        "uvc": (["time", "latitude", "longitude"], uvc_jpw_t)
    },
    coords={
        "time": duacs_ds.time,
        "latitude": duacs_ds.latitude, "longitude": duacs_ds.longitude,
        "latitude_u": np.unique(lat_u).astype(np.float32), "longitude_u": np.unique(lon_u).astype(np.float32),
        "latitude_v": np.unique(lat_v).astype(np.float32), "longitude_v": np.unique(lon_v).astype(np.float32)
    }
)

Visualisation

anim = AnimatedSSHCurrent(cgos_jpw_ds, ("adt", "uvc"), ("Duacs", "jaxparrow (cyclogeos)"))
plt.show()


Comparison with geostrophy

jpw_ds = xr.Dataset(
    {
        "uvg": (["time", "latitude", "longitude"], uvg_jpw_t),
        "uvc": (["time", "latitude", "longitude"], uvc_jpw_t)
    },
    coords={
        "time": duacs_ds.time,
        "latitude": duacs_ds.latitude, "longitude": duacs_ds.longitude
    }
)

anim = AnimatedCurrents(jpw_ds, ("uvg", "uvc"), ("geostrophy", "cyclogeostrophy"))
plt.show()
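To quantify where the two balances depart from each other, one option is the relative change of speed, (|u_c| - |u_g|)/|u_g|, masked where the geostrophic speed is too small for the ratio to be meaningful. A sketch on toy arrays; relative_speed_change and the 0.05 m/s threshold are arbitrary illustrative choices, not part of jaxparrow.

```python
import numpy as np
import xarray as xr


def relative_speed_change(uvc: xr.DataArray, uvg: xr.DataArray, min_speed: float = 0.05) -> xr.DataArray:
    """Hypothetical helper: (|u_c| - |u_g|) / |u_g|, NaN where |u_g| <= min_speed (m/s) or invalid."""
    return ((uvc - uvg) / uvg).where(uvg > min_speed)


# Toy speeds standing in for jpw_ds.uvg and jpw_ds.uvc
uvg = xr.DataArray(np.array([[0.50, 0.40], [0.01, np.nan]]), dims=("latitude", "longitude"))
uvc = xr.DataArray(np.array([[0.60, 0.36], [0.02, 0.10]]), dims=("latitude", "longitude"))
rel = relative_speed_change(uvc, uvg)  # [[0.2, -0.1], [NaN, NaN]]
```

Positive values flag supergeostrophic flow (typically anticyclones in the northern hemisphere) and negative values subgeostrophic flow (typically cyclones).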
