Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make baggage implementation W3C spec compliant #2167

Merged
merged 25 commits into from
Oct 13, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 50 additions & 11 deletions opentelemetry-api/src/opentelemetry/baggage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import typing
from logging import getLogger
from re import compile
from types import MappingProxyType
from typing import Mapping, Optional

from opentelemetry.context import create_key, get_value, set_value
from opentelemetry.context.context import Context
from opentelemetry.util.re import (
_KEY_FORMAT,
_VALUE_FORMAT,
_BAGGAGE_PROPERTY_FORMAT,
)


# Context key under which the baggage mapping is stored (see `set_value` calls below).
_BAGGAGE_KEY = create_key("baggage")
_logger = getLogger(__name__)

# Patterns compiled once at import time; the validators below use `fullmatch` on them.
_KEY_PATTERN = compile(_KEY_FORMAT)
_VALUE_PATTERN = compile(_VALUE_FORMAT)
# NOTE(review): name looks like a typo of `_PROPERTY_PATTERN`; a rename must also update its users.
_PROPERT_PATTERN = compile(_BAGGAGE_PROPERTY_FORMAT)


def get_all(
context: typing.Optional[Context] = None,
) -> typing.Mapping[str, object]:
context: Optional[Context] = None,
) -> Mapping[str, object]:
"""Returns the name/value pairs in the Baggage

Args:
Expand All @@ -39,8 +52,8 @@ def get_all(


def get_baggage(
name: str, context: typing.Optional[Context] = None
) -> typing.Optional[object]:
name: str, context: Optional[Context] = None
) -> Optional[object]:
"""Provides access to the value for a name/value pair in the
Baggage

