yolov9-seg training error: RuntimeError: shape '[32, 65, -1]' is invalid for input of size 131712 #510

Open
SXleader opened this issue Jun 24, 2024 · 0 comments

@SXleader

```
Logging results to runs\train\exp58
Starting training for 100 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
  0%| | 0/33 00:02
Traceback (most recent call last):
  File "E:\study\yolov9-main\train.py", line 634, in <module>
    main(opt)
  File "E:\study\yolov9-main\train.py", line 528, in main
    train(opt.hyp, opt, device, callbacks)
  File "E:\study\yolov9-main\train.py", line 304, in train
    loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
  File "E:\study\yolov9-main\utils\loss_tal.py", line 178, in __call__
    pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
  File "E:\study\yolov9-main\utils\loss_tal.py", line 178, in <listcomp>
    pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
RuntimeError: shape '[32, 65, -1]' is invalid for input of size 131712
```
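`Tensor.view` with one `-1` dimension only succeeds when the element count is an exact multiple of the product of the other dimensions. A minimal pure-Python sketch of that check, applied to the numbers in the error message (`view_is_valid` is a hypothetical helper written for illustration, not a PyTorch function):

```python
from math import prod


def view_is_valid(numel: int, shape: list[int]) -> bool:
    """Mimic the size check of Tensor.view when one dimension is -1.

    Hypothetical helper: the product of the known dimensions must
    divide the total element count exactly.
    """
    known = prod(d for d in shape if d != -1)
    return numel % known == 0


numel = 131712          # "input of size 131712" from the traceback
target = [32, 65, -1]   # feats[0].shape[0], self.no, -1

print(view_is_valid(numel, target))  # False
print(divmod(numel, 32 * 65))        # (63, 672): not a multiple of 2080
print(divmod(numel, 32))             # (4116, 0): an exact multiple of 32
```

The element count divides evenly by 32 but leaves a remainder of 672 when divided by 32 × 65 = 2080, which is exactly why the reshape is rejected.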

I get this error when training yolov9-seg. My train.py configuration is as follows:
```python
def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    # parser.add_argument('--weights', type=str, default=ROOT / 'yolo.pt', help='initial weights path')
    # parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--weights', type=str, default='./yolov9-c-seg.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='models/segment/gelan-c-seg.yaml', help='model.yaml path')
    parser.add_argument('--data', type=str, default=ROOT / './data/my_coco.yaml', help='dataset.yaml path')
    parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-high.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
    parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs, -1 for autobatch')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=434, help='train, val image size (pixels)')
```
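`--imgsz` is set to 434, which is not a multiple of the largest detection stride. Assuming YOLOv9 keeps YOLOv5's `check_img_size` behaviour (rounding the size up to the next multiple of the maximum stride, 32) and three detection levels at strides 8, 16 and 32, the sketch below computes the effective training size and the number of anchor points per image. `effective_img_size` and `num_anchor_points` are hypothetical helpers written for this illustration.

```python
import math


def effective_img_size(imgsz: int, max_stride: int = 32) -> int:
    """Round imgsz up to the next multiple of max_stride.

    Assumption: mirrors YOLOv5-style check_img_size / make_divisible (ceil).
    """
    return math.ceil(imgsz / max_stride) * max_stride


def num_anchor_points(img_size: int, strides=(8, 16, 32)) -> int:
    """Total grid cells over all detection levels for a square input.

    Assumption: three detection levels at strides 8, 16 and 32.
    """
    return sum((img_size // s) ** 2 for s in strides)


size = effective_img_size(434)     # 434 / 32 = 13.56..., rounded up to 14 * 32
anchors = num_anchor_points(size)  # 56*56 + 28*28 + 14*14
print(size, anchors)               # 448 4116
print(131712 == 32 * anchors)      # True
```

Under these assumptions, the size reported in the error, 131712, is exactly 32 × 4116, which may be worth comparing against the shapes that actually reach the loss.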

The error points to this location in utils/loss_tal.py:

```python
def __call__(self, p, targets, img=None, epoch=0):
    loss = torch.zeros(3, device=self.device)  # box, cls, dfl
    feats = p[1] if isinstance(p, tuple) else p[0]
    pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
        (self.reg_max * 4, self.nc), 1)
```
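The excerpt reshapes each level's output to `(batch, no, -1)`, concatenates along the anchor axis, then splits the channel axis into `reg_max * 4` box-distribution channels and `nc` class channels, so `self.no` has to equal `reg_max * 4 + nc`. A minimal pure-Python sketch of that shape bookkeeping, assuming `reg_max = 16` (under which `no = 65` would correspond to `nc = 1`); `concat_and_split_shapes` is a hypothetical helper, and its error string only imitates the PyTorch message:

```python
from math import prod


def concat_and_split_shapes(feat_shapes, nc, reg_max=16):
    """Track shapes through the view -> cat -> split steps of the loss excerpt.

    Hypothetical helper for illustration; assumes no = reg_max * 4 + nc.
    feat_shapes: one shape tuple per detection level, batch dimension first.
    Returns the shapes of (pred_distri, pred_scores).
    """
    no = reg_max * 4 + nc
    batch = feat_shapes[0][0]  # plays the role of feats[0].shape[0]
    total = 0
    for shape in feat_shapes:
        numel = prod(shape)
        # view(batch, no, -1) needs numel to be a multiple of batch * no
        if numel % (batch * no):
            raise RuntimeError(
                f"shape '[{batch}, {no}, -1]' is invalid for input of size {numel}"
            )
        total += numel // (batch * no)  # anchors contributed by this level
    # cat along dim 2, then split dim 1 into reg_max * 4 and nc channels
    return (batch, reg_max * 4, total), (batch, nc, total)


# Detection-head outputs for batch 1, nc = 1, 448 x 448 input (strides 8/16/32)
levels = [(1, 65, 56, 56), (1, 65, 28, 28), (1, 65, 14, 14)]
print(concat_and_split_shapes(levels, nc=1))
# ((1, 64, 4116), (1, 1, 4116))
```

Calling it with a single hypothetical entry `[(32, 4116)]`, chosen only because it has exactly 131712 elements, raises the same message as the traceback. That reproduces the arithmetic; it does not by itself show what the real tensor looks like.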
 
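The dataset has 33 training images and 11 validation images, all 3-channel TIFF files of 434×434 pixels, with labels stored as .txt files. Can anyone help me fix this error?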
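The labels are described only as .txt files. Assuming they follow the YOLO segmentation convention (one object per line: a class index followed by polygon vertices normalized to [0, 1]), a quick sanity check is to parse each line and confirm the class index is below `nc` from the dataset YAML. `parse_seg_label_line` is a hypothetical helper, and the three-point minimum is this sketch's own choice:

```python
def parse_seg_label_line(line: str, nc: int):
    """Parse one line of a YOLO-style segmentation label file.

    Assumed format: "<class> x1 y1 x2 y2 ... xn yn", coordinates in [0, 1].
    Returns (class_id, [(x, y), ...]) or raises ValueError.
    """
    parts = line.split()
    if len(parts) < 7:  # class id + at least 3 (x, y) points (sketch's choice)
        raise ValueError("expected a class id and at least 3 (x, y) points")
    cls = int(float(parts[0]))  # tolerate ids written as "0.0"
    if not 0 <= cls < nc:
        raise ValueError(f"class id {cls} outside [0, {nc})")
    coords = [float(v) for v in parts[1:]]
    if len(coords) % 2:
        raise ValueError("odd number of coordinates")
    if not all(0.0 <= v <= 1.0 for v in coords):
        raise ValueError("coordinates must be normalized to [0, 1]")
    return cls, list(zip(coords[0::2], coords[1::2]))


print(parse_seg_label_line("0 0.1 0.1 0.9 0.1 0.5 0.8", nc=1))
# (0, [(0.1, 0.1), (0.9, 0.1), (0.5, 0.8)])
```

A detection-style line with only `class x y w h` has five fields and is rejected by this check, which can help spot box labels mixed into a segmentation dataset.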