Support different number of input channels to YOLOX backbone #1239

Open · wants to merge 3 commits into base: main

Changes from 2 commits
exps/default/yolox_nano.py (9 changes: 7 additions & 2 deletions)

@@ -14,6 +14,7 @@ def __init__(self):
         super(Exp, self).__init__()
         self.depth = 0.33
         self.width = 0.25
+        self.backbone_in_channels = 3
         self.input_size = (416, 416)
         self.random_size = (10, 20)
         self.mosaic_scale = (0.5, 1.5)
@@ -34,8 +35,12 @@ def init_yolo(M):
             in_channels = [256, 512, 1024]
             # NANO model use depthwise = True, which is main difference.
             backbone = YOLOPAFPN(
-                self.depth, self.width, in_channels=in_channels,
-                act=self.act, depthwise=True,
+                self.depth,
+                self.width,
+                backbone_in_channels=self.backbone_in_channels,
+                in_channels=in_channels,
+                act=self.act,
+                depthwise=True,
             )
             head = YOLOXHead(
                 self.num_classes, self.width, in_channels=in_channels,
yolox/exp/yolox_base.py (14 changes: 12 additions & 2 deletions)

@@ -17,6 +17,8 @@ def __init__(self):
         super().__init__()

         # ---------------- model config ---------------- #
+        # number of input channels, e.g. 3 for RGB input
+        self.backbone_in_channels = 3
         # detect classes number of model
         self.num_classes = 80
         # factor of model depth
@@ -118,8 +120,16 @@ def init_yolo(M):

         if getattr(self, "model", None) is None:
             in_channels = [256, 512, 1024]
-            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, act=self.act)
-            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, act=self.act)
+            backbone = YOLOPAFPN(
+                self.depth,
+                self.width,
+                backbone_in_channels=self.backbone_in_channels,
+                in_channels=in_channels,
+                act=self.act,
+            )
+            head = YOLOXHead(
+                self.num_classes, self.width, in_channels=in_channels, act=self.act
+            )
             self.model = YOLOX(backbone, head)

         self.model.apply(init_yolo)
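
For the experiment-driven workflow, a minimal sketch of how a custom Exp could opt into a different channel count via this new field (the 4-channel value and the class name are illustrative assumptions, not part of this diff):

from yolox.exp import Exp as BaseExp


class FourChannelExp(BaseExp):
    def __init__(self):
        super().__init__()
        # e.g. RGB plus a near-infrared band
        self.backbone_in_channels = 4


exp = FourChannelExp()
model = exp.get_model()  # backbone stem is now built for 4-channel input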
yolox/models/build.py (51 changes: 35 additions & 16 deletions)

@@ -29,7 +29,11 @@


 def create_yolox_model(
-    name: str, pretrained: bool = True, num_classes: int = 80, device=None
+    name: str,
+    pretrained: bool = True,
+    backbone_in_channels: int = 3,
+    num_classes: int = 80,
+    device=None,
 ) -> nn.Module:
     """creates and loads a YOLOX model

@@ -50,9 +54,10 @@ def create_yolox_model(

     assert name in _CKPT_FULL_PATH, f"user should use one of value in {_CKPT_FULL_PATH.keys()}"
     exp: Exp = get_exp(exp_name=name)
+    exp.backbone_in_channels = backbone_in_channels
     exp.num_classes = num_classes
     yolox_model = exp.get_model()
-    if pretrained and num_classes == 80:
+    if pretrained and backbone_in_channels == 3 and num_classes == 80:
Review comment on this line:

I'm not sure if the architecture fully supports it, but in the past, when using N>3 for FRCNN or MRCNN w/ resnet backbone, I had better luck adapting weights to the extra channels. Certainly beats training from scratch.

Rather than ignoring the weights in the case of N_Channels != 3, is it possible to randomize the extra weights or duplicate weights from a different channel?

At the very least, might be nice to log a warning that the weights are being ignored, despite the "pretrained" input being true.

weiji14 (Author) commented on Apr 14, 2022:

> I'm not sure if the architecture fully supports it, but in the past, when using N>3 for FRCNN or MRCNN w/ resnet backbone, I had better luck adapting weights to the extra channels. Certainly beats training from scratch.
>
> Rather than ignoring the weights in the case of N_Channels != 3, is it possible to randomize the extra weights or duplicate weights from a different channel?

Hmm, I'm not sure how to randomize weights for extra channels, do you have some example code to do that? Maybe this can be done in a follow-up Pull Request so as not to overcomplicate things.

> At the very least, might be nice to log a warning that the weights are being ignored, despite the "pretrained" input being true.

Good idea. Or maybe it should just be an error? Edit: decided to just let it raise an error, done in commit 4e42e61.

dcyoung commented on Apr 14, 2022:

Hmmm. For resnet50 channel additions I've mainly used TensorFlow, where I did something like:

tensor = params["conv0/W"]
assert tensor.shape == (7, 7, 3, 64)
# Create a replacement with 4 channels, using the existing first 3 and a copy of the 1st
replacement = np.zeros((7, 7, 4, 64), tensor.dtype)
replacement[:, :, :3, :] = tensor
replacement[:, :, 3, :] = tensor[:, :, 0, :]
params["conv0/W"] = replacement

I'm not sure how well that translates to the architecture used here.

For PyTorch, I believe you can simply modify the state dict before loading. You could do this to avoid loading any tensors with mismatched sizes, i.e. attempt to use all weights which CAN be used. For example, a model trained on a different number of classes could still be used to populate the weights of the backbone, omitting just the weights from the model head. Here is an example from Hugging Face: https://github.com/huggingface/transformers/blob/v4.18.0/src/transformers/modeling_utils.py#L1989

In the case of N channels != 3, you might need to manipulate the weights. I've had success manipulating weights directly like so:

# Load from a PyTorch checkpoint
state_dict = torch.load(archive_file, map_location="cpu")

# Manually manipulate the weights relevant to the 4-channel model
wip = state_dict["param_name"]
# manipulate
wip = ...
# update
state_dict["param_name"] = wip
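
For what it's worth, a minimal PyTorch sketch of that "use every weight whose shape still fits" idea, written against the names used in build.py (illustrative only, not code from this PR):

# yolox_model: freshly built model (possibly N != 3 channels or num_classes != 80)
# ckpt: checkpoint dict loaded via load_state_dict_from_url in build.py
model_state = yolox_model.state_dict()
ckpt_state = ckpt["model"] if "model" in ckpt else ckpt
# Keep only tensors whose name and shape match the new model; the stem conv
# and the head weights drop out automatically when their shapes differ.
compatible = {
    k: v for k, v in ckpt_state.items()
    if k in model_state and v.shape == model_state[k].shape
}
model_state.update(compatible)
yolox_model.load_state_dict(model_state)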

         weights_url = _CKPT_FULL_PATH[name]
         ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
         if "model" in ckpt:
@@ -63,29 +68,43 @@
     return yolox_model


-def yolox_nano(pretrained=True, num_classes=80, device=None):
-    return create_yolox_model("yolox-nano", pretrained, num_classes, device)
+def yolox_nano(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
+    return create_yolox_model(
+        "yolox-nano", pretrained, backbone_in_channels, num_classes, device
+    )


-def yolox_tiny(pretrained=True, num_classes=80, device=None):
-    return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
+def yolox_tiny(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
+    return create_yolox_model(
+        "yolox-tiny", pretrained, backbone_in_channels, num_classes, device
+    )


-def yolox_s(pretrained=True, num_classes=80, device=None):
-    return create_yolox_model("yolox-s", pretrained, num_classes, device)
+def yolox_s(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
+    return create_yolox_model(
+        "yolox-s", pretrained, backbone_in_channels, num_classes, device
+    )


-def yolox_m(pretrained=True, num_classes=80, device=None):
-    return create_yolox_model("yolox-m", pretrained, num_classes, device)
+def yolox_m(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
+    return create_yolox_model(
+        "yolox-m", pretrained, backbone_in_channels, num_classes, device
+    )


-def yolox_l(pretrained=True, num_classes=80, device=None):
-    return create_yolox_model("yolox-l", pretrained, num_classes, device)
+def yolox_l(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
+    return create_yolox_model(
+        "yolox-l", pretrained, backbone_in_channels, num_classes, device
+    )


-def yolox_x(pretrained=True, num_classes=80, device=None):
-    return create_yolox_model("yolox-x", pretrained, num_classes, device)
+def yolox_x(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
+    return create_yolox_model(
+        "yolox-x", pretrained, backbone_in_channels, num_classes, device
+    )


-def yolov3(pretrained=True, num_classes=80, device=None):
-    return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
+def yolov3(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
+    return create_yolox_model(
+        "yolox-tiny", pretrained, backbone_in_channels, num_classes, device
+    )
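
A quick usage sketch of the extended builder API (the 4-channel input and 640x640 size are illustrative assumptions; a non-RGB model cannot reuse the pretrained COCO weights, per the check above):

import torch

from yolox.models.build import yolox_s

model = yolox_s(pretrained=False, backbone_in_channels=4, num_classes=80, device="cpu")
model.eval()

dummy = torch.randn(1, 4, 640, 640)  # one 4-channel image
with torch.no_grad():
    outputs = model(dummy)  # head outputs for the 4-channel batch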
yolox/models/darknet.py (3 changes: 2 additions & 1 deletion)

@@ -99,6 +99,7 @@ def __init__(
         self,
         dep_mul,
         wid_mul,
+        in_channels=3,
         out_features=("dark3", "dark4", "dark5"),
         depthwise=False,
         act="silu",
@@ -112,7 +113,7 @@
         base_depth = max(round(dep_mul * 3), 1)  # 3

         # stem
-        self.stem = Focus(3, base_channels, ksize=3, act=act)
+        self.stem = Focus(in_channels, base_channels, ksize=3, act=act)

         # dark2
         self.dark2 = nn.Sequential(
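
And a small sanity-check sketch for the backbone change on its own (the expected shapes are my assumption for a 640x640 input at width 1.0):

import torch

from yolox.models.darknet import CSPDarknet

backbone = CSPDarknet(dep_mul=1.0, wid_mul=1.0, in_channels=4)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 4, 640, 640))
# feats is a dict keyed by "dark3", "dark4", "dark5", roughly
# (1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20) here.
print({k: tuple(v.shape) for k, v in feats.items()})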
yolox/models/yolo_fpn.py (3 changes: 2 additions & 1 deletion)

@@ -17,11 +17,12 @@ class YOLOFPN(nn.Module):
     def __init__(
         self,
         depth=53,
+        backbone_in_channels=3,
         in_features=["dark3", "dark4", "dark5"],
     ):
         super().__init__()

-        self.backbone = Darknet(depth)
+        self.backbone = Darknet(depth, in_channels=backbone_in_channels)
         self.in_features = in_features

         # out 1
yolox/models/yolo_pafpn.py (5 changes: 4 additions & 1 deletion)

@@ -18,13 +18,16 @@ def __init__(
         self,
         depth=1.0,
         width=1.0,
+        backbone_in_channels=3,
         in_features=("dark3", "dark4", "dark5"),
         in_channels=[256, 512, 1024],
         depthwise=False,
         act="silu",
     ):
         super().__init__()
-        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
+        self.backbone = CSPDarknet(
+            depth, width, in_channels=backbone_in_channels, depthwise=depthwise, act=act
+        )
         self.in_features = in_features
         self.in_channels = in_channels
         Conv = DWConv if depthwise else BaseConv