forked from Glaciohound/VCML
-
Notifications
You must be signed in to change notification settings - Fork 0
/
reasoning.py
207 lines (161 loc) · 6.18 KB
/
reasoning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File : reasoning.py
# Author : Chi Han, Jiayuan Mao
# Email : haanchi@gmail.com, maojiayuan@gmail.com
# Date : 23.07.2019
# Last Modified Date: 21.11.2019
# Last Modified By : Chi Han
#
# This file is part of the VCML codebase
# Distributed under MIT license
import torch
import torch.nn as nn
import torch.nn.functional as F
from utility.common import min_fn, log, detach
# Large finite logit used as "infinite" confidence when selecting all
# objects (see ProgramExecutor.select_object_fn); kept finite so that
# sigmoid/softmax over these logits stays numerically well-behaved.
INF = 100
class ProgramExecutor(nn.Module):
    """Executes a symbolic reasoning program over visual objects and
    concept embeddings, and scores the final result against the answer.

    A program is a sequence of ``{'operation': str, 'argument': ...}``
    steps. The running intermediate ``result`` is always a
    ``(type_tag, tensor)`` pair, where the tag is one of
    'object_logits', 'object_embedding', 'concept_embedding',
    'attribute_logits' or 'boolean'.
    """

    def __init__(self, args, tools, device):
        """
        args   : experiment configuration; reads .not_build_reasoning,
                 .conceptual_weight and .detach_in_rel
        tools  : vocabulary helpers; reads .attributes
                 (presumably maps attribute names to indices and is
                 iterable in index order — TODO confirm against caller)
        device : torch device on which new tensors are created
        """
        super().__init__()
        # build() creates only the exist-calibration parameters and does
        # not read self.args, so calling it before the assignments is safe.
        if not args.not_build_reasoning:
            self.build()
        self.args = args
        self.tools = tools
        self.device = device

    def build(self):
        # Learnable affine calibration applied to the max-logit produced
        # by the `exist` operation before it is read as a boolean logit.
        self.exist_offset = nn.Parameter(torch.tensor(0.))
        self.exist_scale = nn.Parameter(torch.tensor(1.))

    def forward(
        self,
        program, answer, question_cat,
        objects, logits,
        embedding,
    ):
        """Run `program` step by step, then score the final result.

        Returns (loss, output, debug) from self.analyze(); questions
        categorized as 'conceptual' have their loss re-weighted by
        args.conceptual_weight.

        Raises NotImplementedError for 'query' operations and a generic
        Exception for any unrecognized operation.
        """
        result = None
        for j, op in enumerate(program):
            operation = op['operation']
            argument = op['argument']
            if operation == 'select_object':
                result = self.select_object_fn(objects.shape[0], INF)
            elif operation == 'select_concept':
                result = self.select_concept_fn(embedding, argument)
            elif operation == 'unique_object':
                result = self.unique_object_fn(result, objects)
            elif operation == 'query':
                raise NotImplementedError('Querying module not implemented')
            elif operation == 'filter':
                # logits[j] holds this step's per-object concept scores
                result = self.filter_fn(result, logits[j])
            elif operation == 'exist':
                result = self.exist_fn(result)
            elif operation in ['synonym', 'hypernym',
                               'samekind', 'meronym']:
                result = self.judge_relation(
                    result, argument, embedding, operation)
            elif operation == 'isinstanceof':
                result = self.isinstanceof_fn(
                    result, embedding)
            elif operation in ['<END>']:
                pass
            else:
                # BUG FIX: corrected the misspelling 'opeartion' in the
                # original error message.
                raise Exception('unsupported operation: {}'.format(op))
        loss, output, debug = self.analyze(result, answer)
        if question_cat == 'conceptual':
            loss = loss * self.args.conceptual_weight
        return loss, output, debug

    # the following are operation modules

    def select_object_fn(self, n, INF):
        """Select all n objects, each with (near-)infinite logit INF."""
        return 'object_logits', torch.ones(n).to(self.device) * INF

    def select_concept_fn(self, embedding, argument):
        """Look up the embedding vector of the named concept."""
        return 'concept_embedding', \
            embedding.get_embedding('concept', argument)

    def unique_object_fn(self, result, objects):
        """Soft-argmax: softmax-weighted average of object features."""
        weighted_sum = (F.softmax(result[1], dim=0)[None] *
                        objects).sum(0)
        return 'object_embedding', weighted_sum

    def unique_concept_fn(self, result, embedding):
        """Soft-argmax over all concept embeddings."""
        weighted_sum = (F.softmax(result[1], dim=0)[None] *
                        embedding.all_concept_embeddings).sum(0)
        return 'concept_embedding', weighted_sum

    def filter_fn(self, results, logits):
        """Intersect the current object selection with new concept
        logits via element-wise soft-min (min_fn from utility.common)."""
        filtered_logits = min_fn(results[1], logits)
        return 'object_logits', filtered_logits

    def exist_fn(self, result):
        """Reduce object logits to a single boolean logit via max,
        optionally calibrated by the learned offset/scale pair."""
        if self.args.not_build_reasoning:
            output = result[1].max()
        else:
            output = (result[1].max() + self.exist_offset) * self.exist_scale
        return ('boolean', output)

    def isinstanceof_fn(self, result, embedding):
        """Score the current concept against every attribute embedding.

        Column 1 of determine_relation's output is taken as the
        per-attribute logit — presumably the 'instance-of' channel;
        TODO confirm against embedding.determine_relation.
        """
        attributes = embedding.all_attribute_embeddings
        detach_concept = self.args.detach_in_rel
        result = embedding.determine_relation(
            result[1], attributes,
            detach=(detach_concept, False),
        )[:, 1]
        result = (
            'attribute_logits',
            result,
        )
        return result

    def judge_relation(self, result, argument, embedding, operation):
        """Judge a metaconcept relation between the current concept and
        the argument concept; the index picks the relation channel.
        (Index 1 is skipped here — presumably reserved for another
        relation such as isinstanceof; verify against determine_relation.)
        """
        metaconcept_index = {
            'synonym': 0, 'hypernym': 2, 'samekind': 3, 'meronym': 4,
        }[operation]
        another_concept = embedding.get_embedding('concept', argument)
        detach_concept = self.args.detach_in_rel
        judgement = embedding.determine_relation(
            result[1], another_concept,
            detach=(detach_concept, detach_concept),
        )
        result = (
            'boolean',
            judgement[metaconcept_index]
        )
        return result

    # analyzing outputs

    def boolean_analyze(self, result, answer):
        """Score a boolean logit against a yes/no answer.

        Uses the project `log` helper (from utility.common) on the raw
        logit; any answer other than 'yes' is treated as 'no'.
        """
        output = {
            'yes': detach(torch.sigmoid(result[1])),
            'no': detach(torch.sigmoid(-result[1])),
        }
        if answer == 'yes':
            loss = -log(result[1])
        else:
            loss = -log(-result[1])
        return loss, output, {}

    def attribute_logits_analyze(self, result, answer):
        """Cross-entropy over attribute logits; the answer is an
        attribute name looked up through tools.attributes."""
        logs = F.log_softmax(result[1], dim=0)
        output = dict(zip(
            self.tools.attributes,
            detach(logs.exp())
        ))
        index = self.tools.attributes[answer]
        target = torch.LongTensor([index]).to(self.device)
        loss = F.nll_loss(logs[None], target)
        return loss, output, {}

    def analyze(self, result, answer):
        """Dispatch to the analyzer that matches the result's type tag."""
        if result[0] == 'boolean':
            return self.boolean_analyze(result, answer)
        elif result[0] == 'attribute_logits':
            return self.attribute_logits_analyze(result, answer)
        else:
            raise Exception('result type error: wrong type %s' % result[0])

    def init(self):
        # Intentional no-op: kept to satisfy the trainer's module protocol.
        pass
class Classification(nn.Module):
    """Multi-label classification head.

    Stacks a list of per-class logit tensors and scores them against a
    binary target with element-wise binary cross-entropy (no reduction).
    """

    def __init__(self, args, tools, device):
        """
        args   : experiment configuration (stored, not read here)
        tools  : vocabulary helpers (stored, not read here)
        device : torch device the target tensor is moved to
        """
        super().__init__()
        self.args = args
        self.tools = tools
        self.device = device

    def forward(self, logits, answer, argument_index):
        """Return (per-element BCE loss, detached probabilities, {}).

        argument_index is accepted for interface compatibility but unused.
        """
        stacked = torch.stack(logits)
        scores = stacked.transpose(1, 0)
        labels = answer.to(self.device)
        per_element_loss = F.binary_cross_entropy_with_logits(
            scores, labels, reduction='none'
        )
        probabilities = detach(torch.sigmoid(scores))
        return per_element_loss, probabilities, {}

    def init(self):
        # Intentional no-op: kept to satisfy the trainer's module protocol.
        pass