Skip to content

Commit

Permalink
feat(jax): add options to use TensorFlow C library to build the JAX b…
Browse files Browse the repository at this point in the history
…ackend (#4357)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Updated installation documentation to clarify requirements for
TensorFlow and JAX backends.
	- Expanded supported platforms to include Windows x86-64.
- Added instructions for enabling JAX backend during installation from
source.

- **Documentation**
- Enhanced clarity of installation prerequisites and supported
platforms.
- Included a note directing users to the TensorFlow tab for additional
information.

- **Bug Fixes**
	- Improved error handling for unsupported backend configurations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and coderabbitai[bot] authored Nov 14, 2024
1 parent d5295d5 commit 058e066
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 5 deletions.
4 changes: 4 additions & 0 deletions doc/install/easy-install.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ pip install deepmd-kit[jax]

:::::

To generate a SavedModel and use [the LAMMPS module](../third-party/lammps-command.md) and [the i-PI driver](../third-party/ipi.md),
you need to install the TensorFlow.
Switch to the TensorFlow {{ tensorflow_icon }} tab for more information.

::::::

:::::::
Expand Down
4 changes: 2 additions & 2 deletions doc/install/install-from-c-library.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Install from pre-compiled C library {{ tensorflow_icon }}
# Install from pre-compiled C library {{ tensorflow_icon }}, JAX {{ jax_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, JAX {{ jax_icon }}
:::

DeePMD-kit provides pre-compiled C library package (`libdeepmd_c.tar.gz`) in each [release](https://github.com/deepmodeling/deepmd-kit/releases). It can be used to build the [LAMMPS plugin](./install-lammps.md) and [GROMACS patch](./install-gromacs.md), as well as many [third-party software packages](../third-party/out-of-deepmd-kit.md), without building TensorFlow and DeePMD-kit on one's own.
Expand Down
31 changes: 31 additions & 0 deletions doc/install/install-from-source.md
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,15 @@ You can also download libtorch prebuilt library from the [PyTorch website](https

:::

:::{tab-item} JAX {{ jax_icon }}

The JAX backend only depends on the TensorFlow C API, which is included in both TensorFlow C++ library and [TensorFlow C library](https://www.tensorflow.org/install/lang_c).
If you want to use the TensorFlow C++ library, just enable the TensorFlow backend (which depends on the TensorFlow C++ library) and nothing else needs to do.
If you want to use the TensorFlow C library and disable the TensorFlow backend,
download the TensorFlow C library from [this page](https://www.tensorflow.org/install/lang_c#download_and_extract).

:::

::::

### Install DeePMD-kit's C++ interface
Expand Down Expand Up @@ -369,6 +378,17 @@ cmake -DENABLE_PYTORCH=TRUE -DUSE_PT_PYTHON_LIBS=TRUE -DCMAKE_INSTALL_PREFIX=$de

:::

:::{tab-item} JAX {{ jax_icon }}

If you want to use the TensorFlow C++ library, just enable the TensorFlow backend and nothing else needs to do.
If you want to use the TensorFlow C library and disable the TensorFlow backend, set {cmake:variable}`ENABLE_JAX` to `ON` and `CMAKE_PREFIX_PATH` to the root directory of the [TensorFlow C library](https://www.tensorflow.org/install/lang_c).

```bash
cmake -DENABLE_JAX=ON -D CMAKE_PREFIX_PATH=${tensorflow_c_root} ..
```

:::

::::

One may add the following CMake variables to `cmake` using the [`-D <var>=<value>` option](https://cmake.org/cmake/help/latest/manual/cmake.1.html#cmdoption-cmake-D):
Expand All @@ -378,6 +398,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value
**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`

{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.
Setting this option to `ON` will also set {cmake:variable}`ENABLE_JAX` to `ON`.

:::

Expand All @@ -389,6 +410,16 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

:::

:::{cmake:variable} ENABLE_JAX

**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`

{{ jax_icon }} Build the JAX backend.
If {cmake:variable}`ENABLE_TENSORFLOW` is `ON`, the TensorFlow C++ library is used to build the JAX backend;
If {cmake:variable}`ENABLE_TENSORFLOW` is `OFF`, the TensorFlow C library is used to build the JAX backend.

:::

:::{cmake:variable} TENSORFLOW_ROOT

**Type**: `PATH`
Expand Down
25 changes: 25 additions & 0 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ project(DeePMD)

option(ENABLE_TENSORFLOW "Enable TensorFlow interface" OFF)
option(ENABLE_PYTORCH "Enable PyTorch interface" OFF)
option(ENABLE_JAX "Enable JAX interface" OFF)
if(ENABLE_TENSORFLOW)
# JAX requires TF C interface, contained in TF C++ library
set(ENABLE_JAX ON)
endif()
option(BUILD_TESTING "Build test and enable coverage" OFF)
set(DEEPMD_C_ROOT
""
Expand Down Expand Up @@ -246,6 +251,22 @@ if(ENABLE_PYTORCH AND NOT DEEPMD_C_ROOT)
list(APPEND BACKEND_LIBRARY_PATH ${PyTorch_LIBRARY_PATH})
list(APPEND BACKEND_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS})
endif()
if(ENABLE_JAX
AND BUILD_CPP_IF
AND NOT DEEPMD_C_ROOT)
# no way to find it using Python
find_package(TensorFlowC REQUIRED MODULE)
if(DEFINED TENSORFLOWC_LIBRARY)
list(APPEND BACKEND_LIBRARY_PATH ${TENSORFLOWC_LIBRARY})
endif()
if(DEFINED TENSORFLOWC_INCLUDE_DIR)
list(APPEND BACKEND_INCLUDE_DIRS ${TENSORFLOWC_INCLUDE_DIR})
endif()
endif()
if(NOT DEFINED OP_CXX_ABI)
# prevent setting an empty value; this is default on GCC>=5
set(OP_CXX_ABI 1)
endif()
# log enabled backends
if(NOT DEEPMD_C_ROOT)
message(STATUS "Enabled backends:")
Expand All @@ -255,8 +276,12 @@ if(NOT DEEPMD_C_ROOT)
if(ENABLE_PYTORCH)
message(STATUS "- PyTorch")
endif()
if(ENABLE_JAX)
message(STATUS "- JAX")
endif()
if(NOT ENABLE_TENSORFLOW
AND NOT ENABLE_PYTORCH
AND NOT ENABLE_JAX
AND NOT BUILD_PY_IF)
message(FATAL_ERROR "No backend is enabled.")
endif()
Expand Down
4 changes: 4 additions & 0 deletions source/api_cc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ if(ENABLE_PYTORCH
target_link_libraries(${libname} PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(${libname} PRIVATE BUILD_PYTORCH)
endif()
if(ENABLE_JAX)
target_link_libraries(${libname} PRIVATE TensorFlow::tensorflow_c)
target_compile_definitions(${libname} PRIVATE BUILD_JAX)
endif()

target_include_directories(
${libname}
Expand Down
6 changes: 4 additions & 2 deletions source/api_cc/src/DeepPot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
#include "AtomMap.h"
#include "common.h"
#ifdef BUILD_TENSORFLOW
#include "DeepPotJAX.h"
#include "DeepPotTF.h"
#endif
#ifdef BUILD_PYTORCH
#include "DeepPotPT.h"
#endif
#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX)
#include "DeepPotJAX.h"
#endif
#include "device.h"

using namespace deepmd;
Expand Down Expand Up @@ -63,7 +65,7 @@ void DeepPot::init(const std::string& model,
} else if (deepmd::DPBackend::Paddle == backend) {
throw deepmd::deepmd_exception("PaddlePaddle backend is not supported yet");
} else if (deepmd::DPBackend::JAX == backend) {
#ifdef BUILD_TENSORFLOW
#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX)
dp = std::make_shared<deepmd::DeepPotJAX>(model, gpu_rank, file_content);
#else
throw deepmd::deepmd_exception(
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepPotJAX.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#ifdef BUILD_TENSORFLOW
#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX)

#include "DeepPotJAX.h"

Expand Down
40 changes: 40 additions & 0 deletions source/cmake/FindTensorFlowC.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Find TensorFlow C library (libtensorflow) Define target
# TensorFlow::tensorflow_c If TensorFlow::tensorflow_cc is not found, also
# define: - TENSORFLOWC_INCLUDE_DIR - TENSORFLOWC_LIBRARY

if(TARGET TensorFlow::tensorflow_cc)
# since tensorflow_cc contain tensorflow_c, just use it
add_library(TensorFlow::tensorflow_c ALIAS TensorFlow::tensorflow_cc)
set(TensorFlowC_FOUND TRUE)
endif()

if(NOT TensorFlowC_FOUND)
find_path(
TENSORFLOWC_INCLUDE_DIR
NAMES tensorflow/c/c_api.h
PATH_SUFFIXES include
DOC "Path to TensorFlow C include directory")

find_library(
TENSORFLOWC_LIBRARY
NAMES tensorflow
PATH_SUFFIXES lib
DOC "Path to TensorFlow C library")

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
TensorFlowC REQUIRED_VARS TENSORFLOWC_LIBRARY TENSORFLOWC_INCLUDE_DIR)

if(TensorFlowC_FOUND)
set(TensorFlowC_INCLUDE_DIRS ${TENSORFLOWC_INCLUDE_DIR})
set(TensorFlowC_LIBRARIES ${TENSORFLOWC_LIBRARY})
endif()

add_library(TensorFlow::tensorflow_c SHARED IMPORTED GLOBAL)
set_property(TARGET TensorFlow::tensorflow_c PROPERTY IMPORTED_LOCATION
${TENSORFLOWC_LIBRARY})
target_include_directories(TensorFlow::tensorflow_c
INTERFACE ${TENSORFLOWC_INCLUDE_DIR})

mark_as_advanced(TENSORFLOWC_LIBRARY TENSORFLOWC_INCLUDE_DIR)
endif()

0 comments on commit 058e066

Please sign in to comment.