Skip to content

Commit 5d19c85

Browse files
committed
optimized jax multi-gpu support
1 parent 31d9bb2 commit 5d19c85

File tree

5 files changed

+376
-142
lines changed

5 files changed

+376
-142
lines changed

equitrain/backends/jax_backend.py

Lines changed: 128 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import itertools
34
import time
45
from pathlib import Path
56

@@ -22,10 +23,15 @@
2223
from equitrain.backends.jax_runtime import ensure_multiprocessing_spawn
2324
from equitrain.backends.jax_utils import (
2425
ModelBundle,
26+
batched_iterator,
27+
iter_micro_batches,
2528
load_model_bundle,
2629
replicate_to_local_devices,
30+
supports_multiprocessing_workers,
31+
take_chunk,
2732
unreplicate_from_local_devices,
2833
)
34+
from equitrain.backends.jax_utils import is_multi_device as _is_multi_device
2935
from equitrain.backends.jax_utils import (
3036
prepare_sharded_batch as _prepare_sharded_batch,
3137
)
@@ -82,10 +88,6 @@ def _sanitize(x):
8288
return jtu.tree_map(_sanitize, grads)
8389

8490

85-
def _is_multi_device() -> bool:
86-
return jax.local_device_count() > 1
87-
88-
8991
def _replicate_state(state: TrainState) -> TrainState:
    """Pass ``state`` through ``replicate_to_local_devices`` and return the result."""
    replicated_state = replicate_to_local_devices(state)
    return replicated_state
9193

@@ -94,6 +96,33 @@ def _unreplicate(tree):
9496
return unreplicate_from_local_devices(tree)
9597

9698

99+
def _multi_device_chunk_iterator(loader, device_count: int, *, phase: str, logger):
    """Group per-device micro-batches into chunks destined for ``jax.pmap`` calls.

    The first chunk is fetched immediately (this is a plain function, not a
    generator), so a loader that cannot supply ``device_count`` micro-batches
    fails at call time rather than on first iteration.

    Args:
        loader: Source handed to ``iter_micro_batches``.
        device_count: Number of micro-batches requested per chunk.
        phase: Label placed in square brackets at the start of every message.
        logger: Object exposing ``log(level, message)``; when ``None`` the
            drop notice is printed instead.

    Returns:
        An ``itertools.chain`` that yields the first chunk and then whatever
        ``batched_iterator`` produces from the remaining micro-batches.

    Raises:
        RuntimeError: If the first chunk holds fewer than ``device_count``
            micro-batches.
    """

    def _report_dropped(count, expected):
        # Supplied to ``batched_iterator`` as ``remainder_action``; presumably
        # invoked for a trailing chunk that is too short -- confirm in jax_utils.
        notice = (
            f'[{phase}] Dropping incomplete multi-device chunk ({count}/{expected}).'
        )
        if logger is None:
            print(notice)
            return
        logger.log(1, notice)

    stream = iter_micro_batches(loader)
    head = take_chunk(stream, device_count)
    if len(head) < device_count:
        raise RuntimeError(
            f'[{phase}] Need at least {device_count} micro-batches to utilize all '
            'available devices. Reduce --batch-size or the device count.'
        )

    # ``stream`` has already been advanced past ``head``; the rest is chunked lazily.
    tail = batched_iterator(
        stream,
        device_count,
        remainder_action=_report_dropped,
    )
    return itertools.chain([head], tail)
