Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 50 additions & 65 deletions docs/user_guide/examples/tutorial_Argofloats.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,70 +26,62 @@
"source": [
"import numpy as np\n",
"\n",
"# Define the new Kernels that mimic Argo vertical movement\n",
"# Define the new Kernel that mimics Argo vertical movement\n",
"driftdepth = 1000 # maximum depth in m\n",
"maxdepth = 2000 # maximum depth in m\n",
"vertical_speed = 0.10 # sink and rise speed in m/s\n",
"cycletime = 10 * 86400 # total time of cycle in seconds\n",
"drifttime = 9 * 86400 # time of deep drift in seconds\n",
"\n",
"\n",
"def ArgoPhase1(particles, fieldset):\n",
" def SinkingPhase(p):\n",
" \"\"\"Phase 0: Sinking with vertical_speed until depth is driftdepth\"\"\"\n",
" p.dz += vertical_speed * particles.dt\n",
" p.cycle_phase = np.where(p.z + p.dz >= driftdepth, 1, p.cycle_phase)\n",
" p.dz = np.where(p.z + p.dz >= driftdepth, driftdepth - p.z, p.dz)\n",
"def ArgoVerticalMovement(particles, fieldset):\n",
" # Split particles based on their current cycle_phase\n",
" ptcls0 = particles[particles.cycle_phase == 0]\n",
" ptcls1 = particles[particles.cycle_phase == 1]\n",
" ptcls2 = particles[particles.cycle_phase == 2]\n",
" ptcls3 = particles[particles.cycle_phase == 3]\n",
" ptcls4 = particles[particles.cycle_phase == 4]\n",
"\n",
" # Phase 0: Sinking with vertical_speed until depth is driftdepth\n",
" ptcls0.dz += vertical_speed * ptcls0.dt\n",
" ptcls0.cycle_phase = np.where(\n",
" ptcls0.z + ptcls0.dz >= driftdepth, 1, ptcls0.cycle_phase\n",
" )\n",
" ptcls0.dz = np.where(\n",
" ptcls0.z + ptcls0.dz >= driftdepth, driftdepth - ptcls0.z, ptcls0.dz\n",
" )\n",
"\n",
" # Phase 1: Drifting at depth for drifttime seconds\n",
" ptcls1.drift_age += ptcls1.dt\n",
" ptcls1.cycle_phase = np.where(ptcls1.drift_age >= drifttime, 2, ptcls1.cycle_phase)\n",
" ptcls1.drift_age = np.where(ptcls1.drift_age >= drifttime, 0, ptcls1.drift_age)\n",
"\n",
" # Phase 2: Sinking further to maxdepth\n",
" ptcls2.dz += vertical_speed * ptcls2.dt\n",
" ptcls2.cycle_phase = np.where(\n",
" ptcls2.z + ptcls2.dz >= maxdepth, 3, ptcls2.cycle_phase\n",
" )\n",
" ptcls2.dz = np.where(\n",
" ptcls2.z + ptcls2.dz >= maxdepth, maxdepth - ptcls2.z, ptcls2.dz\n",
" )\n",
"\n",
" # Phase 3: Rising with vertical_speed until at surface\n",
" ptcls3.dz -= vertical_speed * ptcls3.dt\n",
" ptcls3.temp = fieldset.thetao[ptcls3.time, ptcls3.z, ptcls3.lat, ptcls3.lon]\n",
" ptcls3.cycle_phase = np.where(\n",
" ptcls3.z + ptcls3.dz <= fieldset.mindepth, 4, ptcls3.cycle_phase\n",
" )\n",
" ptcls3.dz = np.where(\n",
" ptcls3.z + ptcls3.dz <= fieldset.mindepth,\n",
" fieldset.mindepth - ptcls3.z,\n",
" ptcls3.dz,\n",
" )\n",
"\n",
" # Phase 4: Transmitting at surface until cycletime is reached\n",
" ptcls4.cycle_phase = np.where(ptcls4.cycle_age >= cycletime, 0, ptcls4.cycle_phase)\n",
" ptcls4.cycle_age = np.where(ptcls4.cycle_age >= cycletime, 0, ptcls4.cycle_age)\n",
" ptcls4.temp = np.nan # no temperature measurement when at surface\n",
"\n",
" SinkingPhase(particles[particles.cycle_phase == 0])\n",
"\n",
"\n",
"def ArgoPhase2(particles, fieldset):\n",
" def DriftingPhase(p):\n",
" \"\"\"Phase 1: Drifting at depth for drifttime seconds\"\"\"\n",
" p.drift_age += particles.dt\n",
" p.cycle_phase = np.where(p.drift_age >= drifttime, 2, p.cycle_phase)\n",
" p.drift_age = np.where(p.drift_age >= drifttime, 0, p.drift_age)\n",
"\n",
" DriftingPhase(particles[particles.cycle_phase == 1])\n",
"\n",
"\n",
"def ArgoPhase3(particles, fieldset):\n",
" def SecondSinkingPhase(p):\n",
" \"\"\"Phase 2: Sinking further to maxdepth\"\"\"\n",
" p.dz += vertical_speed * particles.dt\n",
" p.cycle_phase = np.where(p.z + p.dz >= maxdepth, 3, p.cycle_phase)\n",
" p.dz = np.where(p.z + p.dz >= maxdepth, maxdepth - p.z, p.dz)\n",
"\n",
" SecondSinkingPhase(particles[particles.cycle_phase == 2])\n",
"\n",
"\n",
"def ArgoPhase4(particles, fieldset):\n",
" def RisingPhase(p):\n",
" \"\"\"Phase 3: Rising with vertical_speed until at surface\"\"\"\n",
" p.dz -= vertical_speed * particles.dt\n",
" p.temp = fieldset.thetao[p.time, p.z, p.lat, p.lon]\n",
" p.cycle_phase = np.where(p.z + p.dz <= fieldset.mindepth, 4, p.cycle_phase)\n",
" p.dz = np.where(\n",
" p.z + p.dz <= fieldset.mindepth,\n",
" fieldset.mindepth - p.z,\n",
" p.dz,\n",
" )\n",
"\n",
" RisingPhase(particles[particles.cycle_phase == 3])\n",
"\n",
"\n",
"def ArgoPhase5(particles, fieldset):\n",
" def TransmittingPhase(p):\n",
" \"\"\"Phase 4: Transmitting at surface until cycletime is reached\"\"\"\n",
" p.cycle_phase = np.where(p.cycle_age >= cycletime, 0, p.cycle_phase)\n",
" p.cycle_age = np.where(p.cycle_age >= cycletime, 0, p.cycle_age)\n",
" p.temp = np.nan # no temperature measurement when at surface\n",
"\n",
" TransmittingPhase(particles[particles.cycle_phase == 4])\n",
"\n",
"\n",
"def ArgoPhase6(particles, fieldset):\n",
" particles.cycle_age += particles.dt # update cycle_age"
]
},
Expand Down Expand Up @@ -136,9 +128,7 @@
"ArgoParticle = parcels.Particle.add_variable(\n",
" [\n",
" parcels.Variable(\"cycle_phase\", dtype=np.int32, initial=0.0),\n",
" parcels.Variable(\n",
" \"cycle_age\", dtype=np.float32, initial=0.0\n",
" ), # TODO update to \"timedelta64[s]\"\n",
" parcels.Variable(\"cycle_age\", dtype=np.float32, initial=0.0),\n",
" parcels.Variable(\"drift_age\", dtype=np.float32, initial=0.0),\n",
" parcels.Variable(\"temp\", dtype=np.float32, initial=np.nan),\n",
" ]\n",
Expand All @@ -155,12 +145,7 @@
"\n",
"# combine Argo vertical movement kernel with built-in Advection kernel\n",
"kernels = [\n",
" ArgoPhase1,\n",
" ArgoPhase2,\n",
" ArgoPhase3,\n",
" ArgoPhase4,\n",
" ArgoPhase5,\n",
" ArgoPhase6,\n",
" ArgoVerticalMovement,\n",
" parcels.kernels.AdvectionRK4,\n",
"]\n",
"\n",
Expand Down
13 changes: 2 additions & 11 deletions docs/user_guide/examples/tutorial_interaction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -293,18 +293,9 @@
" larger_idx = np.where(mass_j > mass_i, pair_j, pair_i)\n",
" smaller_idx = np.where(mass_j > mass_i, pair_i, pair_j)\n",
"\n",
" # perform transfer and mark deletions\n",
" # TODO note that we use temporary arrays for indexing because of KernelParticle bug (GH #2143)\n",
" masses = particles.mass\n",
" states = particles.state\n",
"\n",
" # transfer mass from smaller to larger and mark smaller for deletion\n",
" masses[larger_idx] += particles.mass[smaller_idx]\n",
" states[smaller_idx] = parcels.StatusCode.Delete\n",
"\n",
" # TODO use particle variables directly after KernelParticle bug (GH #2143) is fixed\n",
" particles.mass = masses\n",
" particles.state = states"
" particles.mass[larger_idx] += particles.mass[smaller_idx]\n",
" particles.state[smaller_idx] = parcels.StatusCode.Delete"
]
},
{
Expand Down
3 changes: 0 additions & 3 deletions src/parcels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Variable,
Particle,
ParticleClass,
KernelParticle, # ? remove?
)
from parcels._core.field import Field, VectorField
from parcels._core.basegrid import BaseGrid
Expand Down Expand Up @@ -87,8 +86,6 @@
"logger",
"download_example_dataset",
"list_example_datasets",
# (marked for potential removal)
"KernelParticle",
]

