
automatic batch size for dp test
Resolves deepmodeling#1149.

We start nbatch * natoms at 1024 (a different initial value can be set) and iteratively multiply it by 2 until an OOM error is caught.

A small issue is that catching the TF OOM error is somewhat slow. This is a problem of TF itself and I don't know how to resolve it. Luckily, we only need to catch it once.
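
For intuition, a minimal standalone sketch of the doubling search (the `evaluate` callable is hypothetical, and `OutOfMemoryError` stands in for the wrapped TF OOM error; the real class below also halves and retries rather than stopping):

def probe_max_batch_size(evaluate, initial_size=1024):
    # Double the batch size until evaluation runs out of memory,
    # then fall back to the last size that fit.
    size = initial_size
    while True:
        try:
            evaluate(size)
        except OutOfMemoryError:
            return size // 2
        size *= 2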
njzjz committed Sep 22, 2021
1 parent 53f1567 commit 1117c54
Showing 5 changed files with 132 additions and 4 deletions.
11 changes: 9 additions & 2 deletions deepmd/entrypoints/test.py
@@ -9,10 +9,11 @@
from deepmd.utils import random as dp_random
from deepmd.utils.data import DeepmdData
from deepmd.utils.weight_avg import weighted_average
+from deepmd.utils.batch_size import AutoBatchSize

if TYPE_CHECKING:
    from deepmd.infer import DeepDipole, DeepPolar, DeepPot, DeepWFC
-    from deepmd.infer.deep_eval import DeepTensor
+    from deepmd.infer.deep_tensor import DeepTensor

__all__ = ["test"]

@@ -69,6 +70,7 @@ def test(

    # init model
    dp = DeepPotential(model)
+    auto_batch_size = AutoBatchSize()

    for cc, system in enumerate(all_sys):
        log.info("# ---------------output of dp test--------------- ")
@@ -82,6 +84,7 @@
        err = test_ener(
            dp,
            data,
+            auto_batch_size,
            system,
            numb_test,
            detail_file,
@@ -159,6 +162,7 @@ def save_txt_file(
def test_ener(
    dp: "DeepPot",
    data: DeepmdData,
+    auto_batch_size: AutoBatchSize,
    system: str,
    numb_test: int,
    detail_file: Optional[str],
@@ -226,7 +230,10 @@
    else:
        aparam = None

-    ret = dp.eval(
+    ret = auto_batch_size.execute_all(
+        dp.eval,
+        numb_test,
+        natoms,
        coord,
        box,
        atype,
2 changes: 1 addition & 1 deletion deepmd/infer/deep_pot.py
@@ -324,7 +324,7 @@ def _eval_inner(
            feed_dict_test[self.t_fparam] = np.reshape(fparam, [-1])
        if self.has_aparam:
            feed_dict_test[self.t_aparam] = np.reshape(aparam, [-1])
-        v_out = self.sess.run (t_out, feed_dict = feed_dict_test)
+        v_out = run_sess(self.sess, t_out, feed_dict = feed_dict_test)
        energy = v_out[0]
        force = v_out[1]
        virial = v_out[2]
117 changes: 117 additions & 0 deletions deepmd/utils/batch_size.py
@@ -0,0 +1,117 @@
import logging
from typing import Callable, Tuple

import numpy as np

from deepmd.utils.errors import OutOfMemoryError


class AutoBatchSize:
    """This class allows DeePMD-kit to automatically decide the maximum
    batch size that will not cause an OOM error.

    Notes
    -----
    We assume all OOM errors raise :class:`OutOfMemoryError`.

    Parameters
    ----------
    initial_batch_size : int, default: 1024
        initial batch size (number of total atoms)

    Attributes
    ----------
    current_batch_size : int
        current batch size (number of total atoms)
    maximum_working_batch_size : int
        maximum working batch size
    minimal_not_working_batch_size : int
        minimal not-working batch size
    """

    def __init__(self, initial_batch_size: int = 1024) -> None:
        # See also PyTorchLightning/pytorch-lightning#1638
        # TODO: discuss a proper initial batch size
        self.current_batch_size = initial_batch_size
        self.maximum_working_batch_size = 0
        self.minimal_not_working_batch_size = 2**31

    def execute(self, callable: Callable, start_index: int, natoms: int) -> Tuple[int, tuple]:
        """Execute a method with the current batch size.

        Parameters
        ----------
        callable : Callable
            The method should accept the batch size and start_index as parameters,
            and return the executed batch size and the data.
        start_index : int
            start index
        natoms : int
            the number of atoms

        Returns
        -------
        int
            executed batch size (number of frames)
        tuple
            result from callable; None if execution failed
        """
        try:
            n_batch, result = callable(max(self.current_batch_size // natoms, 1), start_index)
        except OutOfMemoryError as e:
            # TODO: catching the OOM error is very slow; I don't know what TF
            # is doing here, but luckily we only need to catch it once
            self.minimal_not_working_batch_size = min(self.minimal_not_working_batch_size, self.current_batch_size)
            if self.maximum_working_batch_size >= self.minimal_not_working_batch_size:
                self.maximum_working_batch_size = self.minimal_not_working_batch_size // 2
            if self.minimal_not_working_batch_size <= natoms:
                raise OutOfMemoryError(
                    "The callable still throws an out-of-memory (OOM) error even when the batch size is 1!"
                ) from e
            # adjust the next batch size
            self._adjust_batch_size(0.5)
            return 0, None
        else:
            n_tot = n_batch * natoms
            self.maximum_working_batch_size = max(self.maximum_working_batch_size, n_tot)
            # adjust the next batch size
            if n_tot >= self.current_batch_size and self.current_batch_size * 2 < self.minimal_not_working_batch_size:
                self._adjust_batch_size(2)
            return n_batch, result

    def _adjust_batch_size(self, factor: float):
        old_batch_size = self.current_batch_size
        self.current_batch_size = int(self.current_batch_size * factor)
        logging.info("Adjust batch size from %d to %d" % (old_batch_size, self.current_batch_size))

    def execute_all(self, callable: Callable, total_size: int, natoms: int, *args, **kwargs) -> Tuple[np.ndarray]:
        """Execute a method with all given data.

        Parameters
        ----------
        callable : Callable
            The method should accept *args and **kwargs as input and return similar arrays.
        total_size : int
            total size of the data
        natoms : int
            the number of atoms
        *args
            If a 2D np.ndarray, assume the first axis is the batch axis; otherwise pass through unchanged.
        **kwargs
            If a 2D np.ndarray, assume the first axis is the batch axis; otherwise pass through unchanged.

        Returns
        -------
        Tuple[np.ndarray]
            results from callable, concatenated along the batch axis
        """
        def execute_with_batch_size(batch_size: int, start_index: int) -> Tuple[int, Tuple[np.ndarray]]:
            end_index = min(start_index + batch_size, total_size)
            return (end_index - start_index), callable(
                *[(vv[start_index:end_index] if isinstance(vv, np.ndarray) and vv.ndim > 1 else vv) for vv in args],
                **{kk: (vv[start_index:end_index] if isinstance(vv, np.ndarray) and vv.ndim > 1 else vv) for kk, vv in kwargs.items()},
            )

        index = 0
        results = []
        while index < total_size:
            n_batch, result = self.execute(execute_with_batch_size, index, natoms)
            if not isinstance(result, tuple):
                result = (result,)
            index += n_batch
            if n_batch:
                # reshape each result to (n_batch, -1) so batches can be
                # concatenated along the batch axis below
                result = tuple(rr.reshape((n_batch, -1)) for rr in result)
                results.append(result)

        return tuple([np.concatenate(r, axis=0) for r in zip(*results)])
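
For illustration, a hedged sketch of how the class is driven from dp test (variable names follow the `test_ener` call above; the exact arguments of `dp.eval` are abbreviated, and this snippet is not part of the commit):

auto_batch_size = AutoBatchSize()  # start probing at 1024 total atoms
# 2D arrays among the arguments are sliced along the batch (frame) axis;
# everything else is passed through unchanged
energy, force, virial = auto_batch_size.execute_all(
    dp.eval,    # callable to batch
    numb_test,  # total number of frames
    natoms,     # atoms per frame, used to convert the atom budget into frames
    coord, box, atype,
)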
3 changes: 3 additions & 0 deletions deepmd/utils/errors.py
@@ -3,3 +3,6 @@ class GraphTooLargeError(Exception):

class GraphWithoutTensorError(Exception):
    pass
+
+class OutOfMemoryError(Exception):
+    """This error is caused by out-of-memory (OOM)."""
3 changes: 2 additions & 1 deletion deepmd/utils/sess.py
@@ -1,6 +1,7 @@
import os

from deepmd.env import tf
+from deepmd.utils.errors import OutOfMemoryError


def run_sess(sess: tf.Session, *args, **kwargs):
@@ -35,4 +36,4 @@ def run_sess(sess: tf.Session, *args, **kwargs):
"variable (current value: %s).\n" % (
os.getenv("CUDA_VISIBLE_DEVICES", None),
))
raise RuntimeError(MESSAGE) from e
raise OutOfMemoryError(MESSAGE) from e
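
Though `run_sess` is not shown in full here, it presumably wraps `sess.run` and converts TF's resource-exhausted error into the new exception; a minimal sketch under that assumption:

def run_sess(sess, *args, **kwargs):
    # Sketch only: the real function builds the detailed MESSAGE partially shown above.
    try:
        return sess.run(*args, **kwargs)
    except tf.errors.ResourceExhaustedError as e:
        raise OutOfMemoryError("Your memory is not enough...") from e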
