Skip to content

Commit

Permalink
【Hackathon 5th No.15】 为 Paddle 新增 Tensor.to() 以及 Layer.astype() API -…
Browse files Browse the repository at this point in the history
…part (#58244)

* support tensor.to and layer.astype

* add UT and comments

* update

* update com

* update dtype

* fix example test

* update example

* add some ut to fix ci-coverage

* fix codestyle

* fix codestyle

* update

* add ut to test layer params' and buffers' type

* update test

* fix doc

* Update python/paddle/base/dygraph/tensor_patch_methods.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* update doc

---------

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
  • Loading branch information
YibinLiu666 and sunzhongkai588 committed Nov 9, 2023
1 parent 3d167b5 commit 2927ed9
Show file tree
Hide file tree
Showing 4 changed files with 413 additions and 0 deletions.
124 changes: 124 additions & 0 deletions python/paddle/base/dygraph/tensor_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,129 @@ def transform(t, device, dtype, blocking):
warnings.filterwarnings("ignore", category=UserWarning)
return transform(self, device, dtype, blocking)

@framework.dygraph_only
def to(self, *args, **kwargs):
    """
    Performs Tensor dtype and/or device conversion. A paddle.dtype and place
    are inferred from the arguments of ``self.to(*args, **kwargs)``. There are
    three ways to call `to`:

        1. to(dtype, blocking=True)
        2. to(device, dtype=None, blocking=True)
        3. to(other, blocking=True)

    Returns:
        Tensor: the converted tensor, as returned by ``self._to``.

    Raises:
        TypeError: If no argument is given, if more than three arguments
            are given, or if an unknown keyword argument is passed.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> tensorx = paddle.to_tensor([1,2,3])
            >>> print(tensorx)
            Tensor(shape=[3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
                [1, 2, 3])

            >>> tensorx = tensorx.to("cpu")
            >>> print(tensorx.place)
            Place(cpu)

            >>> tensorx = tensorx.to("float32")
            >>> print(tensorx.dtype)
            paddle.float32

            >>> tensorx = tensorx.to("gpu", "int16")
            >>> print(tensorx)
            Tensor(shape=[3], dtype=int16, place=Place(gpu:0), stop_gradient=True,
                [1, 2, 3])
            >>> tensor2 = paddle.to_tensor([4,5,6])
            >>> tensor2
            Tensor(shape=[3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
                [4, 5, 6])
            >>> tensor2 = tensor2.to(tensorx)
            >>> print(tensor2)
            Tensor(shape=[3], dtype=int16, place=Place(gpu:0), stop_gradient=True,
                [4, 5, 6])
    """
    device = None
    dtype = None
    blocking = None
    size_args = len(args)
    size_kwargs = len(kwargs)

    def get_device_dtype_from_tensor(other):
        # Derive the target device/dtype from a reference tensor.
        if other is not None:
            # str(place) looks like "Place(gpu:0)"; strip the "Place(" prefix
            # and the trailing ")" to get the device string.
            device = str(other.place)[6:-1]
            dtype = other.dtype
            return device, dtype
        else:
            return None, None

    usage = (
        "expected one of:\n"
        " * (Union[str, paddle.CPUPlace(), paddle.CUDAPlace(), "
        "paddle.CUDAPinnedPlace(), paddle.XPUPlace(), paddle.CustomPlace()] "
        "device, Union[str, paddle.dtype, numpy.dtype] dtype, bool blocking)\n"
        " * (Union[str, paddle.dtype, numpy.dtype] dtype, bool blocking)\n"
        " * (paddle.Tensor other, bool blocking)"
    )
    # Report the empty call and the over-long call separately so the message
    # matches what the caller actually did.
    if size_args + size_kwargs == 0:
        raise TypeError("to() received no arguments - " + usage)
    if size_args + size_kwargs > 3:
        raise TypeError("to() received too many arguments - " + usage)

    valid_keys = {"device", "dtype", "blocking", "other"}
    valid_dtypes = [
        "bfloat16",
        "float16",
        "float32",
        "float64",
        "int8",
        "int16",
        "int32",
        "int64",
        "uint8",
        "complex64",
        "complex128",
        "bool",
    ]
    invalid_keys = set(kwargs.keys()) - valid_keys
    if invalid_keys:
        # sorted() makes the reported key deterministic when several are bad.
        raise TypeError(
            "to() got an unexpected keyword argument "
            + sorted(invalid_keys)[0]
        )

    if size_args > 0:
        first = args[0]
        if isinstance(first, paddle.Tensor):
            # to(other, blocking)
            device, dtype = get_device_dtype_from_tensor(first)
            if size_args == 2:
                blocking = args[1]
            else:
                blocking = kwargs.get("blocking", None)
        elif isinstance(first, (paddle.dtype, np.dtype)) or (
            isinstance(first, str) and first.lower() in valid_dtypes
        ):
            # to(dtype, blocking)
            dtype = first
            if size_args == 2:
                blocking = args[1]
            else:
                blocking = kwargs.get("blocking", None)
        else:
            # to(device, dtype, blocking)
            device = first
            if size_args == 3:
                dtype, blocking = args[1], args[2]
            else:
                if size_args == 2:
                    dtype = args[1]
                else:
                    dtype = kwargs.get("dtype", None)
                # Honor a ``blocking=`` keyword even when device and dtype
                # were both given positionally.
                blocking = kwargs.get("blocking", None)
    else:
        device = kwargs.get("device", None)
        dtype = kwargs.get("dtype", None)
        blocking = kwargs.get("blocking", None)
        if device is None and dtype is None:
            # Neither was given explicitly: fall back to ``other=<Tensor>``.
            device, dtype = get_device_dtype_from_tensor(
                kwargs.get("other", None)
            )
    return self._to(device, dtype, blocking)

@property
def grad(self):
"""
Expand Down Expand Up @@ -1020,6 +1143,7 @@ def coalesce(self, name=None):
("item", item),
("__setitem__", __setitem__),
("_to", _to),
("to", to),
("values", values),
("to_dense", to_dense),
("to_sparse_coo", to_sparse_coo),
Expand Down
80 changes: 80 additions & 0 deletions python/paddle/nn/layer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,86 @@ def parameters(self, include_sublayers=True):
]
return ret

def astype(self, dtype=None):
    """
    Casts all parameters and buffers to dtype and then return the Layer.

    Parameters:
        dtype(str|paddle.dtype|numpy.dtype): target data type of layer.
            If set str, it can be "bool", "bfloat16", "float16", "float32", "float64",
            "int8", "int16", "int32", "int64", "uint8", "complex64", "complex128".
            Default: None

    Returns:
        Layer, self

    Raises:
        ValueError: If ``dtype`` is neither a paddle.dtype / numpy.dtype
            instance nor one of the supported dtype names (this includes
            the default ``None``).

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.nn as nn
            >>> weight_attr = paddle.ParamAttr(name="weight",initializer=paddle.nn.initializer.Constant(value=1.5))
            >>> bias_attr = paddle.ParamAttr(name="bias",initializer=paddle.nn.initializer.Constant(value=2.5))

            >>> linear = paddle.nn.Linear(2, 2, weight_attr=weight_attr, bias_attr=bias_attr).to(device="cpu",dtype="float32")
            >>> print(linear)
            Linear(in_features=2, out_features=2, dtype=float32)
            >>> print(linear.parameters())
            [Parameter containing:
            Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=False,
                [[1.50000000, 1.50000000],
                    [1.50000000, 1.50000000]]), Parameter containing:
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=False,
                [2.50000000, 2.50000000])]

            >>> linear=linear.astype("int8")
            >>> print(linear)
            Linear(in_features=2, out_features=2, dtype=paddle.int8)
            >>> print(linear.parameters())
            [Parameter containing:
            Tensor(shape=[2, 2], dtype=int8, place=Place(cpu), stop_gradient=False,
                [[1, 1],
                    [1, 1]]), Parameter containing:
            Tensor(shape=[2], dtype=int8, place=Place(cpu), stop_gradient=False,
                [2, 2])]
    """
    valid_dtypes = (
        "bfloat16",
        "float16",
        "float32",
        "float64",
        "int8",
        "int16",
        "int32",
        "int64",
        "uint8",
        "complex64",
        "complex128",
        "bool",
    )
    is_dtype_object = isinstance(dtype, (paddle.dtype, np.dtype))
    is_dtype_name = isinstance(dtype, str) and dtype in valid_dtypes
    # Reject unsupported targets up front, before touching any layer state.
    if not (is_dtype_object or is_dtype_name):
        raise ValueError(
            "dtype value error, must be 'bfloat16', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128', 'bool', or paddle.dtype, numpy.dtype, but received "
            + str(dtype)
        )

    # Normalize names and numpy dtypes so ``_dtype`` always holds the
    # framework's own dtype representation.
    if isinstance(dtype, (str, np.dtype)):
        dtype = framework.convert_np_dtype_to_dtype_(dtype)

    self._dtype = dtype
    for layer in self.sublayers():
        layer._dtype = dtype

    # Cast parameters and buffers through the same ``_to`` entry point;
    # device is None so only the dtype is requested to change.
    for _, param in self.named_parameters(include_sublayers=True):
        param._to(None, dtype)
    for _, buffer in self.named_buffers(include_sublayers=True):
        buffer._to(None, dtype)
    return self

def children(self):
"""
Expand Down
147 changes: 147 additions & 0 deletions test/legacy_test/test_Tensor_to.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle import base