Expand All @@ -56,7 +69,7 @@ def get_baggage(


def set_baggage(
name: str, value: object, context: typing.Optional[Context] = None
name: str, value: object, context: Optional[Context] = None
) -> Context:
"""Sets a value in the Baggage

Expand All @@ -69,13 +82,20 @@ def set_baggage(
A Context with the value updated
"""
baggage = dict(get_all(context=context))
baggage[name] = value
if not _is_valid_key(name):
_logger.warning(
"Baggage key `%s` does not match format, ignoring", name
)
elif not _is_valid_value(str(value)):
_logger.warning(
"Baggage value `%s` does not match format, ignorig", value
srikanthccv marked this conversation as resolved.
Show resolved Hide resolved
)
else:
baggage[name] = value
return set_value(_BAGGAGE_KEY, baggage, context=context)


def remove_baggage(
name: str, context: typing.Optional[Context] = None
) -> Context:
def remove_baggage(name: str, context: Optional[Context] = None) -> Context:
"""Removes a value from the Baggage

Args:
Expand All @@ -91,7 +111,7 @@ def remove_baggage(
return set_value(_BAGGAGE_KEY, baggage, context=context)


def clear(context: typing.Optional[Context] = None) -> Context:
def clear(context: Optional[Context] = None) -> Context:
"""Removes all values from the Baggage

Args:
Expand All @@ -101,3 +121,22 @@ def clear(context: typing.Optional[Context] = None) -> Context:
A Context with all baggage entries removed
"""
return set_value(_BAGGAGE_KEY, {}, context=context)


def _is_valid_key(name: str) -> bool:
    """Report whether ``name`` is matched in full by ``_KEY_PATTERN``."""
    matched = _KEY_PATTERN.fullmatch(name)
    return matched is not None


def _is_valid_value(value: str) -> bool:
    """Check a baggage value together with any ``;``-separated properties.

    The text before the first ``;`` must be matched in full by
    ``_VALUE_PATTERN``; every following segment must be matched in full by
    ``_PROPERT_PATTERN``.

    Args:
        value: The candidate baggage value as a string.

    Returns:
        True when the value part and all property segments match,
        False otherwise.
    """
    # str.split always yields at least one element, so unpacking is safe.
    value_part, *properties = value.split(";")
    if _VALUE_PATTERN.fullmatch(value_part) is None:
        return False
    # Loop variable is ``prop`` so the builtin ``property`` is not shadowed.
    return all(
        _PROPERT_PATTERN.fullmatch(prop) is not None for prop in properties
    )


def _is_valid_pair(key: str, value: str) -> bool:
    """Report whether both the key and the value pass their format checks."""
    # The key is checked first; the value is only checked for a valid key.
    if not _is_valid_key(key):
        return False
    return _is_valid_value(value)
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import typing
from urllib.parse import quote_plus, unquote_plus

from opentelemetry.baggage import get_all, set_baggage
from logging import getLogger
from re import compile, split
from typing import Iterable, Mapping, Optional, Set

from opentelemetry.baggage import get_all, set_baggage, _is_valid_pair
from opentelemetry.context import get_current
from opentelemetry.context.context import Context
from opentelemetry.util.re import _DELIMITER_PATTERN
from opentelemetry.propagators import textmap

_logger = getLogger(__name__)


class W3CBaggagePropagator(textmap.TextMapPropagator):
"""Extracts and injects Baggage which is used to annotate telemetry."""
Expand All @@ -32,7 +38,7 @@ class W3CBaggagePropagator(textmap.TextMapPropagator):
def extract(
self,
carrier: textmap.CarrierT,
context: typing.Optional[Context] = None,
context: Optional[Context] = None,
getter: textmap.Getter = textmap.default_getter,
) -> Context:
"""Extract Baggage from the carrier.
Expand All @@ -49,32 +55,59 @@ def extract(
)

if not header or len(header) > self._MAX_HEADER_LENGTH:
_logger.warning(
"Baggage header `%s` exceeded the maximum number of bytes per baggage-string.",
header,
)
return context

baggage_entries = split(_DELIMITER_PATTERN, header)

if len(baggage_entries) > self._MAX_PAIRS:
_logger.warning(
"Baggage header `%s` exceeded the maximum number of list-members",
header,
)
return context

baggage_entries = header.split(",")
total_baggage_entries = self._MAX_PAIRS
entries = []
for entry in baggage_entries:
if len(entry) > self._MAX_PAIR_LENGTH:
_logger.warning(
"Baggage entry `%s` exceeded the maximum number of bytes per list-member",
entry,
)
return context
srikanthccv marked this conversation as resolved.
Show resolved Hide resolved
if not entry: # empty string
continue
try:
name, value = entry.split("=", 1)
srikanthccv marked this conversation as resolved.
Show resolved Hide resolved
except Exception: # pylint: disable=broad-except
continue
_logger.warning(
"Baggage list-member doesn't match the format: `%s`", entry
)
return context
name = unquote_plus(name).strip().lower()
value = unquote_plus(value).strip()
if not _is_valid_pair(name, value):
_logger.warning("Invalid baggage entry: `%s`", entry)
return context

entries.append((name, value))

for name, value in entries:
context = set_baggage(
unquote_plus(name).strip(),
unquote_plus(value).strip(),
name,
value,
context=context,
)
total_baggage_entries -= 1
if total_baggage_entries == 0:
break

return context
return context # type: ignore

def inject(
self,
carrier: textmap.CarrierT,
context: typing.Optional[Context] = None,
context: Optional[Context] = None,
setter: textmap.Setter = textmap.default_setter,
) -> None:
"""Injects Baggage into the carrier.
Expand All @@ -90,21 +123,21 @@ def inject(
setter.set(carrier, self._BAGGAGE_HEADER_NAME, baggage_string)

@property
def fields(self) -> typing.Set[str]:
def fields(self) -> Set[str]:
"""Returns a set with the fields set in `inject`."""
return {self._BAGGAGE_HEADER_NAME}


def _format_baggage(baggage_entries: typing.Mapping[str, object]) -> str:
def _format_baggage(baggage_entries: Mapping[str, object]) -> str:
return ",".join(
quote_plus(str(key)) + "=" + quote_plus(str(value))
for key, value in baggage_entries.items()
)


def _extract_first_element(
items: typing.Optional[typing.Iterable[textmap.CarrierT]],
) -> typing.Optional[textmap.CarrierT]:
items: Optional[Iterable[textmap.CarrierT]],
) -> Optional[textmap.CarrierT]:
if items is None:
return None
return next(iter(items), None)
5 changes: 3 additions & 2 deletions opentelemetry-api/src/opentelemetry/util/re.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
)
# A value contains a URL encoded UTF-8 string.
_VALUE_FORMAT = r"[\x21\x23-\x2b\x2d-\x3a\x3c-\x5b\x5d-\x7e]*"
_HEADER_FORMAT = _KEY_FORMAT + _OWS + r"=" + _OWS + _VALUE_FORMAT
_HEADER_PATTERN = compile(_HEADER_FORMAT)
_KEY_VALUE_FORMAT = rf"{_OWS}{_KEY_FORMAT}{_OWS}={_OWS}{_VALUE_FORMAT}{_OWS}"
_HEADER_PATTERN = compile(_KEY_VALUE_FORMAT)
_DELIMITER_PATTERN = compile(r"[ \t]*,[ \t]*")

_BAGGAGE_PROPERTY_FORMAT = rf"{_KEY_VALUE_FORMAT}|{_OWS}{_KEY_FORMAT}{_OWS}"

# pylint: disable=invalid-name
def parse_headers(s: str) -> Mapping[str, str]:
Expand Down
36 changes: 23 additions & 13 deletions opentelemetry-api/tests/baggage/test_baggage_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import unittest
from unittest.mock import Mock, patch
from logging import WARNING

from opentelemetry import baggage
from opentelemetry.baggage.propagation import (
Expand Down Expand Up @@ -66,25 +67,32 @@ def test_valid_header_with_properties(self):
expected = {"key1": "val1", "key2": "val2;prop=1"}
self.assertEqual(self._extract(header), expected)

def test_valid_header_with_url_escaped_comma(self):
header = "key%2C1=val1,key2=val2%2Cval3"
expected = {"key,1": "val1", "key2": "val2,val3"}
def test_valid_header_with_url_escaped_values(self):
header = "key%2C1=val1,key2=val2%3Aval3,key3=val4%40%23%24val5"
expected = {
"key,1": "val1",
"key2": "val2:val3",
"key3": "val4@#$val5",
}
self.assertEqual(self._extract(header), expected)

def test_valid_header_with_invalid_value(self):
def test_header_with_invalid_value(self):
header = "key1=val1,key2=val2,a,val3"
expected = {"key1": "val1", "key2": "val2"}
self.assertEqual(self._extract(header), expected)
with self.assertLogs(level=WARNING) as warning:
self._extract(header)
self.assertIn(
"Baggage list-member doesn't match the format",
warning.output[0],
)

def test_valid_header_with_empty_value(self):
    """An entry with nothing after ``=`` is extracted as an empty string."""
    extracted = self._extract("key1=,key2=val2")
    self.assertEqual(extracted, {"key1": "", "key2": "val2"})

def test_invalid_header(self):
header = "header1"
expected = {}
self.assertEqual(self._extract(header), expected)
self.assertEqual(self._extract("header1"), {})
self.assertEqual(self._extract(" = "), {})

def test_header_too_long(self):
long_value = "s" * (W3CBaggagePropagator._MAX_HEADER_LENGTH + 1)
Expand All @@ -111,11 +119,11 @@ def test_header_contains_pair_too_long(self):

def test_extract_unquote_plus(self):
self.assertEqual(
self._extract("key+key=value+value"), {"key key": "value value"}
self._extract("keykey=value%5Evalue"), {"keykey": "value^value"}
)
self.assertEqual(
self._extract("key%2Fkey=value%2Fvalue"),
{"key/key": "value/value"},
self._extract("key%23key=value%23value"),
{"key#key": "value#value"},
)

def test_header_max_entries_skip_invalid_entry(self):
Expand Down Expand Up @@ -163,6 +171,9 @@ def test_inject_no_baggage_entries(self):
output = self._inject(values)
self.assertEqual(None, output)

def test_inject_invalid_entries(self):
    """``_inject`` returns ``None`` for a value that contains a space."""
    injected = self._inject({"key": "val ue"})
    self.assertEqual(None, injected)

def test_inject(self):
values = {
"key1": "val1",
Expand All @@ -178,7 +189,6 @@ def test_inject_escaped_values(self):
"key2": "val3=4",
}
output = self._inject(values)
self.assertIn("key1=val1%2Cval2", output)
self.assertIn("key2=val3%3D4", output)

def test_inject_non_string_values(self):
Expand Down