```python
import torch

def forward(self, obj_vecs, attr_vecs, rela_vecs, edges, rela_masks=None):
    # flatten to 2D for easily indexing the subject and object of each relation in the tensors
    obj_vecs, attr_vecs, rela_vecs, edges, ori_shape = self.feat_3d_to_2d(obj_vecs, attr_vecs, rela_vecs, edges)

    # obj: object features are passed through unchanged
    new_obj_vecs = obj_vecs

    # attr: update from [object; attribute] with a residual connection
    new_attr_vecs = self.gnn_attr(torch.cat([obj_vecs, attr_vecs], dim=-1)) + attr_vecs

    # rela
    # get node features for each triplet <subject, relation, object>
    s_idx = edges[:, 0].contiguous()  # index of subject
    o_idx = edges[:, 1].contiguous()  # index of object
    s_vecs = obj_vecs[s_idx]
    o_vecs = obj_vecs[o_idx]
    if self.opt.rela_gnn_type == 0:
        t_vecs = torch.cat([s_vecs, rela_vecs, o_vecs], dim=1)
    elif self.opt.rela_gnn_type == 1:
        t_vecs = torch.cat([s_vecs + o_vecs, rela_vecs], dim=1)
    else:
        raise NotImplementedError()
    new_rela_vecs = self.gnn_rela(t_vecs) + rela_vecs

    # restore the original (batched) 3D layout
    new_obj_vecs, new_attr_vecs, new_rela_vecs = self.feat_2d_to_3d(
        new_obj_vecs, new_attr_vecs, new_rela_vecs, rela_masks, ori_shape)

    return new_obj_vecs, new_attr_vecs, new_rela_vecs
```
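To make the data flow of `forward` easier to follow, here is a minimal NumPy sketch of the same update on already-flattened 2D features (the job `feat_3d_to_2d` does in the snippet). It is an illustration under stated assumptions, not the repository's implementation: `TinyGNNLayer`, its random weights, and the single ReLU linear layer standing in for `gnn_attr` / `gnn_rela` are all hypothetical, since the snippet does not show how those modules are defined. It shows that objects pass through unchanged, attributes are updated from `[object; attribute]`, and each relation is updated from its `<subject, relation, object>` triplet, both with a residual connection so the output shapes match the inputs.

```python
import numpy as np


class TinyGNNLayer:
    """Hypothetical stand-in for the layer above; weights are random, not trained."""

    def __init__(self, d, rela_gnn_type=0, seed=0):
        if rela_gnn_type not in (0, 1):
            raise NotImplementedError(rela_gnn_type)
        rng = np.random.default_rng(seed)
        self.rela_gnn_type = rela_gnn_type
        # type 0 concatenates [s; r; o] (3d), type 1 concatenates [s + o; r] (2d)
        rela_in = 3 * d if rela_gnn_type == 0 else 2 * d
        # assumed one-layer ReLU MLPs in place of gnn_attr / gnn_rela
        self.w_attr = rng.standard_normal((2 * d, d)) * 0.1
        self.w_rela = rng.standard_normal((rela_in, d)) * 0.1

    @staticmethod
    def _mlp(x, w):
        return np.maximum(x @ w, 0.0)

    def forward(self, obj_vecs, attr_vecs, rela_vecs, edges):
        # obj_vecs, attr_vecs: (N_obj, d); rela_vecs: (N_rel, d)
        # edges: (N_rel, 2) integer rows of [subject_index, object_index]
        new_obj_vecs = obj_vecs  # objects are not updated
        attr_in = np.concatenate([obj_vecs, attr_vecs], axis=-1)
        new_attr_vecs = self._mlp(attr_in, self.w_attr) + attr_vecs  # residual

        s_vecs = obj_vecs[edges[:, 0]]  # subject features per relation
        o_vecs = obj_vecs[edges[:, 1]]  # object features per relation
        if self.rela_gnn_type == 0:
            t_vecs = np.concatenate([s_vecs, rela_vecs, o_vecs], axis=1)
        else:
            t_vecs = np.concatenate([s_vecs + o_vecs, rela_vecs], axis=1)
        new_rela_vecs = self._mlp(t_vecs, self.w_rela) + rela_vecs  # residual
        return new_obj_vecs, new_attr_vecs, new_rela_vecs
```

For example, with 3 objects, 2 relations and `d = 4`, `TinyGNNLayer(4).forward(obj, attr, rela, np.array([[0, 1], [2, 0]]))` returns arrays of shapes `(3, 4)`, `(3, 4)` and `(2, 4)`. Because of the residual terms, zeroing the weights makes the layer an identity map, which is one reason such updates are easy to stack.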
Hello, I'd like to ask: are `new_obj_vecs`, `new_attr_vecs` and `new_rela_vecs` here the three types of image features after being refined by the GNN?
Yes, you are right.