Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add paddle backend in devel #4157

Draft
wants to merge 46 commits into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
59b9af5
Add paddle backend code(WIP)
HydrogenSulfate Sep 3, 2024
7f32618
update runnable code with water/se_e2_a
HydrogenSulfate Sep 4, 2024
217bf36
update correct water/se_e2_a code
HydrogenSulfate Sep 4, 2024
0ad720d
fix extra state
HydrogenSulfate Sep 5, 2024
1d7b0d1
Fix typo and bugs
HydrogenSulfate Sep 6, 2024
f6eeef6
fix concat
HydrogenSulfate Sep 6, 2024
ec021f7
update fix code
HydrogenSulfate Sep 7, 2024
55f71f6
add normalize composite impl
HydrogenSulfate Sep 9, 2024
ebdbed2
use prim function instead of vanilla function, supporting
HydrogenSulfate Sep 10, 2024
e0aeb73
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Sep 10, 2024
0896c90
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Sep 13, 2024
2e79d68
update inference code(WIP)
HydrogenSulfate Sep 14, 2024
482d588
update CMAKE code(WIP)
HydrogenSulfate Sep 18, 2024
a3c4663
update CMAKE
HydrogenSulfate Sep 18, 2024
2b4832a
add build script
HydrogenSulfate Sep 18, 2024
f1100e4
Update water/se_e2_a + LAMMPS code
HydrogenSulfate Sep 23, 2024
3068869
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Sep 23, 2024
68a2d62
fix get_pd_version
HydrogenSulfate Sep 23, 2024
ff7e0ef
fix read_env.py
HydrogenSulfate Sep 23, 2024
023ba53
fix suffix
HydrogenSulfate Sep 23, 2024
8a1834f
fix pd/cxx_op.py
HydrogenSulfate Sep 23, 2024
396bd54
fix main.py
HydrogenSulfate Sep 23, 2024
66734bc
fix get_item_paddle
HydrogenSulfate Sep 23, 2024
ba02ae8
restore in.lammps
HydrogenSulfate Sep 23, 2024
40157dd
restore pyproject.toml
HydrogenSulfate Sep 23, 2024
8d53aec
simplify CMAKE
HydrogenSulfate Sep 23, 2024
e39d466
restore c_api.cc
HydrogenSulfate Sep 23, 2024
b97571e
fix bugs
HydrogenSulfate Sep 24, 2024
50092c6
change pt -> pd
HydrogenSulfate Sep 24, 2024
e02dd11
refactor DeepPotPD.cc
HydrogenSulfate Sep 24, 2024
8343077
update commonPD.h
HydrogenSulfate Sep 24, 2024
09c54f3
remove boost
HydrogenSulfate Sep 24, 2024
bc854b2
refine code
HydrogenSulfate Sep 25, 2024
892fd80
update refined infer code
HydrogenSulfate Sep 26, 2024
04b9064
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Sep 27, 2024
b3a6408
remove redundant code
HydrogenSulfate Sep 27, 2024
15a7e75
refine docstring of get_buffer
HydrogenSulfate Sep 27, 2024
c792563
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Sep 27, 2024
72241ea
update pd version code
HydrogenSulfate Sep 27, 2024
8694476
restore non related files
HydrogenSulfate Sep 27, 2024
e246a34
add paddle to related docs
HydrogenSulfate Sep 27, 2024
5ee8bcf
optimize cmake paddle macro name
HydrogenSulfate Sep 27, 2024
e3c1ceb
update parallel training with paddle backend
HydrogenSulfate Sep 27, 2024
49ba5a5
use 0-D Tensor as buffer shape
HydrogenSulfate Sep 27, 2024
f1cae59
support DCU(rocm)
HydrogenSulfate Sep 29, 2024
a83fb63
simplify compile code
HydrogenSulfate Sep 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 31 additions & 31 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ repos:
hooks:
- id: blacken-docs
# C++
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8
hooks:
- id: clang-format
exclude: ^source/3rdparty|source/lib/src/gpu/cudart/.+\.inc
# - repo: https://github.com/pre-commit/mirrors-clang-format
# rev: v18.1.8
# hooks:
# - id: clang-format
# exclude: ^source/3rdparty|source/lib/src/gpu/cudart/.+\.inc
# markdown, yaml, CSS, javascript
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
hooks:
- id: prettier
types_or: [markdown, yaml, css]
# workflow files cannot be modified by pre-commit.ci
exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
# - repo: https://github.com/pre-commit/mirrors-prettier
# rev: v4.0.0-alpha.8
# hooks:
# - id: prettier
# types_or: [markdown, yaml, css]
# # workflow files cannot be modified by pre-commit.ci
# exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
# Shell
- repo: https://github.com/scop/pre-commit-shfmt
rev: v3.9.0-1
Expand All @@ -75,25 +75,25 @@ repos:
hooks:
- id: cmake-format
#- id: cmake-lint
- repo: https://github.com/njzjz/mirrors-bibtex-tidy
rev: v1.13.0
hooks:
- id: bibtex-tidy
args:
- --curly
- --numeric
- --align=13
- --blank-lines
# disable sort: the order of keys and fields has explict meanings
#- --sort=key
- --duplicates=key,doi,citation,abstract
- --merge=combine
#- --sort-fields
#- --strip-comments
- --trailing-commas
- --encode-urls
- --remove-empty-fields
- --wrap=80
# - repo: https://github.com/njzjz/mirrors-bibtex-tidy
# rev: v1.13.0
# hooks:
# - id: bibtex-tidy
# args:
# - --curly
# - --numeric
# - --align=13
# - --blank-lines
# # disable sort: the order of keys and fields has explict meanings
# #- --sort=key
# - --duplicates=key,doi,citation,abstract
# - --merge=combine
# #- --sort-fields
# #- --strip-comments
# - --trailing-commas
# - --encode-urls
# - --remove-empty-fields
# - --wrap=80
# license header
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.5
Expand Down
5 changes: 5 additions & 0 deletions backend/dp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

