[REVIEW] Add tfidf bm25 #2353

Open
wants to merge 94 commits into
base: branch-24.12

Conversation

jperez999

This PR adds support for TF-IDF and BM25 preprocessing of sparse matrices. The user does not have to work within the confines of a COO or CSR matrix: only the (row, column, value) triplets are required, and with those we can preprocess the values accordingly. I'm putting this up early to get eyes on it and confirm it is going in the right direction, or to adjust if it is not.

Unit tests are still required for these features.
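
For readers following along, the triplet-based preprocessing described above can be sketched on the host in plain C++. This is not the code in this PR and not a RAFT API: the `Triplets` struct and the `document_frequency`, `tfidf` and `bm25` functions are hypothetical names, the formulas are common textbook variants, and the sketch assumes rows index documents, columns index terms, and there are no duplicate (row, column) pairs. The PR itself may use a different orientation or weighting.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference, not the RAFT API.
// Assumption: rows are documents, columns are terms, values are raw counts.
struct Triplets {
  std::vector<int> rows;
  std::vector<int> cols;
  std::vector<float> vals;
};

// Number of stored entries per column, i.e. how many documents contain each term
// (valid only when there are no duplicate (row, col) pairs).
std::vector<int> document_frequency(const Triplets& t, int n_cols)
{
  std::vector<int> df(n_cols, 0);
  for (std::size_t i = 0; i < t.cols.size(); ++i) { ++df[t.cols[i]]; }
  return df;
}

// Textbook TF-IDF: value * log(n_docs / df(term)).
std::vector<float> tfidf(const Triplets& t, int n_rows, int n_cols)
{
  const std::vector<int> df = document_frequency(t, n_cols);
  std::vector<float> out(t.vals.size());
  for (std::size_t i = 0; i < t.vals.size(); ++i) {
    const float idf =
      std::log(static_cast<float>(n_rows) / static_cast<float>(df[t.cols[i]]));
    out[i] = t.vals[i] * idf;
  }
  return out;
}

// Textbook Okapi BM25 with the "+1 inside the log" IDF variant.
std::vector<float> bm25(const Triplets& t, int n_rows, int n_cols,
                        float k1 = 1.2f, float b = 0.75f)
{
  const std::vector<int> df = document_frequency(t, n_cols);

  // Document length = sum of the values in each row; also track the global total.
  std::vector<float> doc_len(n_rows, 0.0f);
  float total = 0.0f;
  for (std::size_t i = 0; i < t.vals.size(); ++i) {
    doc_len[t.rows[i]] += t.vals[i];
    total += t.vals[i];
  }
  const float avg_len = total / static_cast<float>(n_rows);

  std::vector<float> out(t.vals.size());
  for (std::size_t i = 0; i < t.vals.size(); ++i) {
    const float tf   = t.vals[i];
    const float d    = static_cast<float>(df[t.cols[i]]);
    const float idf  = std::log(1.0f + (n_rows - d + 0.5f) / (d + 0.5f));
    const float norm = k1 * (1.0f - b + b * doc_len[t.rows[i]] / avg_len);
    out[i]           = idf * tf * (k1 + 1.0f) / (tf + norm);
  }
  return out;
}
```

A reference like this only needs the three triplet arrays plus the matrix shape, which is the property the description relies on; the same arrays can be read out of either a COO or a CSR structure.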

ajschmidt8 and others added 30 commits July 14, 2020 17:05
[skip ci] Update master references for main branch
[HOTFIX] Remove `-g` from cython compile commands
Our `devel` Docker containers need to be switched to using `conda` compilers to resolve a linking error. `raft` is in those containers, but hasn't yet been built with `conda` compilers. This PR addresses that.

These changes won't cleanly merge into `branch-22.08` unfortunately due to the changes in rapidsai#641, but we can address that another time.

Authors:
   - AJ Schmidt (https://github.com/ajschmidt8)
   - Corey J. Nolet (https://github.com/cjnolet)
   - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
   - Corey J. Nolet (https://github.com/cjnolet)
@shwina I'm going to apologize ahead of time for this, but I was trying to forward-merge your branch 22.10 locally to create a new PR from it and I accidentally pushed to your remote branch. I cherry-picked the commits over to a new branch for the hotfix.

Authors:
   - Bradley Dice (https://github.com/bdice)
   - Ashwin Srinath (https://github.com/shwina)

Approvers:
   - Ray Douglass (https://github.com/raydouglass)
[RELEASE] raft v22.12.01 [skip-gpuci]
Member

@cjnolet cjnolet left a comment

Thanks for these changes, Julio! They look great for the most part. Mostly minor things: 1) we need to use RAFT primitives wherever possible instead of thrust. 2) We should test at larger scales and write more reproducible tests by providing naive kernels to evaluate the results.

* limitations under the License.
*/

#include <raft/core/device_mdarray.hpp>
Member

Ideally you just include what you need, so if you need all of these headers then go ahead and include them. Otherwise, try to remove the ones that are unneeded.

cpp/include/raft/sparse/matrix/detail/preprocessing.cuh (6 outdated review threads, resolved)
cpp/test/sparse/preprocess_coo.cu (1 outdated review thread, resolved)
@cjnolet
Member

cjnolet commented Aug 14, 2024

/ok to test

data.data_handle(),
stream);

thrust::reduce_by_key(raft::resource::get_thrust_policy(handle),
Author

I ended up using the thrust version because it can handle vectors, which lets me use the same code for both the CSR and COO versions of the encoding logic. Also, the RAFT version does not support sparse matrices.

Member

@cjnolet cjnolet Sep 24, 2024

Is this to compute the degree of each row in the sparse format? We already have routines for this, including a coo_degree function here. Degree computation for CSR is trivial: since you already have an array of offsets, you don't even need to count the columns, because you can just diff the array (i.e. compute the difference between each value in the indptr array and the value before it). If you can't guarantee uniqueness, a simple mask is an efficient way to detect the unique entries. For COO, you can then add up the 1s in the mask for each row segment. For a sorted COO, the degree computation is also trivial: you only need the row and column arrays and a segmented reduce.
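
To make the two degree computations above concrete, here is a minimal host-side sketch in plain C++. It is only an illustration of the idea, not RAFT's coo_degree or any other RAFT routine: `csr_row_degree`, `first_occurrence_mask` and `coo_row_degree` are hypothetical names, and the COO input is assumed to be sorted by (row, column).

```cpp
#include <cstddef>
#include <vector>

// CSR: the degree of row r is indptr[r + 1] - indptr[r].
std::vector<int> csr_row_degree(const std::vector<int>& indptr)
{
  std::vector<int> degree;
  for (std::size_t r = 0; r + 1 < indptr.size(); ++r) {
    degree.push_back(indptr[r + 1] - indptr[r]);
  }
  return degree;
}

// Sorted COO: mask[i] is 1 when entry i is the first occurrence of its
// (row, col) pair and 0 when it repeats the previous entry.
std::vector<int> first_occurrence_mask(const std::vector<int>& rows,
                                       const std::vector<int>& cols)
{
  std::vector<int> mask(rows.size(), 1);
  for (std::size_t i = 1; i < rows.size(); ++i) {
    if (rows[i] == rows[i - 1] && cols[i] == cols[i - 1]) { mask[i] = 0; }
  }
  return mask;
}

// Sorted COO: adding up the mask within each row segment gives the number of
// unique columns per row (a segmented reduce, written here as a plain loop).
std::vector<int> coo_row_degree(const std::vector<int>& rows,
                                const std::vector<int>& cols,
                                int n_rows)
{
  const std::vector<int> mask = first_occurrence_mask(rows, cols);
  std::vector<int> degree(n_rows, 0);
  for (std::size_t i = 0; i < rows.size(); ++i) { degree[rows[i]] += mask[i]; }
  return degree;
}
```

For example, an indptr of {0, 2, 2, 5} gives degrees {2, 0, 3}, and a sorted COO with a repeated (0, 1) entry counts that pair only once.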

Author

@jperez999 jperez999 Sep 27, 2024

When we were using this function for rows, coo_degree was absolutely the right play. I was just trying to reuse code, but that ended up causing problems with larger datasets (in the form of illegal memory access errors). I have changed it so this function is only used when we need a column-wise sum of the values (not just a check for whether a value is present, as with rows). We also can't just use L1 normalization, because I need the average column size across all columns as well as the individual column average. The reduce-by-key functions available in RAFT are for dense matrices only. This is why I have opted to use thrust reduce_by_key for the column-based processing.
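
As an illustration of the column-based processing described above, the following plain C++ sketch mimics the reduce-by-key pattern on the host: values are sorted by column index, runs of equal keys are summed, and the per-column sums plus their mean across all columns are returned. It is not the code in this PR. `reduce_by_key_host`, `ColumnStats` and `column_stats` are hypothetical names, and the exact statistics the PR computes may differ from the ones chosen here.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Host analogue of a reduce-by-key: sums values over runs of equal consecutive
// keys, so the keys must be sorted if one output per distinct key is wanted.
std::pair<std::vector<int>, std::vector<float>> reduce_by_key_host(
  const std::vector<int>& keys, const std::vector<float>& vals)
{
  std::vector<int> keys_out;
  std::vector<float> sums_out;
  for (std::size_t i = 0; i < keys.size(); ++i) {
    if (keys_out.empty() || keys_out.back() != keys[i]) {
      keys_out.push_back(keys[i]);
      sums_out.push_back(0.0f);
    }
    sums_out.back() += vals[i];
  }
  return {keys_out, sums_out};
}

struct ColumnStats {
  std::vector<float> col_sum;  // sum of the stored values in each column
  float mean_col_sum;          // average of col_sum over all n_cols columns
};

ColumnStats column_stats(const std::vector<int>& cols,
                         const std::vector<float>& vals,
                         int n_cols)
{
  // Sort the entries by column index so equal keys become consecutive.
  std::vector<std::size_t> order(cols.size());
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::stable_sort(order.begin(), order.end(),
                   [&](std::size_t a, std::size_t b) { return cols[a] < cols[b]; });
  std::vector<int> sorted_keys;
  std::vector<float> sorted_vals;
  for (std::size_t i : order) {
    sorted_keys.push_back(cols[i]);
    sorted_vals.push_back(vals[i]);
  }

  const auto reduced = reduce_by_key_host(sorted_keys, sorted_vals);

  // Scatter the compacted sums into a dense array; empty columns stay at 0.
  ColumnStats stats{std::vector<float>(n_cols, 0.0f), 0.0f};
  float total = 0.0f;
  for (std::size_t i = 0; i < reduced.first.size(); ++i) {
    stats.col_sum[reduced.first[i]] = reduced.second[i];
    total += reduced.second[i];
  }
  if (n_cols > 0) { stats.mean_col_sum = total / static_cast<float>(n_cols); }
  return stats;
}
```

Note that the reduction returns one entry per run of equal keys rather than one per column, which is why the sort and the final scatter are both needed.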

auto keys_out = raft::make_device_vector<int, int64_t>(handle, num_rows);
auto counts_out = raft::make_device_vector<int, int64_t>(handle, num_rows);

thrust::reduce_by_key(raft::resource::get_thrust_policy(handle),
Member

We have a great function already for removing duplicates from sparse formats- it uses a simple mask to figure out where the duplicates are. It's really efficient. Also, if the goal is to get the degree for each row of the matrix, we have functions for this.

Member

I'm really less concerned about thrust in tests... however it does make it easier if we can reuse raft routines

Author

The two available functions for removing duplicates that I saw are compute_duplicates_mask and max_duplicates. I did not find one that takes a mask and removes entries based on it, and what I have here is a function that uses the mask to remove the dupes. max_duplicates works a little differently from the mask: it keeps the entry with the max value, which is not the behavior I want. I would like to take the last vertex value that we see in the COO vectors. That aligns more with compute_duplicates_mask, which is what I used here, but the rest of this code is still required. If you look further up in the code, you will see that we use that function to calculate the mask before we remove the dupes, and the mask is then consumed in this function. If a function already exists that takes this mask directly and removes the 0-valued indices, I would love to use it; I could not find it, though.
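
The "keep the last value seen, then compact by mask" behavior described above can be sketched on the host as follows. This is an illustration under stated assumptions, not the function in this PR and not RAFT's compute_duplicates_mask: the `Coo` struct and `dedupe_keep_last` are hypothetical names, and "last" means last in the original input order, which the stable sort preserves within each (row, column) group.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical host-side COO container for illustration only.
struct Coo {
  std::vector<int> rows;
  std::vector<int> cols;
  std::vector<float> vals;
};

Coo dedupe_keep_last(const Coo& in)
{
  const std::size_t n = in.rows.size();

  // Stable sort by (row, col) so duplicates become adjacent while keeping
  // their original relative order.
  std::vector<std::size_t> order(n);
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::stable_sort(order.begin(), order.end(), [&](std::size_t a, std::size_t b) {
    if (in.rows[a] != in.rows[b]) { return in.rows[a] < in.rows[b]; }
    return in.cols[a] < in.cols[b];
  });

  // mask[i] is 1 when sorted entry i is the last of its (row, col) run.
  std::vector<int> mask(n, 0);
  for (std::size_t i = 0; i < n; ++i) {
    const bool is_last = (i + 1 == n) ||
                         in.rows[order[i]] != in.rows[order[i + 1]] ||
                         in.cols[order[i]] != in.cols[order[i + 1]];
    mask[i] = is_last ? 1 : 0;
  }

  // Exclusive scan of the mask gives each kept entry its output position.
  std::vector<std::size_t> dest(n, 0);
  std::size_t kept = 0;
  for (std::size_t i = 0; i < n; ++i) {
    dest[i] = kept;
    kept += static_cast<std::size_t>(mask[i]);
  }

  // Scatter only the entries whose mask is 1.
  Coo out;
  out.rows.resize(kept);
  out.cols.resize(kept);
  out.vals.resize(kept);
  for (std::size_t i = 0; i < n; ++i) {
    if (mask[i] == 1) {
      out.rows[dest[i]] = in.rows[order[i]];
      out.cols[dest[i]] = in.cols[order[i]];
      out.vals[dest[i]] = in.vals[order[i]];
    }
  }
  return out;
}
```

The mask, scan and scatter steps are written as three separate loops on purpose, since each maps onto a standard data-parallel primitive on the GPU.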

@cjnolet cjnolet changed the base branch from branch-24.10 to branch-24.12 September 26, 2024 14:47
Status: In Progress

8 participants