forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_class_type.cpp
58 lines (49 loc) · 1.88 KB
/
test_class_type.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/torch.h>
namespace torch {
namespace jit {

// Exercises adding and (unsafely) removing attributes on a ClassType,
// including one registered as a parameter, and checks that a removed name
// can be registered again with a different type.
void testClassTypeAddRemoveAttr() {
  auto cu = std::make_shared<CompilationUnit>();
  // NOTE(review): the trailing `true` is presumably ClassType::create's
  // `is_module` flag — confirm against its declaration.
  auto cls = ClassType::create("foo.bar", cu, true);
  // attr1 is registered as a parameter (trailing `true`); attr2 and attr3
  // are plain attributes.
  cls->addAttribute("attr1", TensorType::get(), true);
  cls->addAttribute("attr2", TensorType::get());
  cls->addAttribute("attr3", TensorType::get());
  ASSERT_TRUE(cls->hasAttribute("attr1"));
  ASSERT_TRUE(cls->hasAttribute("attr2"));
  ASSERT_TRUE(cls->hasAttribute("attr3"));

  // Remove the plain attribute attr2 from the middle; the others survive.
  cls->unsafeRemoveAttribute("attr2");
  ASSERT_TRUE(cls->hasAttribute("attr1"));
  ASSERT_FALSE(cls->hasAttribute("attr2"));
  ASSERT_TRUE(cls->hasAttribute("attr3"));

  // Remove the parameter attribute attr1.
  cls->unsafeRemoveAttribute("attr1");
  ASSERT_FALSE(cls->hasAttribute("attr1"));
  ASSERT_FALSE(cls->hasAttribute("attr2"));
  ASSERT_TRUE(cls->hasAttribute("attr3"));

  // The freed name can be reused for a non-parameter attribute with a
  // different type; verify it is actually registered again and that the
  // untouched attr3 is unaffected.
  cls->addAttribute("attr1", IntType::get());
  ASSERT_TRUE(cls->hasAttribute("attr1"));
  ASSERT_FALSE(cls->hasAttribute("attr2"));
  ASSERT_TRUE(cls->hasAttribute("attr3"));
}

// Exercises adding, reading back, and (unsafely) removing constants on a
// ClassType. Each constant holds a distinct value so that a lookup which
// resolves to the wrong slot is detected.
void testClassTypeAddRemoveConstant() {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu);
  cls->addConstant("const1", IValue(1));
  cls->addConstant("const2", IValue(2));
  cls->addConstant("const3", IValue(3));
  ASSERT_EQ(cls->numConstants(), 3);
  ASSERT_TRUE(cls->hasConstant("const1"));
  ASSERT_TRUE(cls->hasConstant("const2"));
  ASSERT_TRUE(cls->hasConstant("const3"));
  ASSERT_FALSE(cls->hasConstant("const4"));

  ASSERT_EQ(cls->getConstant("const1").toInt(), 1);
  ASSERT_EQ(cls->getConstant("const2").toInt(), 2);
  ASSERT_EQ(cls->getConstant("const3").toInt(), 3);

  // Remove the middle constant; the count shrinks and only const2 is gone.
  cls->unsafeRemoveConstant("const2");
  ASSERT_EQ(cls->numConstants(), 2);
  ASSERT_TRUE(cls->hasConstant("const1"));
  ASSERT_FALSE(cls->hasConstant("const2"));
  ASSERT_TRUE(cls->hasConstant("const3"));

  // Erasing a middle entry must not shift names and values out of sync.
  ASSERT_EQ(cls->getConstant("const1").toInt(), 1);
  ASSERT_EQ(cls->getConstant("const3").toInt(), 3);
}

} // namespace jit
} // namespace torch