from scikit_build_core import build as _orig

from .find_paddle import (
find_paddle,
)
from .find_pytorch import (
find_pytorch,
)
Expand Down Expand Up @@ -47,6 +50,7 @@ def get_requires_for_build_wheel(
_orig.get_requires_for_build_wheel(config_settings)
+ find_tensorflow()[1]
+ find_pytorch()[1]
+ find_paddle()[1]
)


Expand All @@ -57,4 +61,5 @@ def get_requires_for_build_editable(
_orig.get_requires_for_build_editable(config_settings)
+ find_tensorflow()[1]
+ find_pytorch()[1]
+ find_paddle()[1]
)
145 changes: 145 additions & 0 deletions backend/find_paddle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib
import os
import site
from functools import (
lru_cache,
)
from importlib.machinery import (
FileFinder,
)
from importlib.util import (
find_spec,
)
from pathlib import (
Path,
)
from sysconfig import (
get_path,
)
from typing import (
List,
Optional,
Tuple,
Union,
)

from packaging.version import (
Version,
)


@lru_cache
def find_paddle() -> Tuple[Optional[str], List[str]]:
"""Find PaddlePadle library.

Tries to find PaddlePadle in the order of:

1. Environment variable `PADDLE_ROOT` if set
2. The current Python environment.
3. user site packages directory if enabled
4. system site packages directory (purelib)

Considering the default PaddlePadle package still uses old CXX11 ABI, we
cannot install it automatically.

Returns
-------
str, optional
PaddlePadle library path if found.
list of str
TensorFlow requirement if not found. Empty if found.
"""
if os.environ.get("DP_ENABLE_PADDLE", "0") == "0":
return None, []
requires = []
pd_spec = None

if (pd_spec is None or not pd_spec) and os.environ.get("PADDLE_ROOT") is not None:
site_packages = Path(os.environ.get("PADDLE_ROOT")).parent.absolute()
pd_spec = FileFinder(str(site_packages)).find_spec("paddle")

# get paddle spec
# note: isolated build will not work for backend
if pd_spec is None or not pd_spec:
pd_spec = find_spec("paddle")

if not pd_spec and site.ENABLE_USER_SITE:
# first search TF from user site-packages before global site-packages
site_packages = site.getusersitepackages()
if site_packages:
pd_spec = FileFinder(site_packages).find_spec("paddle")

if not pd_spec:
# purelib gets site-packages path
site_packages = get_path("purelib")
if site_packages:
pd_spec = FileFinder(site_packages).find_spec("paddle")

# get install dir from spec
try:
pd_install_dir = pd_spec.submodule_search_locations[0] # type: ignore
# AttributeError if ft_spec is None
# TypeError if submodule_search_locations are None
# IndexError if submodule_search_locations is an empty list
except (AttributeError, TypeError, IndexError):
pd_install_dir = None
requires.extend(get_pd_requirement()["paddle"])
return pd_install_dir, requires


