Skip to content

Multi-res grid refinement + Neon backend support#159

Merged
hsalehipour merged 269 commits into
Autodesk:mainfrom
hsalehipour:dev
May 14, 2026
Merged

Multi-res grid refinement + Neon backend support#159
hsalehipour merged 269 commits into
Autodesk:mainfrom
hsalehipour:dev

Conversation

@hsalehipour
Copy link
Copy Markdown
Collaborator

Contributing Guidelines

Description

Grid refinement capability is now supported in XLB through the Neon backend. The Neon backend provides full support for dense grids on multi-GPU systems, as well as multi-resolution grids on single GPUs. All newly introduced functionalities have been carefully tested and optimized. This represents a major enhancement to the library and involves substantial additions and improvements to the codebase.

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update

How Has This Been Tested?

  • All pytest tests pass
============================================= test session starts ==============================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /home/max/repo/test/XLB
collected 93 items                                                                                             

tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py ....                                 [  4%]
tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py ....                                [  8%]
tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py ......               [ 15%]
tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py ......              [ 21%]
tests/boundary_conditions/mask/test_bc_indices_masker_jax.py .......                                     [ 29%]
tests/boundary_conditions/mask/test_bc_indices_masker_warp.py ......                                     [ 35%]
tests/grids/test_grid_jax.py .......                                                                     [ 43%]
tests/grids/test_grid_warp.py ....                                                                       [ 47%]
tests/kernels/collision/test_bgk_collision_jax.py ......                                                 [ 53%]
tests/kernels/collision/test_bgk_collision_warp.py ......                                                [ 60%]
tests/kernels/equilibrium/test_equilibrium_jax.py ......                                                 [ 66%]
tests/kernels/equilibrium/test_equilibrium_warp.py ......                                                [ 73%]
tests/kernels/macroscopic/test_macroscopic_jax.py ......                                                 [ 79%]
tests/kernels/macroscopic/test_macroscopic_warp.py .......                                               [ 87%]
tests/kernels/stream/test_stream_jax.py ......                                                           [ 93%]
tests/kernels/stream/test_stream_warp.py ......                                                          [100%]

======================================== 93 passed in 248.34s (0:04:08) ========================================

Linting and Code Formatting

Make sure the code follows the project's linting and formatting standards. This project uses Ruff for linting.

To run Ruff, execute the following command from the root of the repository:

ruff check .
  • Ruff passes

hsalehipour and others added 14 commits March 9, 2026 15:57
* Fixed some runtime bugs

* fixed some naming/spelling errors

* removed some debugging comments.

* Introduced a new file `cell_type.py` containing boundary-mask constants for fluid voxelss to replace hardcoded values with the new constants.

* Applied renaming of 254 to SFV to function names
- Unified multi-resolution recursion builder in `simulation_manager.py` to streamline the construction of simulation steps.
- Refactored nse_multires_stepper for improved clarity
- Updated performance optimization handling in `multires_momentum_transfer.py` to support multiple fusion strategies.
…ine and clarify the implementation of multi-resolution streaming steps.
(refactoring) Cleaning up multi-res stepper.
… multi-res by ensuring consistent use of `store_dtype` and `compute_dtype`.
Fixed mixed precision handling of the Neon backend
@hsalehipour hsalehipour requested a review from mehdiataei March 13, 2026 22:13
* (build) Introducing Neon backend as an optional installation parameter.

* (install) new installation mode for neon backend.

* (build) Add ARM support for Neon wheel resolution

* (documentation) Fixes to README and AUTHORS

* (ruff) fixes to the style

* (documentation) fix list of supported python versions
Copy link
Copy Markdown
Contributor

@mehdiataei mehdiataei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @massimim, this looks like a strong contribution. I left code-level comments in this round, but I also wanted to share a few higher-level suggestions.

  1. I think framing Neon as a new compute "backend" may be slightly misleading. Conceptually, Neon here does not seem fully parallel to JAX. The implementation largely reuses Warp functionals and then executes them through Neon handles, containers, and skeletons. In that sense, Neon feels more like an execution/runtime layer on top of Warp code generation than a standalone compute backend.

I think a different framing could make the design clearer:

  • If Neon is fundamentally "Warp math + Neon execution", I would be hesitant to model it as a third peer backend throughout the operator hierarchy.
  • Instead, I would consider splitting the abstraction into:
    • kernel / math backend: JAX vs Warp
    • execution runtime: direct Warp launch vs Neon container/skeleton launch

