[Enhance] Add upsample_cfg in irr-pwc decoder #53

Merged · 2 commits · Dec 14, 2021
39 changes: 16 additions & 23 deletions mmflow/models/decoders/irrpwc_decoder.py
@@ -212,7 +212,8 @@ class IRRPWCDecoder(BaseDecoder):
             elements involved to calculate correlation or not.
             Defaults to True.
         warp_cfg (dict): Config for warp operation. Defaults to
-            dict(type='Warp', align_corners=True).
+            dict(type='Warp', align_corners=True), which is the same as the
+            official implementation of IRRPWC.
         densefeat_channels (Sequence[int]): Number of output channels for
             dense layers. Defaults to (128, 128, 96, 64, 32).
         flow_post_processor (dict, optional): Config of flow post process
@@ -230,6 +231,8 @@ class IRRPWCDecoder(BaseDecoder):
             module. Defaults to None.
         flow_div (float): The divisor works for scaling the ground truth.
             Default: 20.
+        upsample_cfg (dict): Config dict for the interpolate operation in
+            PyTorch. Default: dict(mode='bilinear', align_corners=True).
         conv_cfg (dict, optional): Config dict of convolution layer in
             module. Default: None.
         norm_cfg (dict, optional): Config dict of norm layer in module.
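The new argument is exposed through the decoder config. Below is a minimal sketch of overriding it when building the decoder from a config dict; only the `upsample_cfg` key comes from this PR, the surrounding keys are illustrative placeholders rather than a complete IRRPWCDecoder config.

```python
# Illustrative snippet (not part of the diff): overriding the new upsample_cfg.
decoder = dict(
    type='IRRPWCDecoder',
    flow_div=20.,
    # Default is dict(mode='bilinear', align_corners=True); 'nearest' takes
    # no align_corners argument, so it is omitted here.
    upsample_cfg=dict(mode='nearest'))
```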
@@ -259,6 +262,8 @@ def __init__(self,
                  occ_refined_levels: Sequence[str] = ['level0', 'level1'],
                  occ_upsample: dict = None,
                  flow_div: float = 20.,
+                 upsample_cfg: dict = dict(
+                     mode='bilinear', align_corners=True),
                  conv_cfg: Optional[dict] = None,
                  norm_cfg: Optional[dict] = None,
                  act_cfg: dict = dict(type='LeakyReLU', negative_slope=0.1),
@@ -300,6 +305,8 @@ def __init__(self,
             norm_cfg=norm_cfg,
             act_cfg=act_cfg)
 
+        self.upsample_cfg = upsample_cfg
+
         self.flow_refine = build_components(flow_refine)
         self.flow_post_processor = build_components(flow_post_processor)
 
@@ -524,8 +531,7 @@ def _scale_img(self, img: torch.Tensor, h: int, w: int) -> torch.Tensor:
         Returns:
             Tensor: The output image.
         """
-        return F.interpolate(
-            img, size=(h, w), mode='bilinear', align_corners=True)
+        return F.interpolate(img, size=(h, w), **self.upsample_cfg)
 
     def _scale_flow(self, flow, h, w):
         """Scale flow function.
@@ -539,18 +545,10 @@ def _scale_flow(self, flow, h, w):
             Tensor: The output optical flow.
         """
         h_org, w_org = flow.shape[2:]
-        u_scale = float(w) / float(w_org)
-        v_scale = float(h) / float(h_org)
-        u = flow[:, 0, ...] * u_scale
-        v = flow[:, 1, ...] * v_scale
-        u = u[:, None, ...]
-        v = v[:, None, ...]
-
-        return F.interpolate(
-            torch.cat((u, v), dim=1),
-            size=(h, w),
-            mode='bilinear',
-            align_corners=True)
+        scale = torch.Tensor([float(w / w_org), float(h / h_org)]).to(flow)
+        flow = torch.einsum('b c h w, c -> b c h w', flow, scale)
+
+        return F.interpolate(flow, size=(h, w), **self.upsample_cfg)
 
     def _scale_flow_as_gt(self, flow: torch.Tensor, H_img: int,
                           W_img: int) -> torch.Tensor:
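The einsum form multiplies channel 0 (u) by the width ratio and channel 1 (v) by the height ratio, which matches the removed per-channel code; a standalone sanity check (editor's sketch, not part of the PR):

```python
import torch

# Channel-wise scaling via einsum vs. the removed explicit u/v multiplications.
flow = torch.randn(2, 2, 4, 5)      # (B, C=2, H, W); channel 0 = u, channel 1 = v
scale = torch.tensor([1.5, 0.75])   # [w / w_org, h / h_org]

new = torch.einsum('b c h w, c -> b c h w', flow, scale)
old = torch.cat((flow[:, 0:1] * scale[0], flow[:, 1:2] * scale[1]), dim=1)
assert torch.allclose(new, old)
```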
@@ -564,14 +562,9 @@ def _scale_flow_as_gt(self, flow: torch.Tensor,
             Tensor: The output optical flow.
         """
         h_org, w_org = flow.shape[2:]
-        u_scale = float(W_img) / float(w_org)
-        v_scale = float(H_img) / float(h_org)
-        u = flow[:, 0, ...] * u_scale / self.flow_div
-        v = flow[:, 1, ...] * v_scale / self.flow_div
-        u = u[:, None, ...]
-        v = v[:, None, ...]
-
-        return torch.cat((u, v), dim=1)
+        scale = torch.Tensor([float(W_img / w_org),
+                              float(H_img / h_org)]).to(flow) / self.flow_div
+        return torch.einsum('b c h w, c -> b c h w', flow, scale)
 
     def forward_train(
         self,