124+
125+
97126
def _build_train_functions(
98127
loss_fn,
99128
optimizer,
@@ -206,9 +235,14 @@ def _run_train_epoch(
206235
loss_collection = JaxLossCollection()
207236
ema_count = ema_count_start
208237
device_count = jax.local_device_count() if multi_device else 1
238+
use_chunked_multi = multi_device and device_count > 1
209239
total_steps = None
210240
if hasattr(train_loader, '__len__'):
211-
total_steps = len(train_loader)
241+
approx_batches = len(train_loader)
242+
if use_chunked_multi and approx_batches is not None:
243+
total_steps = approx_batches // device_count
244+
else:
245+
total_steps = approx_batches
212246
if max_steps is not None:
213247
if total_steps is not None:
214248
total_steps = min(total_steps, max_steps)
@@ -220,7 +254,13 @@ def _run_train_epoch(
220254
mask_tree = jtu.tree_map(lambda v: jnp.asarray(v, dtype=jnp.bool_), mask)
221255

222256
use_tqdm = bool(getattr(args, 'tqdm', False) and tqdm is not None)
223-
iterator = enumerate(train_loader)
257+
if use_chunked_multi:
258+
chunk_iter = _multi_device_chunk_iterator(
259+
train_loader, device_count, phase='Training', logger=logger
260+
)
261+
iterator = enumerate(chunk_iter)
262+
else:
263+
iterator = enumerate(train_loader)
224264
progress = None
225265
if use_tqdm:
226266
progress = tqdm(
@@ -235,17 +275,17 @@ def _run_train_epoch(
235275
if max_steps is not None and step_index >= max_steps:
236276
break
237277

238-
if isinstance(graph, list):
239-
micro_batches = [g for g in graph if g is not None]
278+
if use_chunked_multi:
279+
micro_batches = graph
240280
else:
241-
micro_batches = [graph]
281+
if isinstance(graph, list):
282+
micro_batches = [g for g in graph if g is not None]
283+
else:
284+
micro_batches = [graph]
242285

243286
if not micro_batches:
244287
continue
245288

246-
micro_count = len(micro_batches)
247-
inv_micro = 1.0 / float(micro_count)
248-
249289
step_start = time.perf_counter()
250290

251291
params_before = state.params
@@ -257,25 +297,26 @@ def _run_train_epoch(
257297
warmup_decay = (1.0 + ema_count) / (10.0 + ema_count)
258298
ema_factor = float(min(float(ema_decay), warmup_decay))
259299

260-
accum_grads = jtu.tree_map(lambda x: jnp.zeros_like(x), state.params)
261300
macro_collection = JaxLossCollection()
262301

263-
for micro_batch in micro_batches:
264-
if multi_device:
265-
prepared_batch = _prepare_sharded_batch(micro_batch, device_count)
266-
_, aux_dev, grads = grad_step_fn(state.params, prepared_batch)
267-
grads = jtu.tree_map(lambda g: g * inv_micro, grads)
268-
accum_grads = jtu.tree_map(lambda acc, g: acc + g, accum_grads, grads)
269-
aux_host = _unreplicate(aux_dev)
270-
else:
302+
if use_chunked_multi:
303+
prepared_batch = _prepare_sharded_batch(micro_batches, device_count)
304+
_, aux_dev, grads = grad_step_fn(state.params, prepared_batch)
305+
accum_grads = grads
306+
aux_host = _unreplicate(aux_dev)
307+
update_collection_from_aux(loss_collection, aux_host)
308+
update_collection_from_aux(macro_collection, aux_host)
309+
else:
310+
inv_micro = 1.0 / float(len(micro_batches))
311+
accum_grads = jtu.tree_map(lambda x: jnp.zeros_like(x), state.params)
312+
for micro_batch in micro_batches:
271313
prepared_batch = _prepare_single_batch(micro_batch)
272314
_, aux_val, grads = grad_step_fn(state.params, prepared_batch)
273315
grads = jtu.tree_map(lambda g: g * inv_micro, grads)
274316
accum_grads = jtu.tree_map(lambda acc, g: acc + g, accum_grads, grads)
275317
aux_host = jax.device_get(aux_val)
276-
277-
update_collection_from_aux(loss_collection, aux_host)
278-
update_collection_from_aux(macro_collection, aux_host)
318+
update_collection_from_aux(loss_collection, aux_host)
319+
update_collection_from_aux(macro_collection, aux_host)
279320

280321
state = apply_updates_fn(state, accum_grads, ema_factor)
281322

@@ -357,6 +398,7 @@ def _run_eval_loop(
357398
*,
358399
max_steps,
359400
multi_device: bool,
401+
logger=None,
360402
):
361403
if loader is None:
362404
return None, JaxLossCollection()
@@ -365,27 +407,38 @@ def _run_eval_loop(
365407
device_count = jax.local_device_count() if multi_device else 1
366408
mean_loss = None
367409

368-
for step_index, graph in enumerate(loader):
410+
if multi_device and device_count > 1:
411+
data_iter = _multi_device_chunk_iterator(
412+
loader, device_count, phase='Eval', logger=logger
413+
)
414+
else:
415+
data_iter = loader
416+
417+
for step_index, graph in enumerate(data_iter):
369418
if max_steps is not None and step_index >= max_steps:
370419
break
371-
if isinstance(graph, list):
372-
micro_batches = [g for g in graph if g is not None]
373-
else:
374-
micro_batches = [graph]
375-
if not micro_batches:
376-
continue
377-
378-
for micro_batch in micro_batches:
379-
if multi_device:
380-
batch = _prepare_sharded_batch(micro_batch, device_count)
381-
else:
382-
batch = _prepare_single_batch(micro_batch)
383-
420+
if multi_device and device_count > 1:
421+
micro_batches = graph
422+
batch = _prepare_sharded_batch(micro_batches, device_count)
384423
loss, aux = eval_step_fn(params, batch)
385-
loss = _unreplicate(loss) if multi_device else jax.device_get(loss)
386-
aux = _unreplicate(aux) if multi_device else jax.device_get(aux)
424+
loss = _unreplicate(loss)
425+
aux = _unreplicate(aux)
387426
update_collection_from_aux(loss_collection, aux)
388427
mean_loss = float(loss)
428+
else:
429+
if isinstance(graph, list):
430+
micro_batches = [g for g in graph if g is not None]
431+
else:
432+
micro_batches = [graph]
433+
if not micro_batches:
434+
continue
435+
for micro_batch in micro_batches:
436+
batch = _prepare_single_batch(micro_batch)
437+
loss, aux = eval_step_fn(params, batch)
438+
loss = jax.device_get(loss)
439+
aux = jax.device_get(aux)
440+
update_collection_from_aux(loss_collection, aux)
441+
mean_loss = float(loss)
389442

390443
if loss_collection.components['total'].count:
391444
mean_loss = loss_collection.components['total'].value
@@ -445,21 +498,51 @@ def train(args):
445498
reduce_cells = bool(getattr(args, 'niggli_reduce', False))
446499
train_seed = getattr(args, 'seed', None)
447500

501+
multi_device = _is_multi_device()
502+
device_count = jax.local_device_count() if multi_device else 1
503+
504+
if getattr(args, 'batch_size', None) is None or args.batch_size <= 0:
505+
raise ValueError('JAX backend requires a positive --batch-size.')
506+
total_batch_size = int(args.batch_size)
507+
per_device_batch = total_batch_size
508+
if multi_device and device_count > 1:
509+
if total_batch_size % device_count != 0:
510+
raise ValueError(
511+
'For JAX multi-device training, --batch-size must be divisible by '
512+
'the number of local devices.'
513+
)
514+
per_device_batch = total_batch_size // device_count
515+
516+
base_workers = max(int(getattr(args, 'num_workers', 0) or 0), 0)
517+
if base_workers > 0 and supports_multiprocessing_workers():
518+
effective_workers = base_workers
519+
if multi_device and device_count > 1:
520+
effective_workers *= device_count
521+
else:
522+
effective_workers = 0
523+
prefetch_requested = getattr(args, 'prefetch_batches', None)
524+
if prefetch_requested is None:
525+
prefetch_batches = effective_workers
526+
else:
527+
prefetch_batches = max(int(prefetch_requested or 0), 0)
528+
448529
def _build_streaming_loader(path: str | None, shuffle: bool):
449530
if path in (None, 'None'):
450531
return None
451532
return get_dataloader(
452533
data_file=path,
453534
atomic_numbers=z_table,
454535
r_max=r_max,
455-
batch_size=args.batch_size,
536+
batch_size=per_device_batch,
456537
shuffle=shuffle,
457538
max_nodes=args.batch_max_nodes,
458539
max_edges=args.batch_max_edges,
459540
drop=getattr(args, 'batch_drop', False),
460541
seed=train_seed if shuffle else None,
461542
niggli_reduce=reduce_cells,
462-
prefetch_batches=args.prefetch_batches,
543+
prefetch_batches=prefetch_batches,
544+
num_workers=effective_workers,
545+
graph_multiple=1,
463546
)
464547

465548
train_loader = _build_streaming_loader(args.train_file, shuffle=args.shuffle)
@@ -475,12 +558,6 @@ def _build_streaming_loader(path: str | None, shuffle: bool):
475558
)
476559

477560
num_species = len(z_table)
478-
multi_device = _is_multi_device()
479-
device_count = jax.local_device_count() if multi_device else 1
480-
if multi_device and device_count > 0 and (args.batch_size % device_count != 0):
481-
raise ValueError(
482-
'For JAX multi-device training, --batch-size must be divisible by the number of local devices.'
483-
)
484561

485562
apply_fn = make_apply_fn(wrapper, num_species=num_species)
486563
loss_settings = LossSettings.from_args(args)
@@ -553,6 +630,7 @@ def _host(tree):
553630
eval_step_fn,
554631
max_steps=valid_max_steps,
555632
multi_device=multi_device,
633+
logger=logger,
556634
)
557635

558636
current_params_host = _host(train_state.params)
@@ -630,6 +708,7 @@ def _host(tree):
630708
eval_step_fn,
631709
max_steps=valid_max_steps,
632710
multi_device=multi_device,
711+
logger=logger,
633712
)
634713

635714
train_metric = LossMetrics(**metric_settings)
@@ -738,6 +817,7 @@ def _host(tree):
738817
eval_step_fn,
739818
max_steps=None,
740819
multi_device=multi_device,
820+
logger=logger,
741821
)
742822
test_metrics = LossMetrics(
743823
include_energy=loss_settings.energy_weight > 0.0,

0 commit comments

Comments
 (0)