I think this would make Neon more generic and would better highlight its real strength: the execution model and skeleton abstraction, rather than presenting it as a bespoke backend. It may also make the integration easier to extend and adopt. Several of the current issues feel like symptoms of the abstraction boundary being one layer off. This likely needs some careful design thought, but I would strongly encourage it. To me, the more compelling framing is that Neon provides a skeleton/runtime that Warp kernels can target.

  1. The multires implementation also feels too monolithic. Kernels, schedule planning, state ownership, and runtime graph compilation all live in roughly the same layer, which makes the system harder to reason about, test, and extend.

One possible improvement would be to introduce a typed MultiresPlan / Schedule layer that represents the recursive timestep as explicit operations, then have a separate Neon graph builder that lowers that plan into containers/skeletons. I would also keep simulation state in a manager and keep kernels separate from schedule construction.

  1. The topology and coordinate model feels too implicit at the moment, which makes it harder to debug and reuse. An explicit MultiresTopology or LevelInfo abstraction could help a lot, with methods such as:
  • level_shape(level)
  • global_bounds(level)
  • to_global(level, coords)
  • face_indices(level, side)
  • active_indices(level)

I think making those concepts explicit would improve both clarity and correctness.

  1. For the new BC, I suggest clearly clarifying the Re that it has been validated for.

Comment thread setup.py Outdated
Comment thread setup.py
Comment thread setup.py Outdated
Comment thread xlb/operator/stepper/nse_multires_stepper.py
Comment thread xlb/helper/simulation_manager.py
Comment thread examples/cfd/data/ahmed.json
Comment thread xlb/grid/grid.py
velocity_set = velocity_set or DefaultConfig.velocity_set
if compute_backend == ComputeBackend.WARP:
from xlb.grid.warp_grid import WarpGrid

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can delete WarpGrid import

Comment thread xlb/grid/neon_grid.py Outdated
Comment thread xlb/grid/neon_grid.py Outdated
Comment thread xlb/grid/multires_grid.py Outdated
@massimim
Copy link
Copy Markdown
Collaborator

Hi @mehdiataei , thank you for the review. This PR is the result of quite a long journey with @hsalehipour and others too.

I'll go through points 1 to 3.

  1. I see your point on the compute "backends" as Neon reuses directly most of the functionals defined by the warp backend as the python version of Neon is built on top of warp. It is different however if we consider their programming models. Neon follows a structured parallel pattern approach as it raises the abstraction level by fully hiding the complexity of multi-GPU systems (for the dense grid). I think the critical point for determining the best structure would be until we get the automatic differentiation for Neon too. I'll open an issue to track this.
  2. I agree on the structure to be too monolithic. There are some changes to the Neon API that are coming that will also need to be considered. Once these are in we can consider a restructuring.
  3. I agree here too. There are still some mechanisms that Neon does not expose yet, but will be there hopefully soon along with a new graph mechanism to improve on independent kernels.

@hsalehipour and I are addressing the remaining comments and will send a PR revision.

massimim and others added 4 commits May 8, 2026 15:11
…#41)

* (build) Introducing Neon backend as an optional installation parameter.

* (install) new installation mode for neon backend.

* (build) Add ARM support for Neon wheel resolution

* (documentation) Fixes to README and AUTHORS

* (ruff) fixes to the style

* (documentation) fix list of supported python versions

