
API

jaxparrow.cyclogeostrophy

CyclogeostrophyResult

Bases: NamedTuple

Result of cyclogeostrophic velocity computation.

This NamedTuple provides named access to results, avoiding positional unpacking errors. All fields except ucg and vcg are optional and depend on the return_* flags passed to the computation function.
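Here is a small stand-alone mirror of the result type that shows why named access is safer than positional unpacking. `ResultSketch` is a hypothetical name. It only reproduces the field layout listed below, using plain NumPy arrays, and is not the library class.

```python
from typing import NamedTuple, Optional

import numpy as np


class ResultSketch(NamedTuple):
    """Hypothetical mirror of the CyclogeostrophyResult field layout."""

    ucg: np.ndarray
    vcg: np.ndarray
    ug: Optional[np.ndarray] = None      # filled only when geostrophic fields are requested
    vg: Optional[np.ndarray] = None
    losses: Optional[np.ndarray] = None  # filled only when losses are requested


ucg = np.ones((2, 3))
vcg = np.zeros((2, 3))
res = ResultSketch(ucg, vcg)

# Named access does not depend on which optional fields were requested.
speed = np.hypot(res.ucg, res.vcg)

# Positional unpacking always yields all five slots, including the None ones.
u, v, ug, vg, losses = res
```

Because the tuple always has five slots, `u, v = result` raises a `ValueError` whatever flags were passed. `result.ucg` always works.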

Attributes:

Name Type Description
ucg Float[Array, 'y x']

\(u\) component of cyclogeostrophic velocity, on the T grid

vcg Float[Array, 'y x']

\(v\) component of cyclogeostrophic velocity, on the T grid

ug Float[Array, 'y x'] | None

\(u\) component of geostrophic velocity, on the T grid (if return_geos=True)

vg Float[Array, 'y x'] | None

\(v\) component of geostrophic velocity, on the T grid (if return_geos=True)

losses Float[Array, 'n_it'] | None

Cyclogeostrophic imbalance over iterations (if return_losses=True)

Source code in jaxparrow/cyclogeostrophy/_core.py
class CyclogeostrophyResult(NamedTuple):
    """
    Result of cyclogeostrophic velocity computation.

    This NamedTuple provides named access to results, avoiding positional unpacking errors.
    All fields except ``ucg`` and ``vcg`` are optional and depend on the
    ``return_*`` flags passed to the computation function.

    Attributes
    ----------
    ucg : Float[jax.Array, "y x"]
        $u$ component of cyclogeostrophic velocity, on the T grid
    vcg : Float[jax.Array, "y x"]
        $v$ component of cyclogeostrophic velocity, on the T grid
    ug : Float[jax.Array, "y x"] | None
        $u$ component of geostrophic velocity, on the T grid (if ``return_geos=True``)
    vg : Float[jax.Array, "y x"] | None
        $v$ component of geostrophic velocity, on the T grid (if ``return_geos=True``)
    losses : Float[jax.Array, "n_it"] | None
        Cyclogeostrophic imbalance over iterations (if ``return_losses=True``)
    """

    ucg: Float[jax.Array, "y x"]
    vcg: Float[jax.Array, "y x"]
    ug: Float[jax.Array, "y x"] | None = None
    vg: Float[jax.Array, "y x"] | None = None
    losses: Float[jax.Array, "n_it"] | None = None

cyclogeostrophic_loss(ug, vg, ucg, vcg, lat_t=None, lon_t=None, lat_u=None, lon_u=None, lat_v=None, lon_v=None, land_mask=None, uv_on_t=True, is_grid_rectilinear=None)

Computes the cyclogeostrophic imbalance loss (a scalar) from a geostrophic and a cyclogeostrophic velocity field.

The velocity fields can be provided either on the T grid (uv_on_t=True) or on the U/V grids (uv_on_t=False).

If provided, lat_u, lon_u, lat_v, and lon_v are expected to follow the NEMO convention.
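On a NEMO-style staggered grid, U(i) lies between T(i) and T(i+1), and V(j) lies between T(j) and T(j+1). Moving a field between U and T points is therefore a two-point average. The NumPy sketch below shows this for the x direction only. The function names and the NaN padding of the edge column that has no neighbour are assumptions for illustration. The sketch does not reproduce `operators.interpolation`, which also takes a land mask.

```python
import numpy as np


def u_to_t(field_u):
    """Average U-point values onto T points along x (axis=1).

    T(i) = (U(i-1) + U(i)) / 2; the first column has no left U neighbour
    and is padded with NaN in this sketch.
    """
    out = np.full(field_u.shape, np.nan)
    out[:, 1:] = 0.5 * (field_u[:, :-1] + field_u[:, 1:])
    return out


def t_to_u(field_t):
    """Average T-point values onto U points along x; the last column is padded with NaN.

    U(i) = (T(i) + T(i+1)) / 2.
    """
    out = np.full(field_t.shape, np.nan)
    out[:, :-1] = 0.5 * (field_t[:, :-1] + field_t[:, 1:])
    return out


on_t = u_to_t(np.array([[1.0, 3.0, 5.0]]))  # [[nan, 2., 4.]]
on_u = t_to_u(np.array([[1.0, 3.0, 5.0]]))  # [[2., 4., nan]]
```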

Parameters:

Name Type Description Default
ug Float[Array, 'y x']

\(u\) component of the geostrophic velocity field

required
vg Float[Array, 'y x']

\(v\) component of the geostrophic velocity field

required
ucg Float[Array, 'y x']

\(u\) component of the cyclogeostrophic velocity field

required
vcg Float[Array, 'y x']

\(v\) component of the cyclogeostrophic velocity field

required
lat_t Float[Array, 'y x']

Latitudes of the T grid.

If lat_t or lon_t is not provided, the T grid coordinates are interpolated from lat_u and lon_u or, failing that, from lat_v and lon_v. A ValueError is raised if none of these pairs is available.

Defaults to None

None
lon_t Float[Array, 'y x']

Longitudes of the T grid.

If lat_t or lon_t is not provided, the T grid coordinates are interpolated from lat_u and lon_u or, failing that, from lat_v and lon_v. A ValueError is raised if none of these pairs is available.

Defaults to None

None
lat_u Float[Array, 'y x']

Latitudes of the U grid.

Defaults to None

None
lon_u Float[Array, 'y x']

Longitudes of the U grid.

Defaults to None

None
lat_v Float[Array, 'y x']

Latitudes of the V grid.

Defaults to None

None
lon_v Float[Array, 'y x']

Longitudes of the V grid.

Defaults to None

None
land_mask Float[Array, 'y x']

Mask defining the marine area of the spatial domain; 1 or True stands for masked (i.e. land)

None
uv_on_t bool

If True, the velocity components are assumed to be located on the T grid (this is important when manipulating staggered grids)

Defaults to True

True
is_grid_rectilinear bool | None

If True, the grid is assumed to be rectilinear and no rotation is applied to the input velocities. If False, the input velocities are rotated to grid coordinates before computing the imbalance. If None, the grid type is inferred from the grid angles (if angles are close to zero, the grid is considered rectilinear).

Defaults to None

None
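When is_grid_rectilinear=None, cyclogeostrophic_imbalance (which this function calls) treats the grid as rectilinear if the absolute grid angle is below 1e-3 everywhere. The sketch below estimates the angle between the i grid direction and geographic east from finite differences of latitude and longitude, then applies the same threshold, assuming the angle is in radians. The angle formula is an assumption: it is a local flat-earth approximation that ignores longitude wrap-around. It is not `geometry.compute_grid_angle`.

```python
import numpy as np


def grid_angle_i(lat, lon):
    """Angle (radians) between the i grid direction (axis=1) and geographic east.

    Local flat-earth estimate: longitude increments are scaled by cos(latitude)
    so both increments are comparable arc lengths. Longitude wrap-around at
    +/-180 degrees is not handled.
    """
    dlat = np.gradient(lat, axis=1)
    dlon = np.gradient(lon, axis=1) * np.cos(np.deg2rad(lat))
    return np.arctan2(dlat, dlon)


def infer_rectilinear(lat, lon, tol=1e-3):
    """Treat the grid as rectilinear when the angle is close to zero everywhere."""
    return bool(np.all(np.abs(grid_angle_i(lat, lon)) < tol))


j, i = np.mgrid[0:11, 0:11].astype(float)
lat_regular, lon_regular = 30.0 + j, 0.0 + i  # i axis points due east
lat_rotated = 30.0 + j + 0.3 * i              # i axis tilted northwards
```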

Returns:

Name Type Description
loss Float[Array, '']

Cyclogeostrophic imbalance loss
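The source below reduces the two imbalance components with `jnp.nansum(u_imbalance ** 2 + v_imbalance ** 2)`. A NumPy equivalent makes one consequence explicit. The squares are added before the NaN-aware sum, so any grid point where either component is NaN (land, or a padded border) is dropped entirely. Its finite component does not count either. The function name is hypothetical.

```python
import numpy as np


def loss_from_imbalance(u_imbalance, v_imbalance):
    """Scalar loss: sum of squared imbalance over all non-NaN points."""
    # u**2 + v**2 is NaN wherever either component is NaN, so nansum skips that point.
    return float(np.nansum(u_imbalance ** 2 + v_imbalance ** 2))


u_imb = np.array([[1.0, np.nan], [0.0, 2.0]])
v_imb = np.array([[1.0, 1.0], [np.nan, 0.0]])
loss = loss_from_imbalance(u_imb, v_imb)  # 2 + 4 = 6.0; the two NaN points are skipped
```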

Source code in jaxparrow/cyclogeostrophy/_core.py
def cyclogeostrophic_loss(
    ug: Float[jax.Array, "y x"],
    vg: Float[jax.Array, "y x"],
    ucg: Float[jax.Array, "y x"],
    vcg: Float[jax.Array, "y x"],
    lat_t: Float[jax.Array, "y x"] = None,
    lon_t: Float[jax.Array, "y x"] = None,
    lat_u: Float[jax.Array, "y x"] = None,
    lon_u: Float[jax.Array, "y x"] = None,
    lat_v: Float[jax.Array, "y x"] = None,
    lon_v: Float[jax.Array, "y x"] = None,
    land_mask: Float[jax.Array, "y x"] = None,
    uv_on_t: bool = True,
    is_grid_rectilinear: bool | None = None,
) -> Float[jax.Array, ""]:
    """
    Computes the cyclogeostrophic imbalance loss (a scalar) from a geostrophic and a cyclogeostrophic velocity field.

    The velocity fields can be provided either on the T grid (``uv_on_t=True``) or on the U/V grids (``uv_on_t=False``).

    If provided, ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are expected to follow the NEMO convention.

    Parameters
    ----------
    ug : Float[jax.Array, "y x"]
        $u$ component of the geostrophic velocity field
    vg : Float[jax.Array, "y x"]
        $v$ component of the geostrophic velocity field
    ucg : Float[jax.Array, "y x"]
        $u$ component of the cyclogeostrophic velocity field
    vcg : Float[jax.Array, "y x"]
        $v$ component of the cyclogeostrophic velocity field
    lat_t : Float[jax.Array, "y x"], optional
        Latitudes of the T grid.

        If ``lat_t`` or ``lon_t`` is not provided, the T grid coordinates are interpolated from ``lat_u`` and ``lon_u`` or, failing that, from ``lat_v`` and ``lon_v``. A `ValueError` is raised if none of these pairs is available.

        Defaults to `None`
    lon_t : Float[jax.Array, "y x"], optional
        Longitudes of the T grid.

        If ``lat_t`` or ``lon_t`` is not provided, the T grid coordinates are interpolated from ``lat_u`` and ``lon_u`` or, failing that, from ``lat_v`` and ``lon_v``. A `ValueError` is raised if none of these pairs is available.

        Defaults to `None`
    lat_u : Float[jax.Array, "y x"], optional
        Latitudes of the U grid.

        Defaults to `None`
    lon_u : Float[jax.Array, "y x"], optional
        Longitudes of the U grid.

        Defaults to `None`
    lat_v : Float[jax.Array, "y x"], optional
        Latitudes of the V grid.

        Defaults to `None`
    lon_v : Float[jax.Array, "y x"], optional
        Longitudes of the V grid.

        Defaults to `None`
    land_mask : Float[jax.Array, "y x"], optional
        Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)
    uv_on_t : bool, optional
        If `True`, the velocity components are assumed to be located on the T grid
        (this is important when manipulating staggered grids)

        Defaults to `True`
    is_grid_rectilinear : bool | None, optional
        If `True`, the grid is assumed to be rectilinear and no rotation is applied to the input velocities.
        If `False`, the input velocities are rotated to grid coordinates before computing the imbalance.
        If `None`, the grid type is inferred from the grid angles (if angles are close to zero, the grid is considered rectilinear).

        Defaults to `None`

    Returns
    -------
    loss : Float[jax.Array, ""]
        Cyclogeostrophic imbalance loss
    """
    u_imbalance, v_imbalance = cyclogeostrophic_imbalance(
        ug, vg, ucg, vcg, lat_t, lon_t, lat_u, lon_u, lat_v, lon_v, land_mask, uv_on_t, is_grid_rectilinear
    )

    return jnp.nansum(u_imbalance ** 2 + v_imbalance ** 2)

cyclogeostrophic_imbalance(ug, vg, ucg, vcg, lat_t=None, lon_t=None, lat_u=None, lon_u=None, lat_v=None, lon_v=None, land_mask=None, uv_on_t=True, is_grid_rectilinear=None)

Computes the cyclogeostrophic imbalance field from a geostrophic and a cyclogeostrophic velocity field.

The velocity fields can be provided either on the T grid (uv_on_t=True) or on the U/V grids (uv_on_t=False).

If provided, lat_u, lon_u, lat_v, and lon_v are expected to follow the NEMO convention.
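For reference, cyclogeostrophic balance can be written, with \(f\) the Coriolis parameter, as \(u = u_g - \frac{1}{f}(u\,\partial_x v + v\,\partial_y v)\) and \(v = v_g + \frac{1}{f}(u\,\partial_x u + v\,\partial_y u)\). The imbalance is the residual of these two equations. The NumPy sketch below evaluates that residual on a regular Cartesian grid with uniform spacing, using `np.gradient` for the derivatives. The sign convention of the residual and the use of centred differences are assumptions. The library's private `_cyclogeostrophic_imbalance` instead uses its own operators, grid spacings derived from `lat_t`/`lon_t`, and a land mask.

```python
import numpy as np


def imbalance_sketch(ug, vg, ucg, vcg, dx, dy, f):
    """Residual of cyclogeostrophic balance on a uniform grid (axis 0 = y, axis 1 = x)."""
    du_dy, du_dx = np.gradient(ucg, dy, dx)
    dv_dy, dv_dx = np.gradient(vcg, dy, dx)
    adv_u = ucg * du_dx + vcg * du_dy  # (u . grad) u
    adv_v = ucg * dv_dx + vcg * dv_dy  # (u . grad) v
    u_imbalance = ucg + adv_v / f - ug
    v_imbalance = vcg - adv_u / f - vg
    return u_imbalance, v_imbalance


# Solid-body rotation u = -omega*y, v = omega*x is exactly balanced by a
# geostrophic rotation rate omega_g = omega + omega**2 / f.
f, omega, dx = 1e-4, 1e-5, 5e3
y, x = np.mgrid[-10:11, -10:11] * dx
omega_g = omega + omega ** 2 / f
u_imb, v_imb = imbalance_sketch(-omega_g * y, omega_g * x, -omega * y, omega * x, dx, dx, f)
```

Passing the same field as both the geostrophic and the cyclogeostrophic guess leaves a non-zero residual of order \(\omega^2 r / f\). That residual is the centripetal term the geostrophic field ignores.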

Parameters:

Name Type Description Default
ug Float[Array, 'y x']

\(u\) component of the geostrophic velocity field

required
vg Float[Array, 'y x']

\(v\) component of the geostrophic velocity field

required
ucg Float[Array, 'y x']

\(u\) component of the cyclogeostrophic velocity field

required
vcg Float[Array, 'y x']

\(v\) component of the cyclogeostrophic velocity field

required
lat_t Float[Array, 'y x']

Latitudes of the T grid.

If lat_t or lon_t is not provided, the T grid coordinates are interpolated from lat_u and lon_u or, failing that, from lat_v and lon_v. A ValueError is raised if none of these pairs is available.

Defaults to None

None
lon_t Float[Array, 'y x']

Longitudes of the T grid.

If lat_t or lon_t is not provided, the T grid coordinates are interpolated from lat_u and lon_u or, failing that, from lat_v and lon_v. A ValueError is raised if none of these pairs is available.

Defaults to None

None
lat_u Float[Array, 'y x']

Latitudes of the U grid.

Defaults to None

None
lon_u Float[Array, 'y x']

Longitudes of the U grid.

Defaults to None

None
lat_v Float[Array, 'y x']

Latitudes of the V grid.

Defaults to None

None
lon_v Float[Array, 'y x']

Longitudes of the V grid.

Defaults to None

None
land_mask Float[Array, 'y x']

Mask defining the marine area of the spatial domain; 1 or True stands for masked (i.e. land)

None
uv_on_t bool

If True, the velocity components are assumed to be located on the T grid (this is important when manipulating staggered grids)

Defaults to True

True
is_grid_rectilinear bool | None

If True, the grid is assumed to be rectilinear and no rotation is applied to the input velocities. If False, the input velocities are rotated to grid coordinates before computing the imbalance. If None, the grid type is inferred from the grid angles (if angles are close to zero, the grid is considered rectilinear).

Defaults to None

None

Returns:

Name Type Description
u_imbalance Float[Array, 'y x']

\(u\) component of the cyclogeostrophic imbalance, on the T grid

v_imbalance Float[Array, 'y x']

\(v\) component of the cyclogeostrophic imbalance, on the T grid

Source code in jaxparrow/cyclogeostrophy/_core.py
def cyclogeostrophic_imbalance(
    ug: Float[jax.Array, "y x"],
    vg: Float[jax.Array, "y x"],
    ucg: Float[jax.Array, "y x"],
    vcg: Float[jax.Array, "y x"],
    lat_t: Float[jax.Array, "y x"] = None,
    lon_t: Float[jax.Array, "y x"] = None,
    lat_u: Float[jax.Array, "y x"] = None,
    lon_u: Float[jax.Array, "y x"] = None,
    lat_v: Float[jax.Array, "y x"] = None,
    lon_v: Float[jax.Array, "y x"] = None,
    land_mask: Float[jax.Array, "y x"] = None,
    uv_on_t: bool = True,
    is_grid_rectilinear: bool | None = None,
) -> tuple[Float[jax.Array, "y x"], Float[jax.Array, "y x"]]:
    """
    Computes the cyclogeostrophic imbalance field from a geostrophic and a cyclogeostrophic velocity field.

    The velocity fields can be provided either on the T grid (``uv_on_t=True``) or on the U/V grids (``uv_on_t=False``).

    If provided, ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are expected to follow the NEMO convention.

    Parameters
    ----------
    ug : Float[jax.Array, "y x"]
        $u$ component of the geostrophic velocity field
    vg : Float[jax.Array, "y x"]
        $v$ component of the geostrophic velocity field
    ucg : Float[jax.Array, "y x"]
        $u$ component of the cyclogeostrophic velocity field
    vcg : Float[jax.Array, "y x"]
        $v$ component of the cyclogeostrophic velocity field
    lat_t : Float[jax.Array, "y x"], optional
        Latitudes of the T grid.

        If ``lat_t`` or ``lon_t`` is not provided, the T grid coordinates are interpolated from ``lat_u`` and ``lon_u`` or, failing that, from ``lat_v`` and ``lon_v``. A `ValueError` is raised if none of these pairs is available.

        Defaults to `None`
    lon_t : Float[jax.Array, "y x"], optional
        Longitudes of the T grid.

        If ``lat_t`` or ``lon_t`` is not provided, the T grid coordinates are interpolated from ``lat_u`` and ``lon_u`` or, failing that, from ``lat_v`` and ``lon_v``. A `ValueError` is raised if none of these pairs is available.

        Defaults to `None`
    lat_u : Float[jax.Array, "y x"], optional
        Latitudes of the U grid.

        Defaults to `None`
    lon_u : Float[jax.Array, "y x"], optional
        Longitudes of the U grid.

        Defaults to `None`
    lat_v : Float[jax.Array, "y x"], optional
        Latitudes of the V grid.

        Defaults to `None`
    lon_v : Float[jax.Array, "y x"], optional
        Longitudes of the V grid.

        Defaults to `None`
    land_mask : Float[jax.Array, "y x"], optional
        Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)
    uv_on_t : bool, optional
        If `True`, the velocity components are assumed to be located on the T grid
        (this is important when manipulating staggered grids)

        Defaults to `True`
    is_grid_rectilinear : bool | None, optional
        If `True`, the grid is assumed to be rectilinear and no rotation is applied to the input velocities.
        If `False`, the input velocities are rotated to grid coordinates before computing the imbalance.
        If `None`, the grid type is inferred from the grid angles (if angles are close to zero, the grid is considered rectilinear).

        Defaults to `None`

    Returns
    -------
    u_imbalance : Float[jax.Array, "y x"]
        $u$ component of the cyclogeostrophic imbalance, on the T grid
    v_imbalance : Float[jax.Array, "y x"]
        $v$ component of the cyclogeostrophic imbalance, on the T grid
    """
    if land_mask is None:
        land_mask = sanitize.init_land_mask(ug)

    if not uv_on_t:
        # geostrophic and cyclogeostrophic fields share the same U/V locations, so both use the same U/V -> T mapping
        ug = operators.interpolation(ug, axis=1, padding="left", land_mask=land_mask)  # U(i), U(i+1) -> T(i+1)
        vg = operators.interpolation(vg, axis=0, padding="left", land_mask=land_mask)  # V(j), V(j+1) -> T(j+1)
        ucg = operators.interpolation(ucg, axis=1, padding="left", land_mask=land_mask)  # U(i), U(i+1) -> T(i+1)
        vcg = operators.interpolation(vcg, axis=0, padding="left", land_mask=land_mask)  # V(j), V(j+1) -> T(j+1)

    if lat_t is None or lon_t is None:
        if lat_u is not None and lon_u is not None:
            lat_t = operators.interpolation(lat_u, axis=1, padding="left", land_mask=land_mask)
            lon_t = operators.interpolation(lon_u, axis=1, padding="left", land_mask=land_mask)
        elif lat_v is not None and lon_v is not None:
            lat_t = operators.interpolation(lat_v, axis=0, padding="left", land_mask=land_mask)
            lon_t = operators.interpolation(lon_v, axis=0, padding="left", land_mask=land_mask)
        else:
            raise ValueError("Either lat_t and lon_t, or lat_u and lon_u, or lat_v and lon_v must be provided")

    grid_angle_i, grid_angle_j = None, None

    if is_grid_rectilinear is None:
        grid_angle_i, grid_angle_j = geometry.compute_grid_angle(lat_t, lon_t)
        is_grid_rectilinear = jnp.all(jnp.abs(grid_angle_i) < 1e-3)

    if not is_grid_rectilinear:
        if grid_angle_i is None or grid_angle_j is None:
            grid_angle_i, grid_angle_j = geometry.compute_grid_angle(lat_t, lon_t)

        # rotate the input velocities to the grid coordinates
        ug, vg = geometry.rotate_to_grid(ug, vg, grid_angle_i, grid_angle_j)
        ucg, vcg = geometry.rotate_to_grid(ucg, vcg, grid_angle_i, grid_angle_j)

    # compute grid spacing once
    dx, dy = geometry.grid_spacing(lat_t, lon_t)
    f = geometry.coriolis_factor(lat_t)

    return _cyclogeostrophic_imbalance(ug, vg, ucg, vcg, dx, dy, f, land_mask)

fixed_point(lat_t, lon_t, ssh_t=None, ug_t=None, vg_t=None, land_mask=None, is_grid_rectilinear=None, rotate_to_geographic=True, return_geos=False, return_losses=False, n_it=20, res_eps=0.01)

Computes the cyclogeostrophic Sea Surface Current (SSC) velocity field using the fixed-point method of Penven et al. (2014).

There are two modes of operation:

  1. SSH mode: Provide lat_t, lon_t, ssh_t (and optionally land_mask). Geostrophic velocities will be computed from SSH.

  2. Geostrophic mode: Provide lat_t, lon_t, ug_t, vg_t (and optionally land_mask). Geostrophic velocities are provided on the T grid.
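The fixed-point method of Penven et al. (2014) iterates the balance directly. Starting from the geostrophic velocities, each iterate is \(u^{n+1} = u_g - \frac{1}{f}(u^n\,\partial_x v^n + v^n\,\partial_y v^n)\) and \(v^{n+1} = v_g + \frac{1}{f}(u^n\,\partial_x u^n + v^n\,\partial_y u^n)\). The NumPy sketch below implements that loop on a uniform grid and freezes each grid point once its update falls below `res_eps`. The per-point stopping rule and the centred differences are simplifications chosen for the example. This is an illustration of the scheme, not a reproduction of the library's `_fixed_point` routine, whose handling of land points and convergence is not shown on this page.

```python
import numpy as np


def advection(u, v, dx, dy):
    """Return ((u . grad) u, (u . grad) v) with centred differences (axis 0 = y)."""
    du_dy, du_dx = np.gradient(u, dy, dx)
    dv_dy, dv_dx = np.gradient(v, dy, dx)
    return u * du_dx + v * du_dy, u * dv_dx + v * dv_dy


def fixed_point_sketch(ug, vg, dx, dy, f, n_it=20, res_eps=0.01):
    u, v = ug.copy(), vg.copy()            # first guess: geostrophy
    active = np.ones(u.shape, dtype=bool)  # points that have not converged yet
    for _ in range(n_it):
        adv_u, adv_v = advection(u, v, dx, dy)
        u_new = ug - adv_v / f
        v_new = vg + adv_u / f
        residual = np.hypot(u_new - u, v_new - v)
        u = np.where(active, u_new, u)     # converged points keep their value
        v = np.where(active, v_new, v)
        active &= residual >= res_eps
        if not active.any():
            break
    return u, v


# Geostrophic solid-body rotation at 1.1e-5 s^-1 with f = 1e-4 s^-1:
# the balanced rate solves omega + omega**2 / f = 1.1e-5, i.e. omega = 1e-5.
f, dx = 1e-4, 5e3
y, x = np.mgrid[-10:11, -10:11] * dx
ucg, vcg = fixed_point_sketch(-1.1e-5 * y, 1.1e-5 * x, dx, dx, f, n_it=100, res_eps=1e-12)
```

For this flow the iteration contracts by a factor of about \(2\omega/f = 0.2\) per step. When the curvature term is large compared with \(f\), that factor can exceed one and the plain iteration no longer converges.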

Parameters:

Name Type Description Default
lat_t Float[Array, 'y x']

Latitude of the T grid

required
lon_t Float[Array, 'y x']

Longitude of the T grid

required
ssh_t Float[Array, 'y x']

SSH field (on the T grid)

Defaults to None, required if geostrophic velocities are not provided

None
ug_t Float[Array, 'y x']

U component of geostrophic velocity on T grid

Defaults to None, required if ssh_t is not provided

None
vg_t Float[Array, 'y x']

V component of geostrophic velocity on T grid

Defaults to None, required if ssh_t is not provided

None
land_mask Float[Array, 'y x']

Mask defining the marine area of the spatial domain; 1 or True stands for masked (i.e. land)

If not provided, it is inferred from the NaN values of ssh_t or ug_t, whichever is given.

Defaults to None

None
is_grid_rectilinear bool | None

If True, the grid is assumed to be rectilinear in geographic coordinates. If False, the grid is assumed to be curvilinear and the grid angle is computed from the grid spacing. If None, the grid is assumed to be rectilinear if the grid angle computed from the grid spacing is close to zero everywhere, and curvilinear otherwise.

Defaults to None

None
rotate_to_geographic bool

If True, rotates the output velocities from grid-relative to geographic coordinates. Rotation is performed using the grid angle computed from the grid spacing. If False, output velocities are in grid-relative coordinates.

If using a rectilinear grid in geographic coordinates, set to False to avoid unnecessary rotation.

Defaults to True

True
return_geos bool

If True, returns the geostrophic SSC velocity field in addition to the cyclogeostrophic one.

Defaults to False

False
return_losses bool

If True, returns the losses (cyclogeostrophic imbalance) over iterations.

Defaults to False

False
n_it int

Maximum number of iterations.

Defaults to 20

20
res_eps float

Residual tolerance of the iterative approach. Where the residual falls below this tolerance, the iteration is considered to have converged locally to cyclogeostrophy.

Defaults to 0.01

0.01
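rotate_to_geographic turns grid-relative components into eastward and northward ones using the local grid angle. The sketch below assumes the angle θ is measured counterclockwise from geographic east to the i grid direction. Under that assumption the conversion is a plain 2-D rotation, shown here with its inverse. The function names and the angle convention are assumptions, not the library's geometry helpers.

```python
import numpy as np


def grid_to_geographic(u_grid, v_grid, angle):
    """Rotate grid-relative (u, v) by +angle (radians) to eastward/northward components."""
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    return u_grid * cos_a - v_grid * sin_a, u_grid * sin_a + v_grid * cos_a


def geographic_to_grid(u_geo, v_geo, angle):
    """Inverse rotation (by -angle): eastward/northward back to grid-relative."""
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    return u_geo * cos_a + v_geo * sin_a, -u_geo * sin_a + v_geo * cos_a


angle = np.deg2rad(np.array([0.0, 30.0, 90.0]))
# A unit velocity along the i axis, seen in geographic coordinates:
u_geo, v_geo = grid_to_geographic(np.ones(3), np.zeros(3), angle)
```

With a zero angle both functions reduce to the identity. So on a rectilinear grid in geographic coordinates, skipping the rotation loses nothing.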

Returns:

Type Description
CyclogeostrophyResult

Named tuple containing:

  • ucg: \(u\) component of cyclogeostrophic velocity, on the T grid
  • vcg: \(v\) component of cyclogeostrophic velocity, on the T grid
  • ug, vg: Geostrophic velocities (if return_geos=True)
  • losses: Cyclogeostrophic imbalance per iteration (if return_losses=True)

Source code in jaxparrow/cyclogeostrophy/_fixed_point.py
def fixed_point(
    lat_t: Float[jax.Array, "y x"],
    lon_t: Float[jax.Array, "y x"],
    ssh_t: Float[jax.Array, "y x"] = None,
    ug_t: Float[jax.Array, "y x"] = None,
    vg_t: Float[jax.Array, "y x"] = None,
    land_mask: Float[jax.Array, "y x"] = None,
    is_grid_rectilinear: bool | None = None,
    rotate_to_geographic: bool = True,
    return_geos: bool = False,
    return_losses: bool = False,
    n_it: int = 20,
    res_eps: float = 0.01
) -> CyclogeostrophyResult:
    """
    Computes the cyclogeostrophic Sea Surface Current (SSC) velocity field
    using the fixed-point method [Penven et al. (2014)](https://doi.org/10.1002/2013JC009528).

    There are two modes of operation:

    1. **SSH mode**: Provide ``lat_t``, ``lon_t``, ``ssh_t`` (and optionally ``land_mask``).
       Geostrophic velocities will be computed from SSH.

    2. **Geostrophic mode**: Provide ``lat_t``, ``lon_t``, ``ug_t``, ``vg_t``
       (and optionally ``land_mask``). Geostrophic velocities are provided on the T grid.

    Parameters
    ----------
    lat_t : Float[jax.Array, "y x"]
        Latitude of the T grid
    lon_t : Float[jax.Array, "y x"]
        Longitude of the T grid
    ssh_t : Float[jax.Array, "y x"], optional
        SSH field (on the T grid)

        Defaults to `None`, required if geostrophic velocities are not provided
    ug_t : Float[jax.Array, "y x"], optional
        U component of geostrophic velocity on T grid

        Defaults to `None`, required if ``ssh_t`` is not provided
    vg_t : Float[jax.Array, "y x"], optional
        V component of geostrophic velocity on T grid

        Defaults to `None`, required if ``ssh_t`` is not provided
    land_mask : Float[jax.Array, "y x"], optional
        Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)

        If not provided, inferred from the `nan` values of ``ssh_t`` or ``ug_t``, whichever is given

        Defaults to `None`
    is_grid_rectilinear : bool | None, optional
        If `True`, the grid is assumed to be rectilinear in geographic coordinates.
        If `False`, the grid is assumed to be curvilinear and the grid angle is computed from the grid spacing.
        If `None`, the grid is assumed to be rectilinear if the grid angle computed from the grid spacing is close to zero everywhere, and curvilinear otherwise.

        Defaults to `None`
    rotate_to_geographic : bool, optional
        If `True`, rotates the output velocities from grid-relative to geographic coordinates.
        Rotation is performed using the grid angle computed from the grid spacing.
        If `False`, output velocities are in grid-relative coordinates.

        If using a rectilinear grid in geographic coordinates, set to `False` to avoid unnecessary rotation.

        Defaults to `True`
    return_geos : bool, optional
        If `True`, returns the geostrophic SSC velocity field in addition to the cyclogeostrophic one.

        Defaults to `False`
    return_losses : bool, optional
        If `True`, returns the losses (cyclogeostrophic imbalance) over iterations.

        Defaults to `False`
    n_it : int, optional
        Maximum number of iterations.

        Defaults to `20`
    res_eps : float, optional
        Residual tolerance of the iterative approach.
        Where the residual falls below this tolerance, the iteration is considered to have converged locally to cyclogeostrophy.

        Defaults to `0.01`

    Returns
    -------
    CyclogeostrophyResult
        Named tuple containing:

        - ``ucg``: $u$ component of cyclogeostrophic velocity, on the T grid
        - ``vcg``: $v$ component of cyclogeostrophic velocity, on the T grid
        - ``ug``, ``vg``: Geostrophic velocities (if ``return_geos=True``)
        - ``losses``: Cyclogeostrophic imbalance per iteration (if ``return_losses=True``)
    """
    setup = setup_cyclogeostrophy(
        lat_t, lon_t, ssh_t=ssh_t, ug_t=ug_t, vg_t=vg_t, land_mask=land_mask, is_grid_rectilinear=is_grid_rectilinear
    )

    ucg, vcg, losses = _fixed_point(
        setup.ug_t, setup.vg_t,
        setup.dx_t, setup.dy_t,
        setup.coriolis_factor_t,
        setup.land_mask, n_it, res_eps, return_losses
    )

    return assemble_result(
        ucg, vcg, setup, rotate_to_geographic, return_geos, return_losses=return_losses, losses=losses
    )

gradient_wind(lat_t, lon_t, ssh_t=None, ug_t=None, vg_t=None, land_mask=None, is_grid_rectilinear=None, rotate_to_geographic=True, return_geos=False)

Computes the cyclogeostrophic Sea Surface Current (SSC) velocity field using the gradient wind approximation.

There are two modes of operation:

  1. SSH mode: Provide lat_t, lon_t, ssh_t (and optionally land_mask). Geostrophic velocities will be computed from SSH.

  2. Geostrophic mode: Provide lat_t, lon_t, ug_t, vg_t (and optionally land_mask). Geostrophic velocities are provided on the T grid.
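The classical gradient-wind balance relates the speed \(V\) of a curved flow to the geostrophic speed \(V_g\) through \(V^2/R + fV = fV_g\). Here \(R\) is the signed radius of curvature, positive for cyclonic curvature when \(f > 0\). The regular root is \(V = 2V_g / (1 + \sqrt{1 + 4V_g/(fR)})\). The sketch below evaluates this textbook formula pointwise. This page does not show how jaxparrow estimates the curvature from the geostrophic field, so \(R\) is simply an input here. The function name is hypothetical.

```python
import numpy as np


def gradient_wind_speed(vg, radius, f):
    """Regular root of V**2 / R + f*V = f*Vg (R > 0 cyclonic when f > 0).

    Returns NaN where 1 + 4*Vg/(f*R) < 0, i.e. where anticyclonic curvature is
    too strong for a real solution to exist.
    """
    disc = 1.0 + 4.0 * vg / (f * radius)
    disc = np.where(disc >= 0.0, disc, np.nan)  # avoid the square root of a negative number
    return 2.0 * vg / (1.0 + np.sqrt(disc))


f = 1e-4
v_cyclonic = gradient_wind_speed(1.1, 1e5, f)       # disc = 1.44, V = 2.2 / 2.2 = 1.0
v_anticyclonic = gradient_wind_speed(1.1, -1e5, f)  # disc = 0.56, V > 1.1
v_straight = gradient_wind_speed(1.1, 1e12, f)      # nearly straight flow, V close to Vg
v_no_solution = gradient_wind_speed(1.1, -1e4, f)   # disc < 0 -> NaN
```

Cyclonic flow comes out subgeostrophic and anticyclonic flow supergeostrophic. The NaN branch marks where this balance has no real solution.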

Parameters:

Name Type Description Default
lat_t Float[Array, 'y x']

Latitude of the T grid

required
lon_t Float[Array, 'y x']

Longitude of the T grid

required
ssh_t Float[Array, 'y x']

SSH field (on the T grid)

Defaults to None, required if geostrophic velocities are not provided

None
ug_t Float[Array, 'y x']

U component of geostrophic velocity on T grid

Defaults to None, required if ssh_t is not provided

None
vg_t Float[Array, 'y x']

V component of geostrophic velocity on T grid

Defaults to None, required if ssh_t is not provided

None
land_mask Float[Array, 'y x']

Mask defining the marine area of the spatial domain; 1 or True stands for masked (i.e. land)

If not provided, it is inferred from the NaN values of ssh_t or ug_t, whichever is given.

Defaults to None

None
is_grid_rectilinear bool | None

If True, the grid is assumed to be rectilinear in geographic coordinates. If False, the grid is assumed to be curvilinear and the grid angle is computed from the grid spacing. If None, the grid is assumed to be rectilinear if the grid angle computed from the grid spacing is close to zero everywhere, and curvilinear otherwise.

Defaults to None

None
rotate_to_geographic bool

If True, rotates the output velocities from grid-relative to geographic coordinates. Rotation is performed using the grid angle computed from the grid spacing. If False, output velocities are in grid-relative coordinates.

If using a rectilinear grid in geographic coordinates, set to False to avoid unnecessary rotation.

Defaults to True

True
return_geos bool

If True, returns the geostrophic SSC velocity field in addition to the cyclogeostrophic one.

Defaults to False

False

Returns:

Type Description
CyclogeostrophyResult

Named tuple containing:

  • ucg: \(u\) component of cyclogeostrophic velocity, on the T grid
  • vcg: \(v\) component of cyclogeostrophic velocity, on the T grid
  • ug, vg: Geostrophic velocities (if return_geos=True)

Source code in jaxparrow/cyclogeostrophy/_gradient_wind.py
def gradient_wind(
    lat_t: Float[jax.Array, "y x"],
    lon_t: Float[jax.Array, "y x"],
    ssh_t: Float[jax.Array, "y x"] = None,
    ug_t: Float[jax.Array, "y x"] = None,
    vg_t: Float[jax.Array, "y x"] = None,
    land_mask: Float[jax.Array, "y x"] = None,
    is_grid_rectilinear: bool | None = None,
    rotate_to_geographic: bool = True,
    return_geos: bool = False
) -> CyclogeostrophyResult:
    """
    Computes the cyclogeostrophic Sea Surface Current (SSC) velocity field
    using the gradient wind approximation.

    There are two modes of operation:

    1. **SSH mode**: Provide ``lat_t``, ``lon_t``, ``ssh_t`` (and optionally ``land_mask``).
       Geostrophic velocities will be computed from SSH.

    2. **Geostrophic mode**: Provide ``lat_t``, ``lon_t``, ``ug_t``, ``vg_t``
       (and optionally ``land_mask``). Geostrophic velocities are provided on the T grid.

    Parameters
    ----------
    lat_t : Float[jax.Array, "y x"]
        Latitude of the T grid
    lon_t : Float[jax.Array, "y x"]
        Longitude of the T grid
    ssh_t : Float[jax.Array, "y x"], optional
        SSH field (on the T grid)

        Defaults to `None`, required if geostrophic velocities are not provided
    ug_t : Float[jax.Array, "y x"], optional
        U component of geostrophic velocity on T grid

        Defaults to `None`, required if ``ssh_t`` is not provided
    vg_t : Float[jax.Array, "y x"], optional
        V component of geostrophic velocity on T grid

        Defaults to `None`, required if ``ssh_t`` is not provided
    land_mask : Float[jax.Array, "y x"], optional
        Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)

        If not provided, inferred from the `nan` values of ``ssh_t`` or ``ug_t``, whichever is given

        Defaults to `None`
    is_grid_rectilinear : bool | None, optional
        If `True`, the grid is assumed to be rectilinear in geographic coordinates.
        If `False`, the grid is assumed to be curvilinear and the grid angle is computed from the grid spacing.
        If `None`, the grid is assumed to be rectilinear if the grid angle computed from the grid spacing is close to zero everywhere, and curvilinear otherwise.

        Defaults to `None`
    rotate_to_geographic : bool, optional
        If `True`, rotates the output velocities from grid-relative to geographic coordinates.
        Rotation is performed using the grid angle computed from the grid spacing.
        If `False`, output velocities are in grid-relative coordinates.

        If using a rectilinear grid in geographic coordinates, set to `False` to avoid unnecessary rotation.

        Defaults to `True`
    return_geos : bool, optional
        If `True`, returns the geostrophic SSC velocity field in addition to the cyclogeostrophic one.

        Defaults to `False`

    Returns
    -------
    CyclogeostrophyResult
        Named tuple containing:

        - ``ucg``: $u$ component of cyclogeostrophic velocity, on the T grid
        - ``vcg``: $v$ component of cyclogeostrophic velocity, on the T grid
        - ``ug``, ``vg``: Geostrophic velocities (if ``return_geos=True``)
    """
    setup = setup_cyclogeostrophy(
        lat_t, lon_t, ssh_t=ssh_t, ug_t=ug_t, vg_t=vg_t, land_mask=land_mask, is_grid_rectilinear=is_grid_rectilinear
    )

    ucg, vcg = _gradient_wind(
        setup.ug_t, setup.vg_t,
        setup.dx_t, setup.dy_t,
        setup.coriolis_factor_t,
        setup.land_mask
    )

    return assemble_result(ucg, vcg, setup, rotate_to_geographic, return_geos)

minimization_based(lat_t, lon_t, ssh_t=None, ug_t=None, vg_t=None, land_mask=None, is_grid_rectilinear=None, rotate_to_geographic=True, return_geos=False, return_losses=False, n_it=2000, optim='sgd', optim_kwargs=None, regularization=None, reg_kwargs=None)

Computes the cyclogeostrophic Sea Surface Current (SSC) velocity field using our minimization-based method.

There are two modes of operation:

  1. SSH mode: Provide lat_t, lon_t, ssh_t (and optionally land_mask). Geostrophic velocities will be computed from SSH.

  2. Geostrophic mode: Provide lat_t, lon_t, ug_t, vg_t (and optionally land_mask). Geostrophic velocities are provided on the T grid and will be interpolated to U/V grids internally.
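The minimization-based method treats the cyclogeostrophic velocities as unknowns and reduces the imbalance loss with an iterative optimizer over `n_it` steps. To keep the gradient writable by hand, the toy below restricts the unknown field to solid-body rotation \(u = -\omega y\), \(v = \omega x\). For that field the loss is proportional to \((\omega + \omega^2/f - \omega_g)^2\), and the toy runs plain gradient descent on the nondimensional rate \(\omega/f\). The restriction, the learning rate, and the geostrophic initial guess are choices made for the example. The library works on full fields with the optimizer selected through `optim`.

```python
import math


def minimize_solid_body(omega_g, f, lr=0.1, n_it=2000):
    """Gradient descent on the cyclogeostrophic loss restricted to solid-body rotation.

    For u = -omega*y, v = omega*x the imbalance magnitude at radius r is
    |omega + omega**2 / f - omega_g| * r, so, up to a positive constant, the loss
    is (w + w**2 - w_g)**2 with w = omega / f.
    """
    w_g = omega_g / f
    w = w_g  # initial guess: the geostrophic rotation rate
    for _ in range(n_it):
        g = w + w ** 2 - w_g
        grad = 2.0 * g * (1.0 + 2.0 * w)  # d/dw of (w + w**2 - w_g)**2
        w -= lr * grad
    return w * f


omega = minimize_solid_body(1.1e-5, 1e-4)
```

The expected answer is exact here: \(10^{-5} + (10^{-5})^2 / 10^{-4} = 1.1 \times 10^{-5}\), so the balanced rate is \(\omega = 10^{-5}\,\mathrm{s^{-1}}\).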

Parameters:

Name Type Description Default
lat_t Float[Array, 'y x']

Latitude of the T grid.

required
lon_t Float[Array, 'y x']

Longitude of the T grid.

required
ssh_t Float[Array, 'y x']

SSH field (on the T grid). Required if geostrophic velocities are not provided.

None
ug_t Float[Array, 'y x']

U component of geostrophic velocity on T grid. If provided with vg_t, bypasses SSH-based computation. Will be interpolated to U grid.

None
vg_t Float[Array, 'y x']

V component of geostrophic velocity on T grid. If provided with ug_t, bypasses SSH-based computation. Will be interpolated to V grid.

None
land_mask Float[Array, 'y x']

Mask defining the marine area of the spatial domain; 1 or True stands for masked (i.e. land)

If not provided, it is inferred from the NaN values of ssh_t or ug_t, whichever is given.

Defaults to None

None
is_grid_rectilinear bool

If True, the grid is assumed to be rectilinear in geographic coordinates. If False, the grid is assumed to be curvilinear and the grid angle is computed from the grid spacing. If None, the grid is assumed to be rectilinear if the grid angle computed from the grid spacing is close to zero everywhere, and curvilinear otherwise.

Defaults to None

None
rotate_to_geographic bool

If True, rotates the output velocities from grid-relative to geographic coordinates. Rotation is performed using the grid angle computed from the grid spacing. If False, output velocities are in grid-relative coordinates.

If using a rectilinear grid in geographic coordinates, set to False to avoid unnecessary rotation.

Defaults to True

True
return_geos bool

If True, returns the geostrophic SSC velocity field in addition to the cyclogeostrophic one.

Defaults to False

False
return_losses bool

If True, returns the losses (cyclogeostrophic imbalance) over iterations.

Defaults to False

False
n_it int

Maximum number of iterations.

Defaults to 2000

2000
optim GradientTransformation | str

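Optimizer to use: either an optax.GradientTransformation, or the name of an optax optimizer factory (e.g. 'sgd', 'adam'). When a name is given, the optimizer is built from optim_kwargs and chained after gradient clipping (optax.clip(1.0)); a GradientTransformation is used as-is.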

Defaults to sgd

'sgd'
optim_kwargs dict

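Keyword arguments passed to the optimizer factory when optim is a string (e.g. learning_rate). Ignored when optim is already an optax.GradientTransformation.

If None, the optimizer is created with learning_rate=0.005 only.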

Defaults to None

None
regularization Callable

A regularization function added to the cyclogeostrophic loss at every iteration. Its signature is defined as follows:

  • Parameter names from {ucg_t, vcg_t, lat_t, lon_t, dx_t, dy_t, coriolis_factor_t, land_mask} are automatically provided, but only ucg_t and vcg_t are required.
  • Any other parameter names must be provided via reg_kwargs.

Must return a scalar.

Defaults to None

None
reg_kwargs dict

Additional keyword arguments passed to the regularization function. Values should be traceable JAX Pytrees.

Defaults to None

None
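As an illustration of the regularization signature described above, here is a minimal sketch of a smoothness penalty. The name smoothness_penalty and its weight argument are hypothetical: ucg_t, vcg_t and land_mask belong to the automatically provided set, while weight does not, so it would have to be supplied through reg_kwargs. It uses jax.numpy so that it stays traceable and differentiable.

```python
import jax.numpy as jnp


def smoothness_penalty(ucg_t, vcg_t, land_mask, weight):
    """Hypothetical regularizer: weighted sum of squared differences between
    neighbouring ocean cells, for both velocity components.

    ucg_t, vcg_t and land_mask are matched by name and filled in automatically;
    `weight` is not in the automatic set, so it must come from reg_kwargs."""
    ocean = ~land_mask.astype(bool)

    def squared_differences(field, axis):
        diff = jnp.diff(field, axis=axis)
        if axis == 1:
            valid = ocean[:, 1:] & ocean[:, :-1]
        else:
            valid = ocean[1:, :] & ocean[:-1, :]
        # Ignore pairs touching land; jnp.where also keeps NaNs on land out of the sum
        return jnp.sum(jnp.where(valid, diff, 0.0) ** 2)

    total = sum(squared_differences(f, ax) for f in (ucg_t, vcg_t) for ax in (0, 1))
    return weight * total  # a scalar, as required
```

It could then be passed as regularization=smoothness_penalty together with reg_kwargs={"weight": 1e-3}; the weight value here is purely illustrative.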

Returns:

Type Description
CyclogeostrophyResult

Named tuple containing:

  • ucg: \(u\) component of cyclogeostrophic velocity, on the T grid
  • vcg: \(v\) component of cyclogeostrophic velocity, on the T grid
  • ug, vg: Geostrophic velocities (if return_geos=True or optimized via regularization)
  • ssh: Optimized SSH field (if SSH regularization was used)
  • losses: Cyclogeostrophic imbalance per iteration (if return_losses=True)

Source code in jaxparrow/cyclogeostrophy/_minimization_based.py
def minimization_based(
    lat_t: Float[jax.Array, "y x"],
    lon_t: Float[jax.Array, "y x"],
    ssh_t: Float[jax.Array, "y x"] = None,
    ug_t: Float[jax.Array, "y x"] = None,
    vg_t: Float[jax.Array, "y x"] = None,
    land_mask: Float[jax.Array, "y x"] = None,
    is_grid_rectilinear: bool | None = None,
    rotate_to_geographic: bool = True,
    return_geos: bool = False,
    return_losses: bool = False,
    n_it: int = 2000,
    optim: optax.GradientTransformation | str = "sgd",
    optim_kwargs: dict = None,
    regularization: Callable = None,
    reg_kwargs: dict = None,
) -> CyclogeostrophyResult:
    """
    Computes the cyclogeostrophic Sea Surface Current (SSC) velocity field
    using our minimization-based method.

    There are two modes of operation:

    1. **SSH mode**: Provide ``lat_t``, ``lon_t``, ``ssh_t`` (and optionally ``land_mask``).
       Geostrophic velocities will be computed from SSH.

    2. **Geostrophic mode**: Provide ``lat_t``, ``lon_t``, ``ug_t``, ``vg_t``
       (and optionally ``land_mask``). Geostrophic velocities are provided on the T grid
       and will be interpolated to U/V grids internally.

    Parameters
    ----------
    lat_t : Float[jax.Array, "y x"]
        Latitude of the T grid.
    lon_t : Float[jax.Array, "y x"]
        Longitude of the T grid.
    ssh_t : Float[jax.Array, "y x"], optional
        SSH field (on the T grid). Required if geostrophic velocities are not provided.
    ug_t : Float[jax.Array, "y x"], optional
        U component of geostrophic velocity on T grid. If provided with ``vg_t``,
        bypasses SSH-based computation. Will be interpolated to U grid.
    vg_t : Float[jax.Array, "y x"], optional
        V component of geostrophic velocity on T grid. If provided with ``ug_t``,
        bypasses SSH-based computation. Will be interpolated to V grid.
    land_mask : Float[jax.Array, "y x"], optional
        Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)

        If not provided, inferred from ``ssh_t`` or ``ug_t`` `nan` values

        Defaults to `None`
    is_grid_rectilinear : bool, optional
        If `True`, the grid is assumed to be rectilinear in geographic coordinates.
        If `False`, the grid is assumed to be curvilinear and the grid angle is computed from the grid spacing. 
        If `None`, the grid is assumed to be rectilinear if the grid angle computed from the grid spacing is close to zero everywhere, and curvilinear otherwise.

        Defaults to `None`
    rotate_to_geographic : bool, optional
        If `True`, rotates the output velocities from grid-relative to geographic coordinates.
        Rotation is performed using the grid angle computed from the grid spacing.
        If `False`, output velocities are in grid-relative coordinates.

        If using a rectilinear grid in geographic coordinates, set to `False` to avoid unnecessary rotation.

        Defaults to `True`
    return_geos : bool, optional
        If `True`, returns the geostrophic SSC velocity field in addition to the cyclogeostrophic one.

        Defaults to `False`
    return_losses : bool, optional
        If `True`, returns the losses (cyclogeostrophic imbalance) over iterations.

        Defaults to `False`
    n_it : int, optional
        Maximum number of iterations.

        Defaults to `2000`
    optim : optax.GradientTransformation | str, optional
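        Optimizer to use: either an ``optax.GradientTransformation``, or the name of an
        ``optax`` optimizer factory (e.g. ``"sgd"``, ``"adam"``). When a name is given, the
        optimizer is built from ``optim_kwargs`` and chained after gradient clipping
        (``optax.clip(1.0)``); a ``GradientTransformation`` is used as-is.

        Defaults to `sgd`
    optim_kwargs : dict, optional
        Keyword arguments passed to the optimizer factory when ``optim`` is a string
        (e.g. ``learning_rate``). Ignored when ``optim`` is a ``GradientTransformation``.

        If `None`, the optimizer is created with ``learning_rate=0.005`` only.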

        Defaults to `None`
    regularization : Callable, optional
        A regularization function added to the cyclogeostrophic loss at every iteration.
        Its signature is defined as follows:

        - Parameter names from ``{ucg_t, vcg_t, lat_t, lon_t, dx_t, dy_t, coriolis_factor_t, land_mask}`` are automatically provided,
        but only ``ucg_t`` and ``vcg_t`` are required.
        - Any other parameter names must be provided via ``reg_kwargs``.

        Must return a scalar.

        Defaults to `None`
    reg_kwargs : dict, optional
        Additional keyword arguments passed to the ``regularization`` function.
        Values should be traceable JAX Pytrees.

        Defaults to `None`

    Returns
    -------
    CyclogeostrophyResult
        Named tuple containing:

        - ``ucg``: $u$ component of cyclogeostrophic velocity, on the T grid
        - ``vcg``: $v$ component of cyclogeostrophic velocity, on the T grid
        - ``ug``, ``vg``: Geostrophic velocities (if ``return_geos=True`` or optimized via regularization)
        - ``ssh``: Optimized SSH field (if SSH regularization was used)
        - ``losses``: Cyclogeostrophic imbalance per iteration (if ``return_losses=True``)
    """
    setup = setup_cyclogeostrophy(
        lat_t, lon_t, ssh_t=ssh_t, ug_t=ug_t, vg_t=vg_t, land_mask=land_mask, is_grid_rectilinear=is_grid_rectilinear
    )

    if isinstance(optim, str):
        if optim_kwargs is None:
            optim_kwargs = {"learning_rate": 0.005}
        optim = getattr(optax, optim)(**optim_kwargs)
        optim = optax.chain(optax.clip(1.0), optim)  # Clip gradients to prevent instability
    elif not isinstance(optim, optax.GradientTransformation):
        raise TypeError(
            "optim should be an optax.GradientTransformation optimizer, or a string referring to such an optimizer."
        )

    # Handle regularization
    reg_wrapper = None
    if regularization is not None:
        reg_wrapper = _build_reg_wrapper(regularization, reg_kwargs)

    ucg, vcg, losses = _minimization_based(
        setup.ug_t, setup.vg_t,
        setup.dx_t, setup.dy_t,
        setup.coriolis_factor_t,
        setup.land_mask, n_it, optim,
        regularization=reg_wrapper, lat_t=lat_t, lon_t=lon_t,
        reg_kwargs=reg_kwargs,
    )

    return assemble_result(
        ucg, vcg, setup, rotate_to_geographic, return_geos, return_losses=return_losses, losses=losses,
    )

jaxparrow.geostrophy

geostrophy(ssh_t, lat_t, lon_t, land_mask=None, is_grid_rectilinear=None, rotate_to_geographic=True)

Computes the geostrophic velocity field from a Sea Surface Height (SSH) field.

Parameters:

Name Type Description Default
ssh_t Float[Array, 'y x']

SSH field (on the T grid)

required
lat_t Float[Array, 'y x']

Latitudes of the T grid

required
lon_t Float[Array, 'y x']

Longitudes of the T grid

required
land_mask Float[Array, 'y x']

Mask defining the marine area of the spatial domain; 1 or True stands for masked (i.e. land).

Defaults to None, in which case the mask is inferred from the NaN values of ssh_t.

None
is_grid_rectilinear bool

If True, the grid is assumed to be rectilinear in geographic coordinates and no rotation is applied. If False, the grid is assumed to be curvilinear, and the grid angles are computed from lat_t and lon_t. If None, the grid is treated as rectilinear if the angle of its i-axis is below 1e-3 rad in absolute value everywhere, and as curvilinear otherwise. Only used when rotate_to_geographic is True.

Defaults to None

None
rotate_to_geographic bool

If True, rotates the velocity field to geographic coordinates (eastward and northward components).

Defaults to True, in which case the returned velocity components are in geographic coordinates. If False, the returned velocity components are in grid coordinates (i.e. along the grid axes, which may not be aligned with geographic east and north directions).

True

Returns:

Name Type Description
ug_t Float[Array, 'y x']

\(u\) component of the geostrophic velocity field, on the T grid

vg_t Float[Array, 'y x']

\(v\) component of the geostrophic velocity field, on the T grid
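For orientation, geostrophic balance relates the velocity to SSH gradients through \(u_g = -(g/f)\,\partial\eta/\partial y\) and \(v_g = (g/f)\,\partial\eta/\partial x\). Below is a minimal collocated (unstaggered) sketch with jnp.gradient, assuming g = 9.81 m s-2 and a constant f. It is not jaxparrow's _geostrophy, which works on the staggered NEMO grids and handles land masks; geostrophic_velocity is an illustrative name.

```python
import jax.numpy as jnp

G = 9.81  # gravitational acceleration in m s-2 (assumed; jaxparrow's constant may differ)


def geostrophic_velocity(ssh, dx, dy, f):
    """Collocated sketch: ug = -(g/f) d(eta)/dy, vg = (g/f) d(eta)/dx.

    ssh in metres, dx/dy in metres per index step, f in s-1 (broadcastable to ssh)."""
    deta_dj, deta_di = jnp.gradient(ssh)  # index-space derivatives along axis 0 (y) and axis 1 (x)
    ug = -G / f * deta_dj / dy
    vg = G / f * deta_di / dx
    return ug, vg


# SSH rising northward by 1 mm per row, 1 km spacing, f = 1e-4 s-1 (northern hemisphere)
ssh = 0.001 * jnp.arange(5.0)[:, None] * jnp.ones((1, 4))
ug, vg = geostrophic_velocity(ssh, dx=1000.0, dy=1000.0, f=1e-4)
```

With high SSH to the north, the sketch gives a uniform westward flow of about 0.098 m/s and no meridional component, consistent with the high lying to the right of the flow in the northern hemisphere.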

Source code in jaxparrow/geostrophy.py
def geostrophy(
    ssh_t: Float[jax.Array, "y x"],
    lat_t: Float[jax.Array, "y x"],
    lon_t: Float[jax.Array, "y x"],
    land_mask: Float[jax.Array, "y x"] = None,
    is_grid_rectilinear: bool | None = None,
    rotate_to_geographic: bool = True
) -> tuple[Float[jax.Array, "y x"], Float[jax.Array, "y x"]]:
    """
    Computes the geostrophic velocity field from a Sea Surface Height (SSH) field.

    Parameters
    ----------
    ssh_t : Float[jax.Array, "y x"]
        SSH field (on the T grid)
    lat_t : Float[jax.Array, "y x"]
        Latitudes of the T grid
    lon_t : Float[jax.Array, "y x"]
        Longitudes of the T grid
    land_mask : Float[jax.Array, "y x"], optional
        Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land).

        Defaults to `None`, in which case the mask is inferred from the `nan` values of ``ssh_t``.
    is_grid_rectilinear : bool, optional
        If `True`, the grid is assumed to be rectilinear in geographic coordinates and no rotation is applied.
        If `False`, the grid is assumed to be curvilinear, and the grid angles are computed from ``lat_t`` and ``lon_t``.
        If `None`, the grid is treated as rectilinear if the i-axis angle is below ``1e-3`` rad in absolute value everywhere, and as curvilinear otherwise.
        Only used when ``rotate_to_geographic`` is `True`.

        Defaults to `None`
    rotate_to_geographic : bool, optional
        If `True`, rotates the velocity field to geographic coordinates (eastward and northward components).

        Defaults to `True`, in which case the returned velocity components are in geographic coordinates. 
        If `False`, the returned velocity components are in grid coordinates (i.e. along the grid axes, which may not be aligned with geographic east and north directions).

    Returns
    -------
    ug_t : Float[jax.Array, "y x"]
        $u$ component of the geostrophic velocity field, on the T grid
    vg_t : Float[jax.Array, "y x"]
        $v$ component of the geostrophic velocity field, on the T grid
    """
    # Make sure the mask is initialized
    land_mask = sanitize.init_land_mask(ssh_t, land_mask)

    # Handle spurious and masked data
    ssh_t = sanitize.sanitize_data(ssh_t, jnp.nan, land_mask)

    ug_t, vg_t = _geostrophy(ssh_t, lat_t, lon_t, land_mask)

    # Handle masked data (set land cells to NaN)
    ug_t = sanitize.sanitize_data(ug_t, jnp.nan, land_mask)
    vg_t = sanitize.sanitize_data(vg_t, jnp.nan, land_mask)

    if rotate_to_geographic:
        grid_angle_i = None
        grid_angle_j = None
        if is_grid_rectilinear is None:
            # determine if the grid is rectilinear by checking the i-axis angle
            grid_angle_i, grid_angle_j = geometry.compute_grid_angle(lat_t, lon_t)
            is_grid_rectilinear = jnp.all(jnp.abs(grid_angle_i) < 1e-3)

        if not is_grid_rectilinear:
            if grid_angle_i is None:
                grid_angle_i, grid_angle_j = geometry.compute_grid_angle(lat_t, lon_t)
            ug_t, vg_t = geometry.rotate_to_geographic(ug_t, vg_t, grid_angle_i, grid_angle_j)

    return ug_t, vg_t

jaxparrow.utils

Utility modules for jaxparrow computations.

geometry

grid_spacing(lat, lon)

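Computes the physical (great-circle) spacing associated with one grid-index step, used to transform derivatives to physical coordinates.

It uses the haversine formula on a sphere of radius EARTH_RADIUS, with longitude differences wrapped to [-180, 180) so that steps across the dateline are handled. The last column of dx and the last row of dy are filled by edge padding, so both outputs have the same shape as the inputs.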
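To sanity-check expected magnitudes, here is a standalone copy of the haversine step used by physical_spacing in the source below, assuming EARTH_RADIUS = 6371 km (the library's constant is not shown here and may differ). The function name haversine is illustrative.

```python
import jax.numpy as jnp

EARTH_RADIUS = 6371e3  # assumed mean Earth radius in metres


def haversine(lat1, lon1, lat2, lon2, radius=EARTH_RADIUS):
    """Great-circle distance between two points given in degrees."""
    lat1_rad, lat2_rad = jnp.radians(lat1), jnp.radians(lat2)
    # Wrap the longitude difference to [-180, 180) so the short way across the dateline is used
    dlon = (lon2 - lon1 + 180.0) % 360.0 - 180.0
    dlat_rad, dlon_rad = jnp.radians(lat2 - lat1), jnp.radians(dlon)
    a = jnp.sin(dlat_rad / 2.0) ** 2 + jnp.cos(lat1_rad) * jnp.cos(lat2_rad) * jnp.sin(dlon_rad / 2.0) ** 2
    return radius * 2.0 * jnp.arctan2(jnp.sqrt(a), jnp.sqrt(1.0 - a))


one_degree = haversine(0.0, 0.0, 0.0, 1.0)            # ~111.2 km along the equator
across_dateline = haversine(0.0, 179.5, 0.0, -179.5)  # also ~111.2 km thanks to the wrap
at_60n = haversine(60.0, 0.0, 60.0, 1.0)               # about half the equatorial value
```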

Parameters:

Name Type Description Default
lat Float[Array, 'y x']

Latitude grid, in degrees

required
lon Float[Array, 'y x']

Longitude grid, in degrees

required

Returns:

Name Type Description
dx Float[Array, 'y x']

Spacing associated with one step in the x-index direction

dy Float[Array, 'y x']

Spacing associated with one step in the y-index direction

Source code in jaxparrow/utils/geometry.py
def grid_spacing(
    lat: Float[jax.Array, "y x"], lon: Float[jax.Array, "y x"]
) -> tuple[
    Float[jax.Array, "y x"], 
    Float[jax.Array, "y x"], 
]:
    """
    Computes the physical (great-circle) spacing associated with one grid-index step,
    used to transform derivatives to physical coordinates.

    It uses the haversine formula, with longitude differences wrapped to [-180, 180)
    so that steps across the dateline are handled.

    Parameters
    ----------
    lat : Float[jax.Array, "y x"]
        Latitude grid
    lon : Float[jax.Array, "y x"]
        Longitude grid

    Returns
    -------
    dx : Float[jax.Array, "y x"]
        Spacing associated with one step in the x-index direction
    dy : Float[jax.Array, "y x"]
        Spacing associated with one step in the y-index direction
    """
    def physical_spacing(lat1, lat2, lon1, lon2):
        # convert to radians
        lat1_rad = jnp.radians(lat1)
        lat2_rad = jnp.radians(lat2)

        # difference in radians; normalize lon diff to [-180, 180] before radians to handle dateline
        dlon = lon2 - lon1
        dlon = (dlon + 180.0) % 360.0 - 180.0   # now in [-180,180]
        dlon_rad = jnp.radians(dlon)

        dlat_rad = jnp.radians(lat2 - lat1)

        # haversine distance
        a = jnp.sin(dlat_rad / 2.0) ** 2 + jnp.cos(lat1_rad) * jnp.cos(lat2_rad) * (jnp.sin(dlon_rad / 2.0) ** 2)
        c = 2.0 * jnp.arctan2(jnp.sqrt(a), jnp.sqrt(1.0 - a))
        d = EARTH_RADIUS * c

        return d

    # physical spacing
    dx = physical_spacing(lat[:, :-1], lat[:, 1:], lon[:, :-1], lon[:, 1:])
    dy = physical_spacing(lat[:-1, :], lat[1:, :], lon[:-1, :], lon[1:, :])

    dx = jnp.pad(dx, ((0, 0), (0, 1)), mode="edge")
    dy = jnp.pad(dy, ((0, 1), (0, 0)), mode="edge")

    return dx, dy

compute_grid_angle(lat, lon)

Computes the local angles of both grid axes relative to geographic east.

For curvilinear grids (e.g., SWOT swaths, tripolar grids), the grid axes are not aligned with geographic east-west/north-south directions. This function computes the rotation angles needed to transform velocity components between grid coordinates and geographic coordinates.

Parameters:

Name Type Description Default
lat Float[Array, 'lat lon']

Latitude grid

required
lon Float[Array, 'lat lon']

Longitude grid

required

Returns:

Name Type Description
angle_i Float[Array, 'lat lon']

Angle of the grid i-axis (axis=1) relative to geographic east, in radians, measured counterclockwise. Range is [-pi, pi].

angle_j Float[Array, 'lat lon']

Angle of the grid j-axis (axis=0) relative to geographic east, in radians, measured counterclockwise. Range is [-pi, pi].

Notes

Both angles are computed using the initial bearing formula between adjacent grid points. For a standard rectilinear lat/lon grid: angle_i ≈ 0 (i-axis ≈ east) and angle_j ≈ π/2 (j-axis ≈ north).

For grids where angle_j - angle_i ≠ π/2, i.e. non-orthogonal grids (axes not perpendicular) or left-handed grids (angle_j - angle_i ≈ -π/2), the rotation functions jaxparrow.utils.geometry.rotate_to_geographic and jaxparrow.utils.geometry.rotate_to_grid handle the axis geometry through the determinant det = sin(angle_j - angle_i).
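The private helper _axis_bearing_to_angle is not shown on this page. As a hedged sketch of the idea described in the notes, the standard initial-bearing formula (clockwise from north) can be converted into an angle measured counterclockwise from east; step_angle is a hypothetical name and its details may differ from jaxparrow's helper.

```python
import jax.numpy as jnp


def step_angle(lat1, lon1, lat2, lon2):
    """Hypothetical helper: direction of the step from point 1 to point 2,
    in radians counterclockwise from geographic east, wrapped to [-pi, pi]."""
    phi1, phi2 = jnp.radians(lat1), jnp.radians(lat2)
    dlon = jnp.radians(lon2 - lon1)
    # Initial bearing of the great circle from point 1 to point 2, clockwise from north
    bearing = jnp.arctan2(
        jnp.sin(dlon) * jnp.cos(phi2),
        jnp.cos(phi1) * jnp.sin(phi2) - jnp.sin(phi1) * jnp.cos(phi2) * jnp.cos(dlon),
    )
    angle = jnp.pi / 2 - bearing  # convert to counterclockwise from east
    return jnp.arctan2(jnp.sin(angle), jnp.cos(angle))


# A standard lat/lon grid cell at the equator
angle_i = step_angle(0.0, 0.0, 0.0, 1.0)  # i-axis: one step east  -> ~0
angle_j = step_angle(0.0, 0.0, 1.0, 0.0)  # j-axis: one step north -> ~pi/2
det = jnp.sin(angle_j - angle_i)          # ~+1: orthogonal, right-handed
```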

Source code in jaxparrow/utils/geometry.py
def compute_grid_angle(
    lat: Float[jax.Array, "lat lon"],
    lon: Float[jax.Array, "lat lon"]
) -> tuple[Float[jax.Array, "lat lon"], Float[jax.Array, "lat lon"]]:
    """
    Computes the local angles of both grid axes relative to geographic east.

    For curvilinear grids (e.g., SWOT swaths, tripolar grids), the grid axes are not aligned
    with geographic east-west/north-south directions. This function computes the rotation angles
    needed to transform velocity components between grid coordinates and geographic coordinates.

    Parameters
    ----------
    lat : Float[jax.Array, "lat lon"]
        Latitude grid
    lon : Float[jax.Array, "lat lon"]
        Longitude grid

    Returns
    -------
    angle_i : Float[jax.Array, "lat lon"]
        Angle of the grid i-axis (axis=1) relative to geographic east, in radians,
        measured counterclockwise. Range is [-pi, pi].
    angle_j : Float[jax.Array, "lat lon"]
        Angle of the grid j-axis (axis=0) relative to geographic east, in radians,
        measured counterclockwise. Range is [-pi, pi].

    Notes
    -----
    Both angles are computed using the initial bearing formula between adjacent grid points.
    For a standard rectilinear lat/lon grid: ``angle_i ≈ 0`` (i-axis ≈ east) and
    ``angle_j ≈ π/2`` (j-axis ≈ north).

    For grids where ``angle_j - angle_i ≠ π/2``, i.e. non-orthogonal grids (axes not perpendicular)
    or left-handed grids (``angle_j - angle_i ≈ -π/2``), the rotation functions
    [`jaxparrow.utils.geometry.rotate_to_geographic`][] and [`jaxparrow.utils.geometry.rotate_to_grid`][]
    handle the axis geometry through the determinant ``det = sin(angle_j - angle_i)``.
    """
    angle_i = _axis_bearing_to_angle(lat, lon, axis=1)
    angle_j = _axis_bearing_to_angle(lat, lon, axis=0)
    return angle_i, angle_j

rotate_to_geographic(u, v, angle_i, angle_j)

Rotates velocity components from grid coordinates to geographic coordinates (eastward and northward components).

Uses the full 2-column rotation matrix defined by the actual directions of both grid axes, correctly handling any grid orientation including non-orthogonal or left-handed grids.
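Written out, the formulas in the source below amount to the following matrix form, with \(\theta_i\) = angle_i and \(\theta_j\) = angle_j. For an orthogonal right-handed grid (\(\theta_j = \theta_i + \pi/2\), hence det = 1), it reduces to an ordinary rotation by \(\theta_i\):

```latex
\begin{pmatrix} u_{\mathrm{east}} \\ v_{\mathrm{north}} \end{pmatrix}
= \frac{1}{\det}
\begin{pmatrix} \cos\theta_i & \cos\theta_j \\ \sin\theta_i & \sin\theta_j \end{pmatrix}
\begin{pmatrix} u \\ v \end{pmatrix},
\qquad
\det = \cos\theta_i \sin\theta_j - \sin\theta_i \cos\theta_j = \sin(\theta_j - \theta_i)

% Orthogonal right-handed case, \theta_j = \theta_i + \pi/2 (\det = 1):
u_{\mathrm{east}} = u\cos\theta_i - v\sin\theta_i,
\qquad
v_{\mathrm{north}} = u\sin\theta_i + v\cos\theta_i
```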

Parameters:

Name Type Description Default
u Float[Array, 'y x']

Velocity component along the grid i-axis (axis=1)

required
v Float[Array, 'y x']

Velocity component along the grid j-axis (axis=0)

required
angle_i Float[Array, 'y x']

Angle of the grid i-axis (axis=1) relative to geographic east, in radians (counterclockwise positive). Typically obtained from jaxparrow.utils.geometry.compute_grid_angle.

required
angle_j Float[Array, 'y x']

Angle of the grid j-axis (axis=0) relative to geographic east, in radians (counterclockwise positive). Typically obtained from jaxparrow.utils.geometry.compute_grid_angle.

required

Returns:

Name Type Description
u_east Float[Array, 'y x']

Eastward velocity component

v_north Float[Array, 'y x']

Northward velocity component

Source code in jaxparrow/utils/geometry.py
def rotate_to_geographic(
    u: Float[jax.Array, "y x"],
    v: Float[jax.Array, "y x"],
    angle_i: Float[jax.Array, "y x"],
    angle_j: Float[jax.Array, "y x"]
) -> tuple[Float[jax.Array, "y x"], Float[jax.Array, "y x"]]:
    """
    Rotates velocity components from grid coordinates to geographic coordinates (eastward and northward components).

    Uses the full 2-column rotation matrix defined by the actual directions of both grid axes,
    correctly handling any grid orientation including non-orthogonal or left-handed grids.

    Parameters
    ----------
    u : Float[jax.Array, "y x"]
        Velocity component along the grid i-axis (axis=1)
    v : Float[jax.Array, "y x"]
        Velocity component along the grid j-axis (axis=0)
    angle_i : Float[jax.Array, "y x"]
        Angle of the grid i-axis (axis=1) relative to geographic east, in radians
        (counterclockwise positive). Typically obtained from [`jaxparrow.utils.geometry.compute_grid_angle`][].
    angle_j : Float[jax.Array, "y x"]
        Angle of the grid j-axis (axis=0) relative to geographic east, in radians
        (counterclockwise positive). Typically obtained from [`jaxparrow.utils.geometry.compute_grid_angle`][].

    Returns
    -------
    u_east : Float[jax.Array, "y x"]
        Eastward velocity component
    v_north : Float[jax.Array, "y x"]
        Northward velocity component
    """
    cos_i = jnp.cos(angle_i)
    sin_i = jnp.sin(angle_i)
    cos_j = jnp.cos(angle_j)
    sin_j = jnp.sin(angle_j)

    # det = sin(angle_j - angle_i): +1 for right-handed grids, -1 for left-handed grids
    det = cos_i * sin_j - sin_i * cos_j

    u_east = (u * cos_i + v * cos_j) / det
    v_north = (u * sin_i + v * sin_j) / det

    return u_east, v_north

rotate_to_grid(u, v, angle_i, angle_j)

Rotates velocity components from geographic coordinates (eastward and northward) to grid coordinates.

This is the inverse of jaxparrow.utils.geometry.rotate_to_geographic.
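A quick way to check the inverse relationship is a round trip. The sketch below uses standalone copies of the two formulas shown in the source listings (named to_geographic and to_grid here to avoid implying they are the library functions) on a grid rotated 30° counterclockwise with orthogonal right-handed axes:

```python
import jax.numpy as jnp


def to_geographic(u, v, angle_i, angle_j):
    det = jnp.sin(angle_j - angle_i)  # equals cos_i * sin_j - sin_i * cos_j
    return ((u * jnp.cos(angle_i) + v * jnp.cos(angle_j)) / det,
            (u * jnp.sin(angle_i) + v * jnp.sin(angle_j)) / det)


def to_grid(u_east, v_north, angle_i, angle_j):
    return (u_east * jnp.sin(angle_j) - v_north * jnp.cos(angle_j),
            -u_east * jnp.sin(angle_i) + v_north * jnp.cos(angle_i))


# Grid rotated 30 degrees counterclockwise, j-axis 90 degrees further
angle_i = jnp.full((2, 2), jnp.deg2rad(30.0))
angle_j = angle_i + jnp.pi / 2
u_east = jnp.array([[1.0, 0.0], [0.5, -2.0]])
v_north = jnp.array([[0.0, 1.0], [0.5, 3.0]])

u_grid, v_grid = to_grid(u_east, v_north, angle_i, angle_j)
u_back, v_back = to_geographic(u_grid, v_grid, angle_i, angle_j)  # recovers u_east, v_north
```

A purely eastward unit vector projects to cos 30° ≈ 0.866 on the i-axis and cos 120° = -0.5 on the j-axis, which is what u_grid[0, 0] and v_grid[0, 0] hold.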

Parameters:

Name Type Description Default
u Float[Array, 'y x']

Eastward velocity component

required
v Float[Array, 'y x']

Northward velocity component

required
angle_i Float[Array, 'y x']

Angle of the grid i-axis (axis=1) relative to geographic east, in radians (counterclockwise positive). Typically obtained from jaxparrow.utils.geometry.compute_grid_angle.

required
angle_j Float[Array, 'y x']

Angle of the grid j-axis (axis=0) relative to geographic east, in radians (counterclockwise positive). Typically obtained from jaxparrow.utils.geometry.compute_grid_angle.

required

Returns:

Name Type Description
u_grid Float[Array, 'y x']

Velocity component along the grid i-axis (axis=1)

v_grid Float[Array, 'y x']

Velocity component along the grid j-axis (axis=0)

Source code in jaxparrow/utils/geometry.py
def rotate_to_grid(
    u: Float[jax.Array, "y x"],
    v: Float[jax.Array, "y x"],
    angle_i: Float[jax.Array, "y x"],
    angle_j: Float[jax.Array, "y x"]
) -> tuple[Float[jax.Array, "y x"], Float[jax.Array, "y x"]]:
    """
    Rotates velocity components from geographic coordinates (eastward and northward) to grid coordinates.

    This is the inverse of [`jaxparrow.utils.geometry.rotate_to_geographic`][].

    Parameters
    ----------
    u : Float[jax.Array, "y x"]
        Eastward velocity component
    v : Float[jax.Array, "y x"]
        Northward velocity component
    angle_i : Float[jax.Array, "y x"]
        Angle of the grid i-axis (axis=1) relative to geographic east, in radians
        (counterclockwise positive). Typically obtained from [`jaxparrow.utils.geometry.compute_grid_angle`][].
    angle_j : Float[jax.Array, "y x"]
        Angle of the grid j-axis (axis=0) relative to geographic east, in radians
        (counterclockwise positive). Typically obtained from [`jaxparrow.utils.geometry.compute_grid_angle`][].

    Returns
    -------
    u_grid : Float[jax.Array, "y x"]
        Velocity component along the grid i-axis (axis=1)
    v_grid : Float[jax.Array, "y x"]
        Velocity component along the grid j-axis (axis=0)
    """
    cos_i = jnp.cos(angle_i)
    sin_i = jnp.sin(angle_i)
    cos_j = jnp.cos(angle_j)
    sin_j = jnp.sin(angle_j)

    u_grid = u * sin_j - v * cos_j
    v_grid = -u * sin_i + v * cos_i

    return u_grid, v_grid

coriolis_factor(lat)

Computes the Coriolis factor from a latitude grid.
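The source below evaluates \(f = 2\Omega\sin\varphi\), with \(\Omega\) = EARTH_ANG_SPEED and \(\varphi\) the latitude converted from degrees to radians, so f is negative in the southern hemisphere. As a worked value, assuming \(\Omega \approx 7.2921\times10^{-5}\) rad s⁻¹ (the constant's actual value is not shown on this page):

```latex
f = 2\,\Omega \sin\varphi,
\qquad
f(45^\circ\mathrm{N}) \approx 2 \times 7.2921\times10^{-5} \times \sin 45^\circ \approx 1.031\times10^{-4}\ \mathrm{s^{-1}}
```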

Parameters:

Name Type Description Default
lat Float[Array, 'y x']

Latitude grid, in degrees

required

Returns:

Name Type Description
cf Float[Array, 'y x']

Coriolis factor grid, f = 2 Ω sin(lat), where Ω is EARTH_ANG_SPEED

Source code in jaxparrow/utils/geometry.py
def coriolis_factor(lat: Float[jax.Array, "y x"]) -> Float[jax.Array, "y x"]:
    """
    Computes the Coriolis factor from a latitude grid.

    Parameters
    ----------
    lat : Float[jax.Array, "y x"]
        Latitudes grid

    Returns
    -------
    cf : Float[jax.Array, "y x"]
        Coriolis factor grid
    """
    return 2 * EARTH_ANG_SPEED * jnp.sin((jnp.radians(lat)))

compute_uv_grids(lat_t, lon_t)

Computes the U and V grids associated with a T grid, following the NEMO convention.

Parameters:

Name Type Description Default
lat_t Float[Array, 'lat lon']

Latitudes of the T grid

required
lon_t Float[Array, 'lat lon']

Longitudes of the T grid

required

Returns:

Name Type Description
lat_u Float[Array, 'lat lon']

Latitudes of the U grid

lon_u Float[Array, 'lat lon']

Longitudes of the U grid

lat_v Float[Array, 'lat lon']

Latitudes of the V grid

lon_v Float[Array, 'lat lon']

Longitudes of the V grid

Source code in jaxparrow/utils/geometry.py
def compute_uv_grids(
    lat_t: Float[jax.Array, "lat lon"],
    lon_t: Float[jax.Array, "lat lon"]
) -> tuple[
    Float[jax.Array, "lat lon"], Float[jax.Array, "lat lon"], Float[jax.Array, "lat lon"], Float[jax.Array, "lat lon"]
]:
    """
    Computes the U and V grids associated with a T grid, following the NEMO convention.

    Parameters
    ----------
    lat_t : Float[jax.Array, "lat lon"]
        Latitudes of the T grid
    lon_t : Float[jax.Array, "lat lon"]
        Longitudes of the T grid

    Returns
    -------
    lat_u : Float[jax.Array, "lat lon"]
        Latitudes of the U grid
    lon_u : Float[jax.Array, "lat lon"]
        Longitudes of the U grid
    lat_v : Float[jax.Array, "lat lon"]
        Latitudes of the V grid
    lon_v : Float[jax.Array, "lat lon"]
        Longitudes of the V grid
    """
    lat_u = interpolation(lat_t, axis=1, padding="right")
    lat_u = lat_u.at[:, -1].set(2 * lat_t[:, -1] - lat_t[:, -2])
    lon_u = interpolation(lon_t, axis=1, padding="right")
    lon_u = lon_u.at[:, -1].set(2 * lon_t[:, -1] - lon_t[:, -2])

    lat_v = interpolation(lat_t, axis=0, padding="right")
    lat_v = lat_v.at[-1, :].set(2 * lat_t[-1, :] - lat_t[-2, :])
    lon_v = interpolation(lon_t, axis=0, padding="right")
    lon_v = lon_v.at[-1, :].set(2 * lon_t[-1, :] - lon_t[-2, :])

    return lat_u, lon_u, lat_v, lon_v

kinematics

magnitude(u, v, land_mask=None, uv_on_t=True)

Computes the magnitude (azimuthal velocity) of a 2d velocity field.

The velocity field can be provided either on the T grid (uv_on_t=True) or on the U/V grids (uv_on_t=False).
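When uv_on_t=False, the source below first averages adjacent U (resp. V) points onto T points, following the comment "U(i), U(i+1) -> T(i+1)", and then takes \(\sqrt{u^2 + v^2}\). A minimal sketch of that path, where average_to_t is a hypothetical stand-in for jaxparrow's interpolation helper: its NaN padding of the first T row/column is an assumption, since the helper's actual padding value is not shown here.

```python
import jax.numpy as jnp


def average_to_t(f, axis):
    """Hypothetical stand-in for interpolation(..., padding="left"):
    T(i+1) = (F(i) + F(i+1)) / 2 along `axis`; the first T row/column is left as NaN."""
    if axis == 1:
        avg = 0.5 * (f[:, :-1] + f[:, 1:])
        return jnp.pad(avg, ((0, 0), (1, 0)), constant_values=jnp.nan)
    avg = 0.5 * (f[:-1, :] + f[1:, :])
    return jnp.pad(avg, ((1, 0), (0, 0)), constant_values=jnp.nan)


def magnitude_from_uv_grids(u, v, land_mask):
    u_t = average_to_t(u, axis=1)  # U grid -> T grid
    v_t = average_to_t(v, axis=0)  # V grid -> T grid
    magn = jnp.sqrt(u_t ** 2 + v_t ** 2)
    return jnp.where(land_mask.astype(bool), jnp.nan, magn)  # land cells -> NaN


u = jnp.array([[0.0, 2.0, 4.0], [0.0, 2.0, 4.0]])
v = jnp.array([[0.0, 0.0, 0.0], [8.0, 8.0, 8.0]])
speed = magnitude_from_uv_grids(u, v, jnp.zeros((2, 3), dtype=bool))
```

Here u averages to [nan, 1, 3] on each row and v to 4 on the second row, so speed[1, 2] = √(3² + 4²) = 5.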

Parameters:

Name Type Description Default
u Float[Array, 'y x']

\(u\) component of the velocity field (on the U or T grid)

required
v Float[Array, 'y x']

\(v\) component of the velocity field (on the V or T grid)

required
land_mask Float[Array, 'y x']

Mask defining the marine area of the spatial domain; 1 or True stands for masked (i.e. land)

Defaults to None

None
uv_on_t bool

If True, the velocity components are assumed to be located on the T grid, and no interpolation is needed. If False, the velocity components are assumed to be located on the U and V grids, and are interpolated to the T grid (following the NEMO convention).

Defaults to True

True

Returns:

Name Type Description
magnitude Float[Array, 'y x']

Magnitude of the velocity field, on the T grid

Source code in jaxparrow/utils/kinematics.py
def magnitude(
    u: Float[jax.Array, "y x"],
    v: Float[jax.Array, "y x"],
    land_mask: Float[jax.Array, "y x"] = None,
    uv_on_t: bool = True
) -> Float[jax.Array, "y x"]:
    """
    Computes the magnitude (azimuthal velocity) of a 2d velocity field.

    The velocity field can be provided either on the T grid (``uv_on_t=True``) or on the U/V grids (``uv_on_t=False``).

    Parameters
    ----------
    u : Float[jax.Array, "y x"]
        $u$ component of the velocity field (on the U or T grid)
    v : Float[jax.Array, "y x"]
        $v$ component of the velocity field (on the V or T grid)
    land_mask : Float[jax.Array, "y x"], optional
        Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)

        Defaults to `None`
    uv_on_t : bool, optional
        If `True`, the velocity components are assumed to be located on the T grid, and interpolation is not needed.
        If `False`, the velocity components are assumed to be located on the U and V grids,
        and are interpolated to the T one (following NEMO convention).

        Defaults to `True`

    Returns
    -------
    magnitude : Float[jax.Array, "y x"]
        Magnitude of the velocity field, on the T grid
    """
    land_mask = init_land_mask(u, land_mask)

    if not uv_on_t:
        u = interpolation(u, axis=1, padding="left", land_mask=land_mask)  # U(i), U(i+1) -> T(i+1)
        v = interpolation(v, axis=0, padding="left", land_mask=land_mask)  # V(j), V(j+1) -> T(j+1)

    magn = jnp.sqrt(u ** 2 + v ** 2)
    magn = sanitize_data(magn, jnp.nan, land_mask)

    return magn

kinetic_energy(u, v, land_mask=None, uv_on_t=True)

Computes the Kinetic Energy (KE) of a velocity field.

The velocity field can be provided either on the T grid (uv_on_t=True) or on the U/V grids (uv_on_t=False).

Parameters:

Name Type Description Default
u Float[Array, 'y x']

\(u\) component of the velocity field (on the U or T grid)

required
v Float[Array, 'y x']

\(v\) component of the velocity field (on the V or T grid)

required
land_mask Float[Array, 'y x']

Mask defining the marine area of the spatial domain; 1 or True stands for masked (i.e. land).

Defaults to None

None
uv_on_t bool

If True, the velocity components are assumed to be located on the T grid, and interpolation is not needed. If False, the velocity components are assumed to be located on the U and V grids, and are interpolated to the T one (following NEMO convention).

Defaults to True

True

Returns:

Name Type Description
kinetic_energy Float[Array, 'y x']

The Kinetic Energy on the T grid

Source code in jaxparrow/utils/kinematics.py
def kinetic_energy(
    u: Float[jax.Array, "y x"],
    v: Float[jax.Array, "y x"],
    land_mask: Float[jax.Array, "y x"] = None,
    uv_on_t: bool = True,
) -> Float[jax.Array, "y x"]:
    """
    Computes the Kinetic Energy (KE) of a velocity field.

    The velocity field can be provided either on the T grid (``uv_on_t=True``) or on the U/V grids (``uv_on_t=False``).

    Parameters
    ----------
    u : Float[jax.Array, "y x"]
        $u$ component of the velocity field (on the U or T grid)
    v : Float[jax.Array, "y x"]
        $v$ component of the velocity field (on the V or T grid)
    land_mask : Float[jax.Array, "y x"], optional
        Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land).

        Defaults to `None`
    uv_on_t : bool, optional
        If `True`, the velocity components are assumed to be located on the T grid, and interpolation is not needed.
        If `False`, the velocity components are assumed to be located on the U and V grids,
        and are interpolated to the T one (following NEMO convention).

        Defaults to `True`

    Returns
    -------
    kinetic_energy : Float[jax.Array, "y x"]
        The Kinetic Energy on the T grid
    """
    # Make sure the mask is initialized
    land_mask = init_land_mask(u, land_mask)

    if not uv_on_t:
        u = interpolation(u, axis=1, padding="left", land_mask=land_mask)  # U(i), U(i+1) -> T(i+1)
        v = interpolation(v, axis=0, padding="left", land_mask=land_mask)  # V(j), V(j+1) -> T(j+1)

    ke = (u ** 2 + v ** 2) / 2
    ke = sanitize_data(ke, jnp.nan, land_mask)

    return ke
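The returned quantity is the kinetic energy per unit mass, $\frac{1}{2}(u^2 + v^2)$, which is half the squared magnitude. With velocities in m/s, it is expressed in m²/s². The NumPy snippet below checks that relation on illustrative T-grid values. The values are made up and are not library output.

```python
import numpy as np

# Illustrative T-grid velocities in m/s; NaN marks a land cell
u = np.array([[0.3, -0.6], [0.0, np.nan]])
v = np.array([[0.4, 0.8], [1.0, np.nan]])

ke = 0.5 * (u ** 2 + v ** 2)      # same formula as kinetic_energy (uv_on_t=True case)
speed = np.sqrt(u ** 2 + v ** 2)  # same formula as magnitude

# KE per unit mass is half the squared speed; NaN propagates over land
print(ke)  # approximately [[0.125, 0.5], [0.5, nan]]
assert np.allclose(ke, 0.5 * speed ** 2, equal_nan=True)
```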

vorticity(u, v, lat_t=None, lon_t=None, lat_u=None, lon_u=None, lat_v=None, lon_v=None, land_mask=None, uv_on_t=True, normalize_by_coriolis=True)

Computes the relative vorticity of a velocity field.

The velocity field can be provided either on the T grid (uv_on_t=True) or on the U/V grids (uv_on_t=False).

If provided, the lat_u, lon_u, lat_v, and lon_v are expected to follow the NEMO convention.

Parameters:

Name Type Description Default
u Float[Array, 'y x']

\(u\) component of the velocity field

required
v Float[Array, 'y x']

\(v\) component of the velocity field

required
lat_t Float[Array, 'y x']

Latitudes of the T grid.

If lat_u, lon_u, lat_v, and lon_v are not provided, lat_t and lon_t must be provided to compute them.

Defaults to None

None
lon_t Float[Array, 'y x']

Longitudes of the T grid.

If lat_u, lon_u, lat_v, and lon_v are not provided, lat_t and lon_t must be provided to compute them.

Defaults to None

None
lat_u Float[Array, 'y x']

Latitudes of the U grid.

Defaults to None

None
lon_u Float[Array, 'y x']

Longitudes of the U grid.

Defaults to None

None
lat_v Float[Array, 'y x']

Latitudes of the V grid.

Defaults to None

None
lon_v Float[Array, 'y x']

Longitudes of the V grid.

Defaults to None

None
land_mask Float[Array, 'y x']

Mask defining the marine area of the spatial domain; 1 or True stands for masked (i.e. land)

None
uv_on_t bool

If True, the velocity components are assumed to be located on the T grid; if False, on the U and V grids (this is important when manipulating staggered grids)

Defaults to True

True
normalize_by_coriolis bool

If True, returns the vorticity normalized by the Coriolis factor

Defaults to True

True

Returns:

Name Type Description
vorticity Float[Array, 'y x']

The vorticity on the T grid, normalized by the Coriolis factor if normalize_by_coriolis=True

Source code in jaxparrow/utils/kinematics.py
def vorticity(
    u: Float[jax.Array, "y x"],
    v: Float[jax.Array, "y x"],
    lat_t: Float[jax.Array, "y x"] = None,
    lon_t: Float[jax.Array, "y x"] = None,
    lat_u: Float[jax.Array, "y x"] = None,
    lon_u: Float[jax.Array, "y x"] = None,
    lat_v: Float[jax.Array, "y x"] = None,
    lon_v: Float[jax.Array, "y x"] = None,
    land_mask: Float[jax.Array, "y x"] = None,
    uv_on_t: bool = True,
    normalize_by_coriolis: bool = True
) -> Float[jax.Array, "y x"]:
    """
    Computes the relative vorticity of a velocity field.

    The velocity field can be provided either on the T grid (``uv_on_t=True``) or on the U/V grids (``uv_on_t=False``).

    If provided, the ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are expected to follow the NEMO convention.

    Parameters
    ----------
    u : Float[jax.Array, "y x"]
        $u$ component of the velocity field
    v : Float[jax.Array, "y x"]
        $v$ component of the velocity field
    lat_t : Float[jax.Array, "y x"], optional
        Latitudes of the T grid.

        If ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are not provided, ``lat_t`` and ``lon_t`` must be provided to compute them.

        Defaults to `None`
    lon_t : Float[jax.Array, "y x"], optional
        Longitudes of the T grid.

        If ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are not provided, ``lat_t`` and ``lon_t`` must be provided to compute them.

        Defaults to `None`
    lat_u : Float[jax.Array, "y x"], optional
        Latitudes of the U grid.

        Defaults to `None`
    lon_u : Float[jax.Array, "y x"], optional
        Longitudes of the U grid.

        Defaults to `None`
    lat_v : Float[jax.Array, "y x"], optional
        Latitudes of the V grid.

        Defaults to `None`
    lon_v : Float[jax.Array, "y x"], optional
        Longitudes of the V grid.

        Defaults to `None`
    land_mask : Float[jax.Array, "y x"], optional
        Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)
    uv_on_t : bool, optional
        If `True`, the velocity components are assumed to be located on the T grid;
        if `False`, on the U and V grids (this is important when manipulating staggered grids)

        Defaults to `True`
    normalize_by_coriolis : bool, optional
        If `True`, returns the vorticity normalized by the Coriolis factor

        Defaults to `True`

    Returns
    -------
    vorticity : Float[jax.Array, "y x"]
        The vorticity on the T grid, normalized by the Coriolis factor if ``normalize_by_coriolis=True``
    """
    u, v, lat_t, lon_t, dx, dy, land_mask = setup_kinematics(
        u, v, lat_t, lon_t, lat_u, lon_u, lat_v, lon_v, land_mask, uv_on_t
    )

    _, du_y = horizontal_derivatives(u, dx=dx, dy=dy, land_mask=land_mask)
    dv_x, _ = horizontal_derivatives(v, dx=dx, dy=dy, land_mask=land_mask)

    vort = dv_x - du_y

    if normalize_by_coriolis:
        f = coriolis_factor(lat_t)
        vort /= f

    vort = sanitize_data(vort, jnp.nan, land_mask)

    return vort
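A quick check of the formula $\zeta = \partial_x v - \partial_y u$ uses solid-body rotation, $u = -\omega y$ and $v = \omega x$, which has uniform vorticity $2\omega$. The NumPy sketch below rests on several assumptions:

- The grid is a uniform Cartesian grid in meters.
- `np.gradient` (centered differences) stands in for `horizontal_derivatives`.
- `coriolis_factor` follows the usual $f = 2\Omega \sin\varphi$ with $\Omega \approx 7.2921 \times 10^{-5}$ rad/s. The library's exact constant is not shown on this page.

```python
import numpy as np

OMEGA_EARTH = 7.2921e-5  # assumed Earth rotation rate (rad/s)


def relative_vorticity(u, v, dx, dy, lat=None):
    """Sketch: dv/dx - du/dy on a uniform grid, optionally divided by f = 2*Omega*sin(lat)."""
    du_dy, _ = np.gradient(u, dy, dx)  # derivatives along axis 0 (y) and axis 1 (x)
    _, dv_dx = np.gradient(v, dy, dx)
    zeta = dv_dx - du_dy
    if lat is not None:
        f = 2.0 * OMEGA_EARTH * np.sin(np.deg2rad(lat))
        zeta = zeta / f
    return zeta


dx = dy = 1000.0                  # 1 km spacing
y, x = np.mgrid[-5:6, -5:6] * dx  # coordinates in meters, rows = y, columns = x
omega = 1e-5                      # angular velocity of the rotating patch (rad/s)
u, v = -omega * y, omega * x

print(relative_vorticity(u, v, dx, dy))            # 2 * omega everywhere
print(relative_vorticity(u, v, dx, dy, lat=30.0))  # 2 * omega / f, with f close to OMEGA_EARTH at 30 degrees
```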

strain_rate(u, v, lat_t=None, lon_t=None, lat_u=None, lon_u=None, lat_v=None, lon_v=None, land_mask=None, uv_on_t=True, normalize_by_coriolis=True)

Computes the strain rate magnitude of a velocity field.

The velocity field can be provided either on the T grid (uv_on_t=True) or on the U/V grids (uv_on_t=False).

If provided, the lat_u, lon_u, lat_v, and lon_v are expected to follow the NEMO convention.

Parameters:

Name Type Description Default
u Float[Array, 'y x']

\(u\) component of the velocity field

required
v Float[Array, 'y x']

\(v\) component of the velocity field

required
lat_t Float[Array, 'y x']

Latitudes of the T grid.

If lat_u, lon_u, lat_v, and lon_v are not provided, lat_t and lon_t must be provided to compute them.

Defaults to None

None
lon_t Float[Array, 'y x']

Longitudes of the T grid.

If lat_u, lon_u, lat_v, and lon_v are not provided, lat_t and lon_t must be provided to compute them.

Defaults to None

None
lat_u Float[Array, 'y x']

Latitudes of the U grid.

Defaults to None

None
lon_u Float[Array, 'y x']

Longitudes of the U grid.

Defaults to None

None
lat_v Float[Array, 'y x']

Latitudes of the V grid.

Defaults to None

None
lon_v Float[Array, 'y x']

Longitudes of the V grid.

Defaults to None

None
land_mask Float[Array, 'y x']

Mask defining the marine area of the spatial domain; 1 or True stands for masked (i.e. land)

None
uv_on_t bool

If True, the velocity components are assumed to be located on the T grid; if False, on the U and V grids (this is important when manipulating staggered grids)

Defaults to True

True
normalize_by_coriolis bool

If True, returns the strain rate normalized by the Coriolis factor

Defaults to True

True

Returns:

Name Type Description
strain_rate Float[Array, 'y x']

The strain rate magnitude on the T grid, normalized by the Coriolis factor if normalize_by_coriolis=True

Source code in jaxparrow/utils/kinematics.py
def strain_rate(
    u: Float[jax.Array, "y x"],
    v: Float[jax.Array, "y x"],
    lat_t: Float[jax.Array, "y x"] = None,
    lon_t: Float[jax.Array, "y x"] = None,
    lat_u: Float[jax.Array, "y x"] = None,
    lon_u: Float[jax.Array, "y x"] = None,
    lat_v: Float[jax.Array, "y x"] = None,
    lon_v: Float[jax.Array, "y x"] = None,
    land_mask: Float[jax.Array, "y x"] = None,
    uv_on_t: bool = True,
    normalize_by_coriolis: bool = True
) -> Float[jax.Array, "y x"]:
    """
    Computes the strain rate magnitude of a velocity field.

    The velocity field can be provided either on the T grid (``uv_on_t=True``) or on the U/V grids (``uv_on_t=False``).

    If provided, the ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are expected to follow the NEMO convention.

    Parameters
    ----------
    u : Float[jax.Array, "y x"]
        $u$ component of the velocity field
    v : Float[jax.Array, "y x"]
        $v$ component of the velocity field
    lat_t : Float[jax.Array, "y x"], optional
        Latitudes of the T grid.

        If ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are not provided, ``lat_t`` and ``lon_t`` must be provided to compute them.

        Defaults to `None`
    lon_t : Float[jax.Array, "y x"], optional
        Longitudes of the T grid.

        If ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are not provided, ``lat_t`` and ``lon_t`` must be provided to compute them.

        Defaults to `None`
    lat_u : Float[jax.Array, "y x"], optional
        Latitudes of the U grid.

        Defaults to `None`
    lon_u : Float[jax.Array, "y x"], optional
        Longitudes of the U grid.

        Defaults to `None`
    lat_v : Float[jax.Array, "y x"], optional
        Latitudes of the V grid.

        Defaults to `None`
    lon_v : Float[jax.Array, "y x"], optional
        Longitudes of the V grid.

        Defaults to `None`
    land_mask : Float[jax.Array, "y x"], optional
        Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)
    uv_on_t : bool, optional
        If `True`, the velocity components are assumed to be located on the T grid;
        if `False`, on the U and V grids (this is important when manipulating staggered grids)

        Defaults to `True`
    normalize_by_coriolis : bool, optional
        If `True`, returns the strain rate normalized by the Coriolis factor

        Defaults to `True`

    Returns
    -------
    strain_rate : Float[jax.Array, "y x"]
        The strain rate magnitude on the T grid, normalized by the Coriolis factor if ``normalize_by_coriolis=True``
    """
    u, v, lat_t, lon_t, dx, dy, land_mask = setup_kinematics(
        u, v, lat_t, lon_t, lat_u, lon_u, lat_v, lon_v, land_mask, uv_on_t
    )

    du_x, du_y = horizontal_derivatives(u, dx=dx, dy=dy, land_mask=land_mask)
    dv_x, dv_y = horizontal_derivatives(v, dx=dx, dy=dy, land_mask=land_mask)

    strain = jnp.sqrt((du_x - dv_y) ** 2 + (dv_x + du_y) ** 2)

    if normalize_by_coriolis:
        f = coriolis_factor(lat_t)
        strain /= f

    strain = sanitize_data(strain, jnp.nan, land_mask)

    return strain
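The strain rate magnitude combines the normal strain $\partial_x u - \partial_y v$ and the shear strain $\partial_x v + \partial_y u$, as in the code above. Two textbook cases give a quick check:

- A pure strain field, $u = a x$ and $v = -a y$, gives $2a$.
- Solid-body rotation gives zero.

Like the vorticity check, this NumPy sketch uses `np.gradient` on an assumed uniform grid in place of `horizontal_derivatives`. It skips the Coriolis normalization.

```python
import numpy as np


def strain_magnitude(u, v, dx, dy):
    """Sketch: sqrt((du/dx - dv/dy)**2 + (dv/dx + du/dy)**2) on a uniform grid."""
    du_dy, du_dx = np.gradient(u, dy, dx)
    dv_dy, dv_dx = np.gradient(v, dy, dx)
    normal = du_dx - dv_dy  # normal (stretching) strain along the grid axes
    shear = dv_dx + du_dy   # shear strain
    return np.sqrt(normal ** 2 + shear ** 2)


dx = dy = 1000.0
y, x = np.mgrid[-5:6, -5:6] * dx
a = 2e-6  # strain rate of the pure strain field (1/s)

print(strain_magnitude(a * x, -a * y, dx, dy))        # 2 * a everywhere
print(strain_magnitude(-1e-5 * y, 1e-5 * x, dx, dy))  # close to 0 for solid-body rotation
```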

radius_of_curvature(u, v, lat_t=None, lon_t=None, lat_u=None, lon_u=None, lat_v=None, lon_v=None, land_mask=None, uv_on_t=True)

Computes the radius of curvature of a 2d velocity field.

The velocity field can be provided either on the T grid (uv_on_t=True) or on the U/V grids (uv_on_t=False).

If provided, the lat_u, lon_u, lat_v, and lon_v are expected to follow the NEMO convention.

Parameters:

Name Type Description Default
u Float[Array, 'y x']

\(u\) component of the velocity field

required
v Float[Array, 'y x']

\(v\) component of the velocity field

required
lat_t Float[Array, 'y x']

Latitudes of the T grid.

If lat_u, lon_u, lat_v, and lon_v are not provided, lat_t and lon_t must be provided to compute them.

Defaults to None

None
lon_t Float[Array, 'y x']

Longitudes of the T grid.

If lat_u, lon_u, lat_v, and lon_v are not provided, lat_t and lon_t must be provided to compute them.

Defaults to None

None
lat_u Float[Array, 'y x']

Latitudes of the U grid.

Defaults to None

None
lon_u Float[Array, 'y x']

Longitudes of the U grid.

Defaults to None

None
lat_v Float[Array, 'y x']

Latitudes of the V grid.

Defaults to None

None
lon_v Float[Array, 'y x']

Longitudes of the V grid.

Defaults to None

None
land_mask Float[Array, 'y x']

Mask defining the marine area of the spatial domain; 1 or True stands for masked (i.e. land)

None
uv_on_t bool

If True, the velocity components are assumed to be located on the T grid; if False, on the U and V grids (this is important when manipulating staggered grids)

Defaults to True

True

Returns:

Name Type Description
rc Float[Array, 'y x']

The radius of curvature of the velocity field in meters, on the T grid

Source code in jaxparrow/utils/kinematics.py
def radius_of_curvature(
    u: Float[jax.Array, "y x"],
    v: Float[jax.Array, "y x"],
    lat_t: Float[jax.Array, "y x"] = None,
    lon_t: Float[jax.Array, "y x"] = None,
    lat_u: Float[jax.Array, "y x"] = None,
    lon_u: Float[jax.Array, "y x"] = None,
    lat_v: Float[jax.Array, "y x"] = None,
    lon_v: Float[jax.Array, "y x"] = None,
    land_mask: Float[jax.Array, "y x"] = None,
    uv_on_t: bool = True,
) -> Float[jax.Array, "y x"]:
    """
    Computes the radius of curvature of a 2d velocity field.

    The velocity field can be provided either on the T grid (``uv_on_t=True``) or on the U/V grids (``uv_on_t=False``).

    If provided, the ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are expected to follow the NEMO convention.

    Parameters
    ----------
    u : Float[jax.Array, "y x"]
        $u$ component of the velocity field
    v : Float[jax.Array, "y x"]
        $v$ component of the velocity field
    lat_t : Float[jax.Array, "y x"], optional
        Latitudes of the T grid.

        If ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are not provided, ``lat_t`` and ``lon_t`` must be provided to compute them.

        Defaults to `None`
    lon_t : Float[jax.Array, "y x"], optional
        Longitudes of the T grid.

        If ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are not provided, ``lat_t`` and ``lon_t`` must be provided to compute them.

        Defaults to `None`
    lat_u : Float[jax.Array, "y x"], optional
        Latitudes of the U grid.

        Defaults to `None`
    lon_u : Float[jax.Array, "y x"], optional
        Longitudes of the U grid.

        Defaults to `None`
    lat_v : Float[jax.Array, "y x"], optional
        Latitudes of the V grid.

        Defaults to `None`
    lon_v : Float[jax.Array, "y x"], optional
        Longitudes of the V grid.

        Defaults to `None`
    land_mask : Float[jax.Array, "y x"], optional
        Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)
    uv_on_t : bool, optional
        If `True`, the velocity components are assumed to be located on the T grid;
        if `False`, on the U and V grids (this is important when manipulating staggered grids)

        Defaults to `True`

    Returns
    -------
    rc : Float[jax.Array, "y x"]
        The radius of curvature of the velocity field in meters, on the T grid
    """
    u, v, lat_t, lon_t, dx, dy, land_mask = setup_kinematics(
        u, v, lat_t, lon_t, lat_u, lon_u, lat_v, lon_v, land_mask, uv_on_t
    )

    return _radius_of_curvature(u, v, dx, dy, land_mask)
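The private helper `_radius_of_curvature` is not shown on this page. One common formulation of the radius of curvature of streamlines is $R = 1/\kappa$, with

$\kappa = \dfrac{u^2 \partial_x v - v^2 \partial_y u - u v (\partial_x u - \partial_y v)}{(u^2 + v^2)^{3/2}}$.

The sketch below implements that formula under the same uniform-grid and `np.gradient` assumptions as the previous checks. It may differ from the library, for example in sign convention or in how near-zero curvature is handled. For solid-body rotation, it recovers the distance to the rotation center, positive for counterclockwise flow.

```python
import numpy as np


def streamline_radius(u, v, dx, dy):
    """Sketch: signed radius of curvature of streamlines, R = 1 / kappa (positive = counterclockwise)."""
    du_dy, du_dx = np.gradient(u, dy, dx)
    dv_dy, dv_dx = np.gradient(v, dy, dx)
    num = u ** 2 * dv_dx - v ** 2 * du_dy - u * v * (du_dx - dv_dy)
    kappa = num / (u ** 2 + v ** 2) ** 1.5
    return 1.0 / kappa


dx = dy = 1.0
y, x = np.mgrid[1:6, 1:6] * 1.0  # grid that avoids the rotation center (zero speed)
omega = 2.0
rc = streamline_radius(-omega * y, omega * x, dx, dy)
print(rc[3, 2])  # point (x=3, y=4): about 5, its distance to the origin
```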

operators

interpolation(field, axis, padding, land_mask=None)

Interpolates the values of a field along a given axis (0 for lat/y, 1 for lon/x), applying padding to the left (i.e. West if axis=1, South if axis=0) or to the right (i.e. East if axis=1, North if axis=0) of the domain.

An open boundary condition is applied:

  • At domain edges: the interpolated value equals the nearest interior value
  • At land/NaN boundaries: if one of the two values is NaN, the valid value is used; if both are NaN, the result is NaN

Parameters:

Name Type Description Default
field Float[Array, 'y x']

Field to interpolate

required
axis Literal[0, 1]

Axis along which interpolation is performed

required
padding Literal['left', 'right']

Padding direction. For example, following NEMO convention, interpolating from U to T points requires a left padding (the midpoint between \(U_i\) and \(U_{i+1}\) corresponds to \(T_{i+1}\)), and interpolating from T to U points a right padding (the midpoint between \(T_i\) and \(T_{i+1}\) corresponds to \(U_i\))

required
land_mask Float[Array, 'y x']

Mask indicating the land domain, whose cells are set to NaN in the output. False/0 indicates ocean cells, True/1 indicates land cells.

Defaults to None, in which case no land masking is applied and values extrapolated by the open boundary condition are kept across the entire domain

None

Returns:

Name Type Description
field Float[Array, 'y x']

Interpolated field

Source code in jaxparrow/utils/operators.py
def interpolation(
    field: Float[jax.Array, "y x"],
    axis: Literal[0, 1],
    padding: Literal["left", "right"],
    land_mask: Float[jax.Array, "y x"] = None,
) -> Float[jax.Array, "y x"]:
    """
    Interpolates the values of a ``field`` along a given ``axis`` (`0` for `lat`/`y`, `1` for `lon`/`x`),
    applying ``padding`` to the `left` (i.e. `West` if ``axis=1``, `South` if ``axis=0``) or
    to the `right` (i.e. `East` if ``axis=1``, `North` if ``axis=0``) of the domain.

    An open boundary condition is applied:

    - At domain edges: the interpolated value equals the nearest interior value
    - At land/NaN boundaries: if one of the two values is NaN, the valid value is used;
      if both are NaN, the result is NaN

    Parameters
    ----------
    field : Float[jax.Array, "y x"]
        Field to interpolate
    axis : Literal[0, 1]
        Axis along which interpolation is performed
    padding : Literal["left", "right"]
        Padding direction.
        For example, following NEMO convention,
        interpolating from U to T points requires a `left` padding
        (the midpoint between $U_i$ and $U_{i+1}$ corresponds to $T_{i+1}$),
        and interpolating from T to U points a `right` padding
        (the midpoint between $T_i$ and $T_{i+1}$ corresponds to $U_i$)
    land_mask : Float[jax.Array, "y x"], optional
        Mask indicating the land domain, whose cells are set to NaN in the output.
        `False`/`0` indicates ocean cells, `True`/`1` indicates land cells.

        Defaults to `None`,
        in which case no land masking is applied and values extrapolated by the open boundary condition are kept across the entire domain

    Returns
    -------
    field : Float[jax.Array, "y x"]
        Interpolated field
    """
    f = jnp.moveaxis(field, axis, -1)

    left = f[:, :-1]
    right = f[:, 1:]

    left_valid = ~jnp.isnan(left)
    right_valid = ~jnp.isnan(right)

    # Open boundary condition for NaN values:
    # - Both valid: average
    # - One valid: use the valid one
    # - Both NaN: NaN
    mid = jnp.where(
        left_valid & right_valid,
        (left + right) * 0.5,
        jnp.where(left_valid, left, jnp.where(right_valid, right, jnp.nan))
    )

    # Pad at the domain boundary with edge value (open boundary condition)
    mid = lax.cond(
        padding == "left",
        lambda: jnp.pad(mid, ((0, 0), (1, 0)), mode='edge'),
        lambda: jnp.pad(mid, ((0, 0), (0, 1)), mode='edge')
    )

    mid = jnp.moveaxis(mid, -1, axis)

    if land_mask is not None:
        mid = jnp.where(land_mask, jnp.nan, mid)

    return mid
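For experimenting outside JAX, here is a NumPy re-implementation of the same midpoint rule. It is a sketch, not the library function, and `midpoint_interpolation` is a hypothetical name. A plain Python `if` replaces `lax.cond`, which is enough here because `padding` is a static string. The example shows the left and right padding and the open boundary condition at NaN cells.

```python
import numpy as np


def midpoint_interpolation(field, axis, padding, land_mask=None):
    """Sketch of `interpolation`: NaN-aware midpoints along `axis`, padded with the edge value."""
    f = np.moveaxis(field, axis, -1)
    left, right = f[:, :-1], f[:, 1:]
    left_ok, right_ok = ~np.isnan(left), ~np.isnan(right)
    # Both valid: average; one valid: keep it; both NaN: NaN
    mid = np.where(left_ok & right_ok, 0.5 * (left + right),
                   np.where(left_ok, left, np.where(right_ok, right, np.nan)))
    # 'left' pads the West/South edge, 'right' pads the East/North edge
    pad = ((0, 0), (1, 0)) if padding == "left" else ((0, 0), (0, 1))
    mid = np.moveaxis(np.pad(mid, pad, mode="edge"), -1, axis)
    if land_mask is not None:
        mid = np.where(land_mask, np.nan, mid)  # land cells become NaN
    return mid


row = np.array([[1.0, 3.0, 5.0]])
print(midpoint_interpolation(row, axis=1, padding="left"))   # [[2. 2. 4.]]  (U -> T)
print(midpoint_interpolation(row, axis=1, padding="right"))  # [[2. 4. 4.]]  (T -> U)
gap = np.array([[1.0, np.nan, 5.0]])
print(midpoint_interpolation(gap, axis=1, padding="left"))   # [[1. 1. 5.]]
```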

derivative(field, axis, land_mask=None)

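Differentiates a field along a given axis (0 for lat/y, 1 for lon/x) using forward finite differences in grid-index units (the result is not divided by the grid spacing). The difference between points i and i+1 is stored at index i, and the result is padded to the right (i.e. East if axis=1, North if axis=0) of the domain; following NEMO convention, differentiating a T-grid field therefore yields values on the U (axis=1) or V (axis=0) grid.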

An open boundary condition is applied (zero second derivative):

  • At domain edges: the boundary derivative equals the nearest interior derivative
  • At land/NaN boundaries: if a difference is NaN (one or both of its values are NaN), it is replaced by the adjacent valid difference, the left (West/South) one taking precedence; if both adjacent differences are NaN, the result is NaN

This is appropriate for domains with sharp physical boundaries (e.g., SWOT swaths) where the signal continues smoothly beyond the observation edge.

Parameters:

Name Type Description Default
field Float[Array, 'y x']

Field to differentiate

required
axis Literal[0, 1]

Axis along which differentiation is performed

required
land_mask Float[Array, 'y x']

Mask indicating the land domain, whose cells are set to NaN in the output. False/0 indicates ocean cells, True/1 indicates land cells.

Defaults to None, in which case no land masking is applied and values extrapolated by the open boundary condition are kept across the entire domain

None

Returns:

Name Type Description
df Float[Array, 'y x']

Differentiated field

Source code in jaxparrow/utils/operators.py
def derivative(
    field: Float[jax.Array, "y x"],
    axis: Literal[0, 1],
    land_mask: Float[jax.Array, "y x"] = None,
) -> Float[jax.Array, "y x"]:
    """
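    Differentiates a ``field`` along a given ``axis`` (`0` for `lat`/`y`, `1` for `lon`/`x`) using forward finite
    differences in grid-index units (the result is not divided by the grid spacing).
    The difference between points `i` and `i+1` is stored at index `i`, and the result is padded to the `right`
    (i.e. `East` if ``axis=1``, `North` if ``axis=0``) of the domain; following NEMO convention,
    differentiating a T-grid field therefore yields values on the U (``axis=1``) or V (``axis=0``) grid.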

    An open boundary condition is applied (zero second derivative):

    - At domain edges: the boundary derivative equals the nearest interior derivative
    - At land/NaN boundaries: if a difference is NaN (one or both of its values are NaN), it is replaced
      by the adjacent valid difference, the left (West/South) one taking precedence; if both adjacent
      differences are NaN, the result is NaN

    This is appropriate for domains with sharp physical boundaries (e.g., SWOT swaths)
    where the signal continues smoothly beyond the observation edge.

    Parameters
    ----------
    field : Float[jax.Array, "y x"]
        Field to differentiate
    axis : Literal[0, 1]
        Axis along which differentiation is performed
    land_mask : Float[jax.Array, "y x"], optional
        Mask indicating the land domain, whose cells are set to NaN in the output.
        `False`/`0` indicates ocean cells, `True`/`1` indicates land cells.

        Defaults to `None`,
        in which case no land masking is applied and values extrapolated by the open boundary condition are kept across the entire domain

    Returns
    -------
    df : Float[jax.Array, "y x"]
        Differentiated field
    """
    f = jnp.moveaxis(field, axis, -1)

    df = jnp.diff(f, axis=-1)

    # Open boundary condition for NaN values:
    # Fill NaN derivatives with immediate neighbor
    left_neighbor = jnp.roll(df, 1, axis=-1).at[..., 0].set(jnp.nan)
    right_neighbor = jnp.roll(df, -1, axis=-1).at[..., -1].set(jnp.nan)

    df = jnp.where(
        ~jnp.isnan(df),
        df,
        jnp.where(
            ~jnp.isnan(left_neighbor), left_neighbor,
            jnp.where(~jnp.isnan(right_neighbor), right_neighbor, jnp.nan)
        )
    )

    # Open boundary condition at domain edges: ∂²f/∂x² = 0
    # The boundary derivative equals the nearest interior derivative
    df = jnp.concatenate([df, df[..., -1:]], axis=-1)

    df = jnp.moveaxis(df, -1, axis)

    if land_mask is not None:
        df = jnp.where(land_mask, jnp.nan, df)

    return df
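Here is the same logic in NumPy, as an illustrative sketch. `forward_difference` is a hypothetical name, and the neighbor shifts use `np.concatenate` instead of `jnp.roll(...).at[...]`. The example on a quadratic profile shows the right-edge padding and the NaN fill from the adjacent difference.

```python
import numpy as np


def forward_difference(field, axis, land_mask=None):
    """Sketch of `derivative`: f[i+1] - f[i] stored at index i, NaN-filled, right edge repeated."""
    f = np.moveaxis(field, axis, -1)
    df = np.diff(f, axis=-1)
    nan_col = np.full(df.shape[:-1] + (1,), np.nan)
    left_neighbor = np.concatenate([nan_col, df[..., :-1]], axis=-1)  # df[i-1]
    right_neighbor = np.concatenate([df[..., 1:], nan_col], axis=-1)  # df[i+1]
    # Keep valid differences; otherwise borrow the left, then the right neighbor
    df = np.where(~np.isnan(df), df,
                  np.where(~np.isnan(left_neighbor), left_neighbor, right_neighbor))
    # Open boundary at the far edge: repeat the last difference (zero second derivative)
    df = np.concatenate([df, df[..., -1:]], axis=-1)
    df = np.moveaxis(df, -1, axis)
    if land_mask is not None:
        df = np.where(land_mask, np.nan, df)
    return df


print(forward_difference(np.array([[0.0, 1.0, 4.0, 9.0]]), axis=1))           # [[1. 3. 5. 5.]]
print(forward_difference(np.array([[0.0, np.nan, 4.0, 9.0, 16.0]]), axis=1))  # NaN, then 5, 5, 7, 7
```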

horizontal_derivatives(field, lat=None, lon=None, dx=None, dy=None, land_mask=None)

Computes the horizontal derivatives of a field defined on an orthogonal grid, using finite differences and applying an open boundary condition at domain edges and land/NaN boundaries.

Derivatives are computed along the grid axes (axis=1 for x, axis=0 for y), divided by the grid spacing dx and dy, and returned on the same grid as the input field.

Parameters:

Name Type Description Default
field Float[Array, 'y x']

Field for which to compute gradients

required
lat Float[Array, 'y x']

Latitude grid corresponding to the field

Defaults to None, in which case dx and dy must be provided

None
lon Float[Array, 'y x']

Longitude grid corresponding to the field

Defaults to None, in which case dx and dy must be provided

None
dx Float[Array, 'y x']

Grid spacing in the eastward direction (i.e. along axis=1)

Defaults to None, in which case lat and lon must be provided

None
dy Float[Array, 'y x']

Grid spacing in the northward direction (i.e. along axis=0)

Defaults to None, in which case lat and lon must be provided

None
land_mask Float[Array, 'y x']

Mask indicating the land domain, whose cells are set to NaN in the output. False/0 indicates ocean cells, True/1 indicates land cells.

Defaults to None, in which case no land masking is applied and values extrapolated by the open boundary condition are kept across the entire domain

None

Returns:

Name Type Description
df_e Float[Array, 'y x']

Eastward derivative of the field, on the same grid as the input field

df_n Float[Array, 'y x']

Northward derivative of the field, on the same grid as the input field

Source code in jaxparrow/utils/operators.py
def horizontal_derivatives(
    field: Float[jax.Array, "y x"],
    lat: Float[jax.Array, "y x"] = None,
    lon: Float[jax.Array, "y x"] = None,
    dx: Float[jax.Array, "y x"] = None,
    dy: Float[jax.Array, "y x"] = None,
    land_mask: Float[jax.Array, "y x"] = None,
) -> tuple[Float[jax.Array, "y x"], Float[jax.Array, "y x"]]:
    """
    Computes the horizontal derivatives of a ``field`` defined on an orthogonal grid, 
    using finite differences and applying an open boundary condition at domain edges and land/NaN boundaries.

    Derivatives are computed along the grid axes (``axis=1`` for `x`, ``axis=0`` for `y`), divided by the grid spacing ``dx`` and ``dy``, and returned on the same grid as the input field.

    Parameters
    ----------
    field : Float[jax.Array, "y x"]
        Field for which to compute gradients
    lat : Float[jax.Array, "y x"], optional
        Latitude grid corresponding to the field

        Defaults to `None`, in which case ``dx`` and ``dy`` must be provided
    lon : Float[jax.Array, "y x"], optional
        Longitude grid corresponding to the field

        Defaults to `None`, in which case ``dx`` and ``dy`` must be provided
    dx : Float[jax.Array, "y x"], optional
        Grid spacing in the eastward direction (i.e. along axis=1)

        Defaults to `None`, in which case ``lat`` and ``lon`` must be provided
    dy : Float[jax.Array, "y x"], optional
        Grid spacing in the northward direction (i.e. along axis=0)

        Defaults to `None`, in which case ``lat`` and ``lon`` must be provided
    land_mask : Float[jax.Array, "y x"], optional
        Mask indicating the land domain, whose cells are set to NaN in the output.
        `False`/`0` indicates ocean cells, `True`/`1` indicates land cells.

        Defaults to `None`,
        in which case no land masking is applied and values extrapolated by the open boundary condition are kept across the entire domain

    Returns
    -------
    df_e : Float[jax.Array, "y x"]
        Eastward derivative of the field, on the same grid as the input field
    df_n : Float[jax.Array, "y x"]
        Northward derivative of the field, on the same grid as the input field
    """
    if dx is None or dy is None:
        if lat is None or lon is None:
            raise ValueError("Either lat/lon or dx/dy must be provided")
        from .geometry import grid_spacing
        dx, dy = grid_spacing(lat, lon)

    # compute derivatives in grid coordinates
    df_x = derivative(field, axis=1, land_mask=land_mask)
    df_y = derivative(field, axis=0, land_mask=land_mask)

    # interpolate from staggered grids to the T grid of the input field
    df_x = interpolation(df_x, axis=1, padding="left", land_mask=land_mask)
    df_y = interpolation(df_y, axis=0, padding="left", land_mask=land_mask)

    # scale derivatives with grid spacing to get physical derivatives
    df_x /= dx
    df_y /= dy

    return df_x, df_y
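Composing the two operators gives a centered scheme in the interior. Let $d_i = f_{i+1} - f_i$ be the output of `derivative`, located between $T_i$ and $T_{i+1}$. Applying the left-padded midpoint of `interpolation` (which places $(d_{i-1} + d_i)/2$ at $T_i$) and then dividing by the spacing gives:

```latex
\left.\frac{\partial f}{\partial x}\right|_{T_i}
\approx \frac{1}{\Delta x_i}\,\frac{d_{i-1} + d_i}{2}
= \frac{(f_i - f_{i-1}) + (f_{i+1} - f_i)}{2\,\Delta x_i}
= \frac{f_{i+1} - f_{i-1}}{2\,\Delta x_i}
```

Here $\Delta x_i$ stands for the `dx` value at index $i$. On a uniform grid this is the standard second-order centered difference. Near domain edges and NaN cells, the open boundary conditions described above take over. The same reasoning applies along `axis=0` with `dy`.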

sanitize

sanitize_data(arr, fill_value, land_mask)

Sanitizes data by replacing nan and infinite values with fill_value and applying fill_value to the masked area.

Parameters:

Name Type Description Default
arr Float[Array, 'y x']

Array to sanitize

required
fill_value float

Value to replace nan/infinite values and the masked area with

required
land_mask Float[Array, 'y x']

Mask to apply, 1 or True for masked

required

Returns:

Name Type Description
arr Float[Array, 'y x']

Sanitized array

Source code in jaxparrow/utils/sanitize.py
def sanitize_data(
    arr: Float[jax.Array, "y x"],
    fill_value: float,
    land_mask: Float[jax.Array, "y x"]
) -> Float[jax.Array, "y x"]:
    """
    Sanitizes data by replacing `nan` and infinite values with ``fill_value`` and applying ``fill_value`` to the masked area.

    Parameters
    ----------
    arr : Float[jax.Array, "y x"]
        Array to sanitize
    fill_value : float
        Value to replace `nan`/infinite values and the masked area with
    land_mask : Float[jax.Array, "y x"]
        Mask to apply, `1` or `True` for masked

    Returns
    -------
    arr : Float[jax.Array, "y x"]
        Sanitized array
    """
    arr = jnp.nan_to_num(arr, copy=False, nan=fill_value, posinf=fill_value, neginf=fill_value)
    arr = jnp.where(land_mask, fill_value, arr)
    return arr
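Infinities are replaced as well as NaNs, so calling `sanitize_data` with `fill_value=jnp.nan` turns every non-finite value into NaN. The snippet below is an illustrative NumPy mirror of the two steps. NumPy's `nan_to_num` accepts the same `nan`, `posinf` and `neginf` keywords.

```python
import numpy as np

arr = np.array([[1.0, np.nan], [np.inf, 4.0]])
land_mask = np.array([[False, False], [False, True]])
fill_value = -1.0

# Step 1: NaN and +/-inf -> fill_value; step 2: masked (land) cells -> fill_value
out = np.nan_to_num(arr, nan=fill_value, posinf=fill_value, neginf=fill_value)
out = np.where(land_mask, fill_value, out)
print(out)  # 1.0 is kept; the NaN, the inf and the land cell all become -1.0
```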

init_land_mask(field, land_mask=None)

If land_mask is None, initializes it from the non-finite (nan or infinite) values of field. If land_mask is not None, simply returns it.

Parameters:

Name Type Description Default
field Float[Array, 'y x']

Field used to initialize the mask (if needed)

required
land_mask Float[Array, 'y x']

Mask to initialize (if None).

Defaults to None

None

Returns:

Name Type Description
land_mask Float[Array, 'y x']

Initialized (if needed) land mask

Source code in jaxparrow/utils/sanitize.py
def init_land_mask(
    field: Float[jax.Array, "y x"],
    land_mask: Float[jax.Array, "y x"] = None
) -> Float[jax.Array, "y x"]:
    """
    If ``land_mask is None``, initializes it from the non-finite (`nan` or infinite) values of ``field``.
    If ``land_mask is not None``, simply returns it.

    Parameters
    ----------
    field : Float[jax.Array, "y x"]
        Field used to initialize the mask (if needed)
    land_mask : Float[jax.Array, "y x"], optional
        Mask to initialize (if `None`).

        Defaults to `None`

    Returns
    -------
    land_mask : Float[jax.Array, "y x"]
        Initialized (if needed) land mask
    """
    if land_mask is None:
        land_mask = ~jnp.isfinite(field)
    return land_mask
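`magnitude` and `kinetic_energy` derive the mask from `u` alone. A cell that is non-finite only in `v` is therefore not flagged, although with `uv_on_t=True` its result is still NaN through the arithmetic. The snippet below is a NumPy sketch of the default behavior, and `default_land_mask` is a hypothetical name.

```python
import numpy as np


def default_land_mask(field, land_mask=None):
    """Sketch of `init_land_mask`: flag NaN and +/-inf cells unless a mask is given."""
    if land_mask is None:
        land_mask = ~np.isfinite(field)
    return land_mask


u = np.array([[0.1, np.nan, -np.inf, 0.3]])
print(default_land_mask(u))  # only the NaN and -inf cells are True
explicit = np.zeros_like(u, dtype=bool)
print(default_land_mask(u, explicit) is explicit)  # True: a provided mask is returned untouched
```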