Skip to content

FAQ

How can I use jaxparrow to estimate cyclogeostrophic currents for large observation or model datasets?

Dask is probably your best bet to handle large datasets that do not fit into (CPU or GPU) memory.

On a CPU backend, you can chunk the dataset with a chunk size of 1 along the time dimension and map the call to minimization_based onto each chunk with xarray.map_blocks:

import dask
import jax.numpy as jnp
import numpy as np
import xarray as xr

from jaxparrow import minimization_based


def do_one_block(in_block):
    """Estimate cyclogeostrophic currents for one block of ``ds``.

    Called by ``xr.map_blocks`` on each chunk of ``ds``. With a chunk size of 1
    along time, ``in_block.ssh`` holds a single time step, while ``lat`` and
    ``lon`` are the 2D (y, x) grid shared by all time steps.

    Returns an ``xr.Dataset`` holding the cyclogeostrophic (``ucg``, ``vcg``)
    and geostrophic (``ug``, ``vg``) velocities, laid out on the same dims as
    ``in_block.ssh`` so that it matches the ``map_blocks`` template.
    """
    # minimization_based is fed a single 2D SSH field, consistent with the 2D
    # lat/lon grid: drop the singleton time axis here (restored on the outputs).
    ssh_2d = in_block.ssh.values[0]

    mb_result = minimization_based(
        lat_t=jnp.asarray(in_block.lat.values),
        lon_t=jnp.asarray(in_block.lon.values),
        ssh_t=jnp.asarray(ssh_2d),
        return_geos=True,  # also return the geostrophic velocities ug / vg
    )

    # [None, :, :] re-inserts the length-1 time axis so every output field has
    # the same dims as in_block.ssh.
    out_block = xr.Dataset(
        {
            "ucg": (in_block.ssh.dims, np.asarray(mb_result.ucg)[None, :, :]),
            "vcg": (in_block.ssh.dims, np.asarray(mb_result.vcg)[None, :, :]),
            "ug": (in_block.ssh.dims, np.asarray(mb_result.ug)[None, :, :]),
            "vg": (in_block.ssh.dims, np.asarray(mb_result.vg)[None, :, :]),
        }
    )

    return out_block


# Placeholders to set: `ds` (assumed to be an xarray Dataset with an `ssh`
# variable of dims (time, y, x), chunked with 1 time step per chunk, and 2D
# (y, x) lat/lon) and OUT_PATH (destination Zarr store) — adapt to your data.

# Output shape taken from the SSH variable itself: lat/lon are 2D (y, x), so
# ds.lat.size / ds.lon.size would both be ny * nx rather than ny / nx.
nt, ny, nx = ds.ssh.shape

# Lazy float32 placeholder describing what do_one_block returns: one time step
# per chunk, full spatial domain.
empty_arr = dask.array.empty((nt, ny, nx), chunks=(1, ny, nx), dtype=np.float32)
template = xr.Dataset(
    {
        "ucg": (ds.ssh.dims, empty_arr),
        "vcg": (ds.ssh.dims, empty_arr),
        "ug": (ds.ssh.dims, empty_arr),
        "vg": (ds.ssh.dims, empty_arr),
    },
)

# Apply do_one_block lazily to every chunk of ds.
result = xr.map_blocks(do_one_block, ds, template=template)
# The template carries no coordinates: re-attach time and the 2D lat/lon grid.
result = result.assign_coords({
    "time": ds.time,
    "lat": (("y", "x"), ds.lat),
    "lon": (("y", "x"), ds.lon),
})

# compute=True triggers the actual computation while writing to Zarr.
result.to_zarr(OUT_PATH, compute=True, consolidated=False)

On a GPU backend, the previous approach can be combined with jax.vmap to further speed up computations: each chunk now contains BATCH_SIZE time steps, which are processed in parallel on the GPU in a single vectorized call:

import dask
import jax
import jax.numpy as jnp
import numpy as np
import xarray as xr

from jaxparrow import minimization_based


# Vectorize minimization_based over the leading (time) axis of the SSH batch.
# lat/lon are shared by every time step (in_axes=None), so they are not mapped.
vmap_cyclogeostrophy = jax.vmap(
    lambda ssh, lat, lon: minimization_based(lat_t=lat, lon_t=lon, ssh_t=ssh, return_geos=True),
    in_axes=(0, None, None),
)


def do_one_block_vmap(in_block: xr.Dataset):
    """Estimate cyclogeostrophic currents for a batch of time steps at once.

    Called by ``xr.map_blocks`` on each chunk of ``ds``. Every time step of
    ``in_block.ssh`` is processed in a single call to ``vmap_cyclogeostrophy``,
    with the 2D ``lat``/``lon`` grid shared across the batch.

    Returns an ``xr.Dataset`` holding the cyclogeostrophic (``ucg``, ``vcg``)
    and geostrophic (``ug``, ``vg``) velocities, on the same dims as
    ``in_block.ssh`` so that it matches the ``map_blocks`` template.
    """
    mb_result = vmap_cyclogeostrophy(
        jnp.asarray(in_block.ssh.values),
        jnp.asarray(in_block.lat.values),
        jnp.asarray(in_block.lon.values),
    )

    # vmap stacks the per-time-step results along a leading batch axis, so each
    # field already has the same layout as in_block.ssh.
    ucg_3d = mb_result.ucg
    vcg_3d = mb_result.vcg
    ug_3d = mb_result.ug
    vg_3d = mb_result.vg

    out_block = xr.Dataset(
        {
            "ucg": (in_block.ssh.dims, np.asarray(ucg_3d)),
            "vcg": (in_block.ssh.dims, np.asarray(vcg_3d)),
            "ug": (in_block.ssh.dims, np.asarray(ug_3d)),
            "vg": (in_block.ssh.dims, np.asarray(vg_3d)),
        }
    )

    return out_block


# Placeholders to set: `ds` (assumed to be an xarray Dataset with an `ssh`
# variable of dims (time, y, x), chunked with BATCH_SIZE time steps per chunk,
# and 2D (y, x) lat/lon), BATCH_SIZE (time steps per chunk, i.e. per vmap call)
# and OUT_PATH (destination Zarr store) — adapt to your data and GPU memory.

# Output shape taken from the SSH variable itself: lat/lon are 2D (y, x), so
# ds.lat.size / ds.lon.size would both be ny * nx rather than ny / nx.
nt, ny, nx = ds.ssh.shape

# Lazy float32 placeholder describing what do_one_block_vmap returns:
# BATCH_SIZE time steps per chunk, full spatial domain.
empty_arr = dask.array.empty((nt, ny, nx), chunks=(BATCH_SIZE, ny, nx), dtype=np.float32)
template = xr.Dataset(
    {
        "ucg": (ds.ssh.dims, empty_arr),
        "vcg": (ds.ssh.dims, empty_arr),
        "ug": (ds.ssh.dims, empty_arr),
        "vg": (ds.ssh.dims, empty_arr),
    },
)

# Synchronous scheduler: Dask computes the blocks one after another in the
# calling thread instead of dispatching them to JAX from a thread pool.
with dask.config.set(scheduler="synchronous"):
    result = xr.map_blocks(do_one_block_vmap, ds, template=template)
    # The template carries no coordinates: re-attach time and the 2D lat/lon grid.
    result = result.assign_coords({
        "time": ds.time,
        "lat": (("y", "x"), ds.lat),
        "lon": (("y", "x"), ds.lon),
    })

    # compute=True triggers the actual computation while writing to Zarr.
    result.to_zarr(OUT_PATH, compute=True, consolidated=False)

Note that in this case you need to force Dask to use the synchronous scheduler, so that blocks are processed one at a time instead of being sent to JAX concurrently from several Dask worker threads. You can see this approach in action in the DUACS example.

I am getting very large current velocity estimates

In our experience, this can happen when the input SSH data contains unbalanced signals. To mitigate this, we clip the updates during the gradient-descent minimization, using the optax.clip transformation:

import optax

# Clip each update to [-1, 1] before handing it to Adam (learning rate 5e-3),
# which tames the very large steps caused by unbalanced SSH signals.
clip_updates = optax.clip(1.0)
adam_step = optax.adam(learning_rate=5e-3)
optimizer = optax.chain(clip_updates, adam_step)

Then pass this optimizer object as the optim argument of the minimization_based function. This setup is used in the "Pseudo-SWOT observations from eNATL60 model data" example.

We also recommend using floating-point types with sufficient precision in JAX, e.g. float64, which requires enabling JAX's 64-bit mode:

import jax
# Enable 64-bit types (e.g. float64) in JAX. NOTE: per the JAX documentation on
# double precision, set this at startup, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)