Skip to content

Commit

Permalink
Fix the cumsum bug for large sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
wawltor committed Jun 21, 2022
1 parent cf0d602 commit d10c1b7
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions paddle/phi/kernels/gpu/cum_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,8 @@ __global__ void BlockScanKernel(T* d_out,
} temp_storage;

int bx = blockIdx.x;
int by = blockIdx.y;

BlockPrefixCallbackOp<T, Op> prefix_op(Identity<T, Op>::value, op);
T block_aggregate = static_cast<T>(0);

// Obtain this block's segment of consecutive keys (blocked across threads)
int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
Expand All @@ -192,7 +190,7 @@ __global__ void BlockScanKernel(T* d_out,
valid_item = scan_size;
}

int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
int offset = block_offset + bx * scan_size;

T thread_keys[ITEMS_PER_THREAD];
BlockLoadT(temp_storage.load)
Expand Down Expand Up @@ -271,7 +269,6 @@ void ScanKernel(const Context& dev_ctx,
return;
}


size_t height = 1;
size_t width = 1;
for (size_t i = 0; i <= axis; i++) {
Expand Down Expand Up @@ -308,6 +305,7 @@ void ScanKernel(const Context& dev_ctx,
int outer_size = height / scan_size;
int inner_size = width;
// Consider the size of shared memory, here block size is 128

dim3 scan_grid(outer_size, inner_size);
dim3 reverse_grid = scan_grid;
if (reverse) {
Expand All @@ -323,13 +321,14 @@ void ScanKernel(const Context& dev_ctx,
in_data, out_data, scan_size, outer_size, inner_size);
}
}
int64_t grid_size = outer_size * inner_size;
if (!transpose && !reverse) {
BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
BlockScanKernel<T, 128, 4, Op><<<grid_size, 128, 0, dev_ctx.stream()>>>(
out_data, in_data, outer_size, inner_size, scan_size, exclusive, op);

} else {
BlockScanKernel<T, 128, 4, Op>
<<<scan_grid, 128, 0, dev_ctx.stream()>>>(next_out_data,
<<<grid_size, 128, 0, dev_ctx.stream()>>>(next_out_data,
next_in_data,
outer_size,
inner_size,
Expand Down Expand Up @@ -391,9 +390,5 @@ PD_REGISTER_KERNEL(cumsum,
int,
int64_t) {}

PD_REGISTER_KERNEL(logcumsumexp,
GPU,
ALL_LAYOUT,
phi::LogcumsumexpKernel,
float,
double) {}
PD_REGISTER_KERNEL(
logcumsumexp, GPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {}

1 comment on commit d10c1b7

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI. You can ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.