Skip to content

Commit 49ec037

Browse files
committed
feat: enhance PyTorch integration by ensuring global visibility
1 parent 0c6de55 commit 49ec037

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

src/programs/CMakeLists.txt

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -268,23 +268,46 @@ if("${GMX_QMMM_PROGRAM}" STREQUAL "TENSORFLOW" OR "${GMX_QMMM_PROGRAM}" STREQUAL
268268
#target_include_directories(gmx PRIVATE "${TENSORFLOW_INCLUDE_DIR}")
269269
endif()
270270

271-
if ("${GMX_QMMM_PROGRAM}" STREQUAL "PYTORCH" OR "${GMX_QMMM_PROGRAM}" STREQUAL "DFTBPLUS_PYTORCH")
271+
if ("${GMX_QMMM_PROGRAM}" STREQUAL "PYTORCH" OR "${GMX_QMMM_PROGRAM}" STREQUAL "DFTBPLUS_PYTORCH")
272272
find_package(Torch REQUIRED)
273-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
274273

275-
add_library(libtorch STATIC IMPORTED)
276-
set_target_properties(libtorch PROPERTIES IMPORTED_LOCATION "${TORCH_LIBRARIES}")
277-
set_property(TARGET libtorch PROPERTY CXX_STANDARD 17)
274+
# # Create missing PyTorch targets with mock properties
275+
# if(NOT TARGET torch_cpu)
276+
# add_library(torch_cpu INTERFACE)
277+
# target_compile_options(torch_cpu INTERFACE ${TORCH_CXX_FLAGS})
278+
# target_compile_definitions(torch_cpu INTERFACE "")
279+
# target_include_directories(torch_cpu INTERFACE ${TORCH_INCLUDE_DIRS})
280+
# endif()
281+
282+
# if(NOT TARGET torch_cuda)
283+
# add_library(torch_cuda INTERFACE)
284+
# target_compile_options(torch_cuda INTERFACE ${TORCH_CXX_FLAGS})
285+
# target_compile_definitions(torch_cuda INTERFACE "")
286+
# target_include_directories(torch_cuda INTERFACE ${TORCH_INCLUDE_DIRS})
287+
# endif()
288+
289+
# Make the targets globally visible
290+
set_property(TARGET torch_cpu PROPERTY IMPORTED_GLOBAL TRUE)
291+
set_property(TARGET torch_cuda PROPERTY IMPORTED_GLOBAL TRUE)
292+
set_property(TARGET torch PROPERTY IMPORTED_GLOBAL TRUE)
293+
294+
# Ensure targets exist in parent scope by exporting variables
295+
set(torch_cpu_EXISTS TRUE CACHE INTERNAL "torch_cpu target exists")
296+
set(torch_cuda_EXISTS TRUE CACHE INTERNAL "torch_cuda target exists")
297+
set(torch_EXISTS TRUE CACHE INTERNAL "torch target exists")
278298

279299
target_link_libraries(libgromacs PRIVATE ${TORCH_LIBRARIES})
280300
target_include_directories(libgromacs PRIVATE ${TORCH_INCLUDE_DIRS})
281-
301+
target_compile_definitions(libgromacs PRIVATE ${TORCH_CXX_FLAGS})
302+
282303
if(GMX_BUILD_MDRUN_ONLY)
283304
target_link_libraries(mdrun-only PRIVATE ${TORCH_LIBRARIES})
284305
target_include_directories(mdrun-only PRIVATE ${TORCH_INCLUDE_DIRS})
306+
target_compile_definitions(mdrun-only PRIVATE ${TORCH_CXX_FLAGS})
285307
else()
286308
target_link_libraries(gmx PRIVATE ${TORCH_LIBRARIES})
287309
target_include_directories(gmx PRIVATE ${TORCH_INCLUDE_DIRS})
310+
target_compile_definitions(gmx PRIVATE ${TORCH_CXX_FLAGS})
288311
endif()
289312

290313
message(STATUS "Torch found (include: ${TORCH_INCLUDE_DIRS}, lib: ${TORCH_LIBRARIES})")

0 commit comments

Comments
 (0)