diff --git a/apps/opik-backend/config.yml b/apps/opik-backend/config.yml index ec31499b..63476638 100644 --- a/apps/opik-backend/config.yml +++ b/apps/opik-backend/config.yml @@ -65,3 +65,9 @@ server: enableVirtualThreads: ${ENABLE_VIRTUAL_THREADS:-false} gzip: enabled: true + +rateLimit: + enabled: ${RATE_LIMIT_ENABLED:-false} + generalEvents: + limit: ${RATE_LIMIT_GENERAL_EVENTS_LIMIT:-10000} + durationInSeconds: ${RATE_LIMIT_GENERAL_EVENTS_DURATION_IN_SEC:-60} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java index 813a43fe..e45aac13 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java @@ -6,6 +6,7 @@ import com.comet.opik.infrastructure.db.DatabaseAnalyticsModule; import com.comet.opik.infrastructure.db.IdGeneratorModule; import com.comet.opik.infrastructure.db.NameGeneratorModule; +import com.comet.opik.infrastructure.ratelimit.RateLimitModule; import com.comet.opik.infrastructure.redis.RedisModule; import com.comet.opik.utils.JsonBigDecimalDeserializer; import com.fasterxml.jackson.annotation.JsonInclude; @@ -60,7 +61,7 @@ public void initialize(Bootstrap bootstrap) { .bundles(JdbiBundle.forDatabase((conf, env) -> conf.getDatabase()) .withPlugins(new SqlObjectPlugin(), new Jackson2Plugin())) .modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule(), - new NameGeneratorModule()) + new RateLimitModule(), new NameGeneratorModule()) .enableAutoConfig() .build()); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItemBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItemBatch.java index 6c3204d2..f03ee900 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItemBatch.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItemBatch.java @@ -1,6 +1,7 @@ package com.comet.opik.api; import com.comet.opik.api.validate.DatasetItemBatchValidation; +import com.comet.opik.infrastructure.ratelimit.RateEventContainer; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonView; import com.fasterxml.jackson.databind.PropertyNamingStrategies; @@ -26,6 +27,12 @@ public record DatasetItemBatch( DatasetItem.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Schema(description = "If null, dataset_id must be provided") String datasetName, @JsonView({ DatasetItem.View.Write.class}) @Schema(description = "If null, dataset_name must be provided") UUID datasetId, - @JsonView({DatasetItem.View.Write.class}) @NotNull @Size(min = 1, max = 1000) @Valid List items){ + @JsonView({DatasetItem.View.Write.class}) @NotNull @Size(min = 1, max = 1000) @Valid List items) + implements + RateEventContainer{ + @Override + public long eventCount() { + return items.size(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemsBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemsBatch.java index 54a93145..771c5172 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemsBatch.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemsBatch.java @@ -1,5 +1,6 @@ package com.comet.opik.api; +import com.comet.opik.infrastructure.ratelimit.RateEventContainer; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import 
com.fasterxml.jackson.annotation.JsonView; import com.fasterxml.jackson.databind.PropertyNamingStrategies; @@ -16,5 +17,12 @@ @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) public record ExperimentItemsBatch( @JsonView( { - ExperimentItem.View.Write.class}) @NotNull @Size(min = 1, max = 1000) @Valid Set experimentItems){ + ExperimentItem.View.Write.class}) @NotNull @Size(min = 1, max = 1000) @Valid Set experimentItems) + implements + RateEventContainer{ + + @Override + public long eventCount() { + return experimentItems.size(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatch.java index ab798e74..1fc5910e 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatch.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatch.java @@ -1,5 +1,6 @@ package com.comet.opik.api; +import com.comet.opik.infrastructure.ratelimit.RateEventContainer; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.databind.PropertyNamingStrategies; import com.fasterxml.jackson.databind.annotation.JsonNaming; @@ -13,6 +14,11 @@ @Builder(toBuilder = true) @JsonIgnoreProperties(ignoreUnknown = true) @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) -public record FeedbackScoreBatch(@NotNull @Size(min = 1, max = 1000) @Valid List scores) { +public record FeedbackScoreBatch( + @NotNull @Size(min = 1, max = 1000) @Valid List scores) implements RateEventContainer { + @Override + public long eventCount() { + return scores.size(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java index 02727bec..74fa4325 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java @@ -1,5 +1,6 @@ package com.comet.opik.api; +import com.comet.opik.infrastructure.ratelimit.RateEventContainer; import com.fasterxml.jackson.annotation.JsonView; import jakarta.validation.Valid; import jakarta.validation.constraints.NotNull; @@ -10,5 +11,10 @@ @Builder(toBuilder = true) public record SpanBatch(@NotNull @Size(min = 1, max = 1000) @JsonView( { - Span.View.Write.class}) @Valid List spans){ + Span.View.Write.class}) @Valid List spans) implements RateEventContainer{ + + @Override + public long eventCount() { + return spans.size(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatch.java index 0765a897..fafa6446 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatch.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatch.java @@ -1,5 +1,6 @@ package com.comet.opik.api; +import com.comet.opik.infrastructure.ratelimit.RateEventContainer; import com.fasterxml.jackson.annotation.JsonView; import jakarta.validation.Valid; import jakarta.validation.constraints.NotNull; @@ -10,5 +11,10 @@ @Builder(toBuilder = true) public record TraceBatch(@NotNull @Size(min = 1, max = 1000) @JsonView( { - Trace.View.Write.class}) @Valid List traces){ + Trace.View.Write.class}) @Valid List traces) implements RateEventContainer{ + + @Override + public long eventCount() { + return traces.size(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java 
b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java index cb54d09b..5f80960d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java @@ -16,6 +16,7 @@ import com.comet.opik.domain.FeedbackScoreDAO; import com.comet.opik.domain.IdGenerator; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.comet.opik.utils.AsyncUtils; import com.comet.opik.utils.JsonUtils; import com.fasterxml.jackson.annotation.JsonView; @@ -136,6 +137,7 @@ public Response findDatasets( @Header(name = "Location", required = true, example = "${basePath}/api/v1/private/datasets/{id}", schema = @Schema(implementation = String.class)) }) }) + @RateLimited public Response createDataset( @RequestBody(content = @Content(schema = @Schema(implementation = Dataset.class))) @JsonView(Dataset.View.Write.class) @NotNull @Valid Dataset dataset, @Context UriInfo uriInfo) { @@ -156,6 +158,7 @@ public Response createDataset( @Operation(operationId = "updateDataset", summary = "Update dataset by id", description = "Update dataset by id", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) + @RateLimited public Response updateDataset(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = DatasetUpdate.class))) @NotNull @Valid DatasetUpdate datasetUpdate) { @@ -346,6 +349,7 @@ private void sendDatasetItems(DatasetItem item, ChunkedOutput writer) @Operation(operationId = "createOrUpdateDatasetItems", summary = "Create/update dataset items", description = "Create/update dataset items based on dataset item id", responses = { @ApiResponse(responseCode = "204", description = "No content"), }) + @RateLimited public Response createDatasetItems( @RequestBody(content = @Content(schema = @Schema(implementation = DatasetItemBatch.class))) @JsonView({ DatasetItem.View.Write.class}) @NotNull @Valid DatasetItemBatch batch) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java index 816535f3..31caf2a3 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java @@ -10,6 +10,7 @@ import com.comet.opik.domain.ExperimentService; import com.comet.opik.domain.IdGenerator; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.comet.opik.utils.AsyncUtils; import com.fasterxml.jackson.annotation.JsonView; import io.dropwizard.jersey.errors.ErrorMessage; @@ -106,6 +107,7 @@ public Response get(@PathParam("id") UUID id) { @Operation(operationId = "createExperiment", summary = "Create experiment", description = "Create experiment", responses = { @ApiResponse(responseCode = "201", description = "Created", headers = { @Header(name = "Location", required = true, example = "${basePath}/v1/private/experiments/{id}", schema = @Schema(implementation = String.class))})}) + @RateLimited public Response create( @RequestBody(content = @Content(schema = @Schema(implementation = Experiment.class))) @JsonView(Experiment.View.Write.class) @NotNull @Valid Experiment experiment, @Context 
UriInfo uriInfo) { @@ -151,6 +153,7 @@ public Response getExperimentItem(@PathParam("id") UUID id) { @Path("/items") @Operation(operationId = "createExperimentItems", summary = "Create experiment items", description = "Create experiment items", responses = { @ApiResponse(responseCode = "204", description = "No content")}) + @RateLimited public Response createExperimentItems( @RequestBody(content = @Content(schema = @Schema(implementation = ExperimentItemsBatch.class))) @NotNull @Valid ExperimentItemsBatch request) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java index c3a2d9cc..345503da 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResource.java @@ -6,6 +6,7 @@ import com.comet.opik.api.Page; import com.comet.opik.domain.FeedbackDefinitionService; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.fasterxml.jackson.annotation.JsonView; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.headers.Header; @@ -100,6 +101,7 @@ public Response getById(@PathParam("id") @NotNull UUID id) { @ApiResponse(responseCode = "201", description = "Created", headers = { @Header(name = "Location", required = true, example = "${basePath}/v1/private/feedback-definitions/{feedbackId}", schema = @Schema(implementation = String.class))}) }) + @RateLimited public Response create( @RequestBody(content = @Content(schema = @Schema(implementation = FeedbackDefinition.class))) @JsonView({ FeedbackDefinition.View.Create.class}) @NotNull @Valid FeedbackDefinition feedbackDefinition, @@ -123,6 +125,7 @@ public Response create( @Operation(operationId = "updateFeedbackDefinition", summary = "Update feedback definition by id", description = "Update feedback definition by id", responses = { @ApiResponse(responseCode = "204", description = "No Content") }) + @RateLimited public Response update(final @PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = FeedbackDefinition.class))) @JsonView({ FeedbackDefinition.View.Update.class}) @NotNull @Valid FeedbackDefinition feedbackDefinition) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java index f6e05d4b..186d5202 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java @@ -8,6 +8,7 @@ import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.domain.ProjectService; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.fasterxml.jackson.annotation.JsonView; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.headers.Header; @@ -100,6 +101,7 @@ public Response getById(@PathParam("id") UUID id) { @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), @ApiResponse(responseCode = "400", description = "Bad Request", 
content = @Content(schema = @Schema(implementation = ErrorMessage.class))) }) + @RateLimited public Response create( @RequestBody(content = @Content(schema = @Schema(implementation = Project.class))) @JsonView(Project.View.Write.class) @Valid Project project, @Context UriInfo uriInfo) { @@ -125,6 +127,7 @@ public Response create( @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))) }) + @RateLimited public Response update(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = ProjectUpdate.class))) @Valid ProjectUpdate project) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java index 53279ead..166e293d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java @@ -14,6 +14,7 @@ import com.comet.opik.domain.SpanService; import com.comet.opik.domain.SpanType; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.comet.opik.utils.AsyncUtils; import com.fasterxml.jackson.annotation.JsonView; import io.swagger.v3.oas.annotations.Operation; @@ -126,6 +127,7 @@ public Response getById(@PathParam("id") @NotNull UUID id) { @Operation(operationId = "createSpan", summary = "Create span", description = "Create span", responses = { @ApiResponse(responseCode = "201", description = "Created", headers = { @Header(name = "Location", required = true, example = "${basePath}/v1/private/spans/{spanId}", schema = @Schema(implementation = String.class))})}) + @RateLimited public Response create( @RequestBody(content = @Content(schema = @Schema(implementation = Span.class))) @JsonView(Span.View.Write.class) @NotNull @Valid Span span, @Context UriInfo uriInfo) { @@ -148,6 +150,7 @@ public Response create( @Path("/batch") @Operation(operationId = "createSpans", summary = "Create spans", description = "Create spans", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response createSpans( @RequestBody(content = @Content(schema = @Schema(implementation = SpanBatch.class))) @JsonView(Span.View.Write.class) @NotNull @Valid SpanBatch spans) { @@ -173,6 +176,7 @@ public Response createSpans( @Operation(operationId = "updateSpan", summary = "Update span by id", description = "Update span by id", responses = { @ApiResponse(responseCode = "204", description = "No Content"), @ApiResponse(responseCode = "404", description = "Not found")}) + @RateLimited public Response update(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = SpanUpdate.class))) @NotNull @Valid SpanUpdate spanUpdate) { @@ -201,6 +205,7 @@ public Response deleteById(@PathParam("id") @NotNull String id) { @Path("/{id}/feedback-scores") @Operation(operationId = "addSpanFeedbackScore", summary = "Add span feedback score", description = "Add span feedback score", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response addSpanFeedbackScore(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = 
@Schema(implementation = FeedbackScore.class))) @NotNull @Valid FeedbackScore score) { @@ -236,6 +241,7 @@ public Response deleteSpanFeedbackScore(@PathParam("id") UUID id, @Path("/feedback-scores") @Operation(operationId = "scoreBatchOfSpans", summary = "Batch feedback scoring for spans", description = "Batch feedback scoring for spans", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response scoreBatchOfSpans( @RequestBody(content = @Content(schema = @Schema(implementation = FeedbackScoreBatch.class))) @NotNull @Valid FeedbackScoreBatch batch) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java index dc918040..9da1ab66 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java @@ -15,6 +15,7 @@ import com.comet.opik.domain.FeedbackScoreService; import com.comet.opik.domain.TraceService; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.comet.opik.utils.AsyncUtils; import com.fasterxml.jackson.annotation.JsonView; import io.swagger.v3.oas.annotations.Operation; @@ -125,6 +126,7 @@ public Response getById(@PathParam("id") UUID id) { @Operation(operationId = "createTrace", summary = "Create trace", description = "Get trace", responses = { @ApiResponse(responseCode = "201", description = "Created", headers = { @Header(name = "Location", required = true, example = "${basePath}/v1/private/traces/{traceId}", schema = @Schema(implementation = String.class))})}) + @RateLimited public Response create( @RequestBody(content = @Content(schema = @Schema(implementation = Trace.class))) @JsonView(Trace.View.Write.class) @NotNull @Valid Trace trace, @Context UriInfo uriInfo) { @@ -150,7 +152,8 @@ public Response create( @Path("/batch") @Operation(operationId = "createTraces", summary = "Create traces", description = "Create traces", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) - public Response createSpans( + @RateLimited + public Response createTraces( @RequestBody(content = @Content(schema = @Schema(implementation = TraceBatch.class))) @JsonView(Trace.View.Write.class) @NotNull @Valid TraceBatch traces) { traces.traces() @@ -174,6 +177,7 @@ public Response createSpans( @Path("{id}") @Operation(operationId = "updateTrace", summary = "Update trace by id", description = "Update trace by id", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response update(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = TraceUpdate.class))) @Valid @NonNull TraceUpdate trace) { @@ -225,6 +229,7 @@ public Response deleteTraces( @Path("/{id}/feedback-scores") @Operation(operationId = "addTraceFeedbackScore", summary = "Add trace feedback score", description = "Add trace feedback score", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response addTraceFeedbackScore(@PathParam("id") UUID id, @RequestBody(content = @Content(schema = @Schema(implementation = FeedbackScore.class))) @NotNull @Valid FeedbackScore score) { @@ -265,6 +270,7 @@ public Response deleteTraceFeedbackScore(@PathParam("id") UUID id, @Path("/feedback-scores") 
@Operation(operationId = "scoreBatchOfTraces", summary = "Batch feedback scoring for traces", description = "Batch feedback scoring for traces", responses = { @ApiResponse(responseCode = "204", description = "No Content")}) + @RateLimited public Response scoreBatchOfTraces( @RequestBody(content = @Content(schema = @Schema(implementation = FeedbackScoreBatch.class))) @NotNull @Valid FeedbackScoreBatch batch) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java index 3c2569ba..a15de2c8 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java @@ -8,7 +8,7 @@ import com.comet.opik.api.error.IdentifierMismatchException; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.db.TransactionTemplate; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.comet.opik.utils.WorkspaceUtils; import com.google.inject.ImplementedBy; import com.google.inject.Singleton; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java index 27caa005..315f20ee 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java @@ -10,7 +10,7 @@ import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.api.error.IdentifierMismatchException; import com.comet.opik.infrastructure.auth.RequestContext; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.comet.opik.utils.WorkspaceUtils; import com.google.common.base.Preconditions; import com.newrelic.api.agent.Trace; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java index b1be4a97..1b2ab35d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java @@ -12,7 +12,7 @@ import com.comet.opik.api.error.IdentifierMismatchException; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.db.TransactionTemplate; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.comet.opik.utils.AsyncUtils; import com.comet.opik.utils.WorkspaceUtils; import com.google.common.base.Preconditions; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java index 754a685a..8bdef4b1 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java @@ -33,4 +33,8 @@ public class OpikConfiguration extends Configuration { @Valid @NotNull @JsonProperty private DistributedLockConfig distributedLock = new DistributedLockConfig(); + + @Valid + @NotNull @JsonProperty + private RateLimitConfig rateLimit = new RateLimitConfig(); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java 
b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java new file mode 100644 index 00000000..11509ac4 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/RateLimitConfig.java @@ -0,0 +1,30 @@ +package com.comet.opik.infrastructure; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.Valid; +import jakarta.validation.constraints.Positive; +import jakarta.validation.constraints.PositiveOrZero; +import lombok.Data; + +import java.util.Map; + +@Data +public class RateLimitConfig { + + public record LimitConfig(@Valid @JsonProperty @PositiveOrZero long limit, + @Valid @JsonProperty @Positive long durationInSeconds) { + } + + @Valid + @JsonProperty + private boolean enabled; + + @Valid + @JsonProperty + private LimitConfig generalLimit; + + @Valid + @JsonProperty + private Map customLimits; + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java index add73588..e5e25404 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthFilter.java @@ -23,6 +23,7 @@ public class AuthFilter implements ContainerRequestFilter { private final AuthService authService; + private final jakarta.inject.Provider requestContext; @Override public void filter(ContainerRequestContext context) throws IOException { @@ -36,7 +37,7 @@ public void filter(ContainerRequestContext context) throws IOException { if (Pattern.matches("/v1/private/.*", requestUri.getPath())) { authService.authenticate(headers, sessionToken); } - + requestContext.get().setHeaders(context.getHeaders()); } HttpHeaders getHttpHeaders(ContainerRequestContext context) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthModule.java index 757a994e..f9cd9a71 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthModule.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthModule.java @@ -2,7 +2,7 @@ import com.comet.opik.infrastructure.AuthenticationConfig; import com.comet.opik.infrastructure.OpikConfiguration; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.google.common.base.Preconditions; import com.google.inject.Provides; import jakarta.inject.Provider; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthService.java index 26fdb916..fa851635 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/AuthService.java @@ -31,6 +31,7 @@ public void authenticate(HttpHeaders headers, Cookie sessionToken) { requestContext.get().setWorkspaceName(currentWorkspaceName); requestContext.get().setUserName(ProjectService.DEFAULT_USER); requestContext.get().setWorkspaceId(ProjectService.DEFAULT_WORKSPACE_ID); + requestContext.get().setApiKey("default"); return; } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java index 
f59a328f..68b55539 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RemoteAuthService.java @@ -1,7 +1,7 @@ package com.comet.opik.infrastructure.auth; import com.comet.opik.domain.ProjectService; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import jakarta.inject.Provider; import jakarta.ws.rs.ClientErrorException; import jakarta.ws.rs.client.Client; @@ -22,7 +22,7 @@ import static com.comet.opik.infrastructure.AuthenticationConfig.UrlConfig; import static com.comet.opik.infrastructure.auth.AuthCredentialsCacheService.AuthCredentials; -import static com.comet.opik.infrastructure.redis.LockService.Lock; +import static com.comet.opik.infrastructure.lock.LockService.Lock; @RequiredArgsConstructor @Slf4j @@ -81,6 +81,7 @@ private void authenticateUsingSessionToken(Cookie sessionToken, String workspace AuthResponse credentials = verifyResponse(response); setCredentialIntoContext(credentials.user(), credentials.workspaceId()); + requestContext.get().setApiKey(sessionToken.getValue()); } } @@ -108,6 +109,7 @@ private void authenticateUsingApiKey(HttpHeaders headers, String workspaceName) } setCredentialIntoContext(credentials.userName(), credentials.workspaceId()); + requestContext.get().setApiKey(apiKey); } private ValidatedAuthCredentials validateApiKeyAndGetCredentials(String workspaceName, String apiKey) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java index b70a0191..a9f174a2 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/auth/RequestContext.java @@ -1,8 +1,11 @@ package com.comet.opik.infrastructure.auth; import com.google.inject.servlet.RequestScoped; +import jakarta.ws.rs.core.MultivaluedMap; +import lombok.Data; @RequestScoped +@Data public class RequestContext { public static final String WORKSPACE_HEADER = "Comet-Workspace"; @@ -10,32 +13,14 @@ public class RequestContext { public static final String WORKSPACE_NAME = "workspaceName"; public static final String SESSION_COOKIE = "sessionToken"; public static final String WORKSPACE_ID = "workspaceId"; + public static final String API_KEY = "apiKey"; + public static final String USER_LIMIT = "Opik-User-Limit"; + public static final String USER_REMAINING_LIMIT = "Opik-User-Remaining-Limit"; + public static final String USER_LIMIT_REMAINING_TTL = "Opik-User-Remaining-Limit-TTL-Millis"; private String userName; private String workspaceName; private String workspaceId; - - public final String getUserName() { - return userName; - } - - public final String getWorkspaceName() { - return workspaceName; - } - - public final String getWorkspaceId() { - return workspaceId; - } - - void setUserName(String workspaceName) { - this.userName = workspaceName; - } - - void setWorkspaceName(String workspaceName) { - this.workspaceName = workspaceName; - } - - public void setWorkspaceId(String workspaceId) { - this.workspaceId = workspaceId; - } + private String apiKey; + private MultivaluedMap headers; } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/LockService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/lock/LockService.java similarity index 92% 
rename from apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/LockService.java
rename to apps/opik-backend/src/main/java/com/comet/opik/infrastructure/lock/LockService.java
index a02bba19..5c9d6a2f 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/LockService.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/lock/LockService.java
@@ -1,4 +1,4 @@
-package com.comet.opik.infrastructure.redis;
+package com.comet.opik.infrastructure.lock;
 
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateEventContainer.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateEventContainer.java
new file mode 100644
index 00000000..a7f07dc2
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateEventContainer.java
@@ -0,0 +1,7 @@
+package com.comet.opik.infrastructure.ratelimit;
+
+public interface RateEventContainer {
+
+    long eventCount();
+
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java
new file mode 100644
index 00000000..aef9e154
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitInterceptor.java
@@ -0,0 +1,99 @@
+package com.comet.opik.infrastructure.ratelimit;
+
+import com.comet.opik.infrastructure.RateLimitConfig;
+import com.comet.opik.infrastructure.auth.RequestContext;
+import io.swagger.v3.oas.annotations.parameters.RequestBody;
+import jakarta.inject.Provider;
+import jakarta.ws.rs.ClientErrorException;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.aopalliance.intercept.MethodInterceptor;
+import org.aopalliance.intercept.MethodInvocation;
+import org.apache.hc.core5.http.HttpStatus;
+
+import java.lang.reflect.Method;
+import java.util.List;
+import java.util.Optional;
+
+import static com.comet.opik.infrastructure.RateLimitConfig.LimitConfig;
+
+@Slf4j
+@RequiredArgsConstructor
+class RateLimitInterceptor implements MethodInterceptor {
+
+    private final Provider<RequestContext> requestContext;
+    private final Provider<RateLimitService> rateLimitService;
+    private final RateLimitConfig rateLimitConfig;
+
+    @Override
+    public Object invoke(MethodInvocation invocation) throws Throwable {
+
+        // Get the method being invoked
+        Method method = invocation.getMethod();
+
+        if (!rateLimitConfig.isEnabled()) {
+            return invocation.proceed();
+        }
+
+        // Check if the method is annotated with @RateLimit
+        if (!method.isAnnotationPresent(RateLimited.class)) {
+            return invocation.proceed();
+        }
+
+        RateLimited rateLimit = method.getAnnotation(RateLimited.class);
+
+        // Check events bucket
+        Optional<LimitConfig> limitConfig = Optional.ofNullable(rateLimitConfig.getCustomLimits())
+                .map(limits -> limits.get(rateLimit.value()));
+
+        String limitBucket = limitConfig.isPresent() ? rateLimit.value() : RateLimited.GENERAL_EVENTS;
+
+        LimitConfig generalLimit = limitConfig
+                .orElse(rateLimitConfig.getGeneralLimit());
+
+        String apiKey = requestContext.get().getApiKey();
+        Object body = getParameters(invocation);
+
+        long events = body instanceof RateEventContainer container ?
+                container.eventCount() : 1;
+
+        verifyRateLimit(events, apiKey, limitBucket, generalLimit);
+
+        try {
+            return invocation.proceed();
+        } finally {
+            setLimitHeaders(apiKey, limitBucket);
+        }
+    }
+
+    private void verifyRateLimit(long events, String apiKey, String bucket, LimitConfig limitConfig) {
+
+        // Check if the rate limit is exceeded
+        Boolean limitExceeded = rateLimitService.get()
+                .isLimitExceeded(apiKey, events, bucket, limitConfig.limit(), limitConfig.durationInSeconds())
+                .block();
+
+        if (Boolean.TRUE.equals(limitExceeded)) {
+            setLimitHeaders(apiKey, bucket);
+            throw new ClientErrorException("Too Many Requests", HttpStatus.SC_TOO_MANY_REQUESTS);
+        }
+    }
+
+    private void setLimitHeaders(String apiKey, String bucket) {
+        requestContext.get().getHeaders().put(RequestContext.USER_LIMIT, List.of(bucket));
+        requestContext.get().getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL,
+                List.of("" + rateLimitService.get().getRemainingTTL(apiKey, bucket).block()));
+        requestContext.get().getHeaders().put(RequestContext.USER_REMAINING_LIMIT,
+                List.of("" + rateLimitService.get().availableEvents(apiKey, bucket).block()));
+    }
+
+    private Object getParameters(MethodInvocation method) {
+
+        for (int i = 0; i < method.getArguments().length; i++) {
+            if (method.getMethod().getParameters()[i].isAnnotationPresent(RequestBody.class)) {
+                return method.getArguments()[i];
+            }
+        }
+
+        return null;
+    }
+}
\ No newline at end of file
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitModule.java
new file mode 100644
index 00000000..693ee823
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitModule.java
@@ -0,0 +1,23 @@
+package com.comet.opik.infrastructure.ratelimit;
+
+import com.comet.opik.infrastructure.OpikConfiguration;
+import com.comet.opik.infrastructure.RateLimitConfig;
+import com.comet.opik.infrastructure.auth.RequestContext;
+import com.google.inject.matcher.Matchers;
+import ru.vyarus.dropwizard.guice.module.support.DropwizardAwareModule;
+
+public class RateLimitModule extends DropwizardAwareModule<OpikConfiguration> {
+
+    @Override
+    protected void configure() {
+
+        var rateLimit = getProvider(RateLimitService.class);
+        var config = configuration(RateLimitConfig.class);
+        var requestContext = getProvider(RequestContext.class);
+
+        var rateLimitInterceptor = new RateLimitInterceptor(requestContext, rateLimit, config);
+
+        bindInterceptor(Matchers.any(), Matchers.annotatedWith(RateLimited.class), rateLimitInterceptor);
+    }
+
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java
new file mode 100644
index 00000000..2f1e4cc2
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitResponseFilter.java
@@ -0,0 +1,37 @@
+package com.comet.opik.infrastructure.ratelimit;
+
+import com.comet.opik.infrastructure.auth.RequestContext;
+import jakarta.inject.Inject;
+import jakarta.ws.rs.container.ContainerRequestContext;
+import jakarta.ws.rs.container.ContainerResponseContext;
+import jakarta.ws.rs.container.ContainerResponseFilter;
+import jakarta.ws.rs.ext.Provider;
+import lombok.RequiredArgsConstructor;
+
+import java.io.IOException;
+import java.util.List;
+
+@Provider
+@RequiredArgsConstructor(onConstructor_ =
@Inject) +public class RateLimitResponseFilter implements ContainerResponseFilter { + + @Override + public void filter(ContainerRequestContext requestContext, ContainerResponseContext responseContext) + throws IOException { + List userLimit = getValueFromHeader(requestContext, RequestContext.USER_LIMIT); + List remainingLimit = getValueFromHeader(requestContext, RequestContext.USER_REMAINING_LIMIT); + List remainingTtl = getValueFromHeader(requestContext, RequestContext.USER_LIMIT_REMAINING_TTL); + + responseContext.getHeaders().put(RequestContext.USER_LIMIT, userLimit); + responseContext.getHeaders().put(RequestContext.USER_REMAINING_LIMIT, remainingLimit); + responseContext.getHeaders().put(RequestContext.USER_LIMIT_REMAINING_TTL, remainingTtl); + } + + private List getValueFromHeader(ContainerRequestContext requestContext, String key) { + return requestContext.getHeaders().getOrDefault(key, List.of()) + .stream() + .map(Object.class::cast) + .toList(); + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java new file mode 100644 index 00000000..68940bca --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimitService.java @@ -0,0 +1,13 @@ +package com.comet.opik.infrastructure.ratelimit; + +import reactor.core.publisher.Mono; + +public interface RateLimitService { + + Mono isLimitExceeded(String apiKey, long events, String bucketName, long limit, + long limitDurationInSeconds); + + Mono availableEvents(String apiKey, String bucketName); + + Mono getRemainingTTL(String apiKey, String bucket); +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java new file mode 100644 index 00000000..e9671f28 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/ratelimit/RateLimited.java @@ -0,0 +1,15 @@ +package com.comet.opik.infrastructure.ratelimit; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface RateLimited { + + String GENERAL_EVENTS = "general_events"; + + String value() default GENERAL_EVENTS; // bucket capacity +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java index 69f8f0e4..adc5b651 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisModule.java @@ -3,6 +3,8 @@ import com.comet.opik.infrastructure.DistributedLockConfig; import com.comet.opik.infrastructure.OpikConfiguration; import com.comet.opik.infrastructure.RedisConfig; +import com.comet.opik.infrastructure.lock.LockService; +import com.comet.opik.infrastructure.ratelimit.RateLimitService; import com.google.inject.Provides; import jakarta.inject.Singleton; import org.redisson.Redisson; @@ -25,4 +27,10 @@ public LockService lockService(RedissonReactiveClient redisClient, return new RedissonLockService(redisClient, distributedLockConfig); } + @Provides + @Singleton + public RateLimitService 
rateLimitService(RedissonReactiveClient redisClient) { + return new RedisRateLimitService(redisClient); + } + } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java new file mode 100644 index 00000000..5d784399 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedisRateLimitService.java @@ -0,0 +1,45 @@ +package com.comet.opik.infrastructure.redis; + +import com.comet.opik.infrastructure.ratelimit.RateLimitService; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import org.redisson.api.RRateLimiterReactive; +import org.redisson.api.RateIntervalUnit; +import org.redisson.api.RateType; +import org.redisson.api.RedissonReactiveClient; +import reactor.core.publisher.Mono; + +import java.time.Duration; + +@RequiredArgsConstructor +public class RedisRateLimitService implements RateLimitService { + + private static final String KEY = "%s:%s"; + + private final RedissonReactiveClient redisClient; + + @Override + public Mono isLimitExceeded(String apiKey, long events, String bucketName, long limit, + long limitDurationInSeconds) { + + RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); + + return rateLimit.trySetRate(RateType.OVERALL, limit, limitDurationInSeconds, RateIntervalUnit.SECONDS) + .then(Mono.defer(() -> rateLimit.expireIfNotSet(Duration.ofSeconds(limitDurationInSeconds)))) + .then(Mono.defer(() -> rateLimit.tryAcquire(events))) + .map(Boolean.FALSE::equals); + } + + @Override + public Mono availableEvents(@NonNull String apiKey, @NonNull String bucketName) { + RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); + return rateLimit.availablePermits(); + } + + @Override + public Mono getRemainingTTL(@NonNull String apiKey, @NonNull String bucketName) { + RRateLimiterReactive rateLimit = redisClient.getRateLimiter(KEY.formatted(bucketName, apiKey)); + return rateLimit.remainTimeToLive(); + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java index 5ef02eac..61589333 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java @@ -1,6 +1,7 @@ package com.comet.opik.infrastructure.redis; import com.comet.opik.infrastructure.DistributedLockConfig; +import com.comet.opik.infrastructure.lock.LockService; import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java index 1c339b9b..10207eb8 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java @@ -4,13 +4,37 @@ import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; import com.comet.opik.infrastructure.auth.TestHttpClientUtils; import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import lombok.Builder; +import 
lombok.experimental.UtilityClass; +import org.apache.commons.collections4.CollectionUtils; import ru.vyarus.dropwizard.guice.hook.GuiceyConfigurationHook; +import ru.vyarus.dropwizard.guice.module.installer.bundle.GuiceyBundle; +import ru.vyarus.dropwizard.guice.module.installer.bundle.GuiceyEnvironment; import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import static com.comet.opik.infrastructure.RateLimitConfig.LimitConfig; + +@UtilityClass public class TestDropwizardAppExtensionUtils { + @Builder + public record AppContextConfig( + String jdbcUrl, + DatabaseAnalyticsFactory databaseAnalyticsFactory, + WireMockRuntimeInfo runtimeInfo, + String redisUrl, + Integer cacheTtlInSeconds, + boolean rateLimitEnabled, + Long limit, + Long limitDurationInSeconds, + Map customLimits, + List customBeans) { + } + public static TestDropwizardAppExtension newTestDropwizardAppExtension(String jdbcUrl, WireMockRuntimeInfo runtimeInfo) { return newTestDropwizardAppExtension(jdbcUrl, null, runtimeInfo); @@ -40,36 +64,78 @@ public static TestDropwizardAppExtension newTestDropwizardAppExtension( WireMockRuntimeInfo runtimeInfo, String redisUrl, Integer cacheTtlInSeconds) { + return newTestDropwizardAppExtension( + AppContextConfig.builder() + .jdbcUrl(jdbcUrl) + .databaseAnalyticsFactory(databaseAnalyticsFactory) + .runtimeInfo(runtimeInfo) + .redisUrl(redisUrl) + .cacheTtlInSeconds(cacheTtlInSeconds) + .build()); + } + + public static TestDropwizardAppExtension newTestDropwizardAppExtension(AppContextConfig appContextConfig) { var list = new ArrayList(); - list.add("database.url: " + jdbcUrl); + list.add("database.url: " + appContextConfig.jdbcUrl()); - if (databaseAnalyticsFactory != null) { - list.add("databaseAnalytics.port: " + databaseAnalyticsFactory.getPort()); - list.add("databaseAnalytics.username: " + databaseAnalyticsFactory.getUsername()); - list.add("databaseAnalytics.password: " + databaseAnalyticsFactory.getPassword()); + if (appContextConfig.databaseAnalyticsFactory() != null) { + list.add("databaseAnalytics.port: " + appContextConfig.databaseAnalyticsFactory().getPort()); + list.add("databaseAnalytics.username: " + appContextConfig.databaseAnalyticsFactory().getUsername()); + list.add("databaseAnalytics.password: " + appContextConfig.databaseAnalyticsFactory().getPassword()); } - if (runtimeInfo != null) { + if (appContextConfig.runtimeInfo() != null) { list.add("authentication.enabled: true"); - list.add("authentication.sdk.url: " + "%s/opik/auth".formatted(runtimeInfo.getHttpsBaseUrl())); - list.add("authentication.ui.url: " + "%s/opik/auth-session".formatted(runtimeInfo.getHttpsBaseUrl())); + list.add("authentication.sdk.url: " + + "%s/opik/auth".formatted(appContextConfig.runtimeInfo().getHttpsBaseUrl())); + list.add("authentication.ui.url: " + + "%s/opik/auth-session".formatted(appContextConfig.runtimeInfo().getHttpsBaseUrl())); - if (cacheTtlInSeconds != null) { - list.add("authentication.apiKeyResolutionCacheTTLInSec: " + cacheTtlInSeconds); + if (appContextConfig.cacheTtlInSeconds() != null) { + list.add("authentication.apiKeyResolutionCacheTTLInSec: " + appContextConfig.cacheTtlInSeconds()); } } GuiceyConfigurationHook hook = injector -> { injector.modulesOverride(TestHttpClientUtils.testAuthModule()); + + injector.bundles(new GuiceyBundle() { + + @Override + public void run(GuiceyEnvironment environment) { + + if (CollectionUtils.isNotEmpty(appContextConfig.customBeans())) { + 
appContextConfig.customBeans() + .forEach(environment::register); + } + } + }); + }; - if (redisUrl != null) { - list.add("redis.singleNodeUrl: %s".formatted(redisUrl)); + if (appContextConfig.redisUrl() != null) { + list.add("redis.singleNodeUrl: %s".formatted(appContextConfig.redisUrl())); list.add("redis.sentinelMode: false"); list.add("redis.lockTimeout: 500"); } + if (appContextConfig.rateLimitEnabled()) { + list.add("rateLimit.enabled: true"); + list.add("rateLimit.generalLimit.limit: %d".formatted(appContextConfig.limit())); + list.add("rateLimit.generalLimit.durationInSeconds: %d" + .formatted(appContextConfig.limitDurationInSeconds())); + + if (appContextConfig.customLimits() != null) { + appContextConfig.customLimits() + .forEach((bucket, limitConfig) -> { + list.add("rateLimit.customLimits.%s.limit: %d".formatted(bucket, limitConfig.limit())); + list.add("rateLimit.customLimits.%s.durationInSeconds: %d".formatted(bucket, + limitConfig.durationInSeconds())); + }); + } + } + return TestDropwizardAppExtension.forApp(OpikApplication.class) .config("src/test/resources/config-test.yml") .configOverrides(list.toArray(new String[0])) diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java index 537d8ccc..9b6a4108 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java @@ -123,6 +123,7 @@ class DatasetsResourceTest { private static final TestDropwizardAppExtension app; private static final WireMockRuntime wireMock; + public static final String[] DATASET_IGNORED_FIELDS = {"id", "createdAt", "lastUpdatedAt", "createdBy", "lastUpdatedBy", "experimentCount", "mostRecentExperimentAt", "experimentCount"}; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/DummyLockService.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/DummyLockService.java index faf96f99..90e33543 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/domain/DummyLockService.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/DummyLockService.java @@ -1,6 +1,6 @@ package com.comet.opik.domain; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java index c5e404b9..31a4184f 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java @@ -2,7 +2,7 @@ import com.comet.opik.api.SpanUpdate; import com.comet.opik.api.error.InvalidUUIDVersionException; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.comet.opik.podam.PodamFactoryUtils; import com.fasterxml.uuid.Generators; import com.fasterxml.uuid.impl.TimeBasedEpochGenerator; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java index 2188190f..cce25f38 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java +++ 
b/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java @@ -8,7 +8,7 @@ import com.comet.opik.api.error.InvalidUUIDVersionException; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.db.TransactionTemplate; -import com.comet.opik.infrastructure.redis.LockService; +import com.comet.opik.infrastructure.lock.LockService; import com.fasterxml.uuid.Generators; import io.r2dbc.spi.Connection; import org.junit.jupiter.api.Assertions; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java new file mode 100644 index 00000000..b889d7b2 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitE2ETest.java @@ -0,0 +1,638 @@ +package com.comet.opik.infrastructure.ratelimit; + +import com.comet.opik.api.DatasetItem; +import com.comet.opik.api.DatasetItemBatch; +import com.comet.opik.api.ExperimentItem; +import com.comet.opik.api.ExperimentItemsBatch; +import com.comet.opik.api.FeedbackScoreBatch; +import com.comet.opik.api.FeedbackScoreBatchItem; +import com.comet.opik.api.Span; +import com.comet.opik.api.SpanBatch; +import com.comet.opik.api.Trace; +import com.comet.opik.api.TraceBatch; +import com.comet.opik.api.resources.utils.AuthTestUtils; +import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; +import com.comet.opik.api.resources.utils.ClientSupportUtils; +import com.comet.opik.api.resources.utils.MigrationUtils; +import com.comet.opik.api.resources.utils.MySQLContainerUtils; +import com.comet.opik.api.resources.utils.RedisContainerUtils; +import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.WireMockUtils; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.podam.PodamFactoryUtils; +import com.redis.testcontainers.RedisContainer; +import io.dropwizard.jersey.errors.ErrorMessage; +import io.reactivex.rxjava3.internal.operators.single.SingleDelay; +import io.swagger.v3.oas.annotations.parameters.RequestBody; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.HttpMethod; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.client.Invocation; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import org.apache.hc.core5.http.HttpStatus; +import org.jdbi.v3.core.Jdbi; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.testcontainers.containers.ClickHouseContainer; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; +import ru.vyarus.dropwizard.guice.test.ClientSupport; +import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; +import uk.co.jemos.podam.api.PodamFactory; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Set; +import 
java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static com.comet.opik.api.Trace.TracePage; +import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; +import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE; +import static com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils.AppContextConfig; +import static com.comet.opik.infrastructure.RateLimitConfig.LimitConfig; +import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; +import static java.util.stream.Collectors.counting; +import static java.util.stream.Collectors.groupingBy; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Testcontainers(parallel = true) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@DisplayName("Rate limit Resource Test") +class RateLimitE2ETest { + + private static final String BASE_RESOURCE_URI = "%s/v1/private/traces"; + + private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer(); + private static final MySQLContainer MYSQL = MySQLContainerUtils.newMySQLContainer(); + private static final ClickHouseContainer CLICKHOUSE = ClickHouseContainerUtils.newClickHouseContainer(); + + @RegisterExtension + private static final TestDropwizardAppExtension app; + private static final WireMockUtils.WireMockRuntime wireMock; + + private static final long LIMIT = 4L; + private static final long LIMIT_DURATION_IN_SECONDS = 1L; + public static final String CUSTOM_LIMIT = "customLimit"; + + private final PodamFactory factory = PodamFactoryUtils.newPodamFactory(); + + @Path("/v1/private/test") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public static class CustomRatedBean { + + @POST + @RateLimited(value = CUSTOM_LIMIT) + public Response test(@RequestBody String test) { + return Response.status(Response.Status.CREATED).build(); + } + } + + static { + MYSQL.start(); + CLICKHOUSE.start(); + REDIS.start(); + + wireMock = WireMockUtils.startWireMock(); + + var databaseAnalyticsFactory = ClickHouseContainerUtils.newDatabaseAnalyticsFactory( + CLICKHOUSE, DATABASE_NAME); + + app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( + AppContextConfig.builder() + .jdbcUrl(MYSQL.getJdbcUrl()) + .databaseAnalyticsFactory(databaseAnalyticsFactory) + .runtimeInfo(wireMock.runtimeInfo()) + .redisUrl(REDIS.getRedisURI()) + .rateLimitEnabled(true) + .limit(LIMIT) + .limitDurationInSeconds(LIMIT_DURATION_IN_SECONDS) + .customLimits(Map.of(CUSTOM_LIMIT, new LimitConfig(1, 1))) + .build()); + } + + private String baseURI; + private ClientSupport client; + + @BeforeAll + void setUpAll(ClientSupport client, Jdbi jdbi) throws Exception { + + MigrationUtils.runDbMigration(jdbi, MySQLContainerUtils.migrationParameters()); + + try (var connection = CLICKHOUSE.createConnection("")) { + MigrationUtils.runDbMigration(connection, CLICKHOUSE_CHANGELOG_FILE, + ClickHouseContainerUtils.migrationParameters()); + } + + this.baseURI = "http://localhost:%d".formatted(client.getPort()); + this.client = client; + + ClientSupportUtils.config(client); + } + + @AfterAll + void tearDownAll() { + wireMock.server().stop(); + } + + private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId, String user) { + AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, 
+    }
+
+    private static void mockSessionCookieTargetWorkspace(String sessionToken, String workspaceName, String workspaceId,
+            String user) {
+        AuthTestUtils.mockSessionCookieTargetWorkspace(wireMock.server(), sessionToken, workspaceName, workspaceId,
+                user);
+    }
+
+    @Test
+    @DisplayName("Rate limit: When using apiKey and limit is exceeded, Then block remaining calls")
+    void rateLimit__whenUsingApiKeyAndLimitIsExceeded__shouldBlockRemainingCalls() {
+
+        String apiKey = UUID.randomUUID().toString();
+        String user = UUID.randomUUID().toString();
+        String workspaceId = UUID.randomUUID().toString();
+        String workspaceName = UUID.randomUUID().toString();
+
+        mockTargetWorkspace(apiKey, workspaceName, workspaceId, user);
+
+        String projectName = UUID.randomUUID().toString();
+
+        Map<Integer, Long> responseMap = triggerCallsWithApiKey(LIMIT * 2, projectName, apiKey, workspaceName);
+
+        assertEquals(LIMIT, responseMap.get(HttpStatus.SC_TOO_MANY_REQUESTS));
+        assertEquals(LIMIT, responseMap.get(HttpStatus.SC_CREATED));
+
+        try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI))
+                .queryParam("project_name", projectName)
+                .queryParam("size", LIMIT * 2)
+                .request()
+                .accept(MediaType.APPLICATION_JSON_TYPE)
+                .header(HttpHeaders.AUTHORIZATION, apiKey)
+                .header(WORKSPACE_HEADER, workspaceName)
+                .get()) {
+
+            // Verify that traces created are equal to the limit
+            assertEquals(HttpStatus.SC_OK, response.getStatus());
+            TracePage page = response.readEntity(TracePage.class);
+
+            assertEquals(LIMIT, page.content().size());
+            assertEquals(LIMIT, page.total());
+            assertEquals(LIMIT, page.size());
+        }
+
+    }
+
+    @Test
+    @DisplayName("Rate limit: When using apiKey and limit is not exceeded given duration, Then allow all calls")
+    void rateLimit__whenUsingApiKeyAndLimitIsNotExceededGivenDuration__thenAllowAllCalls() {
+
+        String apiKey = UUID.randomUUID().toString();
+        String user = UUID.randomUUID().toString();
+        String workspaceId = UUID.randomUUID().toString();
+        String workspaceName = UUID.randomUUID().toString();
+
+        mockTargetWorkspace(apiKey, workspaceName, workspaceId, user);
+
+        String projectName = UUID.randomUUID().toString();
+
+        Map<Integer, Long> responseMap = triggerCallsWithApiKey(LIMIT, projectName, apiKey, workspaceName);
+
+        assertEquals(LIMIT, responseMap.get(HttpStatus.SC_CREATED));
+
+        SingleDelay.timer(LIMIT_DURATION_IN_SECONDS, TimeUnit.SECONDS).blockingGet();
+
+        responseMap = triggerCallsWithApiKey(LIMIT, projectName, apiKey, workspaceName);
+
+        assertEquals(LIMIT, responseMap.get(HttpStatus.SC_CREATED));
+
+        try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI))
+                .queryParam("project_name", projectName)
+                .queryParam("size", LIMIT * 2)
+                .request()
+                .accept(MediaType.APPLICATION_JSON_TYPE)
+                .header(HttpHeaders.AUTHORIZATION, apiKey)
+                .header(WORKSPACE_HEADER, workspaceName)
+                .get()) {
+
+            assertEquals(HttpStatus.SC_OK, response.getStatus());
+            TracePage page = response.readEntity(TracePage.class);
+
+            assertEquals(LIMIT * 2, page.content().size());
+            assertEquals(LIMIT * 2, page.total());
+            assertEquals(LIMIT * 2, page.size());
+        }
+
+    }
+
+    @Test
+    @DisplayName("Rate limit: When using sessionToken and limit is exceeded, Then block remaining calls")
+    void rateLimit__whenUsingSessionTokenAndLimitIsExceeded__shouldBlockRemainingCalls() {
+
+        String sessionToken = UUID.randomUUID().toString();
+        String user = UUID.randomUUID().toString();
+        String workspaceId = UUID.randomUUID().toString();
+        String workspaceName = UUID.randomUUID().toString();
+
+        mockSessionCookieTargetWorkspace(sessionToken, workspaceName, workspaceId, user);
+
+        String projectName = UUID.randomUUID().toString();
+
+        Map<Integer, Long> responseMap = triggerCallsWithCookie(LIMIT * 2, projectName, sessionToken, workspaceName);
+
+        assertEquals(LIMIT, responseMap.get(HttpStatus.SC_TOO_MANY_REQUESTS));
+        assertEquals(LIMIT, responseMap.get(HttpStatus.SC_CREATED));
+
+        try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI))
+                .queryParam("project_name", projectName)
+                .queryParam("size", LIMIT * 2)
+                .request()
+                .accept(MediaType.APPLICATION_JSON_TYPE)
+                .cookie(RequestContext.SESSION_COOKIE, sessionToken)
+                .header(WORKSPACE_HEADER, workspaceName)
+                .get()) {
+
+            assertEquals(HttpStatus.SC_OK, response.getStatus());
+            TracePage page = response.readEntity(TracePage.class);
+
+            assertEquals(LIMIT, page.content().size());
+            assertEquals(LIMIT, page.total());
+            assertEquals(LIMIT, page.size());
+        }
+
+    }
+
+    @Test
+    @DisplayName("Rate limit: When using sessionToken and limit is not exceeded given duration, Then allow all calls")
+    void rateLimit__whenUsingSessionTokenAndLimitIsNotExceededGivenDuration__thenAllowAllCalls() {
+
+        String sessionToken = UUID.randomUUID().toString();
+        String user = UUID.randomUUID().toString();
+        String workspaceId = UUID.randomUUID().toString();
+        String workspaceName = UUID.randomUUID().toString();
+
+        mockSessionCookieTargetWorkspace(sessionToken, workspaceName, workspaceId, user);
+
+        String projectName = UUID.randomUUID().toString();
+
+        Map<Integer, Long> responseMap = triggerCallsWithCookie(LIMIT, projectName, sessionToken, workspaceName);
+
+        assertEquals(LIMIT, responseMap.get(HttpStatus.SC_CREATED));
+
+        SingleDelay.timer(LIMIT_DURATION_IN_SECONDS, TimeUnit.SECONDS).blockingGet();
+
+        responseMap = triggerCallsWithCookie(LIMIT, projectName, sessionToken, workspaceName);
+
+        assertEquals(LIMIT, responseMap.get(HttpStatus.SC_CREATED));
+
+        try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI))
+                .queryParam("project_name", projectName)
+                .queryParam("size", LIMIT * 2)
+                .request()
+                .accept(MediaType.APPLICATION_JSON_TYPE)
+                .cookie(RequestContext.SESSION_COOKIE, sessionToken)
+                .header(WORKSPACE_HEADER, workspaceName)
+                .get()) {
+
+            // Verify that traces created are equal to the limit
+            assertEquals(HttpStatus.SC_OK, response.getStatus());
+            TracePage page = response.readEntity(TracePage.class);
+
+            assertEquals(LIMIT * 2, page.content().size());
+            assertEquals(LIMIT * 2, page.total());
+            assertEquals(LIMIT * 2, page.size());
+        }
+
+    }
+
+    @Test
+    @DisplayName("Rate limit: When remaining limit is less than the batch size, Then reject the request")
+    void rateLimit__whenRemainingLimitIsLessThanRequestedSize__thenRejectTheRequest() {
+
+        String apiKey = UUID.randomUUID().toString();
+        String user = UUID.randomUUID().toString();
+        String workspaceId = UUID.randomUUID().toString();
+        String workspaceName = UUID.randomUUID().toString();
+
+        mockTargetWorkspace(apiKey, workspaceName, workspaceId, user);
+
+        String projectName = UUID.randomUUID().toString();
+
+        Map<Integer, Long> responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName);
+
+        assertEquals(1, responseMap.get(HttpStatus.SC_CREATED));
+
+        List<Trace> traces = IntStream.range(0, (int) LIMIT)
+                .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder()
+                        .projectName(projectName)
+                        .projectId(null)
+                        .build())
+                .toList();
+
+        try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI))
+                .path("batch")
+                .request()
+                .accept(MediaType.APPLICATION_JSON_TYPE)
+                .header(HttpHeaders.AUTHORIZATION, apiKey)
+                .header(WORKSPACE_HEADER, workspaceName)
+                .post(Entity.json(new TraceBatch(traces)))) {
+
+            assertEquals(HttpStatus.SC_TOO_MANY_REQUESTS, response.getStatus());
+            var error = response.readEntity(ErrorMessage.class);
+            assertEquals("Too Many Requests", error.getMessage());
+        }
+    }
+
+    @Test
+    @DisplayName("Rate limit: When a request is rejected due to batch size, Then accept the next request within the remaining limit")
+    void rateLimit__whenAfterRejectRequestDueToBatchSize__thenAcceptTheRequestWithRemainingLimit() {
+
+        String apiKey = UUID.randomUUID().toString();
+        String user = UUID.randomUUID().toString();
+        String workspaceId = UUID.randomUUID().toString();
+        String workspaceName = UUID.randomUUID().toString();
+
+        mockTargetWorkspace(apiKey, workspaceName, workspaceId, user);
+
+        String projectName = UUID.randomUUID().toString();
+
+        Map<Integer, Long> responseMap = triggerCallsWithApiKey(1, projectName, apiKey, workspaceName);
+
+        assertEquals(1, responseMap.get(HttpStatus.SC_CREATED));
+
+        List<Trace> traces = IntStream.range(0, (int) LIMIT)
+                .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder()
+                        .projectName(projectName)
+                        .projectId(null)
+                        .build())
+                .toList();
+
+        try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI))
+                .path("batch")
+                .request()
+                .accept(MediaType.APPLICATION_JSON_TYPE)
+                .header(HttpHeaders.AUTHORIZATION, apiKey)
+                .header(WORKSPACE_HEADER, workspaceName)
+                .post(Entity.json(new TraceBatch(traces)))) {
+
+            assertEquals(HttpStatus.SC_TOO_MANY_REQUESTS, response.getStatus());
+            var error = response.readEntity(ErrorMessage.class);
+            assertEquals("Too Many Requests", error.getMessage());
+        }
+
+        try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI))
+                .path("batch")
+                .request()
+                .accept(MediaType.APPLICATION_JSON_TYPE)
+                .header(HttpHeaders.AUTHORIZATION, apiKey)
+                .header(WORKSPACE_HEADER, workspaceName)
+                .post(Entity.json(new TraceBatch(traces.subList(0, (int) LIMIT - 1))))) {
+
+            assertEquals(HttpStatus.SC_NO_CONTENT, response.getStatus());
+        }
+    }
+
+    @ParameterizedTest
+    @MethodSource
+    @DisplayName("Rate limit: When batch endpoint consumes the remaining limit, Then reject next request")
+    void rateLimit__whenBatchEndpointConsumerRemainingLimit__thenRejectNextRequest(
+            Object batch,
+            Object batch2,
+            String url,
+            String method) {
+
+        String apiKey = UUID.randomUUID().toString();
+        String user = UUID.randomUUID().toString();
+        String workspaceId = UUID.randomUUID().toString();
+        String workspaceName = UUID.randomUUID().toString();
+
+        mockTargetWorkspace(apiKey, workspaceName, workspaceId, user);
+
+        Invocation.Builder request = client.target(url)
+                .request()
+                .accept(MediaType.APPLICATION_JSON_TYPE)
+                .header(HttpHeaders.AUTHORIZATION, apiKey)
+                .header(WORKSPACE_HEADER, workspaceName);
+
+        try (var response = request.method(method, Entity.json(batch))) {
+
+            assertEquals(HttpStatus.SC_NO_CONTENT, response.getStatus());
+        }
+
+        try (var response = request.method(method, Entity.json(batch2))) {
+
+            assertEquals(HttpStatus.SC_TOO_MANY_REQUESTS, response.getStatus());
+            var error = response.readEntity(ErrorMessage.class);
+            assertEquals("Too Many Requests", error.getMessage());
+        }
+    }
+
+    @Test
+    @DisplayName("Rate limit: When processing operations, Then return remaining limit as header")
+    void rateLimit__whenProcessingOperations__thenReturnRemainingLimitAsHeader() {
+
+        String apiKey = UUID.randomUUID().toString();
+        String user = UUID.randomUUID().toString();
+        String workspaceId = UUID.randomUUID().toString();
+        String workspaceName = UUID.randomUUID().toString();
+
+        mockTargetWorkspace(apiKey, workspaceName, workspaceId, user);
+
+        String projectName = UUID.randomUUID().toString();
+
+        IntStream.range(0, (int) LIMIT + 1).forEach(i -> {
+            Trace trace = factory.manufacturePojo(Trace.class).toBuilder()
+                    .projectName(projectName)
+                    .build();
+
+            try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI))
+                    .request()
+                    .accept(MediaType.APPLICATION_JSON_TYPE)
+                    .header(HttpHeaders.AUTHORIZATION, apiKey)
+                    .header(WORKSPACE_HEADER, workspaceName)
+                    .post(Entity.json(trace))) {
+
+                if (i < LIMIT) {
+                    assertEquals(HttpStatus.SC_CREATED, response.getStatus());
+
+                    assertLimitHeaders(response, LIMIT - i - 1, RateLimited.GENERAL_EVENTS,
+                            (int) LIMIT_DURATION_IN_SECONDS);
+                } else {
+                    assertEquals(HttpStatus.SC_TOO_MANY_REQUESTS, response.getStatus());
+                }
+            }
+        });
+    }
+
+    public Stream<Arguments> rateLimit__whenBatchEndpointConsumerRemainingLimit__thenRejectNextRequest() {
+
+        var projectName = UUID.randomUUID().toString();
+
+        var traces = IntStream.range(0, (int) LIMIT)
+                .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder()
+                        .projectName(projectName)
+                        .projectId(null)
+                        .build())
+                .toList();
+
+        var spans = IntStream.range(0, (int) LIMIT)
+                .mapToObj(i -> factory.manufacturePojo(Span.class).toBuilder()
+                        .projectName(projectName)
+                        .projectId(null)
+                        .parentSpanId(null)
+                        .build())
+                .toList();
+
+        var datasetItems = IntStream.range(0, (int) LIMIT)
+                .mapToObj(i -> factory.manufacturePojo(DatasetItem.class).toBuilder()
+                        .experimentItems(null)
+                        .build())
+                .toList();
+
+        var tracesFeedbackScores = IntStream.range(0, (int) LIMIT)
+                .mapToObj(i -> factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder()
+                        .projectId(null)
+                        .build())
+                .toList();
+
+        var spansFeedbackScores = IntStream.range(0, (int) LIMIT)
+                .mapToObj(i -> factory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder()
+                        .projectId(null)
+                        .build())
+                .toList();
+
+        var experimentItems = IntStream.range(0, (int) LIMIT)
+                .mapToObj(i -> factory.manufacturePojo(ExperimentItem.class).toBuilder()
+                        .feedbackScores(null)
+                        .build())
+                .collect(Collectors.toSet());
+
+        return Stream.of(
+                Arguments.of(new TraceBatch(traces), new TraceBatch(List.of(traces.getFirst())),
+                        BASE_RESOURCE_URI.formatted(baseURI) + "/batch", HttpMethod.POST),
+                Arguments.of(new SpanBatch(spans), new SpanBatch(List.of(spans.getFirst())),
+                        "%s/v1/private/spans".formatted(baseURI) + "/batch", HttpMethod.POST),
+                Arguments.of(new DatasetItemBatch(projectName, null, datasetItems),
+                        new DatasetItemBatch(projectName, null, List.of(datasetItems.getFirst())),
+                        "%s/v1/private/datasets".formatted(baseURI) + "/items", HttpMethod.PUT),
+                Arguments.of(new FeedbackScoreBatch(tracesFeedbackScores),
+                        new FeedbackScoreBatch(List.of(tracesFeedbackScores.getFirst())),
+                        BASE_RESOURCE_URI.formatted(baseURI) + "/feedback-scores", HttpMethod.PUT),
+                Arguments.of(new FeedbackScoreBatch(spansFeedbackScores),
+                        new FeedbackScoreBatch(List.of(spansFeedbackScores.getFirst())),
+                        "%s/v1/private/spans".formatted(baseURI) + "/feedback-scores", HttpMethod.PUT),
+                Arguments.of(new ExperimentItemsBatch(experimentItems),
+                        new ExperimentItemsBatch(Set.of(experimentItems.stream().findFirst().orElseThrow())),
+                        "%s/v1/private/experiments".formatted(baseURI) + "/items", HttpMethod.POST));
+    }
+
+    @Test
+    @DisplayName("Rate limit: When custom rated bean method is called, Then rate limit is applied")
is applied") + void rateLimit__whenCustomRatedBeanMethodIsCalled__thenRateLimitIsApplied() { + String apiKey = UUID.randomUUID().toString(); + String user = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId, user); + + try (var response = client.target("%s/v1/private/test".formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(""))) { + + assertEquals(HttpStatus.SC_CREATED, response.getStatus()); + + assertLimitHeaders(response, 0, CUSTOM_LIMIT, 1); + } + + try (var response = client.target("%s/v1/private/test".formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(""))) { + + assertEquals(HttpStatus.SC_TOO_MANY_REQUESTS, response.getStatus()); + + assertLimitHeaders(response, 0, CUSTOM_LIMIT, 1); + } + + } + + private static void assertLimitHeaders(Response response, long expected, String limitBucket, int limitDuration) { + String remainingLimit = response.getHeaderString(RequestContext.USER_REMAINING_LIMIT); + String userLimit = response.getHeaderString(RequestContext.USER_LIMIT); + String remainingTtl = response.getHeaderString(RequestContext.USER_LIMIT_REMAINING_TTL); + + assertEquals(expected, Long.parseLong(remainingLimit)); + assertEquals(limitBucket, userLimit); + assertThat(Long.parseLong(remainingTtl)).isBetween(0L, Duration.ofSeconds(limitDuration).toMillis()); + } + + private Map triggerCallsWithCookie(long limit, String projectName, String sessionToken, + String workspaceName) { + return Flux.range(0, ((int) limit)) + .flatMap(i -> { + Trace trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .build(); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .cookie(RequestContext.SESSION_COOKIE, sessionToken) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(trace))) { + + return Flux.just(response); + } + }, 5) + .toStream() + .collect(groupingBy(Response::getStatus, counting())); + } + + private Map triggerCallsWithApiKey(long limit, String projectName, String apiKey, + String workspaceName) { + return Flux.range(0, ((int) limit)) + .flatMap(i -> { + Trace trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .build(); + + try (var response = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(trace))) { + + return Flux.just(response); + } + }, 5) + .toStream() + .collect(groupingBy(Response::getStatus, counting())); + } + +} \ No newline at end of file diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitSetupTest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitSetupTest.java new file mode 100644 index 00000000..c763ce3c --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/ratelimit/RateLimitSetupTest.java @@ -0,0 +1,97 @@ +package com.comet.opik.infrastructure.ratelimit; + +import com.comet.opik.api.resources.v1.priv.DatasetsResource; +import 
+import com.comet.opik.api.resources.v1.priv.FeedbackDefinitionResource;
+import com.comet.opik.api.resources.v1.priv.ProjectsResource;
+import com.comet.opik.api.resources.v1.priv.SpansResource;
+import com.comet.opik.api.resources.v1.priv.TracesResource;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Stream;
+
+class RateLimitSetupTest {
+
+    @Test
+    void allEventFromDatasetsResourceShouldBeRateLimited() {
+
+        // Given
+        Stream.of("createDataset", "updateDataset", "createDatasetItems")
+                .forEach(methodName -> {
+                    // Then
+                    assertIfMethodAreAnnotated(methodName, DatasetsResource.class);
+                });
+    }
+
+    @Test
+    void allEventFromExperimentsResourceShouldBeRateLimited() {
+
+        // Given
+        Stream.of("create", "createExperimentItems")
+                .forEach(methodName -> {
+                    // Then
+                    assertIfMethodAreAnnotated(methodName, ExperimentsResource.class);
+                });
+    }
+
+    private void assertIfMethodAreAnnotated(String methodName, Class<?> targetClass) {
+        List<Method> targetMethods = Arrays.stream(targetClass.getMethods())
+                .filter(method -> method.getName().equals(methodName))
+                .toList();
+
+        boolean actualMatch = !targetMethods.isEmpty() && targetMethods.stream()
+                .allMatch(method -> method.isAnnotationPresent(RateLimited.class));
+
+        Assertions.assertTrue(actualMatch,
+                "Method %s.%s is not annotated".formatted(targetClass.getSimpleName(), methodName));
+    }
+
+    @Test
+    void allEventFromFeedbackResourceShouldBeRateLimited() {
+
+        // Given
+        Stream.of("create", "update")
+                .forEach(methodName -> {
+                    // Then
+                    assertIfMethodAreAnnotated(methodName, FeedbackDefinitionResource.class);
+                });
+    }
+
+    @Test
+    void allEventFromProjectsResourceShouldBeRateLimited() {
+
+        // Given
+        Stream.of("create", "update")
+                .forEach(methodName -> {
+                    // Then
+                    assertIfMethodAreAnnotated(methodName, ProjectsResource.class);
+                });
+    }
+
+    @Test
+    void allEventFromSpansResourceShouldBeRateLimited() {
+
+        // Given
+        Stream.of("create", "createSpans", "update", "addSpanFeedbackScore", "scoreBatchOfSpans")
+                .forEach(methodName -> {
+                    // Then
+                    assertIfMethodAreAnnotated(methodName, SpansResource.class);
+                });
+    }
+
+    @Test
+    void allEventFromTracesResourceShouldBeRateLimited() {
+
+        // Given
+        Stream.of("create", "createTraces", "update", "addTraceFeedbackScore", "scoreBatchOfTraces")
+                .forEach(methodName -> {
+                    // Then
+                    assertIfMethodAreAnnotated(methodName, TracesResource.class);
+                });
+    }
+
+}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/redis/RedissonLockServiceIntegrationTest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/redis/RedissonLockServiceIntegrationTest.java
index 728921b1..660158af 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/redis/RedissonLockServiceIntegrationTest.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/redis/RedissonLockServiceIntegrationTest.java
@@ -4,6 +4,7 @@
 import com.comet.opik.api.resources.utils.MySQLContainerUtils;
 import com.comet.opik.api.resources.utils.RedisContainerUtils;
 import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils;
+import com.comet.opik.infrastructure.lock.LockService;
 import com.redis.testcontainers.RedisContainer;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
diff --git a/apps/opik-backend/src/test/resources/config-test.yml b/apps/opik-backend/src/test/resources/config-test.yml
index 412ecf9c..19dddb1d 100644
--- a/apps/opik-backend/src/test/resources/config-test.yml
+++ b/apps/opik-backend/src/test/resources/config-test.yml
@@ -65,3 +65,6 @@ server:
   enableVirtualThreads: ${ENABLE_VIRTUAL_THREADS:-false}
   gzip:
     enabled: true
+
+rateLimit:
+  enabled: false