From f8f581f3b1fa9a8d750428a83a9a1f966ad26410 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Tue, 24 Sep 2024 12:54:30 -0700 Subject: [PATCH] Updates on `pg.format`. 1) Delegates `pg.Formattable.__str__()/__repr__()` to `pg.format`, which allows more top-level features to be supported without involving subclass formatting logics. E.g. (`custom_format` and `markdown` arguments). 2) Upgrades the `custom_format` argument of `pg.format` with `Callable` type: This allows users to plugin custom formatting logics more flexibly without intrusively adding method to existing types. 3) Add `pg.Formatting.__str_kwargs__` and `pg.Formatting.__repr_kwargs__` to allow subclasses to override the kwargs arbitration process. PiperOrigin-RevId: 678361797 --- langfun/core/eval/matching.py | 11 +++++++++-- langfun/core/modality.py | 7 +++++++ langfun/core/repr_utils.py | 11 +++++++++-- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/langfun/core/eval/matching.py b/langfun/core/eval/matching.py index 9f67a44..9b10097 100644 --- a/langfun/core/eval/matching.py +++ b/langfun/core/eval/matching.py @@ -266,20 +266,27 @@ def _render_matches(self, s: io.StringIO) -> None: 'Prompt/Response Chain' '' ) + def _maybe_html(v, root_indent: int): + del root_indent + if hasattr(v, '_repr_html_'): + return v._repr_html_() # pylint: disable=protected-access + # Fall back to the default format. + return None + for i, (example, output, message) in enumerate(self.matches): bgcolor = 'white' if i % 2 == 0 else '#DDDDDD' s.write(f'{i + 1}') input_str = lf.repr_utils.escape_quoted( pg.format( example, verbose=False, max_bytes_len=32, - custom_format='_repr_html_' + custom_format=_maybe_html ) ) s.write(f'{input_str}') output_str = lf.repr_utils.escape_quoted( pg.format( output, verbose=False, max_bytes_len=32, - custom_format='_repr_html_' + custom_format=_maybe_html ) ) s.write(f'{output_str}') diff --git a/langfun/core/modality.py b/langfun/core/modality.py index 5fe89c7..0addf11 100644 --- a/langfun/core/modality.py +++ b/langfun/core/modality.py @@ -49,6 +49,13 @@ def format(self, *args, **kwargs) -> str: return super().format(*args, **kwargs) return Modality.text_marker(self.referred_name) + def __str_kwargs__(self) -> dict[str, Any]: + # For modality objects, we don't want to use markdown format when they + # are rendered as parts of the prompt. + kwargs = super().__str_kwargs__() + kwargs.pop('markdown', None) + return kwargs + @abc.abstractmethod def to_bytes(self) -> bytes: """Returns content in bytes.""" diff --git a/langfun/core/repr_utils.py b/langfun/core/repr_utils.py index 7390b25..96e6bf5 100644 --- a/langfun/core/repr_utils.py +++ b/langfun/core/repr_utils.py @@ -121,8 +121,15 @@ def html_repr( s.write('') item_color = item_color or (lambda k, v: (None, '#F1C40F', None, None)) - with (pg.str_format(custom_format='_repr_html_'), - pg.repr_format(custom_format='_repr_html_')): + def maybe_html_format(v: Any, root_indent: int) -> str | None: + del root_indent + if hasattr(v, '_repr_html_'): + return v._repr_html_() # pylint: disable=protected-access + # Fall back to the default format. + return None + + with (pg.str_format(custom_format=maybe_html_format), + pg.repr_format(custom_format=maybe_html_format)): for k, v in pg.object_utils.flatten(value).items(): if isinstance(v, pg.Ref): v = v.value