Skip to content

Commit

Permalink
avoid CSRF attacks by
Browse files Browse the repository at this point in the history
having API-calls send Spring-created CSRF tokens and checking them.
Also https with self-signed certificates is included - but commented out for experimenting
  • Loading branch information
wisskirchenj committed Apr 27, 2024
1 parent bf463ac commit 4d39231
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 44 deletions.
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ spring-boot = "3.2.4"
spring-dependency-management = "1.1.4"
node-gradle = "7.0.2"
sonar-gradle = "5.0.0.4638"
git-gradle = "2.4.1"
node-js = "22.0.0"
graalvm-buildtools = "0.10.1"
8 changes: 5 additions & 3 deletions server/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ plugins {
id("org.springframework.boot") version libs.versions.spring.boot
id("io.spring.dependency-management") version libs.versions.spring.dependency.management
id("org.sonarqube") version libs.versions.sonar.gradle
id("com.gorylenko.gradle-git-properties") version "2.4.1"
id("com.gorylenko.gradle-git-properties") version libs.versions.git.gradle
jacoco
}

Expand All @@ -23,13 +23,16 @@ tasks.jacocoTestReport {
}
}

jacoco {
reportsDirectory = layout.buildDirectory.dir("reports/jacoco")
}

val sonarToken: String by project
sonar {
properties {
property("sonar.token", sonarToken)
property("sonar.projectKey", "flashcards-server")
property("sonar.projectName", "Flashcards Server")
property("sonar.jacoco.reportPaths", "build/reports/jacoco")
property("sonar.junit.reportPaths", "build/test-results/test")
property("sonar.host.url", "http://localhost:9000")
}
Expand All @@ -39,7 +42,6 @@ tasks.sonar {
dependsOn(tasks.jacocoTestReport)
}


configurations {
compileOnly {
extendsFrom(configurations.annotationProcessor.get())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,28 @@
import java.util.Objects;
import java.util.Optional;

import static org.hyperskill.community.flashcards.card.model.Card.QUESTION_KEY;
import static org.hyperskill.community.flashcards.card.model.Card.TAGS_KEY;
import static org.hyperskill.community.flashcards.card.model.Card.TITLE_KEY;
import static org.springframework.data.mongodb.core.query.Criteria.where;

@Service
@Slf4j
@RequiredArgsConstructor
public class CardService {
public static final String CATEGORY_ID_NON_NULL = "Category ID cannot be null";
public static final String CARD_ID_NON_NULL = "Card ID cannot be null";
private static final int PAGE_SIZE = 20;

private final MongoTemplate mongoTemplate;
private final CategoryService categoryService;
private final CardMapper mapper;

public Page<Card> getCardsByCategory(String username, String categoryId, int page, String titleFilter) {
Objects.requireNonNull(categoryId, "Category ID cannot be null");
Objects.requireNonNull(categoryId, CATEGORY_ID_NON_NULL);

final var category = categoryService.findById(username, categoryId);
var pageRequest = PageRequest.of(page, PAGE_SIZE, Sort.by("question"));
var pageRequest = PageRequest.of(page, PAGE_SIZE, Sort.by(QUESTION_KEY));
var query = createFilterQuery(titleFilter);
var count = mongoTemplate.count(query, category.name());
var cards = mongoTemplate.find(query.with(pageRequest), Card.class, category.name());
Expand All @@ -59,9 +65,9 @@ private Query createFilterQuery(String filter) {
if (StringUtils.hasText(filter)) {
var pattern = ".*" + filter + ".*";
var criteria = new Criteria().orOperator(
where("title").regex(pattern, "i"),
where("tags").regex(pattern, "i"),
where("question").regex(pattern, "i")
where(TITLE_KEY).regex(pattern, "i"),
where(TAGS_KEY).regex(pattern, "i"),
where(QUESTION_KEY).regex(pattern, "i")
);
query.addCriteria(criteria);
}
Expand All @@ -80,8 +86,8 @@ public String createCard(String username, CardRequest request, String categoryId
}

public Card getCardById(String username, String cardId, String categoryId) {
Objects.requireNonNull(cardId, "Card ID cannot be null");
Objects.requireNonNull(categoryId, "Category ID cannot be null");
Objects.requireNonNull(cardId, CARD_ID_NON_NULL);
Objects.requireNonNull(categoryId, CATEGORY_ID_NON_NULL);

var category = categoryService.findById(username, categoryId);
var card = Optional.ofNullable(mongoTemplate.findById(cardId, Card.class, category.name()))
Expand All @@ -90,16 +96,16 @@ public Card getCardById(String username, String cardId, String categoryId) {
}

public long deleteCardById(String username, String cardId, String categoryId) {
Objects.requireNonNull(cardId, "Card ID cannot be null");
Objects.requireNonNull(cardId, CARD_ID_NON_NULL);

var collection = getCollectionName(username, categoryId, "d");
var query = Query.query(where(Card.ID_KEY).is(cardId));
return mongoTemplate.remove(query, collection).getDeletedCount();
}

public Card updateCardById(String username, String cardId, CardRequest request, String categoryId) {
Objects.requireNonNull(cardId, "Card ID cannot be null");
Objects.requireNonNull(categoryId, "Category ID cannot be null");
Objects.requireNonNull(cardId, CARD_ID_NON_NULL);
Objects.requireNonNull(categoryId, CATEGORY_ID_NON_NULL);

var category = categoryService.findById(username, categoryId, "w");
var cardBeforeUpdate = getCardById(username, cardId, categoryId);
Expand Down Expand Up @@ -128,17 +134,17 @@ private String getCollectionName(String username, String categoryId) {
* @return category name
*/
private String getCollectionName(String username, String categoryId, String permission) {
Objects.requireNonNull(categoryId, "Category ID cannot be null");
Objects.requireNonNull(categoryId, CATEGORY_ID_NON_NULL);
return categoryService.findById(username, categoryId, permission).name();
}

private Update updateFrom(CardRequest request) {
Objects.requireNonNull(request, "Update request cannot be null");

var update = new Update()
.set("title", request.title())
.set("question", request.question())
.set("tags", request.tags());
.set(TITLE_KEY, request.title())
.set(QUESTION_KEY, request.question())
.set(TAGS_KEY, request.tags());

return switch (request) {
case QuestionAndAnswerRequestDto qna -> update.set("answer", qna.answer());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
public sealed interface Card permits QuestionAndAnswer, SingleChoiceQuiz, MultipleChoiceQuiz {

String ID_KEY = "_id";
String QUESTION_KEY = "question";
String TAGS_KEY = "tags";
String TITLE_KEY = "title";

String id();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import static org.hyperskill.community.flashcards.TestUtils.TEST2;
import static org.hyperskill.community.flashcards.TestUtils.jwtUser;
import static org.hyperskill.community.flashcards.TestUtils.oidc;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
Expand Down Expand Up @@ -81,13 +82,16 @@ void getCardsValidationError_Gives400(String url) throws Exception {
@Test
void putpostDeleteCardsMissingCategoryId_Gives400() throws Exception {
mockMvc.perform(put("/api/cards/id")
.with(oidc(TEST1)))
.with(oidc(TEST1))
.with(csrf()))
.andExpect(status().isBadRequest());
mockMvc.perform(delete("/api/cards/id")
.with(oidc(TEST1)))
.with(oidc(TEST1))
.with(csrf()))
.andExpect(status().isBadRequest());
mockMvc.perform(post("/api/cards")
.with(oidc(TEST1)))
.with(oidc(TEST1))
.with(csrf()))
.andExpect(status().isBadRequest());
}

Expand All @@ -96,14 +100,16 @@ void putpostDeleteCardsMissingCategoryId_Gives400() throws Exception {
@ValueSource(strings = {"/api/cards", "/api/cards/details"})
void getCardsWithReadRights_givesInitialCardData(String endpoint) throws Exception {
mockMvc.perform(get(endpoint + "?categoryId=" + exampleCategoryId)
.with(jwt().jwt(jwtUser(TEST1))))
.with(jwt().jwt(jwtUser(TEST1)))
.with(csrf()))
.andExpect(status().isOk())
.andExpect(jsonPath("$.totalPages").value(3))
.andExpect(jsonPath("$.isLast").value(false))
.andExpect(jsonPath("$.cards[19]").exists())
.andExpect(jsonPath("$.cards[20]").doesNotExist());
mockMvc.perform(get(endpoint + "?page=2&categoryId=" + exampleCategoryId)
.with(jwt().jwt(jwtUser(TEST1))))
.with(jwt().jwt(jwtUser(TEST1)))
.with(csrf()))
.andExpect(status().isOk())
.andExpect(jsonPath("$.isLast").value(true))
.andExpect(jsonPath("$.cards[19]").exists())
Expand All @@ -113,14 +119,16 @@ void getCardsWithReadRights_givesInitialCardData(String endpoint) throws Excepti
@Test
void getCardsNoReadRights_gives403() throws Exception {
mockMvc.perform(get("/api/cards?categoryId=" + exampleCategoryId)
.with(jwt().jwt(jwtUser(TEST2))))
.with(jwt().jwt(jwtUser(TEST2)))
.with(csrf()))
.andExpect(status().isForbidden());
}

@Test
void getCardsCountWithRights_givesAll() throws Exception {
mockMvc.perform(get("/api/cards/count?categoryId=" + exampleCategoryId)
.with(jwt().jwt(jwtUser(TEST1))))
.with(jwt().jwt(jwtUser(TEST1)))
.with(csrf()))
.andExpect(status().isOk())
.andExpect(jsonPath("$").value(60));
}
Expand All @@ -133,19 +141,22 @@ void createCards_AddsAndDeleteCardDeletes() throws Exception {
.tags(Set.of("t1", "t2"))
.title("title")
.build();
var uri= mockMvc.perform(post("/api/cards?categoryId=" + exampleCategoryId)
.with(oidc(TEST1)).contentType("application/json")
var uri = mockMvc.perform(post("/api/cards?categoryId=" + exampleCategoryId)
.with(oidc(TEST1)).with(csrf()).contentType("application/json")
.content(new ObjectMapper().writeValueAsString(cardRequest)))
.andExpect(status().isCreated())
.andReturn().getResponse().getHeader("Location");
mockMvc.perform(get(uri + "?categoryId=" + exampleCategoryId)
.with(jwt().jwt(jwtUser(TEST1))))
.with(jwt().jwt(jwtUser(TEST1)))
.with(csrf()))
.andExpect(status().isOk());
mockMvc.perform(delete(uri + "?categoryId=" + exampleCategoryId)
.with(jwt().jwt(jwtUser(TEST1))))
.with(jwt().jwt(jwtUser(TEST1)))
.with(csrf()))
.andExpect(status().isOk());
mockMvc.perform(get(uri + "?categoryId=" + exampleCategoryId)
.with(jwt().jwt(jwtUser(TEST1))))
.with(jwt().jwt(jwtUser(TEST1)))
.with(csrf()))
.andExpect(status().isNotFound());
}

Expand All @@ -158,8 +169,8 @@ void updateCardOtherCardType_Updates() throws Exception {
.tags(Set.of("t1", "t2"))
.title("updated")
.build();
var uri= mockMvc.perform(post("/api/cards?categoryId=" + exampleCategoryId)
.with(oidc(TEST1)).contentType("application/json")
var uri = mockMvc.perform(post("/api/cards?categoryId=" + exampleCategoryId)
.with(oidc(TEST1)).with(csrf()).contentType("application/json")
.content(new ObjectMapper().writeValueAsString(cardRequest)))
.andExpect(status().isCreated())
.andReturn().getResponse().getHeader("Location");
Expand All @@ -171,24 +182,27 @@ void updateCardOtherCardType_Updates() throws Exception {
.title("updated")
.build();
mockMvc.perform(put(uri + "?categoryId=" + exampleCategoryId)
.with(oidc(TEST1)).contentType("application/json")
.with(oidc(TEST1)).with(csrf()).contentType("application/json")
.content(new ObjectMapper().writeValueAsString(cardRequest)))
.andExpect(status().isOk());
mockMvc.perform(get(uri + "?categoryId=" + exampleCategoryId)
.with(jwt().jwt(jwtUser(TEST1))))
.with(jwt().jwt(jwtUser(TEST1)))
.with(csrf()))
.andExpect(status().isOk())
.andExpect(jsonPath("$.title").value("updated"))
.andExpect(jsonPath("$.correctOption").value(1));
// clean up to not affect other tests
mockMvc.perform(delete(uri + "?categoryId=" + exampleCategoryId)
.with(jwt().jwt(jwtUser(TEST1))))
.with(jwt().jwt(jwtUser(TEST1)))
.with(csrf()))
.andExpect(status().isOk());
}

@Test
void getCardsById_givesAll() throws Exception {
mockMvc.perform(get("/api/cards/count?categoryId=" + exampleCategoryId)
.with(jwt().jwt(jwtUser(TEST1))))
.with(jwt().jwt(jwtUser(TEST1)))
.with(csrf()))
.andExpect(status().isOk())
.andExpect(jsonPath("$").value(60));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.hyperskill.community.flashcards.TestUtils.TEST2;
import static org.hyperskill.community.flashcards.TestUtils.jwtUser;
import static org.hyperskill.community.flashcards.TestUtils.oidc;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
Expand Down Expand Up @@ -61,7 +62,8 @@ void getCategoriesValidationError_Gives400() throws Exception {
@Test
void getCategoriesNotOwner_givesEmptyResponse() throws Exception {
mockMvc.perform(get("/api/categories")
.with(jwt().jwt(jwtUser(TEST2))))
.with(jwt().jwt(jwtUser(TEST2)))
.with(csrf()))
.andExpect(status().isOk())
.andExpect(jsonPath("$.totalPages").value(0))
.andExpect(jsonPath("$.isLast").value(true))
Expand Down Expand Up @@ -115,27 +117,27 @@ void getCategoriesOwnerPage1_isEmpty() throws Exception {
@Test
void createCategory_creates() throws Exception {
var result = mockMvc.perform(post("/api/categories")
.with(oidc(TEST1)).contentType("application/json")
.with(oidc(TEST1)).with(csrf()).contentType("application/json")
.content("{\"name\":\"test\"}"))
.andExpect(status().isCreated())
.andReturn();
var id = Objects.requireNonNull(result.getResponse().getHeader("Location")).split("/")[3];
mockMvc.perform(get("/api/categories/" + id)
.with(oidc(TEST1)))
.with(oidc(TEST1)).with(csrf()))
.andExpect(status().isOk())
.andExpect(jsonPath("$.name").value("test"));
}

@Test
void updateCategory_updates() throws Exception {
var result = mockMvc.perform(post("/api/categories")
.with(oidc(TEST1)).contentType("application/json")
.with(oidc(TEST1)).with(csrf()).contentType("application/json")
.content("{\"name\":\"to-update\"}"))
.andExpect(status().isCreated())
.andReturn();
var id = Objects.requireNonNull(result.getResponse().getHeader("Location")).split("/")[3];
mockMvc.perform(put("/api/categories/" + id)
.with(oidc(TEST1)).contentType("application/json")
.with(oidc(TEST1)).with(csrf()).contentType("application/json")
.content("{\"name\":\"updated\"}"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.name").value("updated"));
Expand All @@ -144,16 +146,16 @@ void updateCategory_updates() throws Exception {
@Test
void deleteCategory_deletes() throws Exception {
var result = mockMvc.perform(post("/api/categories")
.with(oidc(TEST1)).contentType("application/json")
.with(oidc(TEST1)).with(csrf()).contentType("application/json")
.content("{\"name\":\"to-delete\"}"))
.andExpect(status().isCreated())
.andReturn();
var id = Objects.requireNonNull(result.getResponse().getHeader("Location")).split("/")[3];
mockMvc.perform(delete("/api/categories/" + id)
.with(oidc(TEST1)))
.with(oidc(TEST1)).with(csrf()))
.andExpect(status().isOk());
mockMvc.perform(get("/api/categories/" + id)
.with(oidc(TEST1)))
.with(oidc(TEST1)).with(csrf()))
.andExpect(status().isNotFound());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

Expand Down Expand Up @@ -46,6 +47,7 @@ void registerUnauthenticatedValidJson_AddsUser() throws Exception {
// user not existing
assertThrows(UsernameNotFoundException.class, () -> userDetailsService.loadUserByUsername(username));
mockMvc.perform(post("/api/register")
.with(csrf())
.contentType(MediaType.APPLICATION_JSON)
.content(objectMapper.writeValueAsString(
new UserDto(username, "12345678"))))
Expand All @@ -57,11 +59,13 @@ void registerUnauthenticatedValidJson_AddsUser() throws Exception {
@Test
void registerUnauthenticatedExistingUser_Gives400() throws Exception {
mockMvc.perform(post("/api/register")
.with(csrf())
.contentType(MediaType.APPLICATION_JSON)
.content(objectMapper.writeValueAsString(
new UserDto("test@xyz.de", "12345678"))))
.andExpect(status().isOk());
mockMvc.perform(post("/api/register") // and again
mockMvc.perform(post("/api/register")
.with(csrf()) // and again
.contentType(MediaType.APPLICATION_JSON)
.content(objectMapper.writeValueAsString(
new UserDto("test@xyz.de", "12345678"))))
Expand All @@ -71,6 +75,7 @@ void registerUnauthenticatedExistingUser_Gives400() throws Exception {
@Test
void registerUnauthenticatedInvalidDto_Gives400() throws Exception {
mockMvc.perform(post("/api/register") // and again
.with(csrf())
.contentType(MediaType.APPLICATION_JSON)
.content(objectMapper.writeValueAsString(
new UserDto("wrong", "1234"))))
Expand Down

0 comments on commit 4d39231

Please sign in to comment.