Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
elronbandel committed Nov 29, 2023
1 parent d9ac62a commit a685c71
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/unitxt/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .dataclass import AbstractField
from .random_utils import get_random


class Collection(Artifact):
items: typing.Collection = AbstractField()

Expand Down
2 changes: 1 addition & 1 deletion src/unitxt/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
StreamInstanceOperator,
StreamSource,
)
from .random_utils import nested_seed, get_random
from .random_utils import get_random, nested_seed
from .stream import MultiStream, Stream
from .text_utils import nested_tuple_to_string
from .utils import flatten_dict
Expand Down
14 changes: 10 additions & 4 deletions src/unitxt/random_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,37 @@
__default_seed__ = 42
_thread_local = threading.local()


def get_seed():
try:
return _thread_local.seed
except AttributeError:
_thread_local.seed = __default_seed__
return _thread_local.seed


def get_random():
try:
return _thread_local.random
try:
return _thread_local.random
except AttributeError:
_thread_local.random = python_random.Random(get_seed())
return _thread_local.random
return _thread_local.random


random = get_random()


def set_seed(seed):
_thread_local.seed = seed
get_random().seed(seed)



def get_random_string(length):
letters = string.ascii_letters
result_str = "".join(get_random().choice(letters) for _ in range(length))
return result_str


@contextlib.contextmanager
def nested_seed(sub_seed=None):
old_state = get_random().getstate()
Expand Down
2 changes: 1 addition & 1 deletion src/unitxt/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str

inputs = instance["inputs"]
outputs = instance["outputs"]

source = self.template.process_inputs(inputs)
targets = self.template.process_outputs(outputs)

Expand Down
21 changes: 7 additions & 14 deletions tests/test_random_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import string
import unittest

from src.unitxt.random_utils import __default_seed__, nested_seed, get_random, set_seed
from src.unitxt.random_utils import __default_seed__, get_random, nested_seed, set_seed


def first_randomization():
Expand Down Expand Up @@ -52,7 +52,7 @@ def test_sepration_from_global_python_seed(self):
b = first_randomization()

self.assertEqual(a, b)

def test_thread_safety_sanity(self):
import threading
import time
Expand Down Expand Up @@ -85,14 +85,12 @@ def thread_function(name, sleep_time, results):
# x = threading.Thread(target=thread_function, args=(i, sleep_time, results))
# threads.append(x)
# x.start()



for index in range(3):
with self.subTest(f'Within Thread {index}'):
with self.subTest(f"Within Thread {index}"):
self.assertEqual(results[index][0], results[index][1])

with self.subTest(f'Across all threads'):
with self.subTest(f"Across all threads"):
flatten_results = [item for sublist in results for item in sublist]
self.assertEqual(len(set(flatten_results)), 1)

Expand Down Expand Up @@ -127,20 +125,15 @@ def thread_function(name, sleep_time, results):
x = threading.Thread(target=thread_function, args=(i, sleep_time, results))
threads.append(x)
x.start()



for index, thread in enumerate(threads):
thread.join()
with self.subTest(f'Within Thread {index}'):

with self.subTest(f"Within Thread {index}"):
self.assertIsNotNone(results[index][0])
self.assertIsNotNone(results[index][1])
self.assertEqual(results[index][0], results[index][1])


with self.subTest(f'Across all threads'):
with self.subTest(f"Across all threads"):
flatten_results = [item for sublist in results for item in sublist]
self.assertEqual(len(set(flatten_results)), 1)


0 comments on commit a685c71

Please sign in to comment.