This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Unable to jit.script DETR res50 model: Dictionary inputs to traced functions must have consistent type. Found Tensor and List[Dict[str, Tensor]] #208

Closed
lessw2020 opened this issue Aug 18, 2020 · 15 comments
Labels
question Further information is requested

Comments

@lessw2020
Contributor

Instructions To Reproduce the 🐛 Bug:

  1. what changes you made (git diff) or what code you wrote:
     Fine-tuned DETR model, resnet50 backbone, 3 classes.
  2. what exact command you run:
     Prepared a single image per the normal validation process as the 'sample' (resize, tensorize, normalize, unsqueeze to make a batch of 1, push to GPU), then:
     traced = torch.jit.trace(model, single_batch_tensorimg)
  3. what you observed (including full logs):

Runtime error - the summary error is "Dictionary inputs to traced functions must have consistent type. Found Tensor and List[Dict[str, Tensor]]"
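The exact preparation code was not posted; a minimal sketch of the steps described above in plain PyTorch (the interpolation size and the ImageNet normalization constants are assumptions, matching DETR's standard validation transforms):

```python
import torch
import torch.nn.functional as F

# Sketch of the preparation described above (exact code was not posted).
# Normalization constants are the standard ImageNet values DETR uses.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

img = torch.rand(3, 480, 640)  # stand-in for a decoded RGB image
img = F.interpolate(img[None], size=(800, 1066), mode="bilinear",
                    align_corners=False)[0]      # resize
img = (img - mean) / std                         # normalize
single_batch_tensorimg = img.unsqueeze(0)        # unsqueeze to batch of 1
if torch.cuda.is_available():
    single_batch_tensorimg = single_batch_tensorimg.cuda()  # push to gpu

# the failing call from the repro steps:
# traced = torch.jit.trace(model, single_batch_tensorimg)
```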

Full error log:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-83-905883ed8622> in <module>
----> 1 traced = torch.jit.trace(model,single_batch_tensorimg)

~/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    953         return trace_module(func, {'forward': example_inputs}, None,
    954                             check_trace, wrap_check_inputs(check_inputs),
--> 955                             check_tolerance, strict, _force_outplace, _module_class)
    956 
    957     if (hasattr(func, '__self__') and isinstance(func.__self__, torch.nn.Module) and

~/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
   1107             func = mod if method_name == "forward" else getattr(mod, method_name)
   1108             example_inputs = make_tuple(example_inputs)
-> 1109             module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, strict, _force_outplace)
   1110             check_trace_method = module._c._get_method(method_name)
   1111 

RuntimeError: Tracer cannot infer type of {'pred_logits': tensor([[[-1.3257e+00, -4.5160e+00, -3.3199e+00,  6.2382e+00],
         ...]], device='cuda:0', grad_fn=<SelectBackward>), 'pred_boxes': tensor([[[0.4914, 0.4477, 0.8947, 0.7548],
         ...]], device='cuda:0', grad_fn=<SelectBackward>), 'aux_outputs': [{'pred_logits': tensor(...), 'pred_boxes': tensor(...)}, ...]}
:Dictionary inputs to traced functions must have consistent type. Found Tensor and List[Dict[str, Tensor]]

[100×4 logit and 100×4 box tensor values elided for brevity; the dump repeats for each auxiliary decoder output in 'aux_outputs']
         [-2.0550, -2.6860, -2.4078,  5.7369],
         [-2.2476, -2.1310, -1.9005,  5.5592],
         [-1.6438, -3.0448, -2.8049,  5.9452],
         [-0.4286, -2.9496, -2.1185,  5.2636],
         [-0.1255, -2.8942, -2.3951,  5.2833],
         [-0.2177, -2.8967, -1.9466,  4.8758],
         [-1.4949, -2.8404, -2.4574,  5.7899],
         [-1.9299,  3.9416, -0.5631, -1.5091],
         [-1.0772, -3.3255, -2.7333,  5.7596],
         [-1.1553, -3.1691, -2.6078,  5.7869],
         [-0.2924, -3.4447, -2.1180,  5.2366],
         [-1.9532, -2.9933, -2.3763,  5.7377],
         [-1.0900, -3.1266, -2.6527,  5.6877],
         [-2.0086, -2.3956, -2.0749,  5.5983],
         [-0.9961, -3.1074, -2.5959,  5.6631],
         [-0.6606, -3.0797, -2.4826,  5.3518],
         [-0.7554, -3.2609, -2.3134,  5.5423],
         [ 4.0333, -0.1656, -0.3935, -1.9464],
         [-1.6300, -3.0205, -2.5745,  5.5450],
         [-1.4464, -3.0224, -1.8684,  5.7209],
         [-2.2923, -3.3039, -2.0065,  5.7042],
         [-1.8417, -2.8574, -1.3871,  5.4191],
         [-1.8479, -2.5814, -2.3762,  5.5721],
         [-1.8005, -2.8841, -2.5357,  5.4072],
         [-0.9564, -3.2119, -2.5473,  5.6096],
         [-1.0831, -3.2866, -2.7719,  5.8481],
         [-1.1163, -3.0596, -2.2407,  5.4824],
         [-0.5106, -2.8198, -2.5989,  5.4795],
         [-1.0136, -3.1482, -2.5053,  5.4294],
         [-0.5503, -3.2515, -2.2499,  5.2227],
         [-1.8307, -2.6809, -2.4564,  5.7499],
         [-2.1241, -3.6411, -1.9773,  5.6121],
         [-1.2523, -3.1468, -2.3464,  5.5933],
         [-1.0927, -3.2725, -1.6309,  5.2525],
         [-0.8953, -2.9698, -2.5888,  5.4499],
         [-0.0419, -2.8144, -2.4175,  5.2348],
         [-1.0292, -3.0725, -2.1068,  5.4061],
         [-1.8448, -2.8445, -1.2782,  5.4698],
         [-1.8900, -3.0359, -1.5603,  5.3642],
         [-1.5739, -3.0762, -2.2299,  5.9266],
         [-1.0384, -2.9962, -1.9951,  5.5450],
         [-0.2123, -2.9351, -2.4236,  5.2250],
         [-1.3655, -3.1446, -2.7411,  5.9292],
         [-0.9977, -3.2676, -1.8949,  5.6012],
         [-2.3189, -0.4809,  3.7322, -0.5375],
         [-1.2448, -3.1832, -1.9165,  5.6238]]], device='cuda:0',
       grad_fn=<SelectBackward>), 'pred_boxes': tensor([[[0.3452, 0.4236, 0.5137, 0.7136],
         [0.5966, 0.4288, 0.6507, 0.6229],
         [0.2612, 0.5497, 0.4357, 0.4605],
         [0.6839, 0.3533, 0.4651, 0.4353],
         [0.4929, 0.4320, 0.9962, 0.6833],
         [0.4980, 0.4623, 0.9969, 0.6161],
         [0.5167, 0.4669, 0.8921, 0.5866],
         [0.2714, 0.5032, 0.4565, 0.4839],
         [0.2987, 0.4134, 0.4529, 0.7180],
         [0.4916, 0.4285, 0.9138, 0.7579],
         [0.5166, 0.4886, 0.9620, 0.6142],
         [0.4424, 0.4825, 0.6524, 0.5860],
         [0.6539, 0.5351, 0.5402, 0.3493],
         [0.5491, 0.5323, 0.7693, 0.4716],
         [0.6717, 0.7030, 0.5366, 0.3924],
         [0.2654, 0.4112, 0.4170, 0.7247],
         [0.3180, 0.4264, 0.4696, 0.7328],
         [0.6669, 0.4307, 0.5154, 0.5532],
         [0.6458, 0.5412, 0.5122, 0.4214],
         [0.5392, 0.4689, 0.8056, 0.5169],
         [0.6007, 0.5453, 0.6566, 0.3815],
         [0.4893, 0.5343, 0.7789, 0.4097],
         [0.5311, 0.4431, 0.7820, 0.6966],
         [0.5018, 0.4536, 0.9780, 0.6516],
         [0.6033, 0.4404, 0.6700, 0.6404],
         [0.6379, 0.5262, 0.6060, 0.4390],
         [0.5277, 0.4245, 0.9445, 0.7164],
         [0.3239, 0.4782, 0.5292, 0.5347],
         [0.5656, 0.3754, 0.5897, 0.4433],
         [0.4460, 0.4300, 0.6202, 0.6885],
         [0.4916, 0.4875, 0.9995, 0.5495],
         [0.5159, 0.4763, 0.9736, 0.6553],
         [0.5115, 0.5098, 0.6909, 0.7220],
         [0.5284, 0.5017, 0.8627, 0.5313],
         [0.5379, 0.4945, 0.8352, 0.5934],
         [0.4984, 0.4213, 0.9958, 0.7140],
         [0.5095, 0.4925, 0.9329, 0.5776],
         [0.2445, 0.4431, 0.4131, 0.6591],
         [0.4952, 0.4405, 0.9724, 0.6556],
         [0.6943, 0.4029, 0.4380, 0.4537],
         [0.5302, 0.4354, 0.9538, 0.6851],
         [0.5259, 0.4209, 0.8996, 0.7331],
         [0.5095, 0.4579, 0.9880, 0.6413],
         [0.4867, 0.5162, 0.7702, 0.4664],
         [0.4048, 0.4312, 0.5832, 0.6815],
         [0.2863, 0.5430, 0.4358, 0.3623],
         [0.6302, 0.4579, 0.5896, 0.5667],
         [0.6080, 0.4678, 0.6147, 0.7205],
         [0.5002, 0.4845, 0.9566, 0.5778],
         [0.3514, 0.4502, 0.5528, 0.6221],
         [0.5094, 0.4299, 0.8437, 0.7190],
         [0.6541, 0.4285, 0.5521, 0.6735],
         [0.5358, 0.5255, 0.7381, 0.6893],
         [0.3586, 0.4415, 0.5903, 0.6584],
         [0.4502, 0.4177, 0.6587, 0.7150],
         [0.2439, 0.5322, 0.4317, 0.4004],
         [0.3391, 0.4274, 0.5031, 0.7238],
         [0.4957, 0.4544, 0.9982, 0.6779],
         [0.4993, 0.4505, 0.9917, 0.6821],
         [0.5027, 0.4206, 0.9144, 0.6893],
         [0.2686, 0.4217, 0.4246, 0.7322],
         [0.2089, 0.5446, 0.3485, 0.3129],
         [0.5314, 0.4501, 0.9330, 0.6856],
         [0.2908, 0.4217, 0.4504, 0.7331],
         [0.5310, 0.4364, 0.7414, 0.6728],
         [0.4592, 0.4235, 0.7121, 0.7318],
         [0.4755, 0.4225, 0.7031, 0.6971],
         [0.2301, 0.5065, 0.3959, 0.5284],
         [0.3465, 0.4194, 0.5078, 0.7431],
         [0.4989, 0.4861, 0.9501, 0.5636],
         [0.5237, 0.4262, 0.9158, 0.7212],
         [0.4912, 0.5352, 0.9290, 0.3313],
         [0.5392, 0.5367, 0.7641, 0.3953],
         [0.5417, 0.4701, 0.7558, 0.6174],
         [0.5829, 0.6021, 0.6730, 0.4650],
         [0.5954, 0.5055, 0.6446, 0.4591],
         [0.2728, 0.4434, 0.4444, 0.6624],
         [0.4994, 0.4313, 0.7829, 0.7068],
         [0.5099, 0.3872, 0.9722, 0.6102],
         [0.4915, 0.4229, 0.8801, 0.6766],
         [0.5343, 0.4959, 0.8622, 0.5861],
         [0.4965, 0.4514, 0.9941, 0.6664],
         [0.5045, 0.5043, 0.9973, 0.5368],
         [0.5378, 0.4534, 0.7889, 0.6374],
         [0.2370, 0.4365, 0.4008, 0.6881],
         [0.6100, 0.5555, 0.6206, 0.5051],
         [0.5319, 0.5019, 0.8353, 0.5148],
         [0.6205, 0.4127, 0.6338, 0.7021],
         [0.5204, 0.5360, 0.9056, 0.4497],
         [0.5101, 0.4485, 0.9371, 0.6502],
         [0.5966, 0.4728, 0.6990, 0.5794],
         [0.5350, 0.4790, 0.7196, 0.5450],
         [0.6119, 0.5364, 0.6574, 0.3876],
         [0.5811, 0.3825, 0.6973, 0.6082],
         [0.5463, 0.4540, 0.7464, 0.6466],
         [0.5082, 0.4592, 0.9692, 0.6734],
         [0.4029, 0.4363, 0.5727, 0.6946],
         [0.6250, 0.4178, 0.6237, 0.6686],
         [0.7770, 0.5237, 0.3505, 0.3150],
         [0.5973, 0.4250, 0.6790, 0.6539]]], device='cuda:0',
       grad_fn=<SelectBackward>)}, {'pred_logits': tensor([[[-1.1162, -3.7035, -3.0795,  6.1133],
         [-1.2160, -3.5730, -2.2189,  5.3453],
         [-1.9854, -2.3146, -2.5552,  5.7973],
         [-2.0566, -3.0735, -2.3586,  5.9485],
         [-0.8528, -3.4781, -2.6605,  5.7390],
         [-0.3874, -3.1070, -3.0614,  5.8268],
         [-0.7707, -3.2150, -2.7647,  5.6134],
         [-2.1406, -2.1442, -2.4924,  5.5742],
         [-1.3929, -3.7180, -2.6538,  5.8116],
         [-1.1005, -3.7090, -2.8030,  5.8197],
         [-1.0991, -3.4300, -2.6064,  5.7625],
         [-1.4804, -3.1677, -2.9624,  6.0346],
         [-2.2091, -2.6808, -1.3839,  5.2839],
         [-1.0450, -3.3302, -2.3920,  5.4023],
         [-2.7047, -3.9290, -2.1605,  6.4024],
         [-1.2894, -3.5306, -2.8537,  5.7472],
         [-1.2925, -3.8171, -2.8466,  5.9891],
         [-1.5422, -3.1310, -2.0406,  5.3751],
         [-1.8916, -3.2877, -2.4486,  6.3133],
         [-0.8722, -3.2748, -2.8729,  5.7728],
         [-1.5530, -3.1965, -2.0213,  5.2309],
         [-1.6819, -3.0973, -2.1486,  5.8468],
         [-0.6525, -3.8565, -2.8088,  5.8247],
         [-0.9484, -3.2603, -2.8669,  6.0622],
         [-1.1130, -3.3053, -2.5153,  5.5193],
         [-1.4192, -3.1684, -2.2446,  5.1990],
         [-0.6525, -3.7336, -2.7140,  5.7248],
         [-1.8674, -2.2906, -2.4300,  5.5211],
         [-1.3664, -3.7100, -2.7815,  5.8811],
         [-0.8301, -3.7452, -3.1194,  5.9734],
         [-0.5863, -3.2202, -2.9713,  5.8879],
         [-1.5549, -3.4965, -2.6137,  6.0383],
         [-2.1716, -3.9864, -2.8182,  6.6471],
         [-0.8454, -3.4220, -2.7243,  5.8475],
         [-0.7476, -3.5711, -2.7505,  5.8240],
         [-0.8372, -3.5942, -2.6337,  5.5234],
         [-1.2164, -3.5857, -2.6168,  6.2293],
         [-1.8055, -2.5923, -2.4705,  5.3860],
         [-1.4073, -2.9542, -2.6590,  5.7547],
         [-2.2043, -2.7893, -1.9274,  5.4537],
         [-0.5882, -3.7436, -2.5869,  5.5419],
         [-0.7416, -3.7267, -2.4966,  5.5371],
         [-0.6933, -3.5245, -2.5769,  5.9205],
         [-1.0323, -2.4939, -2.8852,  5.2109],
         [-1.1496, -3.5252, -3.0470,  6.0987],
         [-2.0737, -1.8361, -2.1939,  5.1878],
         [-1.3481, -3.1971, -2.0598,  5.1385],
         [-2.1031, -3.7799, -2.4462,  6.1121],
         [-0.9884, -3.1527, -2.9635,  5.6913],
         [-1.5664, -2.6683, -2.5971,  5.3102],
         [-1.0866, -3.6840, -2.5177,  5.5214],
         [-1.2936, -3.6461, -2.2073,  5.4684],
         [-1.8741, -3.9879, -2.8203,  6.4552],
         [-0.7395, -3.1404, -2.8558,  5.0678],
         [-1.9889, -3.1105, -2.8205,  6.2112],
         [-1.9853, -1.9834, -2.3098,  5.5369],
         [-1.2614, -3.5394, -3.1021,  6.1532],
         [-0.8467, -3.2594, -2.5738,  5.5828],
         [-0.5321, -3.3350, -2.8172,  5.7600],
         [-0.6414, -3.3768, -2.4083,  5.2766],
         [-1.3374, -3.2302, -2.6630,  5.6000],
         [-1.9823,  4.1271, -0.4793, -1.5280],
         [-0.8125, -3.7206, -2.9084,  6.0304],
         [-1.3315, -3.5814, -2.8126,  5.8134],
         [-0.6584, -3.7776, -2.3883,  5.4691],
         [-1.7295, -3.2527, -2.7123,  5.9896],
         [-0.9021, -3.6485, -2.9098,  6.0115],
         [-1.7219, -2.8510, -2.4660,  5.8685],
         [-1.2134, -3.4111, -2.7187,  5.5608],
         [-0.6291, -3.2688, -2.8018,  5.8175],
         [-0.6566, -3.7216, -2.6624,  5.6881],
         [ 4.3031,  0.1184, -1.0466, -1.8691],
         [-1.0885, -3.2767, -2.7986,  5.8433],
         [-1.2281, -3.4030, -2.5778,  6.0583],
         [-1.8854, -3.6581, -2.3959,  5.7738],
         [-1.7223, -3.0947, -2.1583,  5.6772],
         [-1.6307, -3.1154, -2.7600,  5.8258],
         [-1.3146, -3.2644, -2.7935,  5.9752],
         [-0.8583, -3.4668, -2.7229,  5.6486],
         [-0.7019, -3.6436, -2.9980,  5.9987],
         [-0.7727, -3.3370, -2.7549,  5.8421],
         [-0.7255, -2.9233, -2.8486,  5.3102],
         [-0.9144, -3.3213, -2.8859,  6.0472],
         [-0.6404, -3.5078, -2.5673,  5.3582],
         [-1.5694, -3.0120, -2.7123,  5.6290],
         [-1.8813, -3.7789, -2.1495,  5.4608],
         [-0.7428, -3.2995, -2.6378,  5.5890],
         [-1.5827, -3.5136, -1.9974,  5.3308],
         [-0.7550, -3.0012, -2.9295,  5.6938],
         [-0.6125, -3.1273, -2.7412,  5.4797],
         [-1.0233, -3.2150, -2.5671,  5.3730],
         [-1.4848, -3.3401, -2.1883,  5.6891],
         [-1.5731, -3.2015, -2.4313,  5.7881],
         [-1.2180, -3.4321, -2.5310,  5.6778],
         [-1.1180, -3.2675, -2.5200,  5.5461],
         [-0.5237, -3.3511, -2.9367,  5.8799],
         [-1.2963, -3.2625, -2.9715,  5.6808],
         [-1.2281, -3.5656, -2.2486,  5.4848],
         [-3.0042, -0.3081,  4.2593, -0.8091],
         [-1.1701, -3.4849, -2.3317,  5.4553]]], device='cuda:0',
       grad_fn=<SelectBackward>), 'pred_boxes': tensor([[[0.1702, 0.4606, 0.2849, 0.7321],
         [0.5229, 0.4664, 0.7909, 0.5886],
         [0.1794, 0.5380, 0.3181, 0.4549],
         [0.6342, 0.4746, 0.4965, 0.5822],
         [0.4871, 0.4810, 0.9909, 0.6433],
         [0.4894, 0.4765, 0.9973, 0.6173],
         [0.4934, 0.5311, 0.9609, 0.4637],
         [0.2162, 0.5345, 0.3881, 0.3895],
         [0.2601, 0.4509, 0.4261, 0.7153],
         [0.4728, 0.4630, 0.7207, 0.7464],
         [0.4880, 0.5489, 0.9970, 0.5006],
         [0.1716, 0.5310, 0.2883, 0.5460],
         [0.5069, 0.5339, 0.8696, 0.3526],
         [0.5004, 0.5490, 0.9000, 0.4269],
         [0.6205, 0.7156, 0.5946, 0.4369],
         [0.2105, 0.4435, 0.3443, 0.7319],
         [0.2432, 0.4403, 0.3907, 0.7508],
         [0.5588, 0.4828, 0.7084, 0.5058],
         [0.4955, 0.5213, 0.8765, 0.4684],
         [0.4902, 0.5435, 0.9792, 0.4820],
         [0.5123, 0.5432, 0.8540, 0.3800],
         [0.4896, 0.5297, 0.9920, 0.4322],
         [0.4779, 0.4921, 0.7076, 0.6520],
         [0.4900, 0.5279, 0.9353, 0.5081],
         [0.4998, 0.5346, 0.8706, 0.4404],
         [0.5207, 0.5507, 0.8463, 0.3923],
         [0.4937, 0.4635, 0.8708, 0.7061],
         [0.2295, 0.5313, 0.4042, 0.4176],
         [0.5295, 0.4830, 0.6036, 0.6269],
         [0.3144, 0.4618, 0.4565, 0.7007],
         [0.4877, 0.5020, 0.9857, 0.5402],
         [0.4901, 0.5578, 0.9758, 0.5342],
         [0.2974, 0.5703, 0.4497, 0.6416],
         [0.4939, 0.5395, 0.9115, 0.4680],
         [0.4899, 0.5215, 0.8905, 0.5683],
         [0.4879, 0.4878, 0.9151, 0.6498],
         [0.4893, 0.4880, 0.9843, 0.6714],
         [0.2185, 0.5207, 0.3928, 0.4724],
         [0.4905, 0.5371, 0.9458, 0.4564],
         [0.7018, 0.5016, 0.4213, 0.4112],
         [0.4879, 0.4914, 0.8177, 0.6248],
         [0.4935, 0.4529, 0.8365, 0.7173],
         [0.4871, 0.4840, 0.9991, 0.6672],
         [0.4775, 0.5481, 0.9031, 0.3607],
         [0.2811, 0.4520, 0.4367, 0.7219],
         [0.2116, 0.5506, 0.3670, 0.3559],
         [0.5244, 0.5215, 0.8424, 0.4114],
         [0.5391, 0.5701, 0.7667, 0.5812],
         [0.4621, 0.5238, 0.6844, 0.5013],
         [0.2304, 0.5351, 0.3946, 0.4050],
         [0.4751, 0.4810, 0.8100, 0.6608],
         [0.5058, 0.4797, 0.7607, 0.6399],
         [0.4294, 0.5329, 0.6049, 0.7048],
         [0.3394, 0.5167, 0.5596, 0.4860],
         [0.1984, 0.4984, 0.3309, 0.6260],
         [0.2048, 0.5215, 0.3719, 0.3858],
         [0.2147, 0.4714, 0.3536, 0.7174],
         [0.4891, 0.5415, 0.9906, 0.5080],
         [0.4897, 0.4724, 0.9985, 0.6923],
         [0.4929, 0.4591, 0.9780, 0.6876],
         [0.2508, 0.4492, 0.4215, 0.7213],
         [0.2106, 0.5442, 0.3519, 0.3156],
         [0.4911, 0.4713, 0.7783, 0.6908],
         [0.2445, 0.4448, 0.4008, 0.7318],
         [0.4928, 0.4719, 0.8208, 0.6396],
         [0.3257, 0.5086, 0.5025, 0.6188],
         [0.4625, 0.4509, 0.6706, 0.7155],
         [0.2037, 0.4770, 0.3550, 0.6187],
         [0.2760, 0.4584, 0.4520, 0.6909],
         [0.4882, 0.4996, 0.9585, 0.5546],
         [0.4878, 0.4604, 0.7394, 0.7019],
         [0.4920, 0.5336, 0.9306, 0.3339],
         [0.5077, 0.5456, 0.8863, 0.3958],
         [0.4878, 0.5201, 0.9987, 0.5524],
         [0.5112, 0.5827, 0.8250, 0.4683],
         [0.4903, 0.5379, 0.9411, 0.4306],
         [0.1927, 0.4942, 0.3308, 0.5906],
         [0.4882, 0.5213, 0.8267, 0.5426],
         [0.4800, 0.4833, 0.8291, 0.6418],
         [0.4834, 0.4752, 0.7484, 0.6918],
         [0.4918, 0.5149, 0.9105, 0.5857],
         [0.4931, 0.5429, 0.9266, 0.4164],
         [0.4894, 0.5239, 0.8400, 0.5194],
         [0.4842, 0.4992, 0.8786, 0.5336],
         [0.2190, 0.4838, 0.3833, 0.5947],
         [0.5852, 0.5808, 0.6921, 0.4246],
         [0.4936, 0.5404, 0.8906, 0.4358],
         [0.4980, 0.4581, 0.7582, 0.6727],
         [0.4950, 0.5393, 0.9369, 0.4023],
         [0.4920, 0.5089, 0.9805, 0.5369],
         [0.5048, 0.5481, 0.8513, 0.4191],
         [0.4832, 0.5249, 0.9516, 0.4689],
         [0.4885, 0.5384, 0.9551, 0.4283],
         [0.4931, 0.4780, 0.9497, 0.6184],
         [0.4896, 0.5500, 0.9681, 0.4385],
         [0.4887, 0.4731, 0.9990, 0.6993],
         [0.3442, 0.5132, 0.5138, 0.5130],
         [0.5190, 0.4620, 0.8191, 0.6131],
         [0.7753, 0.5238, 0.3522, 0.3171],
         [0.5016, 0.4808, 0.8970, 0.5692]]], device='cuda:0',
       grad_fn=<SelectBackward>)}, {'pred_logits': tensor([[[-1.6195e+00, -4.4590e+00, -2.9284e+00,  6.3438e+00],
         [-1.3827e+00, -3.8483e+00, -2.8002e+00,  5.8636e+00],
         [-2.4032e+00, -3.3997e+00, -2.4636e+00,  6.2632e+00],
         [-2.1288e+00, -3.7418e+00, -2.3507e+00,  6.2070e+00],
         [-1.2213e+00, -3.9480e+00, -2.8529e+00,  5.9780e+00],
         [-1.3764e+00, -3.5331e+00, -3.1036e+00,  6.2684e+00],
         [-1.1390e+00, -3.8203e+00, -2.7873e+00,  5.8636e+00],
         [-2.5303e+00, -2.5651e+00, -2.4481e+00,  5.8075e+00],
         [-1.5532e+00, -4.2327e+00, -2.6495e+00,  5.9321e+00],
         [-1.2408e+00, -4.3479e+00, -2.8397e+00,  6.1178e+00],
         [-1.4951e+00, -4.0306e+00, -2.7509e+00,  6.1961e+00],
         [-1.8421e+00, -4.0824e+00, -2.8595e+00,  6.4210e+00],
         [-2.4118e+00, -3.2511e+00, -1.9146e+00,  6.1430e+00],
         [-1.3563e+00, -3.8963e+00, -2.6481e+00,  6.1175e+00],
         [-2.4518e+00, -4.7317e+00, -2.3036e+00,  6.6069e+00],
         [-1.3567e+00, -4.1709e+00, -2.9985e+00,  5.9754e+00],
         [-1.5624e+00, -4.3517e+00, -2.9236e+00,  6.2333e+00],
         [-1.7804e+00, -3.5279e+00, -2.6411e+00,  6.0306e+00],
         [-1.9572e+00, -4.0045e+00, -2.6954e+00,  6.6457e+00],
         [-1.2843e+00, -3.6856e+00, -3.0383e+00,  6.1717e+00],
         [-1.5199e+00, -3.6128e+00, -2.4441e+00,  5.7818e+00],
         [-1.8320e+00, -3.7049e+00, -2.6154e+00,  6.2285e+00],
         [-1.1837e+00, -4.4366e+00, -2.8877e+00,  6.2217e+00],
         [-1.2268e+00, -3.9429e+00, -2.9941e+00,  6.4816e+00],
         [-1.4157e+00, -3.7111e+00, -2.7412e+00,  5.8815e+00],
         [-1.5027e+00, -3.5832e+00, -2.5867e+00,  5.7584e+00],
         [-1.2625e+00, -4.1434e+00, -2.9202e+00,  6.1943e+00],
         [-2.2661e+00, -2.7710e+00, -2.6199e+00,  5.8227e+00],
         [-1.5452e+00, -4.4758e+00, -2.5350e+00,  6.2031e+00],
         [-1.3601e+00, -4.2704e+00, -3.1336e+00,  6.2617e+00],
         [-1.4358e+00, -3.7951e+00, -3.0669e+00,  6.3648e+00],
         [-1.5361e+00, -4.2397e+00, -2.5798e+00,  6.3967e+00],
         [-2.1231e+00, -4.7545e+00, -2.7531e+00,  6.6037e+00],
         [-1.3064e+00, -3.9693e+00, -2.8833e+00,  6.2371e+00],
         [-1.4271e+00, -4.0420e+00, -3.0311e+00,  6.3297e+00],
         [-1.1235e+00, -4.1531e+00, -2.7350e+00,  5.9347e+00],
         [-1.5445e+00, -4.2464e+00, -2.7308e+00,  6.3902e+00],
         [-2.1667e+00, -3.1046e+00, -2.6195e+00,  5.7563e+00],
         [-1.5931e+00, -3.6883e+00, -2.6906e+00,  6.1028e+00],
         [-2.5027e+00, -2.9647e+00, -2.1161e+00,  5.8212e+00],
         [-1.2159e+00, -4.0123e+00, -2.9169e+00,  6.0621e+00],
         [-1.2250e+00, -4.1474e+00, -2.8074e+00,  6.0214e+00],
         [-1.3915e+00, -3.9405e+00, -2.9009e+00,  6.2916e+00],
         [-1.1725e+00, -2.8746e+00, -3.0120e+00,  5.4685e+00],
         [-1.4669e+00, -4.1936e+00, -2.9564e+00,  6.1645e+00],
         [-2.4655e+00, -2.1226e+00, -2.2757e+00,  5.4959e+00],
         [-1.5861e+00, -3.6893e+00, -2.5827e+00,  5.8086e+00],
         [-1.9286e+00, -4.4928e+00, -2.5641e+00,  6.5601e+00],
         [-1.3641e+00, -3.9145e+00, -3.0208e+00,  6.0581e+00],
         [-2.0842e+00, -2.9913e+00, -2.6180e+00,  5.5829e+00],
         [-1.2247e+00, -4.1773e+00, -2.7245e+00,  5.9249e+00],
         [-1.4689e+00, -3.8988e+00, -2.7617e+00,  5.9699e+00],
         [-1.8488e+00, -4.7588e+00, -2.7748e+00,  6.5606e+00],
         [-1.1917e+00, -3.4813e+00, -2.7955e+00,  5.3110e+00],
         [-2.3956e+00, -4.0149e+00, -2.7081e+00,  6.6886e+00],
         [-2.3813e+00, -2.5525e+00, -2.3205e+00,  5.7332e+00],
         [-1.4815e+00, -4.2510e+00, -2.9703e+00,  6.2246e+00],
         [-1.3955e+00, -3.7976e+00, -2.8563e+00,  6.0663e+00],
         [-1.4332e+00, -3.8399e+00, -2.9510e+00,  6.2020e+00],
         [-1.1752e+00, -3.9332e+00, -2.7589e+00,  5.9056e+00],
         [-1.5764e+00, -3.8702e+00, -2.6761e+00,  5.8174e+00],
         [-2.0656e+00,  4.1276e+00, -4.3906e-01, -1.4345e+00],
         [-1.2471e+00, -4.2527e+00, -2.9475e+00,  6.2896e+00],
         [-1.5740e+00, -4.0977e+00, -2.8201e+00,  6.0091e+00],
         [-1.1104e+00, -4.1922e+00, -2.7236e+00,  5.9785e+00],
         [-2.1378e+00, -4.0418e+00, -2.6308e+00,  6.4888e+00],
         [-1.4818e+00, -4.1321e+00, -2.9266e+00,  6.1879e+00],
         [-2.0038e+00, -3.5830e+00, -2.5550e+00,  6.0215e+00],
         [-1.6301e+00, -3.8899e+00, -2.7091e+00,  5.8515e+00],
         [-1.3691e+00, -3.8933e+00, -2.9713e+00,  6.2729e+00],
         [-1.3947e+00, -4.0538e+00, -2.9421e+00,  6.2044e+00],
         [ 4.3452e+00, -2.8690e-03, -1.2556e+00, -1.9017e+00],
         [-1.2634e+00, -3.7075e+00, -2.8204e+00,  6.1854e+00],
         [-1.6534e+00, -3.9514e+00, -2.7208e+00,  6.3711e+00],
         [-1.4873e+00, -4.2531e+00, -2.6164e+00,  6.1156e+00],
         [-1.8539e+00, -3.6203e+00, -2.5836e+00,  6.1172e+00],
         [-1.7158e+00, -3.8673e+00, -2.8349e+00,  6.0272e+00],
         [-1.2710e+00, -3.9977e+00, -3.0403e+00,  6.3519e+00],
         [-1.1534e+00, -3.9524e+00, -2.8002e+00,  5.9601e+00],
         [-1.2766e+00, -4.1817e+00, -2.9220e+00,  6.2540e+00],
         [-1.4688e+00, -4.0363e+00, -2.9055e+00,  6.4675e+00],
         [-1.2079e+00, -3.3657e+00, -2.9877e+00,  5.6439e+00],
         [-1.4096e+00, -3.9031e+00, -3.0698e+00,  6.3509e+00],
         [-9.8401e-01, -3.7624e+00, -2.7551e+00,  5.5333e+00],
         [-1.8545e+00, -3.6251e+00, -2.7463e+00,  5.9438e+00],
         [-1.7581e+00, -4.4184e+00, -2.3085e+00,  6.1497e+00],
         [-1.1907e+00, -3.9245e+00, -2.8815e+00,  6.1389e+00],
         [-1.5773e+00, -3.9300e+00, -2.7206e+00,  5.9558e+00],
         [-1.0606e+00, -3.5542e+00, -3.0294e+00,  6.0895e+00],
         [-1.4166e+00, -3.6723e+00, -2.8987e+00,  6.0354e+00],
         [-1.3823e+00, -3.6085e+00, -2.8037e+00,  5.8583e+00],
         [-1.7279e+00, -3.8922e+00, -2.6499e+00,  6.2398e+00],
         [-1.7093e+00, -3.7567e+00, -2.7649e+00,  6.2483e+00],
         [-1.5017e+00, -3.7128e+00, -2.8925e+00,  5.9712e+00],
         [-1.4966e+00, -3.8514e+00, -2.6731e+00,  6.0206e+00],
         [-1.1200e+00, -4.1016e+00, -2.9189e+00,  6.2101e+00],
         [-1.7542e+00, -3.7218e+00, -2.9580e+00,  6.0476e+00],
         [-1.4632e+00, -3.8406e+00, -2.9112e+00,  6.0628e+00],
         [-3.1172e+00, -5.0856e-01,  4.4328e+00, -7.6503e-01],
         [-1.4182e+00, -3.7123e+00, -2.9029e+00,  5.9989e+00]]],
       device='cuda:0', grad_fn=<SelectBackward>), 'pred_boxes': tensor([[[0.4831, 0.5319, 0.6748, 0.5955],
         [0.5303, 0.4793, 0.8343, 0.5207],
         [0.4420, 0.5738, 0.6661, 0.5213],
         [0.6993, 0.4884, 0.4894, 0.5257],
         [0.5203, 0.5036, 0.8594, 0.5192],
         [0.5145, 0.5104, 0.8863, 0.4615],
         [0.5142, 0.5462, 0.9188, 0.3874],
         [0.3376, 0.5506, 0.5492, 0.3326],
         [0.4104, 0.5135, 0.6020, 0.5497],
         [0.5266, 0.5145, 0.7976, 0.5729],
         [0.5109, 0.5679, 0.9969, 0.4354],
         [0.4717, 0.5606, 0.6730, 0.5274],
         [0.5442, 0.5422, 0.7430, 0.3496],
         [0.5237, 0.5615, 0.9200, 0.3930],
         [0.5604, 0.5628, 0.6212, 0.5938],
         [0.4453, 0.5012, 0.6666, 0.5863],
         [0.4490, 0.4788, 0.6575, 0.6430],
         [0.5806, 0.4787, 0.6582, 0.4769],
         [0.5247, 0.5399, 0.7399, 0.4287],
         [0.5201, 0.5440, 0.9396, 0.4189],
         [0.5217, 0.5497, 0.8925, 0.3481],
         [0.5175, 0.5488, 0.9152, 0.3926],
         [0.5333, 0.5308, 0.7300, 0.5863],
         [0.5173, 0.5495, 0.8851, 0.4150],
         [0.5203, 0.5400, 0.9239, 0.3849],
         [0.5211, 0.5575, 0.8972, 0.3551],
         [0.5277, 0.4936, 0.7905, 0.5897],
         [0.3799, 0.5512, 0.5850, 0.3609],
         [0.5694, 0.5001, 0.6967, 0.5762],
         [0.4984, 0.4928, 0.6746, 0.5845],
         [0.5225, 0.5177, 0.8408, 0.4723],
         [0.5113, 0.5744, 0.9983, 0.4589],
         [0.5331, 0.5235, 0.6810, 0.6517],
         [0.5252, 0.5410, 0.8530, 0.4389],
         [0.5280, 0.5166, 0.8034, 0.4989],
         [0.5259, 0.5123, 0.8828, 0.5264],
         [0.5255, 0.5215, 0.7702, 0.5492],
         [0.3749, 0.5405, 0.5748, 0.4065],
         [0.5079, 0.5651, 0.8727, 0.3719],
         [0.7518, 0.5220, 0.4103, 0.3640],
         [0.5240, 0.4939, 0.8256, 0.5545],
         [0.5263, 0.4829, 0.7970, 0.6086],
         [0.5218, 0.4919, 0.9067, 0.5549],
         [0.5083, 0.5550, 0.9364, 0.3232],
         [0.4482, 0.5063, 0.6469, 0.5696],
         [0.2876, 0.5533, 0.4501, 0.3252],
         [0.5219, 0.5177, 0.8431, 0.3970],
         [0.5153, 0.5823, 0.9911, 0.5523],
         [0.5178, 0.5335, 0.7157, 0.4494],
         [0.3790, 0.5502, 0.5763, 0.3378],
         [0.5212, 0.5263, 0.8309, 0.5019],
         [0.5222, 0.4797, 0.7566, 0.5602],
         [0.5148, 0.5203, 0.7105, 0.6439],
         [0.3888, 0.5488, 0.6089, 0.3771],
         [0.4877, 0.5543, 0.6773, 0.5488],
         [0.3359, 0.5448, 0.5282, 0.3304],
         [0.4417, 0.5335, 0.6581, 0.5375],
         [0.5111, 0.5390, 0.9735, 0.4293],
         [0.5190, 0.4832, 0.8596, 0.5578],
         [0.5106, 0.4779, 0.8847, 0.5794],
         [0.3850, 0.5084, 0.5735, 0.5447],
         [0.2120, 0.5446, 0.3542, 0.3123],
         [0.5302, 0.5176, 0.7469, 0.5867],
         [0.4178, 0.5030, 0.6217, 0.5722],
         [0.5248, 0.4926, 0.8498, 0.5579],
         [0.4975, 0.5560, 0.7114, 0.5328],
         [0.5235, 0.4954, 0.7076, 0.5772],
         [0.4102, 0.5266, 0.6218, 0.4837],
         [0.4073, 0.5210, 0.6183, 0.5172],
         [0.5240, 0.5176, 0.7819, 0.4885],
         [0.5282, 0.4932, 0.7304, 0.5890],
         [0.4913, 0.5331, 0.9302, 0.3316],
         [0.5207, 0.5622, 0.7873, 0.3631],
         [0.5155, 0.5479, 0.9899, 0.4303],
         [0.5215, 0.5806, 0.8828, 0.4227],
         [0.5216, 0.5458, 0.8430, 0.3860],
         [0.4319, 0.5338, 0.6572, 0.4843],
         [0.5203, 0.5456, 0.7951, 0.4592],
         [0.5178, 0.5188, 0.8549, 0.4968],
         [0.5207, 0.5057, 0.7450, 0.5722],
         [0.5211, 0.5263, 0.8715, 0.5075],
         [0.5070, 0.5543, 0.9154, 0.3384],
         [0.5264, 0.5381, 0.7512, 0.4494],
         [0.5134, 0.5181, 0.8871, 0.4635],
         [0.3812, 0.5276, 0.5685, 0.4873],
         [0.5839, 0.5816, 0.6801, 0.3979],
         [0.5215, 0.5488, 0.8670, 0.4045],
         [0.5223, 0.4780, 0.7783, 0.5832],
         [0.5113, 0.5610, 0.9186, 0.3581],
         [0.5165, 0.5303, 0.8985, 0.4189],
         [0.5198, 0.5536, 0.9361, 0.3572],
         [0.5229, 0.5216, 0.7920, 0.4792],
         [0.5166, 0.5513, 0.9085, 0.3657],
         [0.5212, 0.4781, 0.8627, 0.5350],
         [0.5103, 0.5536, 0.9902, 0.3942],
         [0.5056, 0.5165, 0.9755, 0.5147],
         [0.4481, 0.5316, 0.6385, 0.4477],
         [0.5242, 0.4685, 0.8261, 0.5503],
         [0.7750, 0.5241, 0.3524, 0.3153],
         [0.5214, 0.4948, 0.8510, 0.4912]]], device='cuda:0',
       grad_fn=<SelectBackward>)}]}
RuntimeError: Dictionary inputs to traced functions must have consistent type. Found Tensor and List[Dict[str, Tensor]]

 
Expected behavior:

Able to trace and export the traced DETR model for production use :)
(I confirmed I have the latest DETR source that has the PR from June 4 with the fixes to script the resnet models).


Environment:

Collecting environment information...
PyTorch version: 1.6.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: Tesla V100-SXM2-16GB
Nvidia driver version: 435.21
cuDNN version: /usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.6.3

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] numpydoc==1.1.0
[pip3] torch==1.6.0
[pip3] torchvision==0.7.0
[conda] _pytorch_select           0.2                       gpu_0  
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.1.243             h6bb024c_0  
[conda] mkl                       2020.1                      217  
[conda] mkl-service               2.3.0            py37he904b0f_0  
[conda] mkl_fft                   1.1.0            py37h23d657b_0  
[conda] mkl_random                1.1.1            py37h0573a6f_0  
[conda] numpy                     1.18.5           py37ha1c710e_0  
[conda] numpy-base                1.18.5           py37hde5b4d6_0  
[conda] numpydoc                  1.1.0                      py_0  
[conda] pytorch                   1.6.0           py3.7_cuda10.1.243_cudnn7.6.3_0    pytorch
[conda] torchvision               0.7.0                py37_cu101    pytorch
@fmassa
Copy link
Contributor

fmassa commented Aug 18, 2020

Hi,

DETR supports scripting the model, not tracing it.

Can you try instead

model = torch.jit.script(model)
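The difference matters here because scripting compiles the module's Python source, so data-dependent control flow and non-tensor outputs (like DETR's dict of predictions) survive, whereas tracing records one execution path with fixed input types. A minimal sketch of the distinction, using a toy module rather than DETR itself:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # data-dependent branch: tracing would freeze whichever path
        # the example input took; scripting compiles both branches
        if x.sum() > 0:
            return x * 2
        return x - 1

model = Toy().eval()
scripted = torch.jit.script(model)  # no example input needed, unlike trace

print(scripted(torch.ones(3)))   # takes the x * 2 branch
print(scripted(-torch.ones(3)))  # takes the x - 1 branch
```

Tracing this same module with `torch.jit.trace(model, torch.ones(3))` would bake in the `x * 2` branch and silently return the wrong result for negative inputs.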

@lessw2020
Copy link
Contributor Author

Thanks very much @fmassa!
The jit scripting succeeded. However, in trying to use the jit model I get this odd error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
in
----> 1 scores,boxes = smodel(f1t)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),

AttributeError: 'Tensor' object has no attribute 'tensors'

Are there any additional steps needed after scripting? I put it into eval mode (smodel.eval() - not sure that's even needed, as I don't believe jit can update BN etc., but to be safe) and pushed it to the GPU, then tried the above.
I also tested with our normal detect function, and the error was the same.

@fmassa
Copy link
Contributor

fmassa commented Aug 19, 2020

I believe the torchscript model might only work if you pass a list of 3D tensors, and not a single 4D tensor (but I would need to double-check). Can you try that?

@lessw2020
Copy link
Contributor Author

Hi @fmassa - I tested with a list of 3D tensors, a bare 3D tensor, and a list of 4D tensors.
The error is always the same AttributeError, except it says 'list' or 'Tensor' object has no attribute 'tensors'.
I tried to research this a bit but can't find any reference to this specific error so far.
I can debug further if you have any tips on how to proceed?

(screenshot: jit_script_no_tensor)

@alcinos
Contributor

alcinos commented Aug 19, 2020

I believe it would still expect a NestedTensor as input.
Try:

inputs = NestedTensor.from_tensor_list([img]).to(device)
out = smodel(inputs)

Best of luck

@lessw2020
Contributor Author

Hi @alcinos,
Thanks for the suggestion. I did try with a NestedTensor and with nested_tensor_from_tensor_list - in either case I no longer get the attribute error, but the model goes into a death spiral (infinite loop?), locks up my server, and attempts to interrupt fail.
I have to restart the kernel to recover, every time without fail.
I tested with both cpu/gpu and with two different scripted detr models.
It's progress in that it now accepts the input, but end result is still not there yet.

Note that you listed 'inputs = NestedTensor.from_tensor_list([img]).to(device)' ... so, to be safe:

1 - in util.misc I only see a NestedTensor class, and a function 'nested_tensor_from_tensor_list'.
I tried both, but the code above implies a member function of NestedTensor which I don't see. (just want to make sure I didn't miss anything).

2 - am I correct that img is a tensorized/resized/normalized tensor with no batch size added dimension (I just used a single image tensor in the list)?

Thanks!

@fmassa
Contributor

fmassa commented Aug 20, 2020

Hi @lessw2020

1 - NestedTensor.from_tensor_list was renamed to nested_tensor_from_tensor_list in #51. Here is an example showing how to run a model with torchscript using nested_tensor_from_tensor_list:

detr/test_all.py

Lines 69 to 76 in 5e66b4c

def test_model_script_detection(self):
    model = detr_resnet50(pretrained=False).eval()
    scripted_model = torch.jit.script(model)
    x = nested_tensor_from_tensor_list([torch.rand(3, 200, 200), torch.rand(3, 200, 250)])
    out = model(x)
    out_script = scripted_model(x)
    self.assertTrue(out["pred_logits"].equal(out_script["pred_logits"]))
    self.assertTrue(out["pred_boxes"].equal(out_script["pred_boxes"]))

2 - The example above shows that img is a Tensor which has been resized / normalized, with no batch dimension (so a 3d tensor).
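For intuition, nested_tensor_from_tensor_list pads every image in the list up to the per-dimension maximum and prepends a batch dimension (it also builds a boolean padding mask). The shape logic can be sketched in plain Python (padded_batch_shape is a hypothetical helper, not part of the repo):

```python
from typing import List, Tuple

def padded_batch_shape(shapes: List[Tuple[int, ...]]) -> Tuple[int, ...]:
    """Shape of the padded batch that DETR's nested_tensor_from_tensor_list
    would build from images of the given (C, H, W) shapes. Sketch only:
    the real helper also returns the boolean padding mask."""
    max_dims = tuple(max(d) for d in zip(*shapes))
    return (len(shapes),) + max_dims

# The two images from the test above pad up to the larger width:
print(padded_batch_shape([(3, 200, 200), (3, 200, 250)]))  # (2, 3, 200, 250)
```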

I believe this should be enough to get your example working. As such, I'm closing this issue but let us know if the problem persists

@fmassa fmassa closed this as completed Aug 20, 2020
@fmassa
Contributor

fmassa commented Aug 20, 2020

I just realized that I missed one of your questions, which was that the code seemed to deadlock. This might be due to the fact that torchscript is compiling the code at its first / second invocation, which might take a while. How long did you wait for it?

@fmassa fmassa reopened this Aug 20, 2020
@fmassa fmassa added the question Further information is requested label Aug 20, 2020
@lessw2020
Contributor Author

lessw2020 commented Aug 21, 2020

Hi @fmassa - I waited about a minute each time. I didn't realize it was doing the compile; I became alarmed at the lack of responsiveness and assumed it was in an infinite loop.
(For reference, eager-mode inference takes ~260 ms on my server, and I assumed JIT mode would be equal or faster... so a one-minute wait seemed more than enough at the time before restarting the kernel.)

Let me re-run tomorrow and will give it more time and see if that was the root issue.

Thanks for the script above, that's very helpful and also the info about need to wait.

1 - Does this also mean that for production, we would want to push 1 or 2 images as 'warmup' to a given model before we set it to 'live' for incoming images as users won't expect to wait for a minute+ for a response?

Thanks again and will update tomorrow!

@fmassa
Contributor

fmassa commented Aug 21, 2020

1 - Does this also mean that for production, we would want to push 1 or 2 images as 'warmup' to a given model before we set it to 'live' for incoming images as users won't expect to wait for a minute+ for a response?

@lessw2020 yes, for now torchscript uses a JIT (just-in-time) compiler, so we need to feed a few images beforehand so that the model can be compiled. Note that at some point in the future PyTorch might also support AOT (ahead of time) compilation, but it's not yet there.
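That warmup step can be sketched as a small helper (warm_up is hypothetical; in a real deployment `sample` would be a NestedTensor already on the target device, with the calls wrapped in torch.no_grad()):

```python
import time
from typing import Any, Callable

def warm_up(model: Callable[[Any], Any], sample: Any, n_runs: int = 2) -> float:
    """Run a few throwaway inferences so the JIT can compile the model
    before it is marked 'live'; returns the wall-clock seconds spent."""
    start = time.perf_counter()
    for _ in range(n_runs):
        model(sample)
    return time.perf_counter() - start

# Usage with a stand-in callable (a real call would be warm_up(scripted_model, x)):
calls = []
elapsed = warm_up(lambda s: calls.append(s), "dummy_batch", n_runs=3)
print(len(calls))  # 3
```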

Let us know how much time it takes to compile the model in your setup (with / without CUDA, etc). There might be improvements to the compilation time that could be done (if the times are too high)

@lessw2020
Contributor Author

Hi @fmassa - was finally able to test this.
1 - I was unable to load the saved jit model, fyi (let's ignore that for now).
2 - thus, I loaded the eager-mode model, ran jit.script on it, put the jit model in eval mode, and pushed it to the GPU.
3 - I passed in a nested_tensor built from the exact same tensors as in your script.
4 - I ran y = smodel(x) ... and waited literally 12 minutes (I timed it); I was about to kill it when I finally got an error. I am on a V100 GPU, 61 GB RAM, with 8x vCPUs.
5 - the error indicated that the input was not a CUDA float. I then purposely pushed x (the nested tensor) to the GPU, re-ran, and got the same error.

could you kindly review and advise?
On the positive side, after the 12-minute compile the second run (after manually moving to GPU) was super fast (~200 ms?)... but I still get the same error.

'''

RuntimeError Traceback (most recent call last)
in

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "/home/ubuntu/cdetr/models/detr.py", line 69, in forward
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
features, pos = self.backbone(samples)
~~~~~~~~~~~~~ <--- HERE

    src, mask = features[-1].decompose()

File "/home/ubuntu/cdetr/models/backbone.py", line 101, in forward
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
~~~~~~~ <--- HERE
out: List[NestedTensor] = []
pos = []
File "/home/ubuntu/cdetr/models/backbone.py", line 73, in forward
def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
~~~~~~~~~ <--- HERE
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torchvision/models/_utils.py", line 63, in forward
out = OrderedDict()
for name, module in self.items():
x = module(x)
~~~~~~ <--- HERE
if name in self.return_layers:
out_name = self.return_layers[name]
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torchvision/models/_utils.py", line 63, in forward
out = OrderedDict()
for name, module in self.items():
x = module(x)
~~~~~~ <--- HERE
if name in self.return_layers:
out_name = self.return_layers[name]
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 419, in forward
def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.weight)
~~~~~~~~~~~~~~~~~~ <--- HERE
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 415, in _conv_forward
weight, self.bias, self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.bias, self.stride,
~~~~~~~~ <--- HERE
self.padding, self.dilation, self.groups)
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

'''
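The error at the bottom of this trace is a device mismatch: F.conv2d received a CPU input while the weights are CUDA tensors. DETR's NestedTensor.to(device) moves both the padded batch and the mask together; a toy sketch of that pattern (NestedPair is a hypothetical stand-in, with strings in place of real torch devices):

```python
class NestedPair:
    """Toy stand-in for DETR's util.misc.NestedTensor: a padded image
    batch plus its padding mask, tracked with a fake device tag."""
    def __init__(self, tensors, mask, device="cpu"):
        self.tensors = tensors
        self.mask = mask
        self.device = device

    def to(self, device):
        # Move *both* fields together; if the input stays on the CPU while
        # the model's weights are on the GPU, conv2d raises the
        # FloatTensor vs cuda.FloatTensor mismatch seen above.
        return NestedPair(self.tensors, self.mask, device=device)

x = NestedPair([[0.1, 0.2]], [[False, False]])
print(x.to("cuda:0").device)  # cuda:0
```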

@lessw2020
Contributor Author

hi @fmassa - I pushed everything to the CPU and, after about an 8-minute wait, got an error - this looks more promising, though, as it appears to have made it through the model?

(screenshot: jit_cpu_detr_fail)

@fmassa
Contributor

fmassa commented Aug 25, 2020

Hey @lessw2020

The torchscript version only supports models without aux_loss. This is fine for inference, because aux_loss is only used during training.

So what I would recommend is to set aux_loss to False in your model, put it in inference mode, and try scripting it again.

@lessw2020
Contributor Author

lessw2020 commented Aug 27, 2020

Hi @fmassa - thanks a bunch for all the help! The main issue was the aux_loss!
1 - Was able to get JIT mode up and running, both single images and sets.
2 - Save/load issue went away as well as part of the aux_loss removal.

One last question if I may:
(argh, hit space bar and that closed this issue..anyway hopefully you will still see this):
the one thing that threw me was that the arg for aux loss is named "--no_aux_loss", but the value is actually stored as 'aux_loss':
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', ...
In other words, for a value named 'no_aux_loss' I'd expect True to mean no aux_loss, but True for this arg really means aux_loss is present.
That is, my reading of the arg based on the name '--no_aux_loss' is the opposite of its true function.
Is it worth renaming the arg to make it clearer - i.e. the arg is really just 'aux_loss'?
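The store_false pattern above can be checked directly in a few lines (a self-contained sketch of the same argparse behavior used in DETR's main.py):

```python
import argparse

# The flag is spelled --no_aux_loss, but argparse stores it under
# dest='aux_loss'; action='store_false' makes the default True and
# flips it to False only when the flag is passed.
parser = argparse.ArgumentParser()
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false')

print(parser.parse_args([]).aux_loss)                 # True
print(parser.parse_args(['--no_aux_loss']).aux_loss)  # False
```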

Regardless, did want to say thanks a ton for all the help getting JIT mode working as will be using that for production! The initial compile time is still a big hit but at least being aware of it means we can handle it.

@fmassa
Contributor

fmassa commented Sep 5, 2020

Hey, sorry I missed this before

Is it worth updating the name for the arg to make it clearer - i.e. the arg is really just 'aux-loss' ?

The original name of the arg was aux-loss. Given it's a boolean, if we were to switch it back to aux-loss we would also have to change the default value to False, which leads to worse results during training.