class TensorToTest(unittest.TestCase):
    """Tests for ``Tensor.to``: dtype, device, reference-tensor and keyword forms."""

    # Every dtype name that ``Tensor.to`` accepts as a string.
    VALID_DTYPES = (
        "bfloat16",
        "float16",
        "float32",
        "float64",
        "int8",
        "int16",
        "int32",
        "int64",
        "uint8",
        "complex64",
        "complex128",
        "bool",
    )

    @staticmethod
    def _places():
        """Return the device strings to exercise on the current build."""
        places = ["cpu"]
        if base.core.is_compiled_with_cuda():
            places.append("gpu:0")
            places.append("gpu")
        return places

    @staticmethod
    def _expected_place_str(place):
        """Return the ``str(tensor.place)`` expected after moving to ``place``."""
        # A bare "gpu" is expected to resolve to device index 0.
        if place == "gpu":
            return "Place(" + place + ":0)"
        return "Place(" + place + ")"

    def test_Tensor_to_dtype(self):
        tensorx = paddle.to_tensor([1, 2, 3])
        for dtype in self.VALID_DTYPES:
            tensorx = tensorx.to(dtype)
            # assertEqual, not assertTrue: assertTrue(a, b) treats b as the
            # failure message and never compares the two values.
            self.assertEqual(str(tensorx.dtype), "paddle." + dtype)

    def test_Tensor_to_device(self):
        tensorx = paddle.to_tensor([1, 2, 3])
        for place in self._places():
            tensorx = tensorx.to(place)
            self.assertEqual(
                str(tensorx.place), self._expected_place_str(place)
            )

    def test_Tensor_to_device_dtype(self):
        tensorx = paddle.to_tensor([1, 2, 3])
        places = self._places()
        for dtype in self.VALID_DTYPES:
            for place in places:
                tensorx = tensorx.to(place, dtype)
                self.assertEqual(
                    str(tensorx.place), self._expected_place_str(place)
                )
                self.assertEqual(str(tensorx.dtype), "paddle." + dtype)

    def test_Tensor_to_blocking(self):
        tensorx = paddle.to_tensor([1, 2, 3])
        tensorx = tensorx.to("cpu", "int32", False)
        self.assertEqual(str(tensorx.place), "Place(cpu)")
        self.assertEqual(str(tensorx.dtype), "paddle.int32")

        tensor2 = paddle.to_tensor([4, 5, 6])
        tensor2 = tensor2.to(tensorx, False)
        self.assertEqual(str(tensor2.place), "Place(cpu)")
        self.assertEqual(str(tensor2.dtype), "paddle.int32")

        tensor2 = tensor2.to("float16", False)
        self.assertEqual(str(tensor2.dtype), "paddle.float16")

    def test_Tensor_to_other(self):
        tensor1 = paddle.to_tensor([1, 2, 3], dtype="int8", place="cpu")
        tensor2 = paddle.to_tensor([1, 2, 3])
        tensor2 = tensor2.to(tensor1)
        self.assertEqual(tensor2.dtype, tensor1.dtype)
        self.assertEqual(type(tensor2.place), type(tensor1.place))

    def test_kwargs(self):
        tensorx = paddle.to_tensor([1, 2, 3])
        tensorx = tensorx.to(device="cpu", dtype="int8", blocking=True)
        self.assertEqual(str(tensorx.place), "Place(cpu)")
        self.assertEqual(str(tensorx.dtype), "paddle.int8")

        tensor2 = paddle.to_tensor([4, 5, 6])
        tensor2 = tensor2.to(other=tensorx)
        self.assertEqual(str(tensor2.place), "Place(cpu)")
        self.assertEqual(str(tensor2.dtype), "paddle.int8")

    def test_error(self):
        tensorx = paddle.to_tensor([1, 2, 3])
        # assertRaises fails the test when no exception is raised; a bare
        # try/except would pass silently in that case.

        # device value error
        with self.assertRaises(ValueError):
            tensorx.to("error_device")
        # too many arguments
        with self.assertRaises(TypeError):
            tensorx.to("cpu", "int32", False, "test_aug")
        # invalid key
        with self.assertRaises(TypeError):
            tensorx.to("cpu", "int32", test_key=False)

# Run the unittest runner when this file is executed directly as a script.
if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 2927ed9

Please sign in to comment.