From 284e73bf103ad9c5049394c236ace4432b244ad2 Mon Sep 17 00:00:00 2001 From: Andres Cruz Date: Fri, 20 Sep 2024 17:47:56 +0200 Subject: [PATCH 1/2] OPIK-58 Add trace usage --- .../main/java/com/comet/opik/api/Trace.java | 2 + .../opik/domain/FeedbackScoreMapper.java | 3 + .../java/com/comet/opik/domain/TraceDAO.java | 107 +++- .../resources/v1/priv/TracesResourceTest.java | 585 ++++++++++-------- 4 files changed, 418 insertions(+), 279 deletions(-) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Trace.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Trace.java index 2e7edd06..8acbab69 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Trace.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Trace.java @@ -13,6 +13,7 @@ import java.time.Instant; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.UUID; @@ -35,6 +36,7 @@ public record Trace( @JsonView({Trace.View.Public.class, Trace.View.Write.class}) JsonNode output, @JsonView({Trace.View.Public.class, Trace.View.Write.class}) JsonNode metadata, @JsonView({Trace.View.Public.class, Trace.View.Write.class}) Set tags, + @JsonView({Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Map usage, @JsonView({Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, @JsonView({Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, @JsonView({Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreMapper.java index 8add7adb..12fbcd54 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreMapper.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreMapper.java @@ -6,6 +6,7 @@ import org.mapstruct.Mapping; import org.mapstruct.factory.Mappers; +import java.util.List; import java.util.UUID; @Mapper @@ -15,6 +16,8 @@ public interface FeedbackScoreMapper { FeedbackScore toFeedbackScore(FeedbackScoreBatchItem feedbackScoreBatchItem); + List toFeedbackScores(List feedbackScoreBatchItems); + @Mapping(target = "id", source = "entityId") FeedbackScoreBatchItem toFeedbackScoreBatchItem(UUID entityId, String projectName, FeedbackScore feedbackScore); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java index bf967efb..9aa9182f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java @@ -29,6 +29,7 @@ import java.time.Instant; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.UUID; @@ -238,17 +239,77 @@ INSERT INTO traces ( private static final String SELECT_BY_ID = """ SELECT - * - FROM - traces - WHERE id = :id - AND workspace_id = :workspace_id - ORDER BY last_updated_at DESC - LIMIT 1 + t.id, + t.workspace_id, + t.project_id, + t.name, + t.start_time, + t.end_time, + t.input, + t.output, + t.metadata, + t.tags, + t.created_at, + t.last_updated_at, + t.created_by, + t.last_updated_by, + sumMap(s.usage) as usage + FROM ( + SELECT + * + FROM traces + WHERE workspace_id = :workspace_id + AND id = :id + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + ) AS t + LEFT JOIN ( + SELECT + trace_id, + usage + FROM spans + WHERE workspace_id = :workspace_id + AND trace_id = :id + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + ) AS s ON t.id = s.trace_id + GROUP BY + t.id, + t.workspace_id, + t.project_id, + t.name, + t.start_time, + t.end_time, + t.input, + t.output, + t.metadata, + t.tags, + t.created_at, + t.last_updated_at, + t.created_by, + t.last_updated_by + ORDER BY t.id DESC ; """; private static final String SELECT_BY_PROJECT_ID = """ + SELECT + t.id, + t.workspace_id, + t.project_id, + t.name, + t.start_time, + t.end_time, + t.input, + t.output, + t.metadata, + t.tags, + t.created_at, + t.last_updated_at, + t.created_by, + t.last_updated_by, + sumMap(s.usage) as usage + FROM ( SELECT * FROM traces @@ -263,6 +324,7 @@ AND id in ( SELECT * FROM feedback_scores WHERE entity_type = 'trace' + AND workspace_id = :workspace_id AND project_id = :project_id ORDER BY entity_id DESC, last_updated_at DESC LIMIT 1 BY entity_id, name @@ -274,6 +336,33 @@ AND id in ( ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id LIMIT :limit OFFSET :offset + ) AS t + LEFT JOIN ( + SELECT + trace_id, + usage + FROM spans + WHERE workspace_id = :workspace_id + AND project_id = :project_id + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + ) AS s ON t.id = s.trace_id + GROUP BY + t.id, + t.workspace_id, + t.project_id, + t.name, + t.start_time, + t.end_time, + t.input, + t.output, + t.metadata, + t.tags, + t.created_at, + t.last_updated_at, + t.created_by, + t.last_updated_by + ORDER BY t.id DESC ; """; @@ -306,6 +395,7 @@ AND id in ( SELECT * FROM feedback_scores WHERE entity_type = 'trace' + AND workspace_id = :workspace_id AND project_id = :project_id ORDER BY entity_id DESC, last_updated_at DESC LIMIT 1 BY entity_id, name @@ -314,7 +404,7 @@ AND id in ( HAVING ) - ORDER BY last_updated_at DESC + ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id ) AS latest_rows ; @@ -627,6 +717,7 @@ private Publisher mapToDto(Result result) { .collect(Collectors.toSet())) .filter(it -> !it.isEmpty()) .orElse(null)) + .usage(row.get("usage", Map.class)) .createdAt(row.get("created_at", Instant.class)) .lastUpdatedAt(row.get("last_updated_at", Instant.class)) .createdBy(row.get("created_by", String.class)) diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java index 4f822c9c..931ce8c3 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java @@ -43,7 +43,6 @@ import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import org.apache.commons.lang3.RandomStringUtils; -import org.assertj.core.api.recursive.comparison.RecursiveComparisonConfiguration; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -70,8 +69,10 @@ import java.nio.charset.StandardCharsets; import java.sql.SQLException; import java.time.Instant; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.regex.Pattern; @@ -105,9 +106,7 @@ class TracesResourceTest { "lastUpdatedAt", "feedbackScores", "createdBy", "lastUpdatedBy"}; private static final String[] IGNORED_FIELDS_SPANS = {"projectId", "projectName", "createdAt", "lastUpdatedAt", "feedbackScores", "createdBy", "lastUpdatedBy"}; - private static final String[] IGNORED_FIELDS_SCORES = {"projectId", "projectName", "id", "createdAt", - "lastUpdatedAt", - "createdBy", "lastUpdatedBy"}; + private static final String[] IGNORED_FIELDS_SCORES = {"createdAt", "lastUpdatedAt", "createdBy", "lastUpdatedBy"}; private static final String API_KEY = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); @@ -782,6 +781,70 @@ void getByProjectName__whenProjectNameAndIdAreNull__thenReturnBadRequest() { .isEqualTo("Either 'project_name' or 'project_id' query params must be provided"); } + @Test + void findWithUsage() { + var projectName = RandomStringUtils.randomAlphanumeric(10); + var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() + .map(trace -> trace.toBuilder() + .projectName(projectName) + .usage(null) + .feedbackScores(null) + .build()) + .toList(); + batchCreateTracesAndAssert(traces, API_KEY, TEST_WORKSPACE); + + var traceIdToSpansMap = traces.stream() + .flatMap(trace -> PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() + .map(span -> span.toBuilder() + .projectName(projectName) + .traceId(trace.id()) + .feedbackScores(null) + .build())) + .collect(Collectors.groupingBy(Span::traceId)); + batchCreateSpansAndAssert( + traceIdToSpansMap.values().stream().flatMap(List::stream).toList(), API_KEY, TEST_WORKSPACE); + + traces = traces.stream().map(trace -> trace.toBuilder() + .usage(traceIdToSpansMap.get(trace.id()).stream() + .map(Span::usage) + .flatMap(usage -> usage.entrySet().stream()) + .collect(Collectors.groupingBy( + Map.Entry::getKey, Collectors.summingLong(Map.Entry::getValue)))) + .build()).toList(); + getAndAssertPage(TEST_WORKSPACE, projectName, List.of(), traces, traces.reversed(), List.of(), API_KEY); + } + + @Test + void findWithoutUsage() { + var apiKey = UUID.randomUUID().toString(); + var workspaceName = RandomStringUtils.randomAlphanumeric(10); + var workspaceId = UUID.randomUUID().toString(); + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var projectName = RandomStringUtils.randomAlphanumeric(10); + var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() + .map(trace -> trace.toBuilder() + .projectName(projectName) + .usage(null) + .feedbackScores(null) + .build()) + .toList(); + batchCreateTracesAndAssert(traces, apiKey, workspaceName); + + var spans = traces.stream() + .flatMap(trace -> PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() + .map(span -> span.toBuilder() + .projectName(projectName) + .traceId(trace.id()) + .usage(null) + .feedbackScores(null) + .build())) + .toList(); + batchCreateSpansAndAssert(spans, apiKey, workspaceName); + + getAndAssertPage(workspaceName, projectName, List.of(), traces, traces.reversed(), List.of(), apiKey); + } + @Test @DisplayName("when project name is not empty, then return traces by project name") void getByProjectName__whenProjectNameIsNotEmpty__thenReturnTracesByProjectName() { @@ -891,6 +954,7 @@ void getByProjectName__whenFilterWorkspaceName__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName1) + .usage(null) .feedbackScores(null) .build()) .toList(); @@ -900,6 +964,7 @@ void getByProjectName__whenFilterWorkspaceName__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName1) + .usage(null) .feedbackScores(null) .build()) .toList(); @@ -928,6 +993,7 @@ void getByProjectName__whenFilterIdAndNameEqual__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -966,6 +1032,7 @@ void getByProjectName__whenFilterNameEqual__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -998,6 +1065,7 @@ void getByProjectName__whenFilterNameStartsWith__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1030,6 +1098,7 @@ void getByProjectName__whenFilterNameEndsWith__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1062,6 +1131,7 @@ void getByProjectName__whenFilterNameContains__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1096,6 +1166,7 @@ void getByProjectName__whenFilterNameNotContains__thenReturnTracesFiltered() { .projectId(null) .projectName(projectName) .name(traceName) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1131,6 +1202,7 @@ void getByProjectName__whenFilterStartTimeEqual__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1164,6 +1236,7 @@ void getByProjectName__whenFilterStartTimeGreaterThan__thenReturnTracesFiltered( .projectId(null) .projectName(projectName) .startTime(Instant.now().minusSeconds(60 * 5)) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1200,6 +1273,7 @@ void getByProjectName__whenFilterStartTimeGreaterThanEqual__thenReturnTracesFilt .projectId(null) .projectName(projectName) .startTime(Instant.now().minusSeconds(60 * 5)) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1236,6 +1310,7 @@ void getByProjectName__whenFilterStartTimeLessThan__thenReturnTracesFiltered() { .projectId(null) .projectName(projectName) .startTime(Instant.now().plusSeconds(60 * 5)) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1272,6 +1347,7 @@ void getByProjectName__whenFilterStartTimeLessThanEqual__thenReturnTracesFiltere .projectId(null) .projectName(projectName) .startTime(Instant.now().plusSeconds(60 * 5)) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1307,6 +1383,7 @@ void getByProjectName__whenFilterEndTimeEqual__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1339,6 +1416,7 @@ void getByProjectName__whenFilterInputEqual__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1371,6 +1449,7 @@ void getByProjectName__whenFilterOutputEqual__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1405,6 +1484,7 @@ void getByProjectName__whenFilterMetadataEqualString__thenReturnTracesFiltered() .projectName(projectName) .metadata(JsonUtils.getJsonNodeFromString("{\"model\":[{\"year\":2024,\"version\":\"Some " + "version\"}]}")) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1443,6 +1523,7 @@ void getByProjectName__whenFilterMetadataEqualNumber__thenReturnTracesFiltered() .projectName(projectName) .metadata(JsonUtils.getJsonNodeFromString("{\"model\":[{\"year\":2024,\"version\":\"Some " + "version\"}]}")) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1483,6 +1564,7 @@ void getByProjectName__whenFilterMetadataEqualBoolean__thenReturnTracesFiltered( .metadata( JsonUtils.getJsonNodeFromString("{\"model\":[{\"year\":false,\"version\":\"Some " + "version\"}]}")) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1522,6 +1604,7 @@ void getByProjectName__whenFilterMetadataEqualNull__thenReturnTracesFiltered() { .projectName(projectName) .metadata(JsonUtils.getJsonNodeFromString("{\"model\":[{\"year\":2024,\"version\":\"Some " + "version\"}]}")) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1561,6 +1644,7 @@ void getByProjectName__whenFilterMetadataContainsString__thenReturnTracesFiltere .projectName(projectName) .metadata(JsonUtils.getJsonNodeFromString("{\"model\":[{\"year\":2024,\"version\":\"Some " + "version\"}]}")) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1600,6 +1684,7 @@ void getByProjectName__whenFilterMetadataContainsNumber__thenReturnTracesFiltere .projectName(projectName) .metadata(JsonUtils.getJsonNodeFromString("{\"model\":[{\"year\":\"two thousand twenty " + "four\",\"version\":\"OpenAI, Chat-GPT 4.0\"}]}")) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1640,6 +1725,7 @@ void getByProjectName__whenFilterMetadataContainsBoolean__thenReturnTracesFilter .metadata( JsonUtils.getJsonNodeFromString("{\"model\":[{\"year\":false,\"version\":\"Some " + "version\"}]}")) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1679,6 +1765,7 @@ void getByProjectName__whenFilterMetadataContainsNull__thenReturnTracesFiltered( .projectName(projectName) .metadata(JsonUtils.getJsonNodeFromString("{\"model\":[{\"year\":2024,\"version\":\"Some " + "version\"}]}")) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1718,6 +1805,7 @@ void getByProjectName__whenFilterMetadataGreaterThanNumber__thenReturnTracesFilt .projectName(projectName) .metadata(JsonUtils.getJsonNodeFromString("{\"model\":[{\"year\":2020," + "\"version\":\"OpenAI, Chat-GPT 4.0\"}]}")) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -1865,6 +1953,7 @@ void getByProjectName__whenFilterMetadataLessThanNumber__thenReturnTracesFiltere .projectName(projectName) .metadata(JsonUtils.getJsonNodeFromString("{\"model\":[{\"year\":2026," + "\"version\":\"OpenAI, Chat-GPT 4.0\"}]}")) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -2010,6 +2099,7 @@ void getByProjectName__whenFilterTagsContains__thenReturnTracesFiltered() { .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .collect(Collectors.toCollection(ArrayList::new)); @@ -2046,6 +2136,7 @@ void getByProjectName__whenFilterFeedbackScoresEqual__thenReturnTracesFiltered() .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(trace.feedbackScores().stream() .map(feedbackScore -> feedbackScore.toBuilder() .value(factory.manufacturePojo(BigDecimal.class)) @@ -2099,6 +2190,7 @@ void getByProjectName__whenFilterFeedbackScoresGreaterThan__thenReturnTracesFilt .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(updateFeedbackScore(trace.feedbackScores().stream() .map(feedbackScore -> feedbackScore.toBuilder() .value(factory.manufacturePojo(BigDecimal.class)) @@ -2150,6 +2242,7 @@ void getByProjectName__whenFilterFeedbackScoresGreaterThanEqual__thenReturnTrace .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(updateFeedbackScore(trace.feedbackScores().stream() .map(feedbackScore -> feedbackScore.toBuilder() .value(factory.manufacturePojo(BigDecimal.class)) @@ -2196,6 +2289,7 @@ void getByProjectName__whenFilterFeedbackScoresLessThan__thenReturnTracesFiltere .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(updateFeedbackScore(trace.feedbackScores().stream() .map(feedbackScore -> feedbackScore.toBuilder() .value(factory.manufacturePojo(BigDecimal.class)) @@ -2243,6 +2337,7 @@ void getByProjectName__whenFilterFeedbackScoresLessThanEqual__thenReturnTracesFi .map(trace -> trace.toBuilder() .projectId(null) .projectName(projectName) + .usage(null) .feedbackScores(updateFeedbackScore(trace.feedbackScores().stream() .map(feedbackScore -> feedbackScore.toBuilder() .value(factory.manufacturePojo(BigDecimal.class)) @@ -2763,30 +2858,35 @@ private String toURLEncodedQueryParam(List filters) { } private void assertIgnoredFields(List actualTraces, List expectedTraces) { + assertThat(actualTraces).size().isEqualTo(expectedTraces.size()); for (int i = 0; i < actualTraces.size(); i++) { var actualTrace = actualTraces.get(i); var expectedTrace = expectedTraces.get(i); - assertThat(actualTrace.projectId()).isNotNull(); - assertThat(actualTrace.projectName()).isNull(); - assertThat(actualTrace.createdAt()).isAfter(expectedTrace.createdAt()); - assertThat(actualTrace.lastUpdatedAt()).isAfter(expectedTrace.lastUpdatedAt()); - assertThat(actualTrace.lastUpdatedBy()).isEqualTo(USER); - assertThat(actualTrace.lastUpdatedBy()).isEqualTo(USER); - assertThat(actualTrace.feedbackScores()) - .usingRecursiveComparison() - .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) - .ignoringFields(IGNORED_FIELDS_SCORES) - .ignoringCollectionOrder() - .isEqualTo(expectedTrace.feedbackScores()); + assertIgnoredFields(actualTrace, expectedTrace); + } + } - if (expectedTrace.feedbackScores() != null) { - actualTrace.feedbackScores().forEach(feedbackScore -> { - assertThat(feedbackScore.createdAt()).isAfter(expectedTrace.createdAt()); - assertThat(feedbackScore.lastUpdatedAt()).isAfter(expectedTrace.createdAt()); - assertThat(feedbackScore.lastUpdatedBy()).isEqualTo(USER); - assertThat(feedbackScore.lastUpdatedBy()).isEqualTo(USER); - }); - } + private static void assertIgnoredFields(Trace actualTrace, Trace expectedTrace) { + assertThat(actualTrace.projectId()).isNotNull(); + assertThat(actualTrace.projectName()).isNull(); + assertThat(actualTrace.createdAt()).isAfter(expectedTrace.createdAt()); + assertThat(actualTrace.lastUpdatedAt()).isAfter(expectedTrace.lastUpdatedAt()); + assertThat(actualTrace.createdBy()).isEqualTo(USER); + assertThat(actualTrace.lastUpdatedBy()).isEqualTo(USER); + assertThat(actualTrace.feedbackScores()) + .usingRecursiveComparison() + .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .ignoringFields(IGNORED_FIELDS_SCORES) + .ignoringCollectionOrder() + .isEqualTo(expectedTrace.feedbackScores()); + + if (expectedTrace.feedbackScores() != null) { + actualTrace.feedbackScores().forEach(feedbackScore -> { + assertThat(feedbackScore.createdAt()).isAfter(expectedTrace.createdAt()); + assertThat(feedbackScore.lastUpdatedAt()).isAfter(expectedTrace.lastUpdatedAt()); + assertThat(feedbackScore.createdBy()).isEqualTo(USER); + assertThat(feedbackScore.lastUpdatedBy()).isEqualTo(USER); + }); } } @@ -2835,52 +2935,84 @@ private List updateFeedbackScore( class GetTrace { @Test - @DisplayName("Success") - void getTrace() { - - var projectName = generator.generate().toString(); + void getTraceWithUsage() { + var projectName = RandomStringUtils.randomAlphanumeric(10); + var span = factory.manufacturePojo(Span.class); + var usage = Stream.concat( + Map.of("completion_tokens", 2 * 5L, "prompt_tokens", 3 * 5L + 3, "total_tokens", 4 * 5L) + .entrySet().stream(), + span.usage().entrySet().stream()) + .map(entry -> new AbstractMap.SimpleEntry<>(entry.getKey(), entry.getValue().longValue())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); var trace = factory.manufacturePojo(Trace.class) .toBuilder() .id(null) - .name("OpenAPI Trace") .projectName(projectName) .endTime(null) + .input(null) .output(null) - .createdAt(null) - .lastUpdatedAt(null) .metadata(null) .tags(null) - .projectId(null) + .usage(usage) .feedbackScores(null) .build(); - var id = create(trace, API_KEY, TEST_WORKSPACE); - var actualResponse = getById(id, TEST_WORKSPACE, API_KEY); - - var actualEntity = actualResponse.readEntity(Trace.class); - - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); - - assertThat(actualEntity.id()).isEqualTo(id); - assertThat(actualEntity.name()).isEqualTo("OpenAPI Trace"); - assertThat(actualEntity.projectId()).isNotNull(); + var spans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() + .map(spanInStream -> spanInStream.toBuilder() + .projectName(projectName) + .traceId(id) + .usage(Map.of("completion_tokens", 2, "prompt_tokens", 3, "total_tokens", 4)) + .build()) + .collect(Collectors.toList()); + spans.add(factory.manufacturePojo(Span.class).toBuilder() + .projectName(projectName) + .traceId(id) + .usage(null) + .build()); + spans.add(factory.manufacturePojo(Span.class).toBuilder() + .projectName(projectName) + .traceId(id) + .usage(Map.of("prompt_tokens", 3)) + .build()); + spans.add(span.toBuilder() + .projectName(projectName) + .traceId(id) + .build()); + batchCreateSpansAndAssert(spans, API_KEY, TEST_WORKSPACE); - assertThat(actualEntity.createdAt()).isNotNull(); - assertThat(actualEntity.createdAt()).isInstanceOf(Instant.class); - assertThat(actualEntity.lastUpdatedAt()).isNotNull(); - assertThat(actualEntity.lastUpdatedAt()).isInstanceOf(Instant.class); + var projectId = getProjectId(projectName, TEST_WORKSPACE, API_KEY); + trace = trace.toBuilder().id(id).build(); + getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); + } - assertThat(actualEntity.input()).isNotNull(); - assertThat(actualEntity.output()).isNull(); + @Test + void getTraceWithoutUsage() { + var apiKey = UUID.randomUUID().toString(); + var workspaceName = RandomStringUtils.randomAlphanumeric(10); + var workspaceId = UUID.randomUUID().toString(); + mockTargetWorkspace(apiKey, workspaceName, workspaceId); - assertThat(actualEntity.metadata()).isNull(); - assertThat(actualEntity.tags()).isNull(); + var projectName = RandomStringUtils.randomAlphanumeric(10); + var trace = factory.manufacturePojo(Trace.class) + .toBuilder() + .projectName(projectName) + .usage(null) + .feedbackScores(null) + .build(); + create(trace, apiKey, workspaceName); - assertThat(actualEntity.endTime()).isNull(); + var spans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() + .map(spanInStream -> spanInStream.toBuilder() + .projectName(projectName) + .traceId(trace.id()) + .usage(null) + .build()) + .toList(); + batchCreateSpansAndAssert(spans, apiKey, workspaceName); - assertThat(actualEntity.startTime()).isNotNull(); - assertThat(actualEntity.startTime()).isInstanceOf(Instant.class); + var projectId = getProjectId(projectName, workspaceName, apiKey); + getAndAssert(trace, projectId, apiKey, workspaceName); } @Test @@ -2925,32 +3057,19 @@ private void create(UUID entityId, FeedbackScore score, String workspaceName, St } } - private Trace getAndAssert(Trace trace, UUID projectId, String apiKey, String workspaceName) { + private Trace getAndAssert(Trace expectedTrace, UUID projectId, String apiKey, String workspaceName) { + var actualResponse = getById(expectedTrace.id(), workspaceName, apiKey); + var actualTrace = actualResponse.readEntity(Trace.class); - var actualResponse = getById(trace.id(), workspaceName, apiKey); - var actualEntity = actualResponse.readEntity(Trace.class); - - assertThat(actualEntity) - .usingRecursiveComparison( - RecursiveComparisonConfiguration.builder() - .withIgnoredFields(IGNORED_FIELDS_TRACES) - .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) - .build()) - .isEqualTo(trace); - - assertThat(actualEntity.name()).isEqualTo(trace.name()); - assertThat(actualEntity.projectId()).isEqualTo(projectId); - assertThat(actualEntity.input()).isEqualTo(trace.input()); - assertThat(actualEntity.output()).isEqualTo(trace.output()); - assertThat(actualEntity.metadata()).isEqualTo(trace.metadata()); - assertThat(actualEntity.tags()).isEqualTo(trace.tags()); - assertThat(actualEntity.endTime()).isEqualTo(trace.endTime()); - assertThat(actualEntity.startTime()).isEqualTo(trace.startTime()); + assertThat(actualTrace) + .usingRecursiveComparison() + .ignoringFields(IGNORED_FIELDS_TRACES) + .isEqualTo(expectedTrace); - assertThat(actualEntity.createdAt()).isAfter(trace.createdAt()); - assertThat(actualEntity.lastUpdatedAt()).isAfter(trace.lastUpdatedAt()); + assertThat(actualTrace.projectId()).isEqualTo(projectId); + assertIgnoredFields(actualTrace, expectedTrace); - return actualEntity; + return actualTrace; } private void getAndAssertTraceNotFound(UUID id, String apiKey, String testWorkspace) { @@ -2975,21 +3094,13 @@ class CreateTrace { @Test @DisplayName("Success") void createTrace() { - var id = generator.generate(); - var trace = factory.manufacturePojo(Trace.class).toBuilder() .id(id) - .name("OpenAPI traces") .projectName(DEFAULT_PROJECT) - .input(JsonUtils.getJsonNodeFromString("{ \"input\": \"data\"}")) - .output(JsonUtils.getJsonNodeFromString("{ \"output\": \"data\"}")) - .endTime(Instant.now()) - .startTime(Instant.now().minusSeconds(10)) - .metadata(JsonUtils.getJsonNodeFromString("{ \"metadata\": \"data\"}")) - .tags(Set.of("tag1", "tag2")) + .usage(null) + .feedbackScores(null) .build(); - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)).request() .accept(MediaType.APPLICATION_JSON_TYPE) .header(HttpHeaders.AUTHORIZATION, API_KEY) @@ -3003,30 +3114,31 @@ void createTrace() { } var projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); } @Test @DisplayName("when creating traces with different workspaces names, then return created traces") void create__whenCreatingTracesWithDifferentWorkspacesNames__thenReturnCreatedTraces() { - var projectName = generator.generate().toString(); var trace1 = factory.manufacturePojo(Trace.class) .toBuilder() .projectName(DEFAULT_PROJECT) + .usage(null) + .feedbackScores(null) .build(); var trace2 = factory.manufacturePojo(Trace.class) .toBuilder() .projectName(projectName) + .usage(null) + .feedbackScores(null) .build(); - create(trace1, API_KEY, TEST_WORKSPACE); create(trace2, API_KEY, TEST_WORKSPACE); - UUID projectId1 = getProjectId(DEFAULT_PROJECT, TEST_WORKSPACE, API_KEY); - UUID projectId2 = getProjectId(projectName, TEST_WORKSPACE, API_KEY); + var projectId1 = getProjectId(DEFAULT_PROJECT, TEST_WORKSPACE, API_KEY); + var projectId2 = getProjectId(projectName, TEST_WORKSPACE, API_KEY); getAndAssert(trace1, projectId1, API_KEY, TEST_WORKSPACE); getAndAssert(trace2, projectId2, API_KEY, TEST_WORKSPACE); @@ -3036,6 +3148,8 @@ void create__whenCreatingTracesWithDifferentWorkspacesNames__thenReturnCreatedTr void createWithMissingId() { var trace = factory.manufacturePojo(Trace.class).toBuilder() .id(null) + .usage(null) + .feedbackScores(null) .build(); var id = create(trace, API_KEY, TEST_WORKSPACE); @@ -3122,6 +3236,7 @@ void batch__whenCreateTraces__thenReturnNoContent() { .projectName(projectName) .projectId(projectId) .endTime(null) + .usage(null) .feedbackScores(null) .build()) .toList(); @@ -3227,7 +3342,6 @@ private void batchCreateTracesAndAssert(List traces, String apiKey, Strin } private void batchCreateSpansAndAssert(List expectedSpans, String apiKey, String workspaceName) { - try (var actualResponse = client.target(URL_TEMPLATE_SPANS.formatted(baseURI)) .path("batch") .request() @@ -3257,6 +3371,7 @@ void delete() { var traces = List.of(factory.manufacturePojo(Trace.class).toBuilder() .projectName(projectName) + .usage(null) .build()); batchCreateTracesAndAssert(traces, apiKey, workspaceName); @@ -3265,6 +3380,7 @@ void delete() { .map(span -> span.toBuilder() .projectName(projectName) .traceId(trace.id()) + .usage(null) .build())) .toList(); batchCreateSpansAndAssert(spans, apiKey, workspaceName); @@ -3303,6 +3419,7 @@ void deleteWithoutSpansScores() { var traces = List.of(factory.manufacturePojo(Trace.class).toBuilder() .projectName(projectName) + .usage(null) .build()); batchCreateTracesAndAssert(traces, apiKey, workspaceName); @@ -3311,6 +3428,7 @@ void deleteWithoutSpansScores() { .map(span -> span.toBuilder() .projectName(projectName) .traceId(trace.id()) + .usage(null) .feedbackScores(null) .build())) .toList(); @@ -3343,6 +3461,7 @@ void deleteWithoutScores() { var traces = List.of(factory.manufacturePojo(Trace.class).toBuilder() .projectName(projectName) + .usage(null) .feedbackScores(null) .build()); batchCreateTracesAndAssert(traces, apiKey, workspaceName); @@ -3352,6 +3471,7 @@ void deleteWithoutScores() { .map(span -> span.toBuilder() .projectName(projectName) .traceId(trace.id()) + .usage(null) .feedbackScores(null) .build())) .toList(); @@ -3377,6 +3497,7 @@ void deleteWithoutSpans() { var traces = List.of(factory.manufacturePojo(Trace.class).toBuilder() .projectName(projectName) + .usage(null) .feedbackScores(null) .build()); batchCreateTracesAndAssert(traces, apiKey, workspaceName); @@ -3420,6 +3541,7 @@ void deleteTraces() { var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() .map(trace -> trace.toBuilder() .projectName(projectName) + .usage(null) .build()) .toList(); batchCreateTracesAndAssert(traces, apiKey, workspaceName); @@ -3429,6 +3551,7 @@ void deleteTraces() { .map(span -> span.toBuilder() .projectName(projectName) .traceId(trace.id()) + .usage(null) .build())) .toList(); batchCreateSpansAndAssert(spans, apiKey, workspaceName); @@ -3471,6 +3594,7 @@ void deleteTracesWithoutSpansScores() { var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() .map(trace -> trace.toBuilder() .projectName(projectName) + .usage(null) .build()) .toList(); batchCreateTracesAndAssert(traces, apiKey, workspaceName); @@ -3480,6 +3604,7 @@ void deleteTracesWithoutSpansScores() { .map(span -> span.toBuilder() .projectName(projectName) .traceId(trace.id()) + .usage(null) .feedbackScores(null) .build())) .toList(); @@ -3516,6 +3641,7 @@ void deleteTracesWithoutScores() { var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() .map(trace -> trace.toBuilder() .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .toList(); @@ -3526,6 +3652,7 @@ void deleteTracesWithoutScores() { .map(span -> span.toBuilder() .projectName(projectName) .traceId(trace.id()) + .usage(null) .feedbackScores(null) .build())) .toList(); @@ -3555,6 +3682,7 @@ void deleteTracesWithoutSpans() { var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() .map(trace -> trace.toBuilder() .projectName(projectName) + .usage(null) .feedbackScores(null) .build()) .toList(); @@ -3600,6 +3728,7 @@ void setUp() { .metadata(null) .tags(null) .projectId(null) + .usage(null) .feedbackScores(null) .build(); @@ -3868,7 +3997,7 @@ void update__whenTagsIsEmpty__thenAcceptUpdate() { UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - Trace actualTrace = getAndAssert(trace, projectId, API_KEY, + var actualTrace = getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); assertThat(actualTrace.tags()).isNull(); @@ -3889,7 +4018,7 @@ void update__whenMetadataIsEmpty__thenAcceptUpdate() { UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - Trace actualTrace = getAndAssert(trace.toBuilder().metadata(metadata).build(), projectId, + var actualTrace = getAndAssert(trace.toBuilder().metadata(metadata).build(), projectId, API_KEY, TEST_WORKSPACE); assertThat(actualTrace.metadata()).isEqualTo(metadata); @@ -3910,7 +4039,7 @@ void update__whenInputIsEmpty__thenAcceptUpdate() { UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - Trace actualTrace = getAndAssert(trace.toBuilder().input(input).build(), projectId, + var actualTrace = getAndAssert(trace.toBuilder().input(input).build(), projectId, API_KEY, TEST_WORKSPACE); assertThat(actualTrace.input()).isEqualTo(input); @@ -3931,7 +4060,7 @@ void update__whenOutputIsEmpty__thenAcceptUpdate() { UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - Trace actualTrace = getAndAssert(trace.toBuilder().output(output).build(), projectId, + var actualTrace = getAndAssert(trace.toBuilder().output(output).build(), projectId, API_KEY, TEST_WORKSPACE); assertThat(actualTrace.output()).isEqualTo(output); @@ -3952,7 +4081,6 @@ void update__whenUpdatingUsingProjectId__thenAcceptUpdate() { var updatedTrace = trace.toBuilder() .projectId(projectId) .metadata(traceUpdate.metadata()) - .feedbackScores(null) .input(traceUpdate.input()) .output(traceUpdate.output()) .endTime(traceUpdate.endTime()) @@ -3965,7 +4093,7 @@ void update__whenUpdatingUsingProjectId__thenAcceptUpdate() { } private Response getById(UUID id, String workspaceName, String apiKey) { - Response response = client.target(URL_TEMPLATE.formatted(baseURI)) + var response = client.target(URL_TEMPLATE.formatted(baseURI)) .path(id.toString()) .request() .header(HttpHeaders.AUTHORIZATION, apiKey) @@ -4028,9 +4156,7 @@ Stream invalidRequestBodyParams() { @Test @DisplayName("when trace does not exist, then return not found") void feedback__whenTraceDoesNotExist__thenReturnNotFound() { - var id = generator.generate(); - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) .path(id.toString()) .path("feedback-scores") @@ -4049,11 +4175,9 @@ void feedback__whenTraceDoesNotExist__thenReturnNotFound() { @ParameterizedTest @MethodSource("invalidRequestBodyParams") @DisplayName("when feedback request body is invalid, then return bad request") - void feedback__whenFeedbackRequestBodyIsInvalid__thenReturnBadRequest(FeedbackScore feedbackScore, - String errorMessage) { - + void feedback__whenFeedbackRequestBodyIsInvalid__thenReturnBadRequest( + FeedbackScore feedbackScore, String errorMessage) { var id = generator.generate(); - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)).path(id.toString()) .path("feedback-scores") .request() @@ -4070,7 +4194,6 @@ void feedback__whenFeedbackRequestBodyIsInvalid__thenReturnBadRequest(FeedbackSc @Test @DisplayName("when feedback without category name or reason, then return no content") void feedback__whenFeedbackWithoutCategoryNameOrReason__thenReturnNoContent() { - var trace = factory.manufacturePojo(Trace.class) .toBuilder() .projectName(DEFAULT_PROJECT) @@ -4078,34 +4201,26 @@ void feedback__whenFeedbackWithoutCategoryNameOrReason__thenReturnNoContent() { .output(null) .metadata(null) .tags(null) + .usage(null) .feedbackScores(null) .build(); - var id = create(trace, API_KEY, TEST_WORKSPACE); - FeedbackScore score = factory.manufacturePojo(FeedbackScore.class).toBuilder() + var score = factory.manufacturePojo(FeedbackScore.class).toBuilder() .categoryName(null) .reason(null) .value(factory.manufacturePojo(BigDecimal.class)) .build(); - create(id, score, TEST_WORKSPACE, API_KEY); - UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - - var actualEntity = getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); - - assertThat(actualEntity.feedbackScores()).hasSize(1); - - FeedbackScore actualScore = actualEntity.feedbackScores().getFirst(); - - assertEqualsForScores(actualScore, score); + var projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); + trace = trace.toBuilder().feedbackScores(List.of(score)).build(); + getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); } @Test @DisplayName("when feedback with category name or reason, then return no content") void feedback__whenFeedbackWithCategoryNameOrReason__thenReturnNoContent() { - var trace = factory.manufacturePojo(Trace.class) .toBuilder() .projectName(DEFAULT_PROJECT) @@ -4113,25 +4228,20 @@ void feedback__whenFeedbackWithCategoryNameOrReason__thenReturnNoContent() { .output(null) .metadata(null) .tags(null) + .usage(null) .feedbackScores(null) .build(); - var id = create(trace, API_KEY, TEST_WORKSPACE); var score = factory.manufacturePojo(FeedbackScore.class).toBuilder() .value(factory.manufacturePojo(BigDecimal.class)) .build(); - create(id, score, TEST_WORKSPACE, API_KEY); - UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - - Trace actualEntity = getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); - - assertThat(actualEntity.feedbackScores()).hasSize(1); - FeedbackScore actualScore = actualEntity.feedbackScores().getFirst(); + var projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - assertEqualsForScores(actualScore, score); + trace = trace.toBuilder().feedbackScores(List.of(score)).build(); + getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); } @Test @@ -4144,25 +4254,20 @@ void feedback__whenOverridingFeedbackValue__thenReturnNoContent() { .output(null) .metadata(null) .tags(null) + .usage(null) .feedbackScores(null) .build(); - var id = create(trace, API_KEY, TEST_WORKSPACE); var score = factory.manufacturePojo(FeedbackScore.class); - create(id, score, TEST_WORKSPACE, API_KEY); - FeedbackScore newScore = score.toBuilder().value(BigDecimal.valueOf(2)).build(); + var newScore = score.toBuilder().value(BigDecimal.valueOf(2)).build(); create(id, newScore, TEST_WORKSPACE, API_KEY); - UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - var actualEntity = getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); - - assertThat(actualEntity.feedbackScores()).hasSize(1); - FeedbackScore actualScore = actualEntity.feedbackScores().getFirst(); - - assertEqualsForScores(actualScore, newScore); + var projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); + trace = trace.toBuilder().feedbackScores(List.of(newScore)).build(); + getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); } } @@ -4284,18 +4389,17 @@ Stream invalidRequestBodyParams() { @Test @DisplayName("Success") void feedback() { - var trace = factory.manufacturePojo(Trace.class) + var trace1 = factory.manufacturePojo(Trace.class) .toBuilder() .projectName(DEFAULT_PROJECT) .endTime(null) .output(null) .metadata(null) .tags(null) + .usage(null) .feedbackScores(null) .build(); - - var id = create(trace, API_KEY, TEST_WORKSPACE); - + var id1 = create(trace1, API_KEY, TEST_WORKSPACE); var trace2 = factory.manufacturePojo(Trace.class) .toBuilder() .projectName(UUID.randomUUID().toString()) @@ -4303,45 +4407,41 @@ void feedback() { .output(null) .metadata(null) .tags(null) + .usage(null) .feedbackScores(null) .build(); - var id2 = create(trace2, API_KEY, TEST_WORKSPACE); - var score = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() - .id(id) - .projectName(trace.projectName()) + var score1 = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() + .id(id1) + .projectName(trace1.projectName()) .value(factory.manufacturePojo(BigDecimal.class)) .build(); - var score2 = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() .id(id2) .name("hallucination") .projectName(trace2.projectName()) .value(factory.manufacturePojo(BigDecimal.class)) .build(); - var score3 = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() - .id(id) + .id(id1) .name("hallucination") - .projectName(trace.projectName()) + .projectName(trace1.projectName()) .value(factory.manufacturePojo(BigDecimal.class)) .build(); - - var feedbackScoreBatch = FeedbackScoreBatch.builder().scores(List.of(score, score2, score3)).build(); + var feedbackScoreBatch = FeedbackScoreBatch.builder().scores(List.of(score1, score2, score3)).build(); createAndAssertForTrace(feedbackScoreBatch, TEST_WORKSPACE, API_KEY); - UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - UUID projectId2 = getProjectId(trace2.projectName(), TEST_WORKSPACE, API_KEY); - - var actualTrace1 = getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); - var actualTrace2 = getAndAssert(trace2, projectId2, API_KEY, TEST_WORKSPACE); - - assertThat(actualTrace2.feedbackScores()).hasSize(1); - assertThat(actualTrace1.feedbackScores()).hasSize(2); - - assertEqualsForScores(List.of(score, score3), actualTrace1.feedbackScores()); - assertEqualsForScores(List.of(score2), actualTrace2.feedbackScores()); + var projectId1 = getProjectId(trace1.projectName(), TEST_WORKSPACE, API_KEY); + var projectId2 = getProjectId(trace2.projectName(), TEST_WORKSPACE, API_KEY); + trace1 = trace1.toBuilder() + .feedbackScores(FeedbackScoreMapper.INSTANCE.toFeedbackScores(List.of(score1, score3))) + .build(); + trace2 = trace2.toBuilder() + .feedbackScores(FeedbackScoreMapper.INSTANCE.toFeedbackScores(List.of(score2))) + .build(); + getAndAssert(trace1, projectId1, API_KEY, TEST_WORKSPACE); + getAndAssert(trace2, projectId2, API_KEY, TEST_WORKSPACE); } @Test @@ -4350,66 +4450,61 @@ void feedback__whenWorkspaceIsSpecified__thenReturnNoContent() { var projectName = UUID.randomUUID().toString(); var workspaceName = UUID.randomUUID().toString(); var workspaceId = UUID.randomUUID().toString(); - String apiKey = UUID.randomUUID().toString(); + var apiKey = UUID.randomUUID().toString(); mockTargetWorkspace(apiKey, workspaceName, workspaceId); var expectedTrace1 = factory.manufacturePojo(Trace.class).toBuilder() .projectName(DEFAULT_PROJECT) .projectId(null) + .usage(null) .build(); - - var id = create(expectedTrace1, apiKey, workspaceName); + var id1 = create(expectedTrace1, apiKey, workspaceName); var expectedTrace2 = factory.manufacturePojo(Trace.class).toBuilder() .projectName(projectName) .projectId(null) + .usage(null) .build(); - var id2 = create(expectedTrace2, apiKey, workspaceName); - var score = factory.manufacturePojo(FeedbackScoreBatchItem.class) + var score1 = factory.manufacturePojo(FeedbackScoreBatchItem.class) .toBuilder() - .id(id) + .id(id1) .projectName(expectedTrace1.projectName()) .value(factory.manufacturePojo(BigDecimal.class)) .build(); - var score2 = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() .id(id2) .name("hallucination") .projectName(expectedTrace2.projectName()) .value(factory.manufacturePojo(BigDecimal.class)) .build(); - var score3 = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() - .id(id) + .id(id1) .name("hallucination") .projectName(expectedTrace1.projectName()) .value(factory.manufacturePojo(BigDecimal.class)) .build(); - - var feedbackScoreBatch = FeedbackScoreBatch.builder().scores(List.of(score, score2, score3)).build(); + var feedbackScoreBatch = FeedbackScoreBatch.builder().scores(List.of(score1, score2, score3)).build(); createAndAssertForTrace(feedbackScoreBatch, workspaceName, apiKey); - UUID projectId = getProjectId(DEFAULT_PROJECT, workspaceName, apiKey); - UUID projectId2 = getProjectId(projectName, workspaceName, apiKey); - - var actualTrace1 = getAndAssert(expectedTrace1, projectId, apiKey, workspaceName); - var actualTrace2 = getAndAssert(expectedTrace2, projectId2, apiKey, workspaceName); - - assertThat(actualTrace2.feedbackScores()).hasSize(1); - assertThat(actualTrace1.feedbackScores()).hasSize(2); - - assertEqualsForScores(actualTrace1.feedbackScores(), List.of(score, score3)); - assertEqualsForScores(actualTrace2.feedbackScores(), List.of(score2)); + var projectId1 = getProjectId(DEFAULT_PROJECT, workspaceName, apiKey); + var projectId2 = getProjectId(projectName, workspaceName, apiKey); + expectedTrace1 = expectedTrace1.toBuilder() + .feedbackScores(FeedbackScoreMapper.INSTANCE.toFeedbackScores(List.of(score1, score3))) + .build(); + expectedTrace2 = expectedTrace2.toBuilder() + .feedbackScores(FeedbackScoreMapper.INSTANCE.toFeedbackScores(List.of(score2))) + .build(); + getAndAssert(expectedTrace1, projectId1, apiKey, workspaceName); + getAndAssert(expectedTrace2, projectId2, apiKey, workspaceName); } @ParameterizedTest @MethodSource("invalidRequestBodyParams") @DisplayName("when batch request is invalid, then return bad request") void feedback__whenBatchRequestIsInvalid__thenReturnBadRequest(FeedbackScoreBatch batch, String errorMessage) { - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) .path("feedback-scores") .request() @@ -4426,7 +4521,6 @@ void feedback__whenBatchRequestIsInvalid__thenReturnBadRequest(FeedbackScoreBatc @Test @DisplayName("when feedback without category name or reason, then return no content") void feedback__whenFeedbackWithoutCategoryNameOrReason__thenReturnNoContent() { - var trace = factory.manufacturePojo(Trace.class) .toBuilder() .projectName(DEFAULT_PROJECT) @@ -4434,9 +4528,9 @@ void feedback__whenFeedbackWithoutCategoryNameOrReason__thenReturnNoContent() { .output(null) .metadata(null) .tags(null) + .usage(null) .feedbackScores(null) .build(); - var id = create(trace, API_KEY, TEST_WORKSPACE); var score = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() @@ -4446,37 +4540,30 @@ void feedback__whenFeedbackWithoutCategoryNameOrReason__thenReturnNoContent() { .value(factory.manufacturePojo(BigDecimal.class)) .reason(null) .build(); + createAndAssertForTrace( + FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, API_KEY); - createAndAssertForTrace(FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, - API_KEY); - - UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - - var actualEntity = getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); - - assertThat(actualEntity.feedbackScores()).hasSize(1); - - FeedbackScore actualScore = actualEntity.feedbackScores().getFirst(); - - assertEqualsForScores(actualScore, score); + var projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); + trace = trace.toBuilder() + .feedbackScores(FeedbackScoreMapper.INSTANCE.toFeedbackScores(List.of(score))) + .build(); + getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); } @Test @DisplayName("when feedback with category name or reason, then return no content") void feedback__whenFeedbackWithCategoryNameOrReason__thenReturnNoContent() { + var projectName = RandomStringUtils.randomAlphanumeric(10); - var projectName = UUID.randomUUID().toString(); - - Trace expectedTrace = factory.manufacturePojo(Trace.class) + var expectedTrace = factory.manufacturePojo(Trace.class) .toBuilder() .projectName(projectName) .endTime(null) .output(null) .metadata(null) .tags(null) - .feedbackScores(null) + .usage(null) .build(); - var id = create(expectedTrace, API_KEY, TEST_WORKSPACE); var score = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() @@ -4484,25 +4571,23 @@ void feedback__whenFeedbackWithCategoryNameOrReason__thenReturnNoContent() { .projectName(expectedTrace.projectName()) .value(factory.manufacturePojo(BigDecimal.class)) .build(); + createAndAssertForTrace( + FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, API_KEY); - createAndAssertForTrace(FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, - API_KEY); - - var actualEntity = getAndAssert(expectedTrace, - getProjectId(expectedTrace.projectName(), TEST_WORKSPACE, API_KEY), API_KEY, + expectedTrace = expectedTrace.toBuilder() + .feedbackScores(FeedbackScoreMapper.INSTANCE.toFeedbackScores(List.of(score))) + .build(); + getAndAssert( + expectedTrace, + getProjectId(expectedTrace.projectName(), TEST_WORKSPACE, API_KEY), + API_KEY, TEST_WORKSPACE); - - assertThat(actualEntity.feedbackScores()).hasSize(1); - FeedbackScore actualScore = actualEntity.feedbackScores().getFirst(); - - assertEqualsForScores(actualScore, score); } @Test @DisplayName("when overriding feedback value, then return no content") void feedback__whenOverridingFeedbackValue__thenReturnNoContent() { - - var projectName = UUID.randomUUID().toString(); + var projectName = RandomStringUtils.randomAlphanumeric(10); var trace = factory.manufacturePojo(Trace.class) .toBuilder() .projectName(projectName) @@ -4510,54 +4595,44 @@ void feedback__whenOverridingFeedbackValue__thenReturnNoContent() { .output(null) .metadata(null) .tags(null) - .feedbackScores(null) - .feedbackScores(null) + .usage(null) .build(); - var id = create(trace, API_KEY, TEST_WORKSPACE); var score = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() .id(id) .projectName(trace.projectName()) .build(); + createAndAssertForTrace( + FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, API_KEY); - createAndAssertForTrace(FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, - API_KEY); - - FeedbackScoreBatchItem newItem = score.toBuilder().value(factory.manufacturePojo(BigDecimal.class)).build(); + var newScore = score.toBuilder().value(factory.manufacturePojo(BigDecimal.class)).build(); + createAndAssertForTrace( + FeedbackScoreBatch.builder().scores(List.of(newScore)).build(), TEST_WORKSPACE, API_KEY); - createAndAssertForTrace(FeedbackScoreBatch.builder().scores(List.of(newItem)).build(), TEST_WORKSPACE, - API_KEY); - - UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); - - var actualEntity = getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); - - assertThat(actualEntity.feedbackScores()).hasSize(1); - FeedbackScore actualScore = actualEntity.feedbackScores().getFirst(); - - assertEqualsForScores(actualScore, newItem); + var projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); + trace = trace.toBuilder() + .feedbackScores(FeedbackScoreMapper.INSTANCE.toFeedbackScores(List.of(newScore))) + .build(); + getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE); } @Test @DisplayName("when trace does not exist, then return no content and create score") void feedback__whenTraceDoesNotExist__thenReturnNoContentAndCreateScore() { - var id = generator.generate(); - var score = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() .id(id) .projectName(DEFAULT_PROJECT) .build(); - createAndAssertForTrace(FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, - API_KEY); + createAndAssertForTrace( + FeedbackScoreBatch.builder().scores(List.of(score)).build(), TEST_WORKSPACE, API_KEY); } @Test @DisplayName("when feedback trace project and score project do not match, then return conflict") void feedback__whenFeedbackTraceProjectAndScoreProjectDoNotMatch__thenReturnConflict() { - var trace = factory.manufacturePojo(Trace.class) .toBuilder() .id(null) @@ -4570,14 +4645,12 @@ void feedback__whenFeedbackTraceProjectAndScoreProjectDoNotMatch__thenReturnConf .tags(null) .feedbackScores(null) .build(); - var id = create(trace, API_KEY, TEST_WORKSPACE); var score = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() .id(id) .projectName(UUID.randomUUID().toString()) .build(); - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) .path("feedback-scores") .request() @@ -4599,7 +4672,6 @@ void feedback__whenFeedbackSpanBatchHasMaxSize__thenReturnNoContentAndCreateScor var expectedTrace = factory.manufacturePojo(Trace.class).toBuilder() .projectName(DEFAULT_PROJECT) .build(); - var id = create(expectedTrace, API_KEY, TEST_WORKSPACE); var scores = IntStream.range(0, 1000) @@ -4608,14 +4680,12 @@ void feedback__whenFeedbackSpanBatchHasMaxSize__thenReturnNoContentAndCreateScor .id(id) .build()) .toList(); - createAndAssertForTrace(FeedbackScoreBatch.builder().scores(scores).build(), TEST_WORKSPACE, API_KEY); } @Test @DisplayName("when feedback trace id is not valid, then return 400") void feedback__whenFeedbackTraceIdIsNotValid__thenReturn400() { - var score = factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder() .id(UUID.randomUUID()) .projectName(DEFAULT_PROJECT) @@ -4657,33 +4727,6 @@ private void createAndAssert(String path, FeedbackScoreBatch request, String wor } } - private void assertEqualsForScores(FeedbackScore actualScore, FeedbackScore expectedScore) { - assertThat(actualScore) - .usingRecursiveComparison() - .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) - .ignoringFields(IGNORED_FIELDS_SCORES) - .isEqualTo(expectedScore); - } - - private void assertEqualsForScores(FeedbackScore actualScore, FeedbackScoreBatchItem expectedScore) { - assertThat(actualScore) - .usingRecursiveComparison() - .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) - .ignoringFields(IGNORED_FIELDS_SCORES) - .isEqualTo(expectedScore); - } - - private void assertEqualsForScores(List expected, List actual) { - assertThat(actual) - .usingRecursiveComparison( - RecursiveComparisonConfiguration.builder() - .withIgnoredFields(IGNORED_FIELDS_SCORES) - .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) - .build()) - .ignoringCollectionOrder() - .isEqualTo(expected); - } - private int setupTracesForWorkspace(String workspaceName, String workspaceId, String okApikey) { mockTargetWorkspace(okApikey, workspaceName, workspaceId); From 310e5515683ed16f9467ae299bb1fdbd4cad9f5f Mon Sep 17 00:00:00 2001 From: Andres Cruz Date: Wed, 2 Oct 2024 12:56:38 +0200 Subject: [PATCH 2/2] Rev2: using trace star for queries --- .../java/com/comet/opik/domain/TraceDAO.java | 60 ++----------------- 1 file changed, 4 insertions(+), 56 deletions(-) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java index 9aa9182f..265529ef 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java @@ -239,20 +239,7 @@ INSERT INTO traces ( private static final String SELECT_BY_ID = """ SELECT - t.id, - t.workspace_id, - t.project_id, - t.name, - t.start_time, - t.end_time, - t.input, - t.output, - t.metadata, - t.tags, - t.created_at, - t.last_updated_at, - t.created_by, - t.last_updated_by, + t.*, sumMap(s.usage) as usage FROM ( SELECT @@ -274,40 +261,14 @@ LEFT JOIN ( LIMIT 1 BY id ) AS s ON t.id = s.trace_id GROUP BY - t.id, - t.workspace_id, - t.project_id, - t.name, - t.start_time, - t.end_time, - t.input, - t.output, - t.metadata, - t.tags, - t.created_at, - t.last_updated_at, - t.created_by, - t.last_updated_by + t.* ORDER BY t.id DESC ; """; private static final String SELECT_BY_PROJECT_ID = """ SELECT - t.id, - t.workspace_id, - t.project_id, - t.name, - t.start_time, - t.end_time, - t.input, - t.output, - t.metadata, - t.tags, - t.created_at, - t.last_updated_at, - t.created_by, - t.last_updated_by, + t.*, sumMap(s.usage) as usage FROM ( SELECT @@ -348,20 +309,7 @@ LEFT JOIN ( LIMIT 1 BY id ) AS s ON t.id = s.trace_id GROUP BY - t.id, - t.workspace_id, - t.project_id, - t.name, - t.start_time, - t.end_time, - t.input, - t.output, - t.metadata, - t.tags, - t.created_at, - t.last_updated_at, - t.created_by, - t.last_updated_by + t.* ORDER BY t.id DESC ; """;