Indexing failure due to self-mapped broadcast domains #2762

Open
naoyam opened this issue Aug 6, 2024 · 0 comments

Originally seen in #2685.

tl;dr:

  • Non-concretized broadcast domains that are mapped to each other trigger the self-mapping validation error
  • Non-concretized broadcast domains should not matter for indexing, so self-mapping between them should not be a validation error
  • Just removing the validation doesn't work either, likely because some code relies on the self-mapping-free property even for broadcast domains
  • The new IdModel-based indexer doesn't seem to have the issue

This fusion fails at indexing:

TEST_F(NVFuserTest, TMP) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = set(tv0);
  auto tv2 = broadcast(tv1, {true, false});

  auto tv3 = broadcast(tv1, {true, false});

  auto tv4 = add(tv2, tv3);
  fusion.addOutput(tv4);

  tv3->reorder({{0, 1}});

  tv1->inlineAt(1);

  fusion.printMath();
  fusion.printKernel();
}

Inputs:
  T0_g[ iS0{i0} ], float
Outputs:
  T4_g[ bS6{1}, iS7{i0} ], float

%kernel_math {
T1_l[ iS1{i0} ] ca_pos( 1 )
   = Set( T0_g[ iS0{i0} ], cache_op=Streaming )
T2_l[ bS2{1}, iS3{i0} ] produce_pos( 2 )
   = broadcast( T1_l[ iS1{i0} ] ca_pos( 1 ) )
T3_l[ iS5{i0}, bS4{1} ] produce_pos( 1 )
   = broadcast( T1_l[ iS1{i0} ] ca_pos( 1 ) )
T4_g[ bS6{1}, iS7{i0} ]
   = T2_l[ bS2{1}, iS3{i0} ] produce_pos( 2 )
   + T3_l[ iS5{i0}, bS4{1} ] produce_pos( 1 );
} // %kernel_math

What matters most is:

T1_l[ iS1{i0} ] ca_pos( 1 )
T2_l[ bS2{1}, iS3{i0} ] produce_pos( 2 )
T3_l[ iS5{i0}, bS4{1} ] produce_pos( 1 )

Here, iS1, iS3 and iS5 are all inlined together. Since T3 is reordered, the bS2 loop needs to be placed outside the iS1/iS3/iS5 loop and the bS4 loop inside it. In other words, we would need to create loops as:

for i in bS2:
  for j in iS1:
    for k in bS4:
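
As a rough illustration of that loop structure, here is a host-side C++ emulation of the fusion. It is only a sketch written for this issue, not code generated by nvFuser; the function name `emulateLoopNest` and the use of `std::vector` buffers are assumptions made for the example. Both broadcast loops have extent 1, so they run exactly once and do not affect where any element is stored.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side emulation of the intended loop nest (illustrative only).
// Tensor names mirror the fusion: T1 = set(T0), T2/T3 = broadcast(T1),
// T4 = T2 + T3.
std::vector<float> emulateLoopNest(const std::vector<float>& t0) {
  const std::size_t n = t0.size();
  std::vector<float> t2(n), t3(n), t4(n);

  for (int b2 = 0; b2 < 1; ++b2) {          // bS2: broadcast, extent 1
    for (std::size_t i = 0; i < n; ++i) {   // iS1/iS3/iS5, inlined together
      const float t1 = t0[i];               // T1 = Set(T0)
      t2[i] = t1;                           // T2 = broadcast(T1)
      for (int b4 = 0; b4 < 1; ++b4) {      // bS4: broadcast, extent 1
        t3[i] = t1;                         // T3 = broadcast(T1)
      }
    }
  }

  for (int b6 = 0; b6 < 1; ++b6) {          // bS6: broadcast, extent 1
    for (std::size_t i = 0; i < n; ++i) {   // iS7
      t4[i] = t2[i] + t3[i];                // T4 = T2 + T3
    }
  }

  assert(t4.size() == n);  // the output keeps the input extent
  return t4;
}
```

Each output element is simply twice the corresponding input element, regardless of where the two broadcast loops sit in the nest.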

In fact, this is the Kernel IR just after the LoopNestGenerator pass:

After LoopNestGenerator:
FOR 0 in bS2{1}:
  FOR i56 in iS11{( (( (( getMetaData(T0) )).logical_size ))[0] )}:
    T1_l[ iS9{( (( (( getMetaData(T0) )).logical_size ))[0] )} ] ca_pos( 1 )
       = Set( T0_g[ iS8{( (( (( getMetaData(T0) )).logical_size ))[0] )} ], cache_op=Streaming )
    T2_l[ bS2{1}, iS10{( (( (( getMetaData(T0) )).logical_size ))[0] )} ] produce_pos( 2 )
       = broadcast( T1_l[ iS9{( (( (( getMetaData(T0) )).logical_size ))[0] )} ] ca_pos( 1 ) )
    FOR 0 in bS4{1}:
      T3_l[ iS11{( (( (( getMetaData(T0) )).logical_size ))[0] )}, bS4{1} ] produce_pos( 1 )
         = broadcast( T1_l[ iS9{( (( (( getMetaData(T0) )).logical_size ))[0] )} ] ca_pos( 1 ) )
FOR 0 in bS6{1}:
  FOR i55 in iS12{( (( (( getMetaData(T0) )).logical_size ))[0] )}:
    T4_g[ bS6{1}, iS12{( (( (( getMetaData(T0) )).logical_size ))[0] )} ]
       = T2_l[ bS2{1}, iS10{( (( (( getMetaData(T0) )).logical_size ))[0] )} ] produce_pos( 2 )
       + T3_l[ iS11{( (( (( getMetaData(T0) )).logical_size ))[0] )}, bS4{1} ] produce_pos( 1 );

The problematic expression is the one producing T3. Since all broadcast domains of this fusion are exactly mapped, bS2 and bS4 are indeed self-mapped, resulting in the original error:

C++ exception with description "!concrete_to_loop.count(concrete_loop_id) INTERNAL ASSERT FAILED at "csrc/device_lower/analysis/index_compute.cpp":597, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Unsupported loop structure. Two loops are mapped together.bS4{1} and bS2{1}
Exception raised from validateLoopStructure at csrc/device_lower/analysis/index_compute.cpp:597 (most recent call first):
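
To make the failing check concrete, here is a toy model of it, loosely inferred from the assertion text `!concrete_to_loop.count(concrete_loop_id)`. It is not the actual code in index_compute.cpp: the string-based IDs, the `ConcreteMap` alias and the function `findSelfMappedLoops` are all hypothetical names chosen for this sketch. Each loop is mapped to the representative of its exact-mapped group, and the check fails as soon as two loops of the same nest share a representative.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in: maps a loop's iteration domain to the concrete
// (representative) domain of its exact-mapped group.
typedef std::map<std::string, std::string> ConcreteMap;

// Returns true and fills `conflict` with (later loop, earlier loop) if two
// loops in the nest map to the same concrete domain. Returns false if the
// nest is free of self-mapping.
bool findSelfMappedLoops(const std::vector<std::string>& loop_nest,
                         const ConcreteMap& concrete_of,
                         std::pair<std::string, std::string>* conflict) {
  std::map<std::string, std::string> concrete_to_loop;
  for (const std::string& loop_id : loop_nest) {
    ConcreteMap::const_iterator it = concrete_of.find(loop_id);
    assert(it != concrete_of.end());  // every loop needs a concrete domain
    // insert() leaves the map unchanged if the concrete domain is taken.
    auto result = concrete_to_loop.insert(std::make_pair(it->second, loop_id));
    if (!result.second) {
      *conflict = std::make_pair(loop_id, result.first->second);
      return true;
    }
  }
  return false;
}
```

With the nest `bS2, iS11, bS4` and a map that sends both `bS2` and `bS4` to the same concrete domain, this sketch reports the pair `(bS4, bS2)`, the same pair named in the error message.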

I think this validation error is a false alarm, since both loops are just broadcast domains. There's no broadcast forwarding in this fusion, so they should be no-ops for indexing. However, simply disabling the validation resulted in another validation error:

C++ exception with description "loops.size() <= loop_domains.size() INTERNAL ASSERT FAILED at "csrc/device_lower/analysis/index_compute.cpp":215, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Loop domain didn't replay all loops
Exception raised from getNonGlobalInitialIndexParameters at csrc/device_lower/analysis/index_compute.cpp:215 (most recent call first):
frame #0: <unknown function> + 0xcf9893 (0x558eac519893 in ./bin/nvfuser_tests)

Not unexpectedly, there seems to be some code relying on the self-mapping-free property even for broadcast domains.

Note that this fusion seems to work with the new indexer:

$ NVFUSER_ENABLE='id_model(all)' ./bin/nvfuser_tests --gtest_filter="NVFuserTest.TMP"
Note: Google Test filter = *TMP
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from NVFuserTest
[ RUN      ] NVFuserTest.TMP
Inputs:
  T0_g[ iS0{i0} ], float
Outputs:
  T4_g[ bS6{1}, iS7{i0} ], float

%kernel_math {
T1_l[ iS1{i0} ] ca_pos( 1 )
   = Set( T0_g[ iS0{i0} ], cache_op=Streaming )
T2_l[ bS2{1}, iS3{i0} ] produce_pos( 2 )
   = broadcast( T1_l[ iS1{i0} ] ca_pos( 1 ) )
T3_l[ iS5{i0}, bS4{1} ] produce_pos( 1 )
   = broadcast( T1_l[ iS1{i0} ] ca_pos( 1 ) )
T4_g[ bS6{1}, iS7{i0} ]
   = T2_l[ bS2{1}, iS3{i0} ] produce_pos( 2 )
   + T3_l[ iS5{i0}, bS4{1} ] produce_pos( 1 );
} // %kernel_math

[W alias_memory.cpp:821] Warning: Lower_alias_memory : dynamic sized register allocation (function operator())
__global__ void CUDAGeneratedKernel(Tensor<float, 1, 1> T0, Tensor<float, 2, 2> T4) {
  float T2[T0.logical_size[0LL]];
  float T3[T0.logical_size[0LL]];
  #pragma unroll 1
  for(nvfuser_index_t i0 = 0LL; i0 < T0.logical_size[0LL]; ++i0) {
    float T1[1LL];
    T1[0LL] = 0LL;
    T1[0LL]
       = T0[(T0.alloc_stride[0LL] * i0)];
    T2[i0]
       = T1[0LL];
    T3[i0]
       = T1[0LL];
  }
  #pragma unroll 1
  for(nvfuser_index_t i1 = 0LL; i1 < T0.logical_size[0LL]; ++i1) {
    T4[i1]
      = T2[i1]
      + T3[i1];
  }
}

This is expected since in the new indexer, the broadcast domains are not going to participate in indexing.
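
The idea that broadcast domains do not take part in indexing can be illustrated with a minimal sketch. This is a toy row-major linearization, not the IdModel-based indexer itself; the `LoopDomain` struct and the function `linearIndexSkippingBroadcast` are hypothetical names introduced here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical description of one loop domain of a tensor.
struct LoopDomain {
  long extent;
  bool is_broadcast;
};

// Row-major linearization of the loop indices that ignores broadcast
// domains entirely, so their position in the loop nest has no effect.
long linearIndexSkippingBroadcast(const std::vector<LoopDomain>& domains,
                                  const std::vector<long>& loop_indices) {
  assert(domains.size() == loop_indices.size());
  long index = 0;
  for (std::size_t d = 0; d < domains.size(); ++d) {
    if (domains[d].is_broadcast) {
      continue;  // broadcast domains contribute nothing to the index
    }
    index = index * domains[d].extent + loop_indices[d];
  }
  return index;
}
```

Under this model, T2 with domains `[bS2{1}, iS3{i0}]` and T3 with domains `[iS5{i0}, bS4{1}]` get the same index for the same iteration of the inlined loop, which is consistent with the `T2[i0]` and `T3[i0]` accesses in the generated kernel.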

Since it should work with the new indexer, I don't think it's worthwhile to fix the issue with the legacy indexer.
