-
Notifications
You must be signed in to change notification settings - Fork 6
/
SpatialConvolutionDCT.lua
77 lines (69 loc) · 2.39 KB
/
SpatialConvolutionDCT.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
local SpatialConvolutionDCT, parent = torch.class('SpatialConvolutionDCT', 'nn.Module')

--- Wrap an existing spatial convolution and fix its filters to a DCT basis.
-- All computation is delegated to the wrapped module; this wrapper only
-- re-initialises the wrapped weight (see reset) and restricts which tensors
-- are exposed as trainable parameters (see parameters).
-- @param conv_module convolution module exposing nInputPlane, nOutputPlane,
--   kW, kH, weight, bias and gradBias (presumably nn.SpatialConvolution or
--   nn.SpatialConvolutionMM — TODO confirm against callers)
function SpatialConvolutionDCT:__init(conv_module)
   -- nn.Module.__init creates the default output/gradInput tensors that the
   -- rest of the nn API expects every module to have.
   parent.__init(self)
   assert(conv_module ~= nil, 'SpatialConvolutionDCT: conv_module is required')
   self.conv = conv_module
   self:reset()
end
--- Overwrite the wrapped convolution's weight with DCT dictionary atoms.
-- odct3dict / odct2dict are defined elsewhere (not visible here); presumably
-- each column of the matrix they return is one dictionary atom — TODO confirm.
-- The first nOutputPlane columns are kept, transposed, and copied
-- element-wise into conv.weight. The bias is left untouched.
-- @param stdv accepted for signature compatibility with other reset(stdv)
--   implementations; unused
function SpatialConvolutionDCT:reset(stdv)
   local conv = self.conv
   local nIn, nOut = conv.nInputPlane, conv.nOutputPlane
   local dictionary
   if nIn > 1 then
      -- multi-plane input: use the 3-D dictionary builder
      dictionary = odct3dict(nIn, conv.kW, conv.kH, nOut)
   else
      -- single-plane input: use the 2-D dictionary builder
      dictionary = odct2dict(conv.kW, conv.kH, nOut)
   end
   conv.weight:copy(dictionary:narrow(2, 1, nOut):t())
end
--- Forward pass, delegated entirely to the wrapped convolution.
-- self.output is made to alias the wrapped module's output tensor.
-- @param input input tensor, passed through unchanged
-- @return the wrapped convolution's output
function SpatialConvolutionDCT:updateOutput(input)
   local result = self.conv:updateOutput(input)
   self.output = result
   return result
end
--- Backward pass w.r.t. the input, delegated to the wrapped convolution.
-- self.gradInput is made to alias the wrapped module's gradInput tensor.
-- @param input the input given to the forward pass
-- @param gradOutput gradient w.r.t. this module's output
-- @return the wrapped convolution's gradInput
function SpatialConvolutionDCT:updateGradInput(input, gradOutput)
   local result = self.conv:updateGradInput(input, gradOutput)
   self.gradInput = result
   return result
end
--- Trainable parameters of this module: only the wrapped convolution's bias.
-- conv.weight is not returned, so optimisers driven by
-- parameters()/getParameters() leave the DCT filters unchanged.
-- @return {bias}, {gradBias}
function SpatialConvolutionDCT:parameters()
   local conv = self.conv
   local params, grads = {conv.bias}, {conv.gradBias}
   return params, grads
end
--- Accumulate the bias gradient only; the DCT weights stay frozen.
-- parameters() hands conv.gradBias to the optimiser, so it has to be filled
-- here. Otherwise the bias only ever sees a zero gradient and never trains.
-- gradWeight is deliberately not computed because the weights are not
-- exposed as parameters.
-- @param input unused (the bias gradient does not depend on the input)
-- @param gradOutput gradient w.r.t. the output; assumed planes x H x W, or
--   batch x planes x H x W (the layouts nn spatial convolutions use)
-- @param scale optional multiplier for the accumulated gradient (default 1)
function SpatialConvolutionDCT:accGradParameters(input, gradOutput, scale)
   local gradBias = self.conv.gradBias
   if gradBias == nil then
      -- wrapped convolution has no bias: nothing to accumulate
      return
   end
   scale = scale or 1
   local perPlane
   if gradOutput:dim() == 4 then
      -- batch x planes x H x W: reduce over the batch and both spatial dims
      perPlane = gradOutput:sum(1):sum(3):sum(4)
   else
      -- planes x H x W: reduce over both spatial dims
      perPlane = gradOutput:sum(2):sum(3)
   end
   -- torch sum(dim) keeps singleton dims; flatten to match gradBias
   gradBias:add(scale, perPlane:view(gradBias:nElement()))
end
--- Convert the wrapped convolution (and this wrapper's buffers) to type t.
-- With no argument, returns the wrapped convolution's current type.
-- @param t tensor type name, e.g. 'torch.CudaTensor'
-- @param tensorCache optional cache passed down by containers so that
--   tensors shared between modules remain shared after conversion
-- @return self, so that `m = m:cuda()` / `m = m:float()` keep working
function SpatialConvolutionDCT:type(t, tensorCache)
   if t == nil then
      return self.conv:type()
   end
   self.conv:type(t, tensorCache)
   -- After a forward/backward pass, self.output / self.gradInput alias the
   -- wrapped module's buffers; re-point them at the converted tensors.
   self.output = self.conv.output
   self.gradInput = self.conv.gradInput
   self._type = t
   return self
end
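-- Legacy implementation, kept for reference only: it subclassed
-- nn.SpatialConvolutionMM directly instead of wrapping a convolution module.
--require 'cunn'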
--
--local SpatialConvolutionDCT, parent = torch.class('SpatialConvolutionDCT', 'nn.SpatialConvolutionMM')
--
--function SpatialConvolutionDCT:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padding)
-- parent.__init(self, nInputPlane, nOutputPlane, kW, kH, dW, dH, padding)
--
-- self:reset()
--end
--
--function SpatialConvolutionDCT:reset(stdv)
-- if stdv then
-- stdv = stdv * math.sqrt(3)
-- else
-- stdv = 1/math.sqrt(self.kW*self.kH*self.nInputPlane)
-- end
-- if nn.oldSeed then
-- self.weight:apply(function()
-- return torch.uniform(-stdv, stdv)
-- end)
-- self.bias:apply(function()
-- return torch.uniform(-stdv, stdv)
-- end)
-- else
-- self.weight:uniform(-stdv, stdv)
-- self.bias:uniform(-stdv, stdv)
-- end
-- if self.nInputPlane>1 then
-- self.weight:copy(odct3dict(self.nInputPlane,self.kW,self.kH,self.nOutputPlane):narrow(2,1,self.nOutputPlane):t())
-- else
--
-- self.weight:copy(odct2dict(self.kW,self.kH,self.nOutputPlane):narrow(2,1,self.nOutputPlane):t())
-- end
--end
--
--function SpatialConvolutionDCT:parameters()
--return {self.bias}, {self.gradBias}
--end
--
--function SpatialConvolutionDCT:accGradParameters(input, gradOutput, scale)
--end