Skip to content

Commit

Permalink
Add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
vahidk committed Jul 7, 2024
1 parent 5034c32 commit 4b02596
Show file tree
Hide file tree
Showing 12 changed files with 260 additions and 531 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Continuous integration workflow: installs the package in editable mode and
# runs the unit-test suite on every push to main and on every pull request
# that targets main.
name: CI

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      # checkout@v4 / setup-python@v5 run on a supported Node runtime; the
      # previous majors (v3 / v4) used the deprecated Node 16 runtime.
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .

      # run_tests.py exits non-zero when any test fails, which fails the job.
      - name: Run tests
        run: |
          python run_tests.py
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
*.pyc
dist
tfrecord.egg-info
/test_*
/*.proto
/*.sh
/.pypirc
11 changes: 0 additions & 11 deletions MANIFEST

This file was deleted.

4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ This library allows reading and writing tfrecord files efficiently in python. Th

## Installation

```pip3 install tfrecord```
```
pip3 install 'tfrecord[torch]'
```

## Usage

Expand Down
11 changes: 11 additions & 0 deletions run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import unittest
import sys

if __name__ == '__main__':
    # Collect the test modules found under the 'tests' directory and run
    # them with the plain-text runner.
    suite = unittest.TestLoader().discover('tests')
    outcome = unittest.TextTestRunner().run(suite)
    # Signal failure to the caller (e.g. CI) through a non-zero exit status.
    if not outcome.wasSuccessful():
        sys.exit(1)
8 changes: 6 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

setup(
name='tfrecord',
version='1.14.4',
version='1.14.5',
description='TFRecord reader',
long_description=long_description,
long_description_content_type='text/markdown',
Expand All @@ -27,5 +27,9 @@
url='https://github.com/vahidk/tfrecord',
packages=find_packages(),
license='MIT',
install_requires=install_requires
install_requires=install_requires,
extras_require={
'torch': ['torch'],
},
test_suite='tests',
)
72 changes: 72 additions & 0 deletions tests/test_read_and_write.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
import tempfile
import unittest

import numpy as np

from tfrecord.reader import example_loader, tfrecord_iterator
from tfrecord.writer import TFRecordWriter


class TestReadWrite(unittest.TestCase):
    """Round-trip tests: write one record to a TFRecord file and read it back."""

    def _make_temp_path(self):
        """Create an empty temporary file and return its path.

        Removal is registered with ``addCleanup`` so the file is deleted even
        when an assertion fails or the test raises part-way through.
        """
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            filename = temp_file.name
        self.addCleanup(os.remove, filename)
        return filename

    def write_tfrecord(self, filename, records):
        """Write every item of ``records`` to ``filename`` via TFRecordWriter."""
        writer = TFRecordWriter(filename)
        try:
            for datum in records:
                writer.write(datum)
        finally:
            # Close the writer even if a write fails so the handle is released.
            writer.close()

    def read_tfrecord(self, filename):
        """Return all raw records yielded by ``tfrecord_iterator`` as a list."""
        return list(tfrecord_iterator(filename))

    def _round_trip(self, datum):
        """Write ``datum`` as the only record and return the loaded examples.

        Also asserts that exactly one raw record was written.
        """
        filename = self._make_temp_path()
        self.write_tfrecord(filename, [datum])

        records = self.read_tfrecord(filename)
        self.assertEqual(len(records), 1)

        return list(example_loader(filename, None))

    def test_write_and_read_integers(self):
        example = self._round_trip({"int_key": (123, "int")})
        np.testing.assert_array_equal(
            example[0]["int_key"], np.array([123], dtype=np.int64)
        )

    def test_write_and_read_floats(self):
        example = self._round_trip({"float_key": (1.23, "float")})
        np.testing.assert_array_equal(
            example[0]["float_key"], np.array([1.23], dtype=np.float32)
        )

    def test_write_and_read_string_arrays(self):
        example = self._round_trip({"string_key": ([b"test1", b"test2"], "byte")})
        self.assertEqual(example[0]["string_key"], b"test1")


if __name__ == "__main__":
unittest.main()
46 changes: 46 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest
from unittest.mock import mock_open, patch

import numpy as np
from tfrecord.reader import (
example_loader,
sequence_loader,
tfrecord_iterator,
process_feature,
)

from tfrecord import example_pb2


class TestFeatureProcessing(unittest.TestCase):
    """Checks ``process_feature`` against one feature of each list type."""

    def setUp(self):
        # One single-element Feature per protobuf list kind used by the tests.
        bytes_list = example_pb2.BytesList(value=[b"test"])
        float_list = example_pb2.FloatList(value=[1.0])
        int64_list = example_pb2.Int64List(value=[1])

        self.feature_bytes = example_pb2.Feature(bytes_list=bytes_list)
        self.feature_float = example_pb2.Feature(float_list=float_list)
        self.feature_int = example_pb2.Feature(int64_list=int64_list)

    def test_process_feature_bytes(self):
        typename_mapping = {"byte": "bytes_list"}
        decoded = process_feature(
            self.feature_bytes, "byte", typename_mapping, "key"
        )
        self.assertEqual(decoded, b"test")

    def test_process_feature_float(self):
        typename_mapping = {"float": "float_list"}
        expected = np.array([1.0], dtype=np.float32)
        decoded = process_feature(
            self.feature_float, "float", typename_mapping, "key"
        )
        np.testing.assert_array_equal(decoded, expected)

    def test_process_feature_int(self):
        typename_mapping = {"int": "int64_list"}
        expected = np.array([1], dtype=np.int64)
        decoded = process_feature(
            self.feature_int, "int", typename_mapping, "key"
        )
        np.testing.assert_array_equal(decoded, expected)


if __name__ == "__main__":
unittest.main()
41 changes: 41 additions & 0 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import unittest
import tempfile
import os
import numpy as np

from tfrecord.reader import tfrecord_iterator
from tfrecord.writer import TFRecordWriter


class TestTFRecordWriter(unittest.TestCase):
    """Checks the raw record bytes produced by ``TFRecordWriter``."""

    def _make_temp_path(self):
        """Create an empty temporary file and return its path.

        Removal is registered with ``addCleanup`` so the file is deleted even
        when an assertion fails or the test raises part-way through.
        """
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            filename = temp_file.name
        self.addCleanup(os.remove, filename)
        return filename

    def _write_and_read(self, *write_args):
        """Write one record with ``writer.write(*write_args)`` and read it back.

        Returns the list of raw records yielded by ``tfrecord_iterator``.
        """
        filename = self._make_temp_path()
        writer = TFRecordWriter(filename)
        try:
            writer.write(*write_args)
        finally:
            # Close the writer even if the write fails so the handle is released.
            writer.close()
        return list(tfrecord_iterator(filename))

    def test_tfrecord_writer_write_example(self):
        datum = {"key": (b"value", "byte")}
        records = self._write_and_read(datum)
        self.assertEqual(records[0], b"\n\x12\n\x10\n\x03key\x12\t\n\x07\n\x05value")

    def test_tfrecord_writer_write_sequence_example(self):
        datum = {"key": (b"value", "byte")}
        sequence_datum = {"seq_key": ([b"seq_value"], "byte")}
        # The second mapping is passed as write()'s second positional argument.
        records = self._write_and_read(datum, sequence_datum)
        self.assertTrue(records[0].tobytes().startswith(b"\n\x12\n\x10\n\x03key\x12\t\n\x07\n\x05value"))


if __name__ == "__main__":
unittest.main()
5 changes: 4 additions & 1 deletion tfrecord/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from tfrecord import tools
from tfrecord import torch
try:
from tfrecord import torch
except ImportError:
pass

from tfrecord import example_pb2
from tfrecord import iterator_utils
Expand Down
Loading

0 comments on commit 4b02596

Please sign in to comment.