_stdlib_warnings.warn(
Expand Down
10 changes: 5 additions & 5 deletions src/parcels/_core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
_unitconverters_map,
)
from parcels._core.index_search import GRID_SEARCH_ERROR, LEFT_OUT_OF_BOUNDS, RIGHT_OUT_OF_BOUNDS, _search_time_index
from parcels._core.particle import KernelParticle
from parcels._core.particlesetview import ParticleSetView
from parcels._core.statuscodes import (
AllParcelsErrorCodes,
StatusCode,
Expand All @@ -35,9 +35,9 @@


def _deal_with_errors(error, key, vector_type: VectorType):
if isinstance(key, KernelParticle):
if isinstance(key, ParticleSetView):
key.state = AllParcelsErrorCodes[type(error)]
elif isinstance(key[-1], KernelParticle):
elif isinstance(key[-1], ParticleSetView):
key[-1].state = AllParcelsErrorCodes[type(error)]
else:
raise RuntimeError(f"{error}. Error could not be handled because particles was not part of the Field Sampling.")
Expand Down Expand Up @@ -229,7 +229,7 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
def __getitem__(self, key):
self._check_velocitysampling()
try:
if isinstance(key, KernelParticle):
if isinstance(key, ParticleSetView):
return self.eval(key.time, key.z, key.lat, key.lon, key)
else:
return self.eval(*key)
Expand Down Expand Up @@ -330,7 +330,7 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):