@lru_cache
def get_pd_requirement(pd_version: str = "") -> dict:
"""Get PaddlePadle requirement when Paddle is not installed.

If pd_version is not given and the environment variable `PADDLE_VERSION` is set, use it as the requirement.

Parameters
----------
pd_version : str, optional
Paddle version

Returns
-------
dict
PaddlePadle requirement.
"""
if pd_version is None:
return {"paddle": []}
if pd_version == "":
pd_version = os.environ.get("PADDLE_VERSION", "")

return {
"paddle": [
# uv has different local version behaviors, i.e. `==2.3.1` cannot match `==2.3.1+cpu`
# https://github.com/astral-sh/uv/blob/main/PIP_COMPATIBILITY.md#local-version-identifiers
# luckily, .* (prefix matching) defined in PEP 440 can match any local version
# https://peps.python.org/pep-0440/#version-matching
f"paddle=={Version(pd_version).base_version}.*"
if pd_version != ""
else "paddle>=3.0.0",
],
}


@lru_cache
def get_pd_version(pd_path: Optional[Union[str, Path]]) -> str:
"""Get Paddle version from a Paddle Python library path.

Parameters
----------
pd_path : str or Path
pd Python library path

Returns
-------
str
version
"""
if pd_path is None or pd_path == "":
return ""
version_file = Path(pd_path) / "version.py"
spec = importlib.util.spec_from_file_location("paddle.version", version_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.full_version
17 changes: 17 additions & 0 deletions backend/read_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
Version,
)

from .find_paddle import (
find_paddle,
)
from .find_pytorch import (
find_pytorch,
get_pt_version,
Expand Down Expand Up @@ -120,6 +123,19 @@
cmake_args.append("-DENABLE_PYTORCH=OFF")
pt_version = None

if os.environ.get("DP_ENABLE_PADDLE", "0") == "1":
pd_install_dir, _ = find_paddle()
pt_version = get_pt_version(pd_install_dir)
cmake_args.extend(
[
"-DENABLE_PADDLE=ON",
f"-DCMAKE_PREFIX_PATH={pd_install_dir}",
]
)
else:
cmake_args.append("-DENABLE_PADDLE=OFF")
pd_version = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix variable assignment for Paddle version.

The implementation for Paddle support looks good overall, but there's still an issue to address:

On line 128, pt_version is assigned instead of pd_version. This appears to be a copy-paste error that was not fixed in the previous review.

Please apply the following change:

 if os.environ.get("DP_ENABLE_PADDLE", "0") == "1":
     pd_install_dir, _ = find_paddle()
-    pt_version = get_pt_version(pd_install_dir)
+    pd_version = get_pt_version(pd_install_dir)
     cmake_args.extend(
         [
             "-DENABLE_PADDLE=ON",
             f"-DCMAKE_PREFIX_PATH={pd_install_dir}",
         ]
     )
 else:
     cmake_args.append("-DENABLE_PADDLE=OFF")
     pd_version = None

This change will correctly assign the Paddle version.

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if os.environ.get("DP_ENABLE_PADDLE", "0") == "1":
pd_install_dir, _ = find_paddle()
pt_version = get_pt_version(pd_install_dir)
cmake_args.extend(
[
"-DENABLE_PADDLE=ON",
f"-DCMAKE_PREFIX_PATH={pd_install_dir}",
]
)
else:
cmake_args.append("-DENABLE_PADDLE=OFF")
pd_version = None
if os.environ.get("DP_ENABLE_PADDLE", "0") == "1":
pd_install_dir, _ = find_paddle()
pd_version = get_pt_version(pd_install_dir)
cmake_args.extend(
[
"-DENABLE_PADDLE=ON",
f"-DCMAKE_PREFIX_PATH={pd_install_dir}",
]
)
else:
cmake_args.append("-DENABLE_PADDLE=OFF")
pd_version = None


cmake_args = [
"-DBUILD_PY_IF:BOOL=TRUE",
*cmake_args,
Expand All @@ -131,6 +147,7 @@
extra_scripts,
tf_version,
pt_version,
pd_version,

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error

Local variable 'pd_version' may be used before it is initialized.
)


Expand Down
Loading
Loading