@@ -1,5 +1,6 @@
 from __future__ import annotations

+import itertools
 import time
 from pathlib import Path

@@ -22,10 +23,15 @@
 from equitrain.backends.jax_runtime import ensure_multiprocessing_spawn
 from equitrain.backends.jax_utils import (
     ModelBundle,
+    batched_iterator,
+    iter_micro_batches,
     load_model_bundle,
     replicate_to_local_devices,
+    supports_multiprocessing_workers,
+    take_chunk,
     unreplicate_from_local_devices,
 )
+from equitrain.backends.jax_utils import is_multi_device as _is_multi_device
 from equitrain.backends.jax_utils import (
     prepare_sharded_batch as _prepare_sharded_batch,
 )
@@ -82,10 +88,6 @@ def _sanitize(x):
     return jtu.tree_map(_sanitize, grads)


-def _is_multi_device() -> bool:
-    return jax.local_device_count() > 1
-
-
 def _replicate_state(state: TrainState) -> TrainState:
     return replicate_to_local_devices(state)

@@ -94,6 +96,33 @@ def _unreplicate(tree):
     return unreplicate_from_local_devices(tree)


+def _multi_device_chunk_iterator(loader, device_count: int, *, phase: str, logger):
+    """Group per-device micro-batches to feed into ``jax.pmap`` calls."""
+    micro_iter = iter_micro_batches(loader)
+    first_chunk = take_chunk(micro_iter, device_count)
+    if len(first_chunk) < device_count:
+        raise RuntimeError(
+            f'[{phase}] Need at least {device_count} micro-batches to utilize all '
+            'available devices. Reduce --batch-size or the device count.'
+        )
+
+    def _warn(count, expected):
+        message = (
+            f'[{phase}] Dropping incomplete multi-device chunk ({count}/{expected}).'
+        )
+        if logger is not None:
+            logger.log(1, message)
+        else:
+            print(message)
+
+    remainder = batched_iterator(
+        micro_iter,
+        device_count,
+        remainder_action=_warn,
+    )
+    return itertools.chain([first_chunk], remainder)
+
+
 def _build_train_functions(
     loss_fn,
     optimizer,
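Note: for readers without equitrain.backends.jax_utils open, the sketch below shows one plausible shape for the three helpers the new iterator relies on. The signatures are inferred purely from the call sites above; the real implementations may differ.

# Illustrative sketch only -- not the actual equitrain helpers.
import itertools

def iter_micro_batches(loader):
    """Yield individual micro-batches, flattening list-valued loader items."""
    for item in loader:
        if isinstance(item, list):
            yield from (g for g in item if g is not None)
        elif item is not None:
            yield item

def take_chunk(iterator, size):
    """Take up to `size` items from `iterator` as a list."""
    return list(itertools.islice(iterator, size))

def batched_iterator(iterator, size, *, remainder_action=None):
    """Yield lists of exactly `size` items; report any trailing partial chunk."""
    while True:
        chunk = take_chunk(iterator, size)
        if len(chunk) < size:
            if chunk and remainder_action is not None:
                remainder_action(len(chunk), size)
            return
        yield chunk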
@@ -206,9 +235,14 @@ def _run_train_epoch(
     loss_collection = JaxLossCollection()
     ema_count = ema_count_start
     device_count = jax.local_device_count() if multi_device else 1
+    use_chunked_multi = multi_device and device_count > 1
     total_steps = None
     if hasattr(train_loader, '__len__'):
-        total_steps = len(train_loader)
+        approx_batches = len(train_loader)
+        if use_chunked_multi and approx_batches is not None:
+            total_steps = approx_batches // device_count
+        else:
+            total_steps = approx_batches
     if max_steps is not None:
         if total_steps is not None:
             total_steps = min(total_steps, max_steps)
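A quick sanity check on the new step arithmetic: each chunked optimizer step consumes one micro-batch per device, so the reported epoch length shrinks accordingly (hypothetical numbers below).

approx_batches = 1000                          # micro-batches reported by the loader
device_count = 4                               # local JAX devices
total_steps = approx_batches // device_count   # 250 chunked optimizer steps per epoch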
@@ -220,7 +254,13 @@ def _run_train_epoch(
     mask_tree = jtu.tree_map(lambda v: jnp.asarray(v, dtype=jnp.bool_), mask)

     use_tqdm = bool(getattr(args, 'tqdm', False) and tqdm is not None)
-    iterator = enumerate(train_loader)
+    if use_chunked_multi:
+        chunk_iter = _multi_device_chunk_iterator(
+            train_loader, device_count, phase='Training', logger=logger
+        )
+        iterator = enumerate(chunk_iter)
+    else:
+        iterator = enumerate(train_loader)
     progress = None
     if use_tqdm:
         progress = tqdm(
@@ -235,17 +275,17 @@ def _run_train_epoch(
         if max_steps is not None and step_index >= max_steps:
             break

-        if isinstance(graph, list):
-            micro_batches = [g for g in graph if g is not None]
+        if use_chunked_multi:
+            micro_batches = graph
         else:
-            micro_batches = [graph]
+            if isinstance(graph, list):
+                micro_batches = [g for g in graph if g is not None]
+            else:
+                micro_batches = [graph]

         if not micro_batches:
             continue

-        micro_count = len(micro_batches)
-        inv_micro = 1.0 / float(micro_count)
-
         step_start = time.perf_counter()

         params_before = state.params
@@ -257,25 +297,26 @@ def _run_train_epoch(
         warmup_decay = (1.0 + ema_count) / (10.0 + ema_count)
         ema_factor = float(min(float(ema_decay), warmup_decay))

-        accum_grads = jtu.tree_map(lambda x: jnp.zeros_like(x), state.params)
         macro_collection = JaxLossCollection()

-        for micro_batch in micro_batches:
-            if multi_device:
-                prepared_batch = _prepare_sharded_batch(micro_batch, device_count)
-                _, aux_dev, grads = grad_step_fn(state.params, prepared_batch)
-                grads = jtu.tree_map(lambda g: g * inv_micro, grads)
-                accum_grads = jtu.tree_map(lambda acc, g: acc + g, accum_grads, grads)
-                aux_host = _unreplicate(aux_dev)
-            else:
+        if use_chunked_multi:
+            prepared_batch = _prepare_sharded_batch(micro_batches, device_count)
+            _, aux_dev, grads = grad_step_fn(state.params, prepared_batch)
+            accum_grads = grads
+            aux_host = _unreplicate(aux_dev)
+            update_collection_from_aux(loss_collection, aux_host)
+            update_collection_from_aux(macro_collection, aux_host)
+        else:
+            inv_micro = 1.0 / float(len(micro_batches))
+            accum_grads = jtu.tree_map(lambda x: jnp.zeros_like(x), state.params)
+            for micro_batch in micro_batches:
                 prepared_batch = _prepare_single_batch(micro_batch)
                 _, aux_val, grads = grad_step_fn(state.params, prepared_batch)
                 grads = jtu.tree_map(lambda g: g * inv_micro, grads)
                 accum_grads = jtu.tree_map(lambda acc, g: acc + g, accum_grads, grads)
                 aux_host = jax.device_get(aux_val)
-
-            update_collection_from_aux(loss_collection, aux_host)
-            update_collection_from_aux(macro_collection, aux_host)
+                update_collection_from_aux(loss_collection, aux_host)
+                update_collection_from_aux(macro_collection, aux_host)

         state = apply_updates_fn(state, accum_grads, ema_factor)

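The chunked branch keeps the returned gradients as-is (accum_grads = grads) instead of rescaling by inv_micro. That is consistent with a jax.pmap-style step that averages gradients across devices with a collective; the sketch below illustrates that pattern under the assumption that grad_step_fn behaves this way, which the actual project code may refine.

# Illustrative pmap step (assumed pattern, not equitrain's actual grad_step_fn):
# every device gets one micro-batch along the leading axis, and gradients are
# averaged across devices with a collective, so no host-side 1/num_micro scaling
# is needed afterwards.
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    pred = batch['x'] @ params
    return jnp.mean((pred - batch['y']) ** 2)


@partial(jax.pmap, axis_name='devices')
def grad_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grads = jax.lax.pmean(grads, axis_name='devices')  # device-averaged gradients
    loss = jax.lax.pmean(loss, axis_name='devices')
    return loss, grads


n_dev = jax.local_device_count()
params = jnp.ones((3,))
replicated = jnp.broadcast_to(params, (n_dev, 3))  # one parameter copy per device
batch = {
    'x': jnp.ones((n_dev, 8, 3)),                  # one micro-batch per device
    'y': jnp.zeros((n_dev, 8)),
}
loss, grads = grad_step(replicated, batch)         # grads identical on every device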
@@ -357,6 +398,7 @@ def _run_eval_loop(
     *,
     max_steps,
     multi_device: bool,
+    logger=None,
 ):
     if loader is None:
         return None, JaxLossCollection()
@@ -365,27 +407,38 @@ def _run_eval_loop(
     device_count = jax.local_device_count() if multi_device else 1
     mean_loss = None

-    for step_index, graph in enumerate(loader):
+    if multi_device and device_count > 1:
+        data_iter = _multi_device_chunk_iterator(
+            loader, device_count, phase='Eval', logger=logger
+        )
+    else:
+        data_iter = loader
+
+    for step_index, graph in enumerate(data_iter):
         if max_steps is not None and step_index >= max_steps:
             break
-        if isinstance(graph, list):
-            micro_batches = [g for g in graph if g is not None]
-        else:
-            micro_batches = [graph]
-        if not micro_batches:
-            continue
-
-        for micro_batch in micro_batches:
-            if multi_device:
-                batch = _prepare_sharded_batch(micro_batch, device_count)
-            else:
-                batch = _prepare_single_batch(micro_batch)
-
+        if multi_device and device_count > 1:
+            micro_batches = graph
+            batch = _prepare_sharded_batch(micro_batches, device_count)
             loss, aux = eval_step_fn(params, batch)
-            loss = _unreplicate(loss) if multi_device else jax.device_get(loss)
-            aux = _unreplicate(aux) if multi_device else jax.device_get(aux)
+            loss = _unreplicate(loss)
+            aux = _unreplicate(aux)
             update_collection_from_aux(loss_collection, aux)
             mean_loss = float(loss)
+        else:
+            if isinstance(graph, list):
+                micro_batches = [g for g in graph if g is not None]
+            else:
+                micro_batches = [graph]
+            if not micro_batches:
+                continue
+            for micro_batch in micro_batches:
+                batch = _prepare_single_batch(micro_batch)
+                loss, aux = eval_step_fn(params, batch)
+                loss = jax.device_get(loss)
+                aux = jax.device_get(aux)
+                update_collection_from_aux(loss_collection, aux)
+                mean_loss = float(loss)

     if loss_collection.components['total'].count:
         mean_loss = loss_collection.components['total'].value
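In the chunked path, _prepare_sharded_batch now receives a whole list of per-device micro-batches rather than a single batch to split. One plausible way such a helper shards that list is to stack matching pytree leaves along a new leading device axis, as in the hedged sketch below; the real helper in equitrain.backends.jax_utils may pad or reshape differently.

# Assumed illustration only: stack a chunk of equally shaped micro-batches so the
# leading axis matches jax.local_device_count(), the layout jax.pmap expects.
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def stack_micro_batches(micro_batches):
    return jtu.tree_map(lambda *leaves: jnp.stack(leaves, axis=0), *micro_batches)


chunk = [{'x': jnp.ones((8, 3))} for _ in range(jax.local_device_count())]
sharded = stack_micro_batches(chunk)   # leaf shapes become (device_count, 8, 3)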
@@ -445,21 +498,51 @@ def train(args):
     reduce_cells = bool(getattr(args, 'niggli_reduce', False))
     train_seed = getattr(args, 'seed', None)

+    multi_device = _is_multi_device()
+    device_count = jax.local_device_count() if multi_device else 1
+
+    if getattr(args, 'batch_size', None) is None or args.batch_size <= 0:
+        raise ValueError('JAX backend requires a positive --batch-size.')
+    total_batch_size = int(args.batch_size)
+    per_device_batch = total_batch_size
+    if multi_device and device_count > 1:
+        if total_batch_size % device_count != 0:
+            raise ValueError(
+                'For JAX multi-device training, --batch-size must be divisible by '
+                'the number of local devices.'
+            )
+        per_device_batch = total_batch_size // device_count
+
+    base_workers = max(int(getattr(args, 'num_workers', 0) or 0), 0)
+    if base_workers > 0 and supports_multiprocessing_workers():
+        effective_workers = base_workers
+        if multi_device and device_count > 1:
+            effective_workers *= device_count
+    else:
+        effective_workers = 0
+    prefetch_requested = getattr(args, 'prefetch_batches', None)
+    if prefetch_requested is None:
+        prefetch_batches = effective_workers
+    else:
+        prefetch_batches = max(int(prefetch_requested or 0), 0)
+
     def _build_streaming_loader(path: str | None, shuffle: bool):
         if path in (None, 'None'):
             return None
         return get_dataloader(
             data_file=path,
             atomic_numbers=z_table,
             r_max=r_max,
-            batch_size=args.batch_size,
+            batch_size=per_device_batch,
             shuffle=shuffle,
             max_nodes=args.batch_max_nodes,
             max_edges=args.batch_max_edges,
             drop=getattr(args, 'batch_drop', False),
             seed=train_seed if shuffle else None,
             niggli_reduce=reduce_cells,
-            prefetch_batches=args.prefetch_batches,
+            prefetch_batches=prefetch_batches,
+            num_workers=effective_workers,
+            graph_multiple=1,
         )

     train_loader = _build_streaming_loader(args.train_file, shuffle=args.shuffle)
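To make the derived loader settings concrete, here is a small worked example with hypothetical flag values (--batch-size 64, four local devices, --num-workers 2, no --prefetch-batches given):

total_batch_size = 64                                  # --batch-size
device_count = 4                                       # jax.local_device_count()
per_device_batch = total_batch_size // device_count    # 16 graphs per device chunk
effective_workers = 2 * device_count                   # 8 loader workers feed all devices
prefetch_batches = effective_workers                   # defaults to the worker count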
@@ -475,12 +558,6 @@ def _build_streaming_loader(path: str | None, shuffle: bool):
     )

     num_species = len(z_table)
-    multi_device = _is_multi_device()
-    device_count = jax.local_device_count() if multi_device else 1
-    if multi_device and device_count > 0 and (args.batch_size % device_count != 0):
-        raise ValueError(
-            'For JAX multi-device training, --batch-size must be divisible by the number of local devices.'
-        )

     apply_fn = make_apply_fn(wrapper, num_species=num_species)
     loss_settings = LossSettings.from_args(args)
@@ -553,6 +630,7 @@ def _host(tree):
         eval_step_fn,
         max_steps=valid_max_steps,
         multi_device=multi_device,
+        logger=logger,
     )

     current_params_host = _host(train_state.params)
@@ -630,6 +708,7 @@ def _host(tree):
         eval_step_fn,
         max_steps=valid_max_steps,
         multi_device=multi_device,
+        logger=logger,
     )

     train_metric = LossMetrics(**metric_settings)
@@ -738,6 +817,7 @@ def _host(tree):
         eval_step_fn,
         max_steps=None,
         multi_device=multi_device,
+        logger=logger,
     )
     test_metrics = LossMetrics(
         include_energy=loss_settings.energy_weight > 0.0,