def __getitem__(self, key):
try:
if isinstance(key, KernelParticle):
if isinstance(key, ParticleSetView):
return self.eval(key.time, key.z, key.lat, key.lon, key)
else:
return self.eval(*key)
Expand Down
26 changes: 1 addition & 25 deletions src/parcels/_core/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from parcels._core.utils.time import TimeInterval
from parcels._reprs import _format_list_items_multiline

__all__ = ["KernelParticle", "Particle", "ParticleClass", "Variable"]
__all__ = ["Particle", "ParticleClass", "Variable"]
_TO_WRITE_OPTIONS = [True, False, "once"]


Expand Down Expand Up @@ -116,30 +116,6 @@ def add_variable(self, variable: Variable | list[Variable]):
return ParticleClass(variables=self.variables + variable)


class KernelParticle:
"""Simple class to be used in a kernel that links a particle (on the kernel level) to a particle dataset."""

def __init__(self, data, index):
self._data = data
self._index = index

def __getattr__(self, name):
return self._data[name][self._index]

def __setattr__(self, name, value):
if name in ["_data", "_index"]:
object.__setattr__(self, name, value)
else:
self._data[name][self._index] = value

def __getitem__(self, index):
self._index = index
return self

def __len__(self):
return len(self._index)


def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_vars: list[Variable]):
existing_names = {var.name for var in existing_vars}
for var in new_vars:
Expand Down
5 changes: 3 additions & 2 deletions src/parcels/_core/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from parcels._core.converters import _convert_to_flat_array
from parcels._core.kernel import Kernel
from parcels._core.particle import KernelParticle, Particle, create_particle_data
from parcels._core.particle import Particle, create_particle_data
from parcels._core.particlesetview import ParticleSetView
from parcels._core.statuscodes import StatusCode
from parcels._core.utils.time import (
TimeInterval,
Expand Down Expand Up @@ -166,7 +167,7 @@ def __getattr__(self, name):

def __getitem__(self, index):
"""Get a single particle by index."""
return KernelParticle(self._data, index=index)
return ParticleSetView(self._data, index=index)

def __setattr__(self, name, value):
if name in ["_data"]:
Expand Down
Loading
Loading