
Compile error in where(x, a, b) with single precision a or b #2403

Open
jacobhinkle opened this issue Feb 2, 2023 · 1 comment
@jacobhinkle (Collaborator)

🐛 Describe the bug

Current tests pass double-precision constants to where(), which works. There is currently no built-in Float scalar (i.e. no using Float = Scalar<float> alias), but we'd like to extend where to support single-precision arguments. However, after defining that type locally, the following test fails to compile:

TEST_F(NVFuserTest, FusionWhereFloat32_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1, DataType::Bool);
  fusion.addInput(tv0);

  using Float = Scalar<float>; // no built-in Float scalar, so we define it

  auto tv1 = where(tv0,
      IrBuilder::create<Float>(3.0),
      IrBuilder::create<Float>(5.0));
  fusion.addOutput(tv1);

  auto options = at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0);
  auto t0 = at::randint(0, 1, {5}, options);
  auto ref = at::where(t0, (float)3.0, (float)5.0);

  std::vector<IValue> inputs = {t0};
  auto lparams = schedulePointwise(&fusion, inputs);

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs, lparams);
  /*
  C++ exception with description "false INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/third_party/nvfuser/csrc/executor_utils.cpp":1237, please report a bug to PyTorch.
  ...
  __global__ void kernel1(Tensor<bool, 1> T0, Tensor<float, 1> T1) {
    int i59;
    i59 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
    if ((i59 < T0.size[0])) {
      bool T2[1];
      T2[0] = 0;
      T2[0]
         = T0[(((T0.stride[0] * ((nvfuser_index_t)blockIdx.x)) * 128) + (T0.stride[0] * ((nvfuser_index_t)threadIdx.x)))];
      float T3[1];
      T3[0]
         = where(T2[0], f3, f2);
      T1[i59]
         = T3[0];
    }
  }

  CUDA NVRTC compile error: __tmp_kernel1.cu(8920): error: identifier "f3" is undefined

  __tmp_kernel1.cu(8920): error: identifier "f2" is undefined

  2 errors detected in the compilation of "__tmp_kernel1.cu".
  */
  auto cg_outputs = fe.runFusion(inputs);

  testValidate(&fusion, cg_outputs, inputs, {ref}, __LINE__, __FILE__);
}

If this is the only known use case for single-precision scalars, it may be simpler to add a dtype argument to the C++ where. Otherwise, the codegen would need to be made aware of scalar types other than Double so that such scalars appear in the kernel signature (above, f2 and f3 are referenced in the generated kernel but never declared as parameters).
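As a standalone sketch of the first option (the dtype argument): the enum, make_scalar, and where_scalar below are hypothetical illustrations, not nvFuser API. The point is just that the caller picks the precision and the constants are materialized in that type, rather than always defaulting to double.

#include <cassert>
#include <variant>

// Hypothetical standalone sketch (not nvFuser code): constants are
// constructed in the precision the caller requests via a dtype argument.
enum class DataType { Float, Double };

using ScalarValue = std::variant<float, double>;

// make_scalar is a made-up helper illustrating dtype-directed construction.
ScalarValue make_scalar(double v, DataType dtype) {
  if (dtype == DataType::Float) {
    return static_cast<float>(v); // narrow to single precision
  }
  return v;
}

// Host-side stand-in for where(pred, a, b) on one scalar pair.
ScalarValue where_scalar(bool pred, ScalarValue a, ScalarValue b) {
  return pred ? a : b;
}

int main() {
  auto a = make_scalar(3.0, DataType::Float);
  auto b = make_scalar(5.0, DataType::Float);

  // Both constants carry float, matching at::where(t0, 3.0f, 5.0f).
  assert(std::holds_alternative<float>(a));
  assert(std::get<float>(where_scalar(true, a, b)) == 3.0f);
  assert(std::get<float>(where_scalar(false, a, b)) == 5.0f);
  return 0;
}

Under this scheme the generated kernel would receive the scalars in the requested type as kernel parameters, so identifiers like f2/f3 would be declared rather than dangling.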

Versions

Collecting environment information...
PyTorch version: 2.0.0a0+git4121ffc
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.10.9 (main, Feb 1 2023, 00:41:45) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.13.0-40-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080
Nvidia driver version: 525.85.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0a0+git4121ffc
[pip3] torchvision==0.15.0a0+0dceac0
[conda] Could not collect

@jacobhinkle (Collaborator, Author)

Note that this is not specific to where; a similar error occurs for add(tv0, IrBuilder::create<Float>(5.0)).
