Add eager support for aten::equal. (#12020)
WilBrady authored Jun 30, 2022
1 parent 0ee0b8c commit 0fa2041
Showing 3 changed files with 90 additions and 4 deletions.
2 changes: 1 addition & 1 deletion orttraining/orttraining/eager/opgen/opgen/atenops.py
@@ -138,7 +138,7 @@ def __init__(self, dY, X):
"aten::_local_scalar_dense": MakeTorchFallback(),
"aten::gt.Scalar_out": MakeTorchFallback(),
"aten::lt.Scalar_out": MakeTorchFallback(),
"aten::equal": MakeTorchFallback(),
"aten::equal": SignatureOnly(),
"aten::_softmax": Softmax("self", axis="dim"),
"aten::argmax.out": SignatureOnly(),
}
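The one-line change above switches aten::equal from the generic CPU fallback to a hand-written kernel. As a rough illustration of what the two handler kinds mean, here is a toy sketch only; the real SignatureOnly and MakeTorchFallback classes drive opgen's C++ code generation and look different:

```python
# Toy sketch, not the actual opgen implementation.
class SignatureOnly:
    """Generate only the C++ signature; a hand-written kernel (ort_aten.cpp) supplies the body."""

class MakeTorchFallback:
    """Generate a body that routes the op back to PyTorch's CPU implementation."""

ops = {
    "aten::lt.Scalar_out": MakeTorchFallback(),  # still falls back to CPU
    "aten::equal": SignatureOnly(),              # now handled by equal() in ort_aten.cpp
}

hand_written = [name for name, handler in ops.items() if isinstance(handler, SignatureOnly)]
print(hand_written)  # ['aten::equal']
```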
66 changes: 66 additions & 0 deletions orttraining/orttraining/eager/ort_aten.cpp
@@ -682,6 +682,72 @@ at::Tensor& out) {
return out;
}

// aten::equal(Tensor self, Tensor other) -> bool
bool equal(
const at::Tensor& self,
const at::Tensor& other) {
ORT_LOG_FN(self, other);

if (
std::vector<at::ScalarType> supportedTypes =
{at::kFloat, at::kBFloat16, at::kHalf, at::kDouble, at::kLong, at::kByte, at::kInt, at::kShort, at::kBool};
!IsSupportedType(self, supportedTypes) ||
!IsSupportedType(other, supportedTypes)) {
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(equal)>::call(self, other);
}

auto& invoker = GetORTInvoker(self.device());

auto ort_input_self = create_ort_value(invoker, self);
auto ort_input_other = create_ort_value(invoker, other);

auto& ort_tensor_self = ort_input_self.Get<onnxruntime::Tensor>();
auto& shape_self = ort_tensor_self.Shape();
auto& ort_tensor_other = ort_input_other.Get<onnxruntime::Tensor>();
auto& shape_other = ort_tensor_other.Shape();

// ensure the shapes are equal
if (shape_self != shape_other) return false;

// to check the contents, do an elementwise comparison and then reduce with the
// minimum: false compares less than true, so any false element reduces the
// result to false.
std::vector<OrtValue> ort_outputs_0_Equal(1);

auto equalStatus = invoker.Invoke("Equal", {
std::move(ort_input_self),
std::move(ort_input_other),
}, ort_outputs_0_Equal, nullptr);

if (!equalStatus.IsOK())
throw std::runtime_error(
"ORT Equal return failure status:" + equalStatus.ErrorMessage());

// now reduce the resulting tensor of bool values to its minimum value (any false)
NodeAttributes attrs(1);
attrs["keepdims"] = create_ort_attribute(
"keepdims", 0, at::ScalarType::Int);

std::vector<OrtValue> ort_outputs_0_ReduceMin(1);

// ReduceMin does not support bool or short, and CastToType does not support Byte
// (GetONNXTensorProtoDataType doesn't support byte), which leaves us with int
OrtValue equalAsInt = CastToType(invoker, ort_outputs_0_Equal[0], at::ScalarType::Int);

auto reduceStatus = invoker.Invoke("ReduceMin", {
std::move(equalAsInt),
}, ort_outputs_0_ReduceMin, &attrs);

if (!reduceStatus.IsOK())
throw std::runtime_error(
"ORT ReduceMin return failure reduceStatus:" + reduceStatus.ErrorMessage());

auto* ort_tensor = ort_outputs_0_ReduceMin[0].GetMutable<onnxruntime::Tensor>();
// the first (and only) element of the tensor is 0 if any comparison was false, 1 if all were true
return *(ort_tensor->Data<int>()) != 0;
}

} // namespace aten
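The kernel's reduction trick can be summarized in a few lines of plain PyTorch. This is an illustrative CPU-only sketch of the same logic, not the ORT code path, and the function name equal_via_reduce is made up:

```python
import torch

def equal_via_reduce(self: torch.Tensor, other: torch.Tensor) -> bool:
    # Mirrors the kernel above: shape check first, then elementwise Equal,
    # cast to int (ReduceMin has no bool support in ORT), and a min-reduction.
    if self.shape != other.shape:
        return False
    elementwise = torch.eq(self, other).to(torch.int32)
    return int(elementwise.min()) != 0  # 0 means at least one element differed

a = torch.tensor([1.0, 1.5])
b = torch.tensor([1.0, 1.5])
c = torch.tensor([1.0, 1.8])
print(equal_via_reduce(a, b))  # True
print(equal_via_reduce(a, c))  # False
```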

26 changes: 23 additions & 3 deletions orttraining/orttraining/eager/test/ort_ops.py
@@ -120,9 +120,29 @@ def test_min(self):

def test_equal(self):
device = self.get_device()
-cpu_x = torch.ones(3, 3, dtype=torch.float32)
-cpu_y = torch.ones(3, 3, dtype=torch.float32)
-assert torch.equal(cpu_x.to(device), cpu_y.to(device))
+cpu_a = torch.Tensor([1.0, 1.5])
+ort_a = cpu_a.to(device)
+cpu_b = torch.Tensor([1.0, 1.5])
+ort_b = cpu_b.to(device)
+cpu_c = torch.Tensor([1.0, 1.8])
+ort_c = cpu_c.to(device)
+cpu_d = torch.Tensor([1.0, 1.5, 2.1])
+ort_d = cpu_d.to(device)
+cpu_e = torch.Tensor([[1.0, 1.5]])
+ort_e = cpu_e.to(device)
+
+# a == b
+assert torch.equal(cpu_a, cpu_b)
+assert torch.equal(ort_a, ort_b)
+# a != c because one value differs
+assert not torch.equal(cpu_a, cpu_c)
+assert not torch.equal(ort_a, ort_c)
+# a != d because the sizes differ
+assert not torch.equal(cpu_a, cpu_d)
+assert not torch.equal(ort_a, ort_d)
+# a != e because the number of dimensions differs
+assert not torch.equal(cpu_a, cpu_e)
+assert not torch.equal(ort_a, ort_e)

def test_torch_ones(self):
device = self.get_device()
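The new test_equal case can also be run on its own. A minimal sketch using pytest's -k filter follows; it assumes the eager-mode extension is built and importable, and that the path is given relative to the repository root:

```python
import pytest

# Select only the aten::equal test from the eager-mode op tests.
pytest.main(["orttraining/orttraining/eager/test/ort_ops.py", "-k", "test_equal"])
```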