* (install) Add installation unit tests for JAX, Warp and Neon backends, update utils for Warp to JAX conversion
…ion (#42)

* (build) Introducing Neon backend as an optional installation parameter.

* (install) new installation mode for neon backend.

* (build) Add ARM support for Neon wheel resolution

* (documentation) Fixes to README and AUTHORS

* (ruff) fixes to the style

* (documentation) fix list of supported python versions

* (install) Add installation unit tests for JAX, Warp and Neon backends, update utils for Warp to JAX conversion

* (install) Enhance warp-lang uninstallation process for Neon installation

- Updated `_uninstall_warp_lang` function to include a reason for uninstallation.
- Modified `InstallWithNeonHooks` class to uninstall `warp-lang` before and after installation when the `[neon]` extra is requested.
- Added a new test to verify that pre-existing `warp-lang` is uninstalled during the editable install of XLB with the `[neon]` extra.
* (build) Introducing Neon backend as an optional installation parameter.

* (install) new installation mode for neon backend.

* (build) Add ARM support for Neon wheel resolution

* (documentation) Fixes to README and AUTHORS

* (ruff) fixes to the style

* (documentation) fix list of supported python versions

* (install) Add installation unit tests for JAX, Warp and Neon backends, update utils for Warp to JAX conversion

* (install) Enhance warp-lang uninstallation process for Neon installation

- Updated `_uninstall_warp_lang` function to include a reason for uninstallation.
- Modified `InstallWithNeonHooks` class to uninstall `warp-lang` before and after installation when the `[neon]` extra is requested.
- Added a new test to verify that pre-existing `warp-lang` is uninstalled during the editable install of XLB with the `[neon]` extra.

* (cleaning) removed unnecessary 'pass' call

* (refactor) Update JSON data structure and clean up code

- Modified the `ahmed.json` file to ensure consistent formatting of velocity and height data.
- Removed a hardcoded comment in `multires_grid.py` regarding device initialization.
- Commented out the `info_print` call in `neon_grid.py` to reduce output clutter.
- Enhanced the `_create_constant_prescribed_profile` method in `bc_halfway_bounce_back.py` to better handle different compute backends and added error handling for unsupported backends.
* added missing neon functionals and fixed forced_collision

* fixed global coordinate calculations in NeonMultiresGrid for improved bounding box face detection across levels

* Remove deprecated BC from README

* fixed an issue in setting prescribed values in zouhe bc

* Fixes to the PR comments (#45)

* fix(install): add h5py when the neon installation option is selected.

* refactor(mres): changed 'omega' to 'coalescence_factor' in function signatures.

* refactor(grid): change default values of sparsity_pattern_list and sparsity_pattern_origins to None

* fix(docs): correct typos in multires_flow_past_sphere_3d.py and update NEON backend description in compute_backend.py

* ruff

---------

Co-authored-by: massimim <57805133+massimim@users.noreply.github.com>
@hsalehipour
Copy link
Copy Markdown
Collaborator Author

Thanks a lot, @mehdiataei , for the thorough review. We have now addressed the remaining issues in the latest commits, and the changes are ready to be merged.

@mehdiataei
Copy link
Copy Markdown
Contributor

Hi @hsalehipour , maybe I’m missing something, but are there any benchmarks reported for this implementation? For example, DrivAer at the reference Reynolds number of 4.87M.

I think this is important, since people may use this for industrial applications. If those benchmarks cannot be passed, it would be helpful to clearly document the limits of the method in an MD file inside the example folder, along with the multi-resolution example and the relevant explanations.

Also, there are a number of WIP/debugging commits that should be squashed. Could you please clean those up?

@hsalehipour
Copy link
Copy Markdown
Collaborator Author

@mehdiataei Agreed. The new example multires_windtunnel_3d already serves that purpose and relies on ahmed.json for reference data.

We attempted to squash just the WIP commits via interactive rebase. The squash itself worked, but replaying the ~190 commits that sit above the squash range produced merge conflicts in the upper commits, and we aborted to keep the branch intact. Resolving those conflicts is feasible but non-trivial, so we're proposing simple merge instead as we did for the OOC PR.

@mehdiataei
Copy link
Copy Markdown
Contributor

mehdiataei commented May 14, 2026

What is the Reynolds number range within which the BC is intended to operate? I believe this should be explicitly clarified, particularly since the BC was previously presented as a novel design. In addition, I suggest either add a README file explaining the validity and add the tables etc or cite the appropriate literature on the method.

Merging such a large number of commits (269 commits while the full library history is just 369 commits) into one PR is fairly unconventional and also risks reducing the visibility of substantial prior contributions by pushing them further back in the project history. We can also fix the OCC merge if that's your concern.

Let me know what you think.

@mehdiataei
Copy link
Copy Markdown
Contributor

If the concern is primarily around conflict resolution or preserving the existing branch state, I would be happy to help with the squashing/rebase process to make the PR history cleaner and more manageable.

Please let me know!

@hsalehipour
Copy link
Copy Markdown
Collaborator Author

I don’t think it would be appropriate to provide empirical bounds on the stability or accuracy of HybridBC, as these characteristics are highly problem- and geometry-dependent. Based on the experiments conducted so far, the method has demonstrated strong stability. Any discrepancies with reference data are likely influenced by other important modeling components that are not yet incorporated into the library, such as wall models.

Regarding the commit history, Max and I, as maintainers of the library, have decided to follow a squash-and-merge strategy for these integrations in order to keep the history concise and manageable. It would be appreciated if the same approach could also be applied retroactively to the previous OOC merge.

@hsalehipour hsalehipour merged commit 01c648a into Autodesk:main May 14, 2026
10 checks passed
@github-actions github-actions Bot locked and limited conversation to collaborators May 14, 2026
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants