From 5bc4d86deed23b5e8eb939df4014338226f40ac6 Mon Sep 17 00:00:00 2001 From: Leon Derczynski Date: Thu, 11 Jul 2024 18:23:43 +0200 Subject: [PATCH] defensive coding: allow Python generators in more places (#782) * defensive coding: allow generators in more places * update workflow to treat generators more defensively, casting to list if there's a risk of multiple consumption --- garak/attempt.py | 7 +++- garak/evaluators/base.py | 10 +++-- garak/harnesses/base.py | 4 +- tests/test_internal_structures.py | 63 +++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 6 deletions(-) create mode 100644 tests/test_internal_structures.py diff --git a/garak/attempt.py b/garak/attempt.py index 84fd684d..24b14d01 100644 --- a/garak/attempt.py +++ b/garak/attempt.py @@ -1,5 +1,7 @@ """Defines the Attempt class, which encapsulates a prompt with metadata and results""" +from collections.abc import Iterable +from types import GeneratorType from typing import Any, List import uuid @@ -179,8 +181,9 @@ def __setattr__(self, name: str, value: Any) -> None: self._add_first_turn("user", value) elif name == "outputs": - if not isinstance(value, list): - raise TypeError("Value for attempt.outputs must be a list") + if not (isinstance(value, list) or isinstance(value, GeneratorType)): + raise TypeError("Value for attempt.outputs must be a list or generator") + value = list(value) if len(self.messages) == 0: raise TypeError("A prompt must be set before outputs are given") # do we have only the initial prompt? 
in which case, let's flesh out messages a bit diff --git a/garak/evaluators/base.py b/garak/evaluators/base.py index 879b2a19..6152e695 100644 --- a/garak/evaluators/base.py +++ b/garak/evaluators/base.py @@ -5,7 +5,7 @@ import json import logging -from typing import List +from typing import Iterable from colorama import Fore, Style @@ -33,19 +33,23 @@ def test(self, test_value: float) -> bool: """ return False # fail everything by default - def evaluate(self, attempts: List[garak.attempt.Attempt]) -> None: + def evaluate(self, attempts: Iterable[garak.attempt.Attempt]) -> None: """ evaluate feedback from detectors expects a list of attempts that correspond to one probe outputs results once per detector """ - if len(attempts) == 0: + if isinstance(attempts, list) and len(attempts) == 0: logging.debug( "evaluators.base.Evaluator.evaluate called with 0 attempts, expected 1+" ) return + attempts = list( + attempts + ) # disprefer this but getting detector_names from first one for the loop below is a pain + self.probename = attempts[0].probe_classname detector_names = attempts[0].detector_results.keys() diff --git a/garak/harnesses/base.py b/garak/harnesses/base.py index 75f3a3db..00be58ba 100644 --- a/garak/harnesses/base.py +++ b/garak/harnesses/base.py @@ -116,7 +116,9 @@ def run(self, model, probes, detectors, evaluator, announce_probe=True) -> None: detector_probe_name = d.detectorname.replace("garak.detectors.", "") attempt_iterator.set_description("detectors." 
+ detector_probe_name) for attempt in attempt_iterator: - attempt.detector_results[detector_probe_name] = d.detect(attempt) + attempt.detector_results[detector_probe_name] = list( + d.detect(attempt) + ) if first_detector: eval_outputs += attempt.outputs diff --git a/tests/test_internal_structures.py b/tests/test_internal_structures.py new file mode 100644 index 00000000..87b4b999 --- /dev/null +++ b/tests/test_internal_structures.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Iterable +import importlib +import tempfile + +import pytest + +import garak._config +import garak._plugins +import garak.attempt +import garak.evaluators.base +import garak.generators.test + +# probes should be able to return a generator of attempts +# -> probes.base.Probe._execute_all (1) should be able to consume a generator of attempts +# generators should be able to return a generator of outputs +# -> attempts (2) should be able to consume a generator of outputs +# detectors should be able to return generators of results +# -> evaluators (3) should be able to consume generators of results --> enforced in the harness, which casts detector results to a list because they may be consumed more than once + + + +@pytest.fixture(autouse=True) +def _config_loaded(): + importlib.reload(garak._config) + garak._config.load_base_config() + temp_report_file = tempfile.NamedTemporaryFile(mode="w+") + garak._config.transient.reportfile = temp_report_file + garak._config.transient.report_filename = temp_report_file.name + yield + temp_report_file.close() + + +def test_generator_consume_attempt_generator(): + count = 5 + attempts = (garak.attempt.Attempt(prompt=str(i)) for i in range(count)) + p = garak._plugins.load_plugin("probes.test.Blank") + g = garak._plugins.load_plugin("generators.test.Blank") + p.generator = g + results = p._execute_all(attempts) + + assert isinstance(results, Iterable), "_execute_all 
should return an Iterable" + result_len = 0 + for _attempt in results: + assert isinstance( + _attempt, garak.attempt.Attempt + ), "_execute_all should return attempts" + result_len += 1 + assert ( + result_len == count + ), "there should be the same number of attempts in the passed generator as results returned in _execute_all" + +def test_attempt_outputs_can_consume_generator(): + a = garak.attempt.Attempt(prompt="fish") + count = 5 + str_iter = ("abc" for _ in range(count)) + a.outputs = str_iter + outputs_list = list(a.outputs) + assert len(outputs_list) == count, "attempt.outputs should have same cardinality as generator used to populate it" + assert len(list(a.outputs)) == len(outputs_list), "attempt.outputs should have the same cardinality every time" +