From 0490c3647bfbd43728393d2ff8ff03afa8fa0a28 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Mon, 10 Oct 2022 02:41:26 +0800 Subject: [PATCH] collider: improve performance Reduces the number of atomic instructions needed and performs allocation only once. This improves the performance for complicated models by ~10%. --- src/collider/src/collider.cpp | 70 ++++++++++++++++------------------- 1 file changed, 32 insertions(+), 38 deletions(-) diff --git a/src/collider/src/collider.cpp b/src/collider/src/collider.cpp index 9fd62f71a..775c883b8 100644 --- a/src/collider/src/collider.cpp +++ b/src/collider/src/collider.cpp @@ -157,26 +157,28 @@ struct CreateRadixTree { } }; -template +template struct FindCollisions { thrust::pair querryTri_; - int* numOverlaps_; - const int maxOverlaps_; + int* counts; const Box* nodeBBox_; const thrust::pair* internalChildren_; __host__ __device__ int RecordCollision(int node, - const thrust::tuple& query) { + thrust::tuple& query) { const T& queryObj = thrust::get<0>(query); const int queryIdx = thrust::get<1>(query); + int& count = counts[queryIdx]; bool overlaps = nodeBBox_[node].DoesOverlap(queryObj); if (overlaps && IsLeaf(node)) { - int pos = AtomicAdd(*numOverlaps_, 1); - if (pos >= maxOverlaps_) - return -1; // Didn't allocate enough memory; bail out - querryTri_.first[pos] = queryIdx; - querryTri_.second[pos] = Node2Leaf(node); + if (allocateOnly) { + count++; + } else { + int pos = count++; + querryTri_.first[pos] = queryIdx; + querryTri_.second[pos] = Node2Leaf(node); + } } return overlaps && IsInternal(node); // Should traverse into node } @@ -188,15 +190,16 @@ struct FindCollisions { int top = -1; // Depth-first search int node = kRoot; + const int queryIdx = thrust::get<1>(query); + // equal offsets imply that this query does not have any collisions + if (!allocateOnly && counts[queryIdx] == counts[queryIdx + 1]) return; while (1) { int internal = Node2Internal(node); int child1 = internalChildren_[internal].first; int child2 = 
internalChildren_[internal].second; int traverse1 = RecordCollision(child1, query); - if (traverse1 < 0) return; int traverse2 = RecordCollision(child2, query); - if (traverse2 < 0) return; if (!traverse1 && !traverse2) { if (top < 0) break; // done @@ -268,33 +271,24 @@ Collider::Collider(const VecDH& leafBB, */ template SparseIndices Collider::Collisions(const VecDH& querriesIn) const { - int maxOverlaps = querriesIn.size() * 4; - SparseIndices querryTri(maxOverlaps); - int nOverlaps = 0; - while (1) { - // scalar number of overlaps found - VecDH nOverlapsD(1, 0); - // calculate Bounding Box overlaps - for_each_n( - autoPolicy(querriesIn.size()), zip(querriesIn.cbegin(), countAt(0)), - querriesIn.size(), - FindCollisions({querryTri.ptrDpq(), nOverlapsD.ptrD(), maxOverlaps, - nodeBBox_.ptrD(), internalChildren_.ptrD()})); - nOverlaps = nOverlapsD[0]; - if (nOverlaps <= maxOverlaps) - break; - else { // if not enough memory was allocated, guess how much will be needed - int lastQuery = querryTri.Get(0).back(); - float ratio = static_cast(querriesIn.size()) / lastQuery; - if (ratio > 1000) // do not trust the ratio if it is too large - maxOverlaps *= 2; - else - maxOverlaps *= 2 * ratio; - querryTri.Resize(maxOverlaps); - } - } - // remove unused part of array - querryTri.Resize(nOverlaps); + // note that the length is 1 larger than the number of queries so the last + // element can store the sum when using exclusive scan + VecDH counts(querriesIn.size() + 1, 0); + auto policy = autoPolicy(querriesIn.size()); + // compute the number of collisions to determine the size for allocation and + // offset, this avoids the need for atomic + for_each_n(policy, zip(querriesIn.cbegin(), countAt(0)), querriesIn.size(), + FindCollisions( + {thrust::pair(nullptr, nullptr), counts.ptrD(), + nodeBBox_.ptrD(), internalChildren_.ptrD()})); + // compute start index for each query and total count + exclusive_scan(policy, counts.begin(), counts.end(), counts.begin()); + SparseIndices 
querryTri(counts.back()); + // actually recording collisions + for_each_n( + policy, zip(querriesIn.cbegin(), countAt(0)), querriesIn.size(), + FindCollisions({querryTri.ptrDpq(), counts.ptrD(), + nodeBBox_.ptrD(), internalChildren_.ptrD()})); return querryTri; }