FAQ
How can I use jaxparrow to estimate cyclogeostrophic currents for large observation or model datasets?
Dask is probably your best bet to handle large datasets that do not fit into (CPU or GPU) memory.
On a CPU backend, you can chunk the dataset with a chunk size of 1 along the time dimension and apply minimization_based to each chunk with xarray.map_blocks (below, ds is the input dataset and OUT_PATH the output Zarr store):
import dask.array  # explicit submodule import; also binds the name ``dask``
import jax.numpy as jnp
import numpy as np
import xarray as xr
from jaxparrow import minimization_based


def do_one_block(in_block: xr.Dataset) -> xr.Dataset:
    """Estimate geostrophic and cyclogeostrophic currents for one Dask chunk.

    ``in_block`` is one chunk of the input dataset, holding a single time step
    (chunk size 1 along the leading time axis of ``ssh``).

    Returns a dataset with ``ucg``, ``vcg``, ``ug`` and ``vg``, each laid out
    like ``in_block.ssh``.
    """
    # minimization_based works on a single 2D field, so drop the length-1
    # leading time axis of the chunk before calling it.
    ssh_2d = in_block.ssh.values[0]
    mb_result = minimization_based(
        lat_t=jnp.asarray(in_block.lat.values),
        lon_t=jnp.asarray(in_block.lon.values),
        ssh_t=jnp.asarray(ssh_2d),
        return_geos=True,
    )
    dims = in_block.ssh.dims
    # Re-insert the length-1 time axis so every output matches in_block.ssh.dims.
    out_block = xr.Dataset(
        {
            "ucg": (dims, np.asarray(mb_result.ucg)[None, :, :]),
            "vcg": (dims, np.asarray(mb_result.vcg)[None, :, :]),
            "ug": (dims, np.asarray(mb_result.ug)[None, :, :]),
            "vg": (dims, np.asarray(mb_result.vg)[None, :, :]),
        }
    )
    return out_block
# ds (input dataset, chunked with size 1 along time) and OUT_PATH (output Zarr
# store) are user-provided.
# Take all sizes from ssh itself, assumed (time, y, x): lat/lon are 2D here
# (see the coordinate assignment below), so ds.lat.size would be ny * nx.
nt, ny, nx = ds.ssh.shape
empty_arr = dask.array.empty((nt, ny, nx), chunks=(1, ny, nx), dtype=np.float32)
# Lazy template describing the outputs' layout for xr.map_blocks.
template = xr.Dataset(
    {name: (ds.ssh.dims, empty_arr) for name in ("ucg", "vcg", "ug", "vg")},
)
result = xr.map_blocks(do_one_block, ds, template=template)
result = result.assign_coords({
    "time": ds.time,
    "lat": (("y", "x"), ds.lat),
    "lon": (("y", "x"), ds.lon),
})
# The lazy graph is actually computed here.
result.to_zarr(OUT_PATH, compute=True, consolidated=False)
On a GPU backend, the previous approach can be combined with jax.vmap to further speed up computations: chunk the dataset with BATCH_SIZE time steps per chunk and process all time slices of a chunk in parallel on the GPU:
import dask.array  # explicit submodule import; also binds ``dask`` (dask.config is used below)
import jax
import jax.numpy as jnp
import numpy as np
import xarray as xr
from jaxparrow import minimization_based

# Vectorize minimization_based over the leading (time) axis of ssh. lat and lon
# are shared by every time slice, hence ``None`` in in_axes.
vmap_cyclogeostrophy = jax.vmap(
    lambda ssh, lat, lon: minimization_based(lat_t=lat, lon_t=lon, ssh_t=ssh, return_geos=True),
    in_axes=(0, None, None),
)
def do_one_block_vmap(in_block: xr.Dataset):
    """Estimate currents for every time slice of one Dask chunk in a single vmapped call.

    ``in_block`` is one chunk of the input dataset; its ``ssh`` holds up to
    BATCH_SIZE time steps along the leading axis.

    Returns a dataset with ``ucg``, ``vcg``, ``ug`` and ``vg``, each laid out
    like ``in_block.ssh``.
    """
    mb_result = vmap_cyclogeostrophy(
        jnp.asarray(in_block.ssh.values),
        jnp.asarray(in_block.lat.values),
        jnp.asarray(in_block.lon.values),
    )
    dims = in_block.ssh.dims
    # vmap keeps the leading batch (time) axis, so the outputs already match dims.
    out_block = xr.Dataset(
        {
            "ucg": (dims, np.asarray(mb_result.ucg)),
            "vcg": (dims, np.asarray(mb_result.vcg)),
            "ug": (dims, np.asarray(mb_result.ug)),
            "vg": (dims, np.asarray(mb_result.vg)),
        }
    )
    return out_block
# ds (input dataset, chunked with BATCH_SIZE along time so its chunks match the
# template below), OUT_PATH (output Zarr store) and BATCH_SIZE are user-provided.
# Take all sizes from ssh itself, assumed (time, y, x): lat/lon are 2D here
# (see the coordinate assignment below), so ds.lat.size would be ny * nx.
nt, ny, nx = ds.ssh.shape
empty_arr = dask.array.empty((nt, ny, nx), chunks=(BATCH_SIZE, ny, nx), dtype=np.float32)
# Lazy template describing the outputs' layout for xr.map_blocks.
template = xr.Dataset(
    {name: (ds.ssh.dims, empty_arr) for name in ("ucg", "vcg", "ug", "vg")},
)
# map_blocks only builds a lazy graph; the blocks are computed by to_zarr,
# so the synchronous scheduler must still be active at that point.
with dask.config.set(scheduler="synchronous"):
    result = xr.map_blocks(do_one_block_vmap, ds, template=template)
    result = result.assign_coords({
        "time": ds.time,
        "lat": (("y", "x"), ds.lat),
        "lon": (("y", "x"), ds.lon),
    })
    result.to_zarr(OUT_PATH, compute=True, consolidated=False)
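Note that in this case, you need to force Dask to use the synchronous scheduler, so that chunks are processed one after another rather than by several threads competing for the same GPU. The scheduler setting must be active when the computation is triggered (here, by to_zarr), not only when the lazy graph is built.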
You can see this approach in action in the DUACS example.
I am getting very large current velocity estimates
In our experience, this can happen when the input SSH data contains unbalanced signals.
To mitigate this, we clip updates during the gradient descent minimization.
This is achieved using the optax.clip transformation:
import optax

# optax.chain applies its transformations in order: the incoming gradients are
# first clipped element-wise to [-1, 1], then passed to Adam (lr = 5e-3).
clip_step = optax.clip(1.0)
adam_step = optax.adam(learning_rate=5e-3)
optimizer = optax.chain(clip_step, adam_step)
Then pass this optimizer as the optim argument of the minimization_based function.
This is employed in the Pseudo-SWOT observations from eNATL60 model data example.
We also recommend using JAX floating point types with sufficient precision, e.g., float64. Double precision must be enabled at startup, before any JAX array is created:
import jax
# Switch JAX's default floating point precision to 64 bits (float64).
# Run this at program startup, before any JAX array is created.
jax.config.update("jax_enable_x64", True)