Skip to content

Commit

Permalink
Use single-precision floats for UMAP, for speed.
Browse files Browse the repository at this point in the history
This is also consistent with the configuration in our R packages.
  • Loading branch information
LTLA committed Oct 14, 2024
1 parent 6ce5a2a commit 26f8e97
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 6 deletions.
4 changes: 2 additions & 2 deletions js/runUmap.js
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ export function initializeUmap(x, options = {}) {
nnres = x;
}

raw_coords = utils.createFloat64WasmArray(2 * nnres.numberOfCells());
raw_coords = utils.createFloat32WasmArray(2 * nnres.numberOfCells());
output = gc.call(
module => module.initialize_umap(nnres.results, epochs, minDist, raw_coords.offset, nthreads),
module => module.initialize_umap(nnres.results, neighbors, epochs, minDist, raw_coords.offset, nthreads),
UmapStatus,
raw_coords
);
Expand Down
11 changes: 11 additions & 0 deletions js/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ export function createBigUint64WasmArray (length) {
return wa.createBigUint64WasmArray(wasmArraySpace(), length);
}

/**
* Helper function to create a Float32WasmArray from the **wasmarrays.js** package.
*
* @param {number} length - Length of the array.
*
* @return {Float32WasmArray} Float32WasmArray on the **scran.js** Wasm heap.
*/
export function createFloat32WasmArray(length) {
return wa.createFloat32WasmArray(wasmArraySpace(), length);
}

/**
* Helper function to create a Float64WasmArray from the **wasmarrays.js** package.
*
Expand Down
20 changes: 16 additions & 4 deletions src/run_umap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <chrono>

struct UmapStatus {
typedef umappp::Status<int32_t, double> Status;
typedef umappp::Status<int32_t, float> Status;

Status status;

Expand All @@ -27,7 +27,7 @@ struct UmapStatus {

UmapStatus deepcopy(uintptr_t Y) const {
auto copy = status;
copy.set_embedding(reinterpret_cast<double*>(Y), false);
copy.set_embedding(reinterpret_cast<float*>(Y), false);
return UmapStatus(std::move(copy));
}

Expand All @@ -41,8 +41,20 @@ UmapStatus initialize_umap(const NeighborResults& neighbors, int32_t num_epochs,
opt.min_dist = min_dist;
opt.num_epochs = num_epochs;
opt.num_threads = nthreads;
double* embedding = reinterpret_cast<double*>(Y);
auto stat = umappp::initialize(neighbors.neighbors, 2, embedding, opt);

std::vector<std::vector<std::pair<int32_t, float> > > copy(neighbors.neighbors.size());
for (size_t i = 0, end = copy.size(); i < end; ++i) {
auto& output = copy[i];
const auto& src = neighbors.neighbors[i];
size_t n = src.size();
output.reserve(n);
for (size_t j = 0; j < n; ++j) {
output.emplace_back(src[j].first, src[j].second);
}
}

float* embedding = reinterpret_cast<float*>(Y);
auto stat = umappp::initialize(std::move(copy), 2, embedding, opt);
return UmapStatus(std::move(stat));
}

Expand Down

0 comments on commit 26f8e97

Please sign in to comment.