diff --git a/opentelemetry-sdk/CHANGELOG.md b/opentelemetry-sdk/CHANGELOG.md index d6a7828d64..2d12e54d98 100644 --- a/opentelemetry-sdk/CHANGELOG.md +++ b/opentelemetry-sdk/CHANGELOG.md @@ -20,6 +20,8 @@ Released 2020-11-02 ([#1289](https://github.com/open-telemetry/opentelemetry-python/pull/1289)) - Set initial checkpoint timestamp in aggregators ([#1237](https://github.com/open-telemetry/opentelemetry-python/pull/1237)) +- Allow samplers to modify tracestate + ([#1319](https://github.com/open-telemetry/opentelemetry-python/pull/1319)) - Remove TracerProvider coupling from Tracer init ([#1295](https://github.com/open-telemetry/opentelemetry-python/pull/1295)) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py index a350e3687a..c0189e807e 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py @@ -780,7 +780,7 @@ def start_span( # pylint: disable=too-many-locals # The sampler may also add attributes to the newly-created span, e.g. # to include information about the sampling result. sampling_result = self.sampler.should_sample( - context, trace_id, name, attributes, links, + context, trace_id, name, attributes, links, trace_state ) trace_flags = ( @@ -793,7 +793,7 @@ def start_span( # pylint: disable=too-many-locals self.ids_generator.generate_span_id(), is_remote=False, trace_flags=trace_flags, - trace_state=trace_state, + trace_state=sampling_result.trace_state, ) # Only record if is_recording() is true diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py index dbb12d5d63..ffa51506ff 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py @@ -68,6 +68,7 @@ # pylint: disable=unused-import from opentelemetry.context import Context from opentelemetry.trace import Link, get_current_span +from opentelemetry.trace.span import TraceState from opentelemetry.util.types import Attributes @@ -93,6 +94,8 @@ class SamplingResult: decision: A sampling decision based off of whether the span is recorded and the sampled flag in trace flags in the span context. attributes: Attributes to add to the `opentelemetry.trace.Span`. + trace_state: The tracestate used for the `opentelemetry.trace.Span`. + Could possibly have been modified by the sampler. """ def __repr__(self) -> str: @@ -101,13 +104,17 @@ def __repr__(self) -> str: ) def __init__( - self, decision: Decision, attributes: Attributes = None, + self, + decision: Decision, + attributes: "Attributes" = None, + trace_state: "TraceState" = None, ) -> None: self.decision = decision if attributes is None: self.attributes = MappingProxyType({}) else: self.attributes = MappingProxyType(attributes) + self.trace_state = trace_state class Sampler(abc.ABC): @@ -119,6 +126,7 @@ def should_sample( name: str, attributes: Attributes = None, links: Sequence["Link"] = None, + trace_state: "TraceState" = None, ) -> "SamplingResult": pass @@ -140,10 +148,11 @@ def should_sample( name: str, attributes: Attributes = None, links: Sequence["Link"] = None, + trace_state: "TraceState" = None, ) -> "SamplingResult": if self._decision is Decision.DROP: return SamplingResult(self._decision) - return SamplingResult(self._decision, attributes) + return SamplingResult(self._decision, attributes, trace_state) def get_description(self) -> str: if self._decision is Decision.DROP: @@ -194,6 +203,7 @@ def should_sample( name: str, attributes: Attributes = None, links: Sequence["Link"] = None, + trace_state: "TraceState" = None, ) -> "SamplingResult": decision = Decision.DROP if trace_id & self.TRACE_ID_LIMIT < self.bound: @@ -226,6 +236,7 @@ def should_sample( name: str, attributes: Attributes = None, links: Sequence["Link"] = None, + trace_state: "TraceState" = None, ) -> "SamplingResult": if parent_context is not None: parent_span_context = get_current_span( @@ -246,6 +257,7 @@ def should_sample( name=name, attributes=attributes, links=links, + trace_state=trace_state, ) def get_description(self): diff --git a/opentelemetry-sdk/tests/trace/test_sampling.py b/opentelemetry-sdk/tests/trace/test_sampling.py index fad5816ca8..d51a59c106 100644 --- a/opentelemetry-sdk/tests/trace/test_sampling.py +++ b/opentelemetry-sdk/tests/trace/test_sampling.py @@ -47,14 +47,18 @@ def test_is_sampled(self): class TestSamplingResult(unittest.TestCase): def test_ctr(self): attributes = {"asd": "test"} + trace_state = dict() + # pylint: disable=E1137 + trace_state["test"] = "123" result = sampling.SamplingResult( - sampling.Decision.RECORD_ONLY, attributes + sampling.Decision.RECORD_ONLY, attributes, trace_state ) self.assertIs(result.decision, sampling.Decision.RECORD_ONLY) with self.assertRaises(TypeError): result.attributes["test"] = "mess-this-up" self.assertTrue(len(result.attributes), 1) self.assertEqual(result.attributes["asd"], "test") + self.assertEqual(result.trace_state["test"], "123") class TestSampler(unittest.TestCase):