
Accuracy Errors for Resnext50 #1698

Open
TedThemistokleous opened this issue Apr 19, 2023 · 6 comments
Assignees
Labels
bug Something isn't working roadmap Tasks to finish for a release

Comments

@TedThemistokleous
Collaborator

TedThemistokleous commented Apr 19, 2023

Since we got resnext50 running via #1283, there are still a few issues to sort out:

Model mirror link: https://zenodo.org/record/6617879/files/resnext50_32x4d_fpn.onnx

  1. Using the --fp16 flag fails on one of the outputs with the get_tuple_elem operator
  2. We're currently unable to do a 1:1 accuracy run against the onnxruntime CPU EP using accuracy_checker.py
  3. Turning on MIGRAPHX_GPU_DEBUG=1 gives out-of-bounds asserts from our JIT gather kernel

Current speculation is that there's some odd interaction between topk and Gather/GatherND.

The way to test this appears to be adding a reshape between these operators, via a matcher, to pare down the data vector shape: the onnxruntime CPU EP seems to indicate an expected output of 5 items vs. the 300 we still see.

Further investigation is needed for the fp16 issue as well as the accuracy testing.
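As a rough illustration of the 5-vs-300 discrepancy, any accuracy comparison would need to check only the valid prefix of the padded static-shape output against the reference. The names and sizes below are hypothetical, not taken from accuracy_checker.py:

```python
import numpy as np

# Hypothetical sketch: comparing a padded fixed-size output (e.g. 300 slots)
# against a reference EP that only emits the valid detections (e.g. 5).
def compare_prefix(padded, reference, atol=1e-3):
    """Compare only the leading len(reference) entries of the padded output."""
    k = len(reference)
    return np.allclose(padded[:k], reference, atol=atol)

reference = np.array([0.9, 0.8, 0.7, 0.6, 0.5])       # 5 valid detections
padded = np.concatenate([reference, np.zeros(295)])   # 300-slot static output
assert compare_prefix(padded, reference)
```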

@TedThemistokleous TedThemistokleous self-assigned this Apr 19, 2023
@TedThemistokleous TedThemistokleous added bug Something isn't working roadmap Tasks to finish for a release labels Apr 19, 2023
@TedThemistokleous
Collaborator Author

TedThemistokleous commented May 9, 2023

Just a quick update here with a log dump, as I've seen weird behavior differences between the Navi system and the remote MI250.

This looks related to the interaction between topk and gathernd. I believe the added transpose is causing some odd error which flips our lens dimensions, causing us to index outside the proper vector output range and throw an error.

I had to use a bunch of the MIGRAPHX_GPU_DEBUG flags to get this, and leveraged rocgdb to get a proper look at the internals of the kernel to see what was going on. "bt full" seemed to dump what I needed to see what is happening here. I tried this with random, fill1 and fill0 inputs, and it appears the fill mode doesn't affect the output.

The key thing was to use MIGRAPHX_GPU_OPTIMIZE=0 so the args don't all get optimized out as we compile said kernels.

I've added a simple verify test called test_gathernd_1d.cpp for now, to debug this without having to run the entire network and to see if I can quickly replicate this error/state and what triggers it. I think I'll have to add a new matcher sooner rather than later for topk->transpose->gathernd.
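For reference, the 1-D gathernd semantics the verify test exercises can be sketched in numpy (an illustrative model, not the MIGraphX kernel): an index equal to the data length is exactly the out-of-bounds case the shape assert below catches (slice_offset = 23760000 against data of length 23760000).

```python
import numpy as np

# Rough model of 1-D GatherND with batch_dims=0: indices has shape
# {num_slices, 1}, and each row selects one element of `data`.
def gathernd_1d(data, indices):
    out = np.empty(indices.shape[0], dtype=data.dtype)
    for j, (idx,) in enumerate(indices):
        if not 0 <= idx < data.shape[0]:
            raise IndexError(f"slice offset {idx} outside data of length {data.shape[0]}")
        out[j] = data[idx]
    return out

data = np.arange(8, dtype=np.float32)
assert np.array_equal(gathernd_1d(data, np.array([[0], [3], [7]])), [0.0, 3.0, 7.0])

try:
    gathernd_1d(data, np.array([[8]]))  # index == len(data): out of bounds
    raise AssertionError("expected IndexError")
except IndexError:
    pass
```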

Run instruction: main:@864 = gpu::topk[k=5940000,axis=0,largest=1](main:@862,main:@863) -> [float_type, {5940000}, {1}, int64_type, {5940000}, {1}]
Time: 0.008746ms, 8681.9ms
Run instruction: main:@865 = transpose[permutation={1, 0}](main:@840) -> float_type, {23760000, 1}, {1, 23760000}
Time: 0.003006ms, 0.003447ms
Run instruction: main:@866 = load[offset=289263744,end=384303744](main:@1) -> float_type, {23760000}, {1}
Time: 0.002234ms, 0.002675ms
Run instruction: main:@867 = gpu::code_object[code_object=522336,symbol_name=gathernd_kernel,global=225280,local=1024,](main:@722,main:@865,main:@866) -> float_type, {23760000}, {1}
./migraphx/kernels/shape.hpp:83: index_int migraphx::shape<migraphx::integral_const_array<unsigned int, 23760000>, migraphx::integral_const_array<unsigned int, 1>>::index(index_int) const [Lens = migraphx::integral_const_array<unsigned int, 23760000>, Strides = migraphx::integral_const_array<unsigned int, 1>]: assertion 'i == compute_index(i)' failed.

Thread 2124 "driver" received signal SIGABRT, Aborted.
[Switching to thread 2124, lane 43 (AMDGPU Lane 5:1:1:11/43 (110,0,0)[683,0,0])]
0x00007faa0d12b518 in abort () at /opt/rocm-5.5.0/include/hip/amd_detail/amd_device_functions.h:800
800         return __builtin_trap();
(gdb) bt full
#0  0x00007faa0d12b518 in abort () at /opt/rocm-5.5.0/include/hip/amd_detail/amd_device_functions.h:800
No locals.
#1  0x00007faa0d1502ac in migraphx::assert_fail<char [22], char [29], int, char [278]> (assertion=..., file=..., 
    line=<error reading variable: Cannot access memory at address 0x2000000000a2c>, function=...) at ./migraphx/kernels/debug.hpp:144
No locals.
#2  0x00007faa0d1453a8 in _ZZNK8migraphx5shapeINS_20integral_const_arrayIjJLj23760000EEEENS1_IjJLj1EEEEE5indexEjENKUlDpOT_E_clIJRA22_KcRA29_SA_iRA278_SA_EEEDaS7_ (this=0x2000000000a28, private_migraphx_xs=..., private_migraphx_xs=..., private_migraphx_xs=..., 
    private_migraphx_xs=...) at ./migraphx/kernels/shape.hpp:83
No locals.
#3  0x00007faa0d143cdc in migraphx::shape<migraphx::integral_const_array<unsigned int, 23760000u>, migraphx::integral_const_array<unsigned int, 1u> >::index (this=0x20000000009c0, i=23760000) at ./migraphx/kernels/shape.hpp:83
No locals.
#4  0x00007faa0d1434f0 in migraphx::tensor_view<float, migraphx::shape<migraphx::integral_const_array<unsigned int, 23760000u>, migraphx::integral_const_array<unsigned int, 1u> > >::index_to_offset::index_to_offset<unsigned long> (this=0x20000000007f0, i=23760000)
    at ./migraphx/kernels/tensor_view.hpp:59
No locals.
#5  0x00007faa0d12e3e4 in migraphx::source_location_capture<migraphx::tensor_view<float, migraphx::shape<migraphx::integral_const_array<unsigned int, 23760000u>, migraphx::integral_const_array<unsigned int, 1u> > >::index_to_offset>::source_location_capture<unsigned long, migraphx::tensor_view<float, migraphx::shape<migraphx::integral_const_array<unsigned int, 23760000u>, migraphx::integral_const_array<unsigned int, 1u> > >::index_to_offset> (this=0x20000000007f0, px=23760000, ploc=...) at ./migraphx/kernels/debug.hpp:127
No locals.
#6  0x00007faa0d119ff0 in _ZZN8migraphx8gatherndINS_11tensor_viewIfNS_5shapeINS_20integral_const_arrayIjJLj23760000EEEENS3_IjJLj1EEEEEEEENS1_IfNS2_INS3_IjJLj23760000ELj1EEEENS3_IjJLj1ELj23760000EEEEEEEES7_NS_17gathernd_settingsINS_16generic_constantIZZZ15gathernd_kernelENKUlDpOT_E_clIJS7_SB_S7_EEEDaSG_ENKUlvE_clEvE3funEEEEEEvRKT_RKT0_RKT1_T2_ENKUlSN_E_clIjEEDaSN_ (this=0x2000000000720, i=113323)
    at ./migraphx/kernels/gathernd.hpp:93
        indices_ptr = 0x7f9fef324480
        j = 113323
        batch_idx = 0
        slice_indices = 0x7f9fef392f2c
        relative_slice_offset = 23760000
        slice_offset = 23760000
#7  0x00007faa0d116fd0 in _ZN8migraphx5index11invoke_loopIZNS_8gatherndINS_11tensor_viewIfNS_5shapeINS_20integral_const_arrayIjJLj23760000EEEENS5_IjJLj1EEEEEEEENS3_IfNS4_INS5_IjJLj23760000ELj1EEEENS5_IjJLj1ELj23760000EEEEEEEES9_NS_17gathernd_settingsINS_16generic_constantIZZZ15gathernd_kernelENKUlDpOT_E_clIJS9_SD_S9_EEEDaSI_ENKUlvE_clEvE3funEEEEEEvRKT_RKT0_RKT1_T2_EUlSP_E_jjEEDTclfp_fp0_EESP_SS_SV_ (f=..., i=113323)
    at ./migraphx/kernels/index.hpp:160
No locals.
#8  0x00007faa0d116bac in _ZN8migraphx5index15for_stride_loopIZNS_8gatherndINS_11tensor_viewIfNS_5shapeINS_20integral_const_arrayIjJLj23760000EEEENS5_IjJLj1EEEEEEEENS3_IfNS4_INS5_IjJLj23760000ELj1EEEENS5_IjJLj1ELj23760000EEEEEEEES9_NS_17gathernd_settingsINS_16generic_constantIZZZ15gathernd_kernelENKUlDpOT_E_clIJS9_SD_S9_EEEDaSI_ENKUlvE_clEvE3funEEEEEEvRKT_RKT0_RKT1_T2_EUlSP_E_NS_17integral_constantIjLj23760000EEENS10_IjLj225280EEEEEvjSS_SV_SP_ (start=113323, n=..., stride=..., f=...) at ./migraphx/kernels/index.hpp:182
        i = 113323
        k = 0
#9  0x00007faa0d116124 in _ZN8migraphx5index10for_strideILb0EZNS_8gatherndINS_11tensor_viewIfNS_5shapeINS_20integral_const_arrayIjJLj23760000EEEENS5_IjJLj1EEEEEEEENS3_IfNS4_INS5_IjJLj23760000ELj1EEEENS5_IjJLj1ELj23760000EEEEEEEES9_NS_17gathernd_settingsINS_16generic_constantIZZZ15gathernd_kernelENKUlDpOT_E_clIJS9_SD_S9_EEEDaSI_ENKUlvE_clEvE3funEEEEEEvRKT_RKT0_RKT1_T2_EUlSP_E_NS_17integral_constantIjLj23760000EEENS10_IjLj225280EEEEEvjSV_SY_SS_ (start=113323, n=..., stride=..., f=...) at ./migraphx/kernels/index.hpp:214
No locals.
#10 0x00007faa0d1158a4 in _ZNK8migraphx5index13global_strideIZNS_8gatherndINS_11tensor_viewIfNS_5shapeINS_20integral_const_arrayIjJLj23760000EEEENS5_IjJLj1EEEEEEEENS3_IfNS4_INS5_IjJLj23760000ELj1EEEENS5_IjJLj1ELj23760000EEEEEEEES9_NS_17gathernd_settingsINS_16generic_constantIZZZ15gathernd_kernelENKUlDpOT_E_clIJS9_SD_S9_EEEDaSI_ENKUlvE_clEvE3funEEEEEEvRKT_RKT0_RKT1_T2_EUlSP_E_NS_17integral_constantIjLj23760000EEEEEvSS_SP_ (this=0x2000000000228, n=..., f=...) at ./migraphx/kernels/index.hpp:226
No locals.
#11 0x00007faa0d1151f8 in _ZN8migraphx8gatherndINS_11tensor_viewIfNS_5shapeINS_20integral_const_arrayIjJLj23760000EEEENS3_IjJLj1EEEEEEEENS1_IfNS2_INS3_IjJLj23760000ELj1EEEENS3_IjJLj1ELj23760000EEEEEEEES7_NS_17gathernd_settingsINS_16generic_constantIZZZ15gathernd_kernelENKUlDpOT_E_clIJS7_SB_S7_EEEDaSG_ENKUlvE_clEvE3funEEEEEEvRKT_RKT0_RKT1_T2_ (data_t=..., indices_t=..., output_t=..., s=...)
    at ./migraphx/kernels/gathernd.hpp:68
        ind = {global = 113323, local = 683, group = 110}
        batch_dims = {static value = 0}
        output_shape = {lens = {<migraphx::array<unsigned int, 1u>> = {d = {23760000}}, <No data fields>}, 
          strides = {<migraphx::array<unsigned int, 1u>> = {d = {1}}, <No data fields>}}
        indices_shape = {lens = {<migraphx::array<unsigned int, 2u>> = {d = {23760000, 1}}, <No data fields>}, 
          strides = {<migraphx::array<unsigned int, 2u>> = {d = {1, 23760000}}, <No data fields>}}
        data_shape = {lens = {<migraphx::array<unsigned int, 1u>> = {d = {23760000}}, <No data fields>}, 
          strides = {<migraphx::array<unsigned int, 1u>> = {d = {1}}, <No data fields>}}
        indices_shape_lens = {<migraphx::array<unsigned int, 2u>> = {d = {23760000, 1}}, <No data fields>}
        data_shape_lens = {<migraphx::array<unsigned int, 1u>> = {d = {23760000}}, <No data fields>}
        num_slice_dims = 1
        num_slices = 23760000
        slice_size = 1
        num_batches = 1
        data_batch_stride = 23760000
        num_slices_per_batch = 23760000
#12 0x00007faa0d113444 in _ZZ15gathernd_kernelENKUlDpOT_E_clIJN8migraphx11tensor_viewIfNS4_5shapeINS4_20integral_const_arrayIjJLj23760000EEEENS7_IjJLj1EEEEEEEENS5_IfNS6_INS7_IjJLj23760000ELj1EEEENS7_IjJLj1ELj23760000EEEEEEEESB_EEEDaS1_ (this=0x20000000000b8, xs=..., xs=..., 
    xs=...) at main.cpp:16
        settings = {batch_dims = {static value = 0}}
#13 0x00007faa0d112ddc in _ZN8migraphx17make_tensors_implIZ15gathernd_kernelEUlDpOT_E_JLj0ELj1ELj2EEJvvvEEEDaT_NS_6detail3seqIJXspT0_EEEEDpPT1_ (f=..., xs=0x7f9fdb5dd080, xs=0x7f9fdb5dd080, xs=0x7f9fdb5dd080) at ./migraphx/kernels/args.hpp:39
No locals.
#14 0x00007faa0d112650 in _ZZZN8migraphx12make_tensorsEvENKUlDpPT_E_clIJvvvEEEDaS2_ENKUlT_E_clIZ15gathernd_kernelEUlDpOS0_E_EEDaS5_ (
    this=0x2000000000038, f=...) at ./migraphx/kernels/args.hpp:45
No locals.
#15 0x00007faa0d1124c0 in gathernd_kernel (in_data=0x7f9fe7ec6620, in_indices=0x7f9fef324480, output=0x7f9fdb5dd080) at main.cpp:14

@TedThemistokleous
Collaborator Author

I'm having issues trying to reproduce this output in a test of the block that seems to throw the assert before the gathernd:

main:@2049 = sigmoid(main:@2043) -> float_type, {90000, 264}, {264, 1}
main:@2050 = slice[axes={0},starts={0},ends={0}](main:@401) -> int64_type, {0}, {1}   ]  <- never used
main:@2051 = concat[axis=0](main:@2050,main:@397) -> int64_type, {1}, {1}             ]
main:@2052 = contiguous(main:@2049) -> float_type, {90000, 264}, {264, 1}
main:@2053 = reshape[dims={-1}](main:@2052) -> float_type, {23760000}, {1}
main:@2054 = multibroadcast[out_lens={23760000},out_dyn_dims={}](main:@396) -> float_type, {23760000}, {0}
main:@2055 = greater(main:@2053,main:@2054) -> float_type, {23760000}, {1}
main:@2056 = convert[target_type=0](main:@2055) -> bool_type, {23760000}, {1}
main:@2057 = nonzero(main:@2056) -> int64_type, {1, 23760000}, {23760000, 1}
main:@2058 = transpose[permutation={1, 0}](main:@2057) -> int64_type, {23760000, 1}, {1, 23760000}
main:@2059 = gathernd[batch_dims=0](main:@2053,main:@2058) -> float_type, {23760000}, {1}

This is primarily due to reshape saying the shape is invalid when I try to mirror what I'm seeing in that section of code. Interestingly, that's not 1:1 with what Netron shows:

[image: Netron view of this section of the graph]

It's as if the second block isn't read in correctly at all for the reshape.

When this compiles, it looks like it's doing the following for this branch (taken from my gdb dump and cherry-picked):


Run instruction: main:@692 = gpu::code_object[code_object=667328,symbol_name=gather_kernel,global=225280,local=1024,](main:@690,main:@2,main:@691) -> float_type, {90000, 264}, {264, 1}
Run instruction: main:@693 = load[offset=505871616,end=600911616](main:@1) -> float_type, {90000, 264}, {264, 1}
Run instruction: main:@694 = gpu::code_object[code_object=739560,symbol_name=sigmoid_kernel,global=5940000,local=1024,](main:@692,main:@693) -> float_type, {90000, 264}, {264, 1}
Run instruction: main:@722 = reshape[dims={-1}](main:@694) -> float_type, {23760000}, {1}
Run instruction: main:@723 = load[offset=6551328,end=30311328](main:@1) -> bool_type, {23760000}, {1}
Run instruction: main:@724 = gpu::code_object[code_object=878608,symbol_name=greater_convert_kernel,global=5940000,local=1024,](main:@722,main:@723) -> bool_type, {23760000}, {1}
Run instruction: main:@725 = load[offset=829404576,end=1019484576](main:@1) -> int64_type, {1, 23760000}, {23760000, 1}
Run instruction: main:@726 = gpu::nonzero(main:@724,main:@725) -> int64_type, {1, 23760000}, {23760000, 1}
Run instruction: main:@839 = load[offset=602576704,end=697616704](main:@1) -> float_type, {1, 23760000}, {23760000, 1}
Run instruction: main:@840 = gpu::code_object[code_object=815040,symbol_name=convert_kernel,global=5940000,local=1024,](main:@726,main:@839) -> float_type, {1, 23760000}, {23760000, 1}
Run instruction: main:@865 = transpose[permutation={1, 0}](main:@840) -> float_type, {23760000, 1}, {1, 23760000}
Run instruction: main:@866 = load[offset=359126528,end=454166528](main:@1) -> float_type, {23760000}, {1}
Run instruction: main:@867 = gpu::code_object[code_object=522336,symbol_name=gathernd_kernel,global=225280,local=1024,](main:@722,main:@865,main:@866) -> float_type, {23760000}, {1}

It appears we aren't parsing that branch in correctly then, unless we're doing some other functionality with reshapes?
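The suspect branch above can be modeled in numpy with shrunken, illustrative shapes (names are ours, not MIGraphX's). Under ONNX semantics NonZero returns only the positions of true elements, so every gathered index is in range; a static-shape backend that pre-allocates the worst-case buffer has to pad the remainder, and any pad index >= len(data) would reproduce the out-of-bounds gather:

```python
import numpy as np

# Shapes shrunk from {90000, 264} down to {6, 4} for illustration.
rng = np.random.default_rng(0)

scores = 1.0 / (1.0 + np.exp(-rng.standard_normal((6, 4))))  # sigmoid
flat = scores.reshape(-1)                                    # reshape[dims={-1}]
mask = flat > 0.5                                            # greater + convert to bool
idx = np.nonzero(mask)[0].reshape(-1, 1)                     # nonzero + transpose -> {N, 1}
picked = flat[idx[:, 0]]                                     # gathernd[batch_dims=0]

# ONNX semantics: N == number of true elements, so every index is valid.
assert picked.shape[0] == mask.sum()
assert np.all(picked > 0.5)

# A pad value equal to len(flat) (cf. slice_offset = 23760000 in the
# backtrace) would be out of range; no valid index can take that value.
assert flat.shape[0] not in idx
```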

@TedThemistokleous
Collaborator Author

A few updates on this.

It isn't an issue with reshapes, after talking with @pfultz2. That part is in fact correct.

Data-driven ops seem to be doing something incorrect here: nonzero, gathernd, and topk are using larger vectors than needed. The GPU assert is related to this, with nonzero->transpose->gathernd not giving the correct index and going out of bounds.

I tried an accuracy test, since the output size of 300 is valid and baked into the network when viewing it in Netron. From the looks of it, onnxruntime is cutting down the vector size of the data, resulting in only 5 outputs. After relaxing the length condition and checking the first 5 values of the output, I'm seeing accuracy failures between the CPU EP and MIGraphX.

From the previous initial run, topk was also a large amount of the overhead, which I've found is related to the issues with the other data-driven ops. The value of K being used is the largest vector size, which results in topk just performing a sort on the entire vector of data. This is similar to the transpose->gathernd issue we're seeing due to nonzero not cutting down the size.

I think the next step here is to combine these ops in a pass/matcher here to handle the correct sizes at run time.
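A quick sketch of why this K degenerates topk into a full sort (illustrative numpy, not the GPU kernel):

```python
import numpy as np

# When nonzero can't shrink its output under static shapes, K is forced up
# to the full vector length, so topk sorts the entire input.
def topk(values, k):
    order = np.argsort(-values, kind="stable")[:k]
    return values[order], order

x = np.random.default_rng(0).random(1000).astype(np.float32)

vals_full, _ = topk(x, k=len(x))   # K == N: just a full descending sort
assert np.array_equal(vals_full, np.sort(x)[::-1])

vals5, _ = topk(x, k=5)            # the size the CPU EP actually needs
assert np.array_equal(vals5, vals_full[:5])
```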

@TedThemistokleous
Collaborator Author

After talking this over with @CharlieL7 it appears the only way out of this is to use dynamic shapes.

This is due to the following block, with respect to handling topk->gather:

[image: topk->gather block from the graph]

In SSD resnet, for example, we have a similar configuration of blocks in which topk outputs to two separate gathers. In the resnext50 case, one of these gathers is fed by a concat of all subsequent topk outputs.

This is an issue because nonzero sets the value of K for all topk ops, which results in K being the largest possible shape under static shapes, and thus padding at the end of each topk output. The resulting buffers before we do a gather are then:

     TOPK1                           TOPK2            ...   TOPKN
[valid data | 0...0]  [valid data | 0.....0]  ... [valid data | 0..0]

                                                    

Thus, after the concat, we attempt to gather on:

[valid data | 0....0 | valid data | 0....0] with no idea where the proper data boundaries are.
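The boundary problem can be shown with a small numpy sketch (sizes are illustrative): once the padded branches are concatenated, a consumer that assumes contiguous valid data reads padding instead of real scores.

```python
import numpy as np

# Two topk branches, each padded out to the worst-case K under static shapes.
# After concat, the pad/valid boundaries are gone.
K = 6
branch1 = np.concatenate([[0.9, 0.8, 0.7], np.zeros(K - 3)])  # 3 valid scores
branch2 = np.concatenate([[0.95, 0.85],    np.zeros(K - 2)])  # 2 valid scores
merged = np.concatenate([branch1, branch2])

naive_top5 = merged[:5]                # assumes valid data is contiguous
true_top5 = np.sort(merged)[::-1][:5]  # what we actually want
assert not np.array_equal(naive_top5, true_top5)  # padding corrupted the result
```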

One "solution" discussed was to move the concat back and combine the topk branches for inference, but this defeats the purpose of retinanet parallelizing them.

If we want to move forward with this and dynamic shapes:

  • Add dynamic shapes to nonzero
  • Add dynamic shapes to gather
  • Add dynamic shapes to concat

or

  • Move around operations with a matcher <- risky and may not work.

@TedThemistokleous
Collaborator Author

Set to backlog until I can get dynamic shape support.

@TedThemistokleous
Collaborator Author

Related to #1886
