Skip to content

Commit

Permalink
Parametrize multiple of number of channels in Conv (#12508)
Browse files Browse the repository at this point in the history
* Parametrize multiple of number of channels in Conv

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix issue when exporting

Signed-off-by: Angelo Delli Santi <dellisanti.angelo@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Angelo Delli Santi <dellisanti.angelo@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
3 people authored Dec 18, 2023
1 parent dafe39b commit f33d42d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
9 changes: 6 additions & 3 deletions models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,12 @@ def call(self, inputs):

def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
anchors, nc, gd, gw, ch_mul = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get(
'channel_multiple')
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
if not ch_mul:
ch_mul = 8

layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
Expand All @@ -399,7 +402,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
BottleneckCSP, C3, C3x]:
c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
c2 = make_divisible(c2 * gw, ch_mul) if c2 != no else c2

args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3, C3x]:
Expand All @@ -414,7 +417,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
if m is Segment:
args[3] = make_divisible(args[3] * gw, 8)
args[3] = make_divisible(args[3] * gw, ch_mul)
args.append(imgsz)
else:
c2 = ch[f]
Expand Down
9 changes: 6 additions & 3 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,13 @@ def _from_yaml(self, cfg):
def parse_model(d, ch): # model_dict, input_channels(3)
# Parse a YOLOv5 model.yaml dictionary
LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
anchors, nc, gd, gw, act, ch_mul = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get(
'activation'), d.get('channel_multiple')
if act:
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
LOGGER.info(f"{colorstr('activation:')} {act}") # print
if not ch_mul:
ch_mul = 8
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)

Expand All @@ -319,7 +322,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)
c2 = make_divisible(c2 * gw, ch_mul)

args = [c1, c2, *args[1:]]
if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x}:
Expand All @@ -335,7 +338,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
if m is Segment:
args[3] = make_divisible(args[3] * gw, 8)
args[3] = make_divisible(args[3] * gw, ch_mul)
elif m is Contract:
c2 = ch[f] * args[0] ** 2
elif m is Expand:
Expand Down

0 comments on commit f33d42d

Please sign in to comment.