Skip to content

Commit

Permalink
Forcing empty list/empty tuple behavior. (#514)
Browse files Browse the repository at this point in the history
* Forcing empty list/empty tuple behavior.

* Tailscale.

* No tailscale anymore ?

* Do not push the package it's not necessary.

* Remove `Debug`.
  • Loading branch information
Narsil committed Aug 5, 2024
1 parent 6f791c5 commit aa4ad82
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 15 deletions.
18 changes: 7 additions & 11 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,6 @@ jobs:
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- uses: tailscale/github-action@v1
with:
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
- name: Set short sha
id: vars
run: echo "GITHUB_SHA_SHORT=$(git rev-parse --short HEAD)" >> $GITHUB_ENV
Expand All @@ -107,7 +104,7 @@ jobs:
with:
# list of Docker images to use as base name for tags
images: |
registry.internal.huggingface.tech/safetensors/s390x
ghcr.io/safetensors/s390x
# generate Docker tags based on the following events/attributes
tags: |
type=schedule
Expand All @@ -118,18 +115,17 @@ jobs:
type=semver,pattern={{major}}
type=sha
- name: Login to Registry
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
registry: registry.internal.huggingface.tech
username: ${{ secrets.REGISTRY_USERNAME }}
password: ${{ secrets.REGISTRY_PASSWORD }}
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Test big endian
uses: docker/build-push-action@v4
with:
push: true
platforms: linux/s390x
file: Dockerfile.s390x.test
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=registry.internal.huggingface.tech/safetensors/s390x:cache,mode=min
cache-to: type=registry,ref=registry.internal.huggingface.tech/safetensors/s390x:cache,mode=min
cache-from: type=registry,ref=ghcr.io/safetensors/s390x:cache,mode=max
cache-to: type=registry,ref=ghcr.io/safetensors/s390x:cache,mode=max
24 changes: 20 additions & 4 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use pyo3::exceptions::{PyException, PyFileNotFoundError};
use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
use pyo3::types::IntoPyDict;
use pyo3::types::PySlice;
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyList};
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyList, PySlice};
use pyo3::Bound as PyBound;
use pyo3::{intern, PyErr};
use safetensors::slice::TensorIndexer;
Expand Down Expand Up @@ -842,10 +841,27 @@ impl PySafeSlice {
pub fn __getitem__(&self, slices: &PyBound<'_, PyAny>) -> PyResult<PyObject> {
match &self.storage.as_ref() {
Storage::Mmap(mmap) => {
let slices: Slice = slices.extract()?;
let pyslices = slices;
let slices: Slice = pyslices.extract()?;
let is_list = pyslices.is_instance_of::<PyList>();
let slices: Vec<SliceIndex> = match slices {
Slice::Slice(slice) => vec![slice],
Slice::Slices(slices) => slices,
Slice::Slices(slices) => {
if slices.is_empty() && is_list {
vec![SliceIndex::Slice(PySlice::new_bound(
pyslices.py(),
0,
0,
0,
))]
} else if is_list {
return Err(SafetensorError::new_err(
"Non empty lists are not implemented",
));
} else {
slices
}
}
};
let data = &mmap[self.info.data_offsets.0 + self.offset
..self.info.data_offsets.1 + self.offset];
Expand Down
20 changes: 20 additions & 0 deletions bindings/python/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ def test_torch_slice(self):
self.assertEqual(list(tensor.shape), [10, 5])
torch.testing.assert_close(tensor, A)

tensor = slice_[tuple()]
self.assertEqual(list(tensor.shape), [10, 5])
torch.testing.assert_close(tensor, A)

tensor = slice_[:2]
self.assertEqual(list(tensor.shape), [2, 5])
torch.testing.assert_close(tensor, A[:2])
Expand All @@ -270,6 +274,10 @@ def test_torch_slice(self):
self.assertEqual(list(tensor.shape), [8])
torch.testing.assert_close(tensor, A[2:, -1])

tensor = slice_[list()]
self.assertEqual(list(tensor.shape), [0, 5])
torch.testing.assert_close(tensor, A[list()])

def test_numpy_slice(self):
A = np.random.rand(10, 5)
tensors = {
Expand All @@ -284,6 +292,10 @@ def test_numpy_slice(self):
self.assertEqual(list(tensor.shape), [10, 5])
self.assertTrue(np.allclose(tensor, A))

tensor = slice_[tuple()]
self.assertEqual(list(tensor.shape), [10, 5])
self.assertTrue(np.allclose(tensor, A))

tensor = slice_[:2]
self.assertEqual(list(tensor.shape), [2, 5])
self.assertTrue(np.allclose(tensor, A[:2]))
Expand Down Expand Up @@ -312,10 +324,18 @@ def test_numpy_slice(self):
self.assertEqual(list(tensor.shape), [8])
self.assertTrue(np.allclose(tensor, A[2:, -5]))

tensor = slice_[list()]
self.assertEqual(list(tensor.shape), [0, 5])
self.assertTrue(np.allclose(tensor, A[list()]))

with self.assertRaises(SafetensorError) as cm:
tensor = slice_[2:, -6]
self.assertEqual(str(cm.exception), "Invalid index -6 for dimension 1 of size 5")

with self.assertRaises(SafetensorError) as cm:
tensor = slice_[[0, 1]]
self.assertEqual(str(cm.exception), "Non empty lists are not implemented")

with self.assertRaises(SafetensorError) as cm:
tensor = slice_[2:, 20]
self.assertEqual(
Expand Down

0 comments on commit aa4ad82

Please sign in to comment.