diff --git a/CHANGELOG.md b/CHANGELOG.md index 3afaf991022..ff8be54ede5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.13.0...HEAD) +- Enabled custom samplers via entry points + ([#2972](https://github.com/open-telemetry/opentelemetry-python/pull/2972)) - Update log symbol names ([#2943](https://github.com/open-telemetry/opentelemetry-python/pull/2943)) - Update explicit histogram bucket boundaries diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py index fa057f785ca..c2280e1b27e 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py @@ -23,7 +23,6 @@ from os import environ from typing import Dict, Optional, Sequence, Tuple, Type -from pkg_resources import iter_entry_points from typing_extensions import Literal from opentelemetry.environment_variables import ( @@ -55,6 +54,7 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter from opentelemetry.sdk.trace.id_generator import IdGenerator +from opentelemetry.sdk.util import _import_config_components from opentelemetry.semconv.resource import ResourceAttributes from opentelemetry.trace import set_tracer_provider @@ -226,26 +226,6 @@ def _init_logging( logging.getLogger().addHandler(handler) -def _import_config_components( - selected_components, entry_point_name -) -> Sequence[Tuple[str, object]]: - component_entry_points = { - ep.name: ep for ep in iter_entry_points(entry_point_name) - } - component_impls = [] - for selected_component in selected_components: - entry_point = component_entry_points.get(selected_component, None) - if not entry_point: - raise RuntimeError( - f"Requested component '{selected_component}' not found in entry points for '{entry_point_name}'" - ) - - component_impl = entry_point.load() - component_impls.append((selected_component, component_impl)) - - return component_impls - - def _import_exporters( trace_exporter_names: Sequence[str], metric_exporter_names: Sequence[str], @@ -287,10 +267,9 @@ def _import_exporters( def _import_id_generator(id_generator_name: str) -> IdGenerator: - # pylint: disable=unbalanced-tuple-unpacking - [(id_generator_name, id_generator_impl)] = _import_config_components( + id_generator_name, id_generator_impl = _import_config_components( [id_generator_name.strip()], "opentelemetry_id_generator" - ) + )[0] if issubclass(id_generator_impl, IdGenerator): return id_generator_impl diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py index 3cdd34cfe8c..38a3338b02f 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py @@ -64,7 +64,7 @@ ... The tracer sampler can also be configured via environment variables ``OTEL_TRACES_SAMPLER`` and ``OTEL_TRACES_SAMPLER_ARG`` (only if applicable). -The list of known values for ``OTEL_TRACES_SAMPLER`` are: +The list of built-in values for ``OTEL_TRACES_SAMPLER`` are: * always_on - Sampler that always samples spans, regardless of the parent span's sampling decision. * always_off - Sampler that never samples spans, regardless of the parent span's sampling decision. @@ -73,8 +73,7 @@ * parentbased_always_off - Sampler that respects its parent span's sampling decision, but otherwise never samples. * parentbased_traceidratio - Sampler that respects its parent span's sampling decision, but otherwise samples probabalistically based on rate. -Sampling probability can be set with ``OTEL_TRACES_SAMPLER_ARG`` if the sampler is traceidratio or parentbased_traceidratio, when not provided rate will be set to 1.0 (maximum rate possible). - +Sampling probability can be set with ``OTEL_TRACES_SAMPLER_ARG`` if the sampler is traceidratio or parentbased_traceidratio. Rate must be in the range [0.0,1.0]. When not provided rate will be set to 1.0 (maximum rate possible). Prev example but with environment variables. Please make sure to set the env ``OTEL_TRACES_SAMPLER=traceidratio`` and ``OTEL_TRACES_SAMPLER_ARG=0.001``. @@ -97,13 +96,45 @@ # created spans will now be sampled by the TraceIdRatioBased sampler with rate 1/1000. with trace.get_tracer(__name__).start_as_current_span("Test Span"): ... + +In order to create a configurable custom sampler, create an entry point for the custom sampler factory method under the entry point group, ``opentelemetry_traces_sampler``. The custom sampler factory +method must be of type ``Callable[[str], Sampler]``, taking a single string argument and returning a Sampler object. The single input will come from the string value of the +``OTEL_TRACES_SAMPLER_ARG`` environment variable. If ``OTEL_TRACES_SAMPLER_ARG`` is not configured, the input will be an empty string. For example: + +.. code:: python + + setup( + ... + entry_points={ + ... + "opentelemetry_traces_sampler": [ + "custom_sampler_name = path.to.sampler.factory.method:CustomSamplerFactory.get_sampler" + ] + } + ) + # ... + class CustomRatioSampler(Sampler): + def __init__(rate): + # ... + # ... + class CustomSamplerFactory: + @staticmethod + get_sampler(sampler_argument): + try: + rate = float(sampler_argument) + return CustomSampler(rate) + except ValueError: # In case argument is empty string. + return CustomSampler(0.5) + +In order to configure you application with a custom sampler's entry point, set the ``OTEL_TRACES_SAMPLER`` environment variable to the key name of the entry point. For example, to configured the +above sampler, set ``OTEL_TRACES_SAMPLER=custom_sampler_name`` and ``OTEL_TRACES_SAMPLER_ARG=0.5``. """ import abc import enum import os from logging import getLogger from types import MappingProxyType -from typing import Optional, Sequence +from typing import Callable, Optional, Sequence # pylint: disable=unused-import from opentelemetry.context import Context @@ -111,6 +142,7 @@ OTEL_TRACES_SAMPLER, OTEL_TRACES_SAMPLER_ARG, ) +from opentelemetry.sdk.util import _import_config_components from opentelemetry.trace import Link, SpanKind, get_current_span from opentelemetry.trace.span import TraceState from opentelemetry.util.types import Attributes @@ -161,6 +193,9 @@ def __init__( self.trace_state = trace_state +_OTEL_SAMPLER_ENTRY_POINT_GROUP = "opentelemetry_traces_sampler" + + class Sampler(abc.ABC): @abc.abstractmethod def should_sample( @@ -372,22 +407,37 @@ def __init__(self, rate: float): def _get_from_env_or_default() -> Sampler: - trace_sampler = os.getenv( + traces_sampler_name = os.getenv( OTEL_TRACES_SAMPLER, "parentbased_always_on" ).lower() - if trace_sampler not in _KNOWN_SAMPLERS: - _logger.warning("Couldn't recognize sampler %s.", trace_sampler) - trace_sampler = "parentbased_always_on" - - if trace_sampler in ("traceidratio", "parentbased_traceidratio"): - try: - rate = float(os.getenv(OTEL_TRACES_SAMPLER_ARG)) - except ValueError: - _logger.warning("Could not convert TRACES_SAMPLER_ARG to float.") - rate = 1.0 - return _KNOWN_SAMPLERS[trace_sampler](rate) - return _KNOWN_SAMPLERS[trace_sampler] + if traces_sampler_name in _KNOWN_SAMPLERS: + if traces_sampler_name in ("traceidratio", "parentbased_traceidratio"): + try: + rate = float(os.getenv(OTEL_TRACES_SAMPLER_ARG)) + except ValueError: + _logger.warning( + "Could not convert TRACES_SAMPLER_ARG to float." + ) + rate = 1.0 + return _KNOWN_SAMPLERS[traces_sampler_name](rate) + return _KNOWN_SAMPLERS[traces_sampler_name] + try: + traces_sampler_factory = _import_sampler_factory(traces_sampler_name) + sampler_arg = os.getenv(OTEL_TRACES_SAMPLER_ARG, "") + traces_sampler = traces_sampler_factory(sampler_arg) + if not isinstance(traces_sampler, Sampler): + message = f"Traces sampler factory, {traces_sampler_factory}, produced output, {traces_sampler}, which is not a Sampler object." + _logger.warning(message) + raise ValueError(message) + return traces_sampler + except Exception as exc: # pylint: disable=broad-except + _logger.warning( + "Using default sampler. Failed to initialize custom sampler, %s: %s", + traces_sampler_name, + exc, + ) + return _KNOWN_SAMPLERS["parentbased_always_on"] def _get_parent_trace_state(parent_context) -> Optional["TraceState"]: @@ -395,3 +445,10 @@ def _get_parent_trace_state(parent_context) -> Optional["TraceState"]: if parent_span_context is None or not parent_span_context.is_valid: return None return parent_span_context.trace_state + + +def _import_sampler_factory(sampler_name: str) -> Callable[[str], Sampler]: + _, sampler_impl = _import_config_components( + [sampler_name.strip()], _OTEL_SAMPLER_ENTRY_POINT_GROUP + )[0] + return sampler_impl diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/util/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/util/__init__.py index e1857d8e62d..52104243532 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/util/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/util/__init__.py @@ -14,11 +14,11 @@ import datetime import threading -from collections import OrderedDict, deque -from collections.abc import MutableMapping, Sequence -from typing import Optional +from collections import OrderedDict, abc, deque +from typing import List, Optional, Sequence, Tuple from deprecated import deprecated +from pkg_resources import iter_entry_points def ns_to_iso_str(nanoseconds): @@ -41,7 +41,27 @@ def get_dict_as_key(labels): ) -class BoundedList(Sequence): +def _import_config_components( + selected_components: List[str], entry_point_name: str +) -> Sequence[Tuple[str, object]]: + component_entry_points = { + ep.name: ep for ep in iter_entry_points(entry_point_name) + } + component_impls = [] + for selected_component in selected_components: + entry_point = component_entry_points.get(selected_component, None) + if not entry_point: + raise RuntimeError( + f"Requested component '{selected_component}' not found in entry points for '{entry_point_name}'" + ) + + component_impl = entry_point.load() + component_impls.append((selected_component, component_impl)) + + return component_impls + + +class BoundedList(abc.Sequence): """An append only list with a fixed max size. Calls to `append` and `extend` will drop the oldest elements if there is @@ -92,7 +112,7 @@ def from_seq(cls, maxlen, seq): @deprecated(version="1.4.0") # type: ignore -class BoundedDict(MutableMapping): +class BoundedDict(abc.MutableMapping): """An ordered dict with a fixed max capacity. Oldest elements are dropped when the dict is full and a new element is diff --git a/opentelemetry-sdk/tests/test_configurator.py b/opentelemetry-sdk/tests/test_configurator.py index cf1b8253ddb..947ae623bc8 100644 --- a/opentelemetry-sdk/tests/test_configurator.py +++ b/opentelemetry-sdk/tests/test_configurator.py @@ -257,7 +257,7 @@ def test_trace_init_otlp(self): @patch.dict(environ, {OTEL_PYTHON_ID_GENERATOR: "custom_id_generator"}) @patch("opentelemetry.sdk._configuration.IdGenerator", new=IdGenerator) - @patch("opentelemetry.sdk._configuration.iter_entry_points") + @patch("opentelemetry.sdk.util.iter_entry_points") def test_trace_init_custom_id_generator(self, mock_iter_entry_points): mock_iter_entry_points.configure_mock( return_value=[ diff --git a/opentelemetry-sdk/tests/trace/test_trace.py b/opentelemetry-sdk/tests/trace/test_trace.py index 8b8d33faa45..3f4d0d0da1c 100644 --- a/opentelemetry-sdk/tests/trace/test_trace.py +++ b/opentelemetry-sdk/tests/trace/test_trace.py @@ -21,7 +21,7 @@ from logging import ERROR, WARNING from random import randint from time import time_ns -from typing import Optional +from typing import Optional, Sequence from unittest import mock from opentelemetry import trace as trace_api @@ -39,15 +39,27 @@ OTEL_TRACES_SAMPLER, OTEL_TRACES_SAMPLER_ARG, ) -from opentelemetry.sdk.trace import Resource, sampling +from opentelemetry.sdk.trace import Resource from opentelemetry.sdk.trace.id_generator import RandomIdGenerator +from opentelemetry.sdk.trace.sampling import ( + ALWAYS_OFF, + ALWAYS_ON, + Decision, + ParentBased, + Sampler, + SamplingResult, + StaticSampler, + TraceIdRatioBased, +) from opentelemetry.sdk.util import ns_to_iso_str from opentelemetry.sdk.util.instrumentation import InstrumentationInfo from opentelemetry.test.spantestutil import ( get_span_with_dropped_attributes_events_links, new_tracer, ) -from opentelemetry.trace import Status, StatusCode +from opentelemetry.trace import Link, SpanKind, Status, StatusCode +from opentelemetry.trace.span import TraceState +from opentelemetry.util.types import Attributes class TestTracer(unittest.TestCase): @@ -139,6 +151,78 @@ def test_tracer_provider_accepts_concurrent_multi_span_processor(self): ) +class CustomSampler(Sampler): + def __init__(self) -> None: + pass + + def get_description(self) -> str: + return "CustomSampler" + + def should_sample( + self, + parent_context: Optional["Context"], + trace_id: int, + name: str, + kind: SpanKind = None, + attributes: Attributes = None, + links: Sequence[Link] = None, + trace_state: TraceState = None, + ) -> "SamplingResult": + return SamplingResult( + Decision.RECORD_AND_SAMPLE, + None, + None, + ) + + +class CustomRatioSampler(TraceIdRatioBased): + def __init__(self, ratio): + self.ratio = ratio + super().__init__(ratio) + + def get_description(self) -> str: + return "CustomSampler" + + def should_sample( + self, + parent_context: Optional["Context"], + trace_id: int, + name: str, + kind: SpanKind = None, + attributes: Attributes = None, + links: Sequence[Link] = None, + trace_state: TraceState = None, + ) -> "SamplingResult": + return SamplingResult( + Decision.RECORD_AND_SAMPLE, + None, + None, + ) + + +class CustomSamplerFactory: + @staticmethod + def get_custom_sampler(unused_sampler_arg): + return CustomSampler() + + @staticmethod + def get_custom_ratio_sampler(sampler_arg): + return CustomRatioSampler(float(sampler_arg)) + + @staticmethod + def empty_get_custom_sampler(sampler_arg): + return + + +class IterEntryPoint: + def __init__(self, name, class_type): + self.name = name + self.class_type = class_type + + def load(self): + return self.class_type + + class TestTracerSampling(unittest.TestCase): def tearDown(self): reload(trace) @@ -165,12 +249,10 @@ def test_default_sampler(self): def test_default_sampler_type(self): tracer_provider = trace.TracerProvider() - self.assertIsInstance(tracer_provider.sampler, sampling.ParentBased) - # pylint: disable=protected-access - self.assertEqual(tracer_provider.sampler._root, sampling.ALWAYS_ON) + self.verify_default_sampler(tracer_provider) def test_sampler_no_sampling(self): - tracer_provider = trace.TracerProvider(sampling.ALWAYS_OFF) + tracer_provider = trace.TracerProvider(ALWAYS_OFF) tracer = tracer_provider.get_tracer(__name__) # Check that the default tracer creates no-op spans if the sampler @@ -194,10 +276,8 @@ def test_sampler_with_env(self): # pylint: disable=protected-access reload(trace) tracer_provider = trace.TracerProvider() - self.assertIsInstance(tracer_provider.sampler, sampling.StaticSampler) - self.assertEqual( - tracer_provider.sampler._decision, sampling.Decision.DROP - ) + self.assertIsInstance(tracer_provider.sampler, StaticSampler) + self.assertEqual(tracer_provider.sampler._decision, Decision.DROP) tracer = tracer_provider.get_tracer(__name__) @@ -216,9 +296,169 @@ def test_ratio_sampler_with_env(self): # pylint: disable=protected-access reload(trace) tracer_provider = trace.TracerProvider() - self.assertIsInstance(tracer_provider.sampler, sampling.ParentBased) + self.assertIsInstance(tracer_provider.sampler, ParentBased) self.assertEqual(tracer_provider.sampler._root.rate, 0.25) + @mock.patch.dict( + "os.environ", {OTEL_TRACES_SAMPLER: "non_existent_entry_point"} + ) + def test_sampler_with_env_non_existent_entry_point(self): + # pylint: disable=protected-access + reload(trace) + tracer_provider = trace.TracerProvider() + self.verify_default_sampler(tracer_provider) + + @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") + @mock.patch.dict( + "os.environ", {OTEL_TRACES_SAMPLER: "custom_sampler_factory"} + ) + def test_custom_sampler_with_env(self, mock_iter_entry_points): + mock_iter_entry_points.return_value = [ + IterEntryPoint( + "custom_sampler_factory", + CustomSamplerFactory.get_custom_sampler, + ) + ] + # pylint: disable=protected-access + reload(trace) + tracer_provider = trace.TracerProvider() + self.assertIsInstance(tracer_provider.sampler, CustomSampler) + + @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") + @mock.patch.dict( + "os.environ", {OTEL_TRACES_SAMPLER: "custom_sampler_factory"} + ) + def test_custom_sampler_with_env_bad_factory(self, mock_iter_entry_points): + mock_iter_entry_points.return_value = [ + IterEntryPoint( + "custom_sampler_factory", + CustomSamplerFactory.empty_get_custom_sampler, + ) + ] + # pylint: disable=protected-access + reload(trace) + tracer_provider = trace.TracerProvider() + self.verify_default_sampler(tracer_provider) + + @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") + @mock.patch.dict( + "os.environ", + { + OTEL_TRACES_SAMPLER: "custom_sampler_factory", + OTEL_TRACES_SAMPLER_ARG: "0.5", + }, + ) + def test_custom_sampler_with_env_unused_arg(self, mock_iter_entry_points): + mock_iter_entry_points.return_value = [ + IterEntryPoint( + "custom_sampler_factory", + CustomSamplerFactory.get_custom_sampler, + ) + ] + # pylint: disable=protected-access + reload(trace) + tracer_provider = trace.TracerProvider() + self.assertIsInstance(tracer_provider.sampler, CustomSampler) + + @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") + @mock.patch.dict( + "os.environ", + { + OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory", + OTEL_TRACES_SAMPLER_ARG: "0.5", + }, + ) + def test_custom_ratio_sampler_with_env(self, mock_iter_entry_points): + mock_iter_entry_points.return_value = [ + IterEntryPoint( + "custom_ratio_sampler_factory", + CustomSamplerFactory.get_custom_ratio_sampler, + ) + ] + # pylint: disable=protected-access + reload(trace) + tracer_provider = trace.TracerProvider() + self.assertIsInstance(tracer_provider.sampler, CustomRatioSampler) + self.assertEqual(tracer_provider.sampler.ratio, 0.5) + + @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") + @mock.patch.dict( + "os.environ", + { + OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory", + OTEL_TRACES_SAMPLER_ARG: "foobar", + }, + ) + def test_custom_ratio_sampler_with_env_bad_arg( + self, mock_iter_entry_points + ): + mock_iter_entry_points.return_value = [ + IterEntryPoint( + "custom_ratio_sampler_factory", + CustomSamplerFactory.get_custom_ratio_sampler, + ) + ] + # pylint: disable=protected-access + reload(trace) + tracer_provider = trace.TracerProvider() + self.verify_default_sampler(tracer_provider) + + @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") + @mock.patch.dict( + "os.environ", + { + OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory", + }, + ) + def test_custom_ratio_sampler_with_env_no_arg( + self, mock_iter_entry_points + ): + mock_iter_entry_points.return_value = [ + IterEntryPoint( + "custom_ratio_sampler_factory", + CustomSamplerFactory.get_custom_ratio_sampler, + ) + ] + # pylint: disable=protected-access + reload(trace) + tracer_provider = trace.TracerProvider() + self.verify_default_sampler(tracer_provider) + + @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") + @mock.patch.dict( + "os.environ", + { + OTEL_TRACES_SAMPLER: "custom_sampler_factory", + OTEL_TRACES_SAMPLER_ARG: "0.5", + }, + ) + def test_custom_ratio_sampler_with_env_multiple_entry_points( + self, mock_iter_entry_points + ): + mock_iter_entry_points.return_value = [ + IterEntryPoint( + "custom_ratio_sampler_factory", + CustomSamplerFactory.get_custom_ratio_sampler, + ), + IterEntryPoint( + "custom_sampler_factory", + CustomSamplerFactory.get_custom_sampler, + ), + IterEntryPoint( + "custom_z_sampler_factory", + CustomSamplerFactory.empty_get_custom_sampler, + ), + ] + # pylint: disable=protected-access + reload(trace) + tracer_provider = trace.TracerProvider() + self.assertIsInstance(tracer_provider.sampler, CustomSampler) + + def verify_default_sampler(self, tracer_provider): + self.assertIsInstance(tracer_provider.sampler, ParentBased) + # pylint: disable=protected-access + self.assertEqual(tracer_provider.sampler._root, ALWAYS_ON) + class TestSpanCreation(unittest.TestCase): def test_start_span_invalid_spancontext(self): @@ -712,7 +952,7 @@ def test_sampling_attributes(self): "attr-in-both": "decision-attr", } tracer_provider = trace.TracerProvider( - sampling.StaticSampler(sampling.Decision.RECORD_AND_SAMPLE) + StaticSampler(Decision.RECORD_AND_SAMPLE) ) self.tracer = tracer_provider.get_tracer(__name__)