Skip to content

Commit

Permalink
Make managedIdentitySource not static
Browse files Browse the repository at this point in the history
  • Loading branch information
Avery-Dunn committed Aug 27, 2024
1 parent 3b47606 commit c46e5d2
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class ManagedIdentityApplication extends AbstractApplicationBase implemen
static TokenCache sharedTokenCache = new TokenCache();

@Getter(value = AccessLevel.PUBLIC)
static ManagedIdentitySourceType managedIdentitySource = ManagedIdentityClient.getManagedIdentitySource();
ManagedIdentitySourceType managedIdentitySource = ManagedIdentityClient.getManagedIdentitySource();

@Getter(value = AccessLevel.PACKAGE)
static IEnvironmentVariables environmentVariables;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,24 @@
class ManagedIdentityClient {
private static final Logger LOG = LoggerFactory.getLogger(ManagedIdentityClient.class);

private static ManagedIdentitySourceType managedIdentitySourceType;

protected static void resetManagedIdentitySourceType() {
managedIdentitySourceType = ManagedIdentitySourceType.NONE;
}

static ManagedIdentitySourceType getManagedIdentitySource() {
if (managedIdentitySourceType != null && managedIdentitySourceType != ManagedIdentitySourceType.NONE) {
return managedIdentitySourceType;
}

IEnvironmentVariables environmentVariables = AbstractManagedIdentitySource.getEnvironmentVariables();

if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER))) {
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.SERVICE_FABRIC;
return ManagedIdentitySourceType.SERVICE_FABRIC;
} else {
managedIdentitySourceType = ManagedIdentitySourceType.APP_SERVICE;
return ManagedIdentitySourceType.APP_SERVICE;
}
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.CLOUD_SHELL;
return ManagedIdentitySourceType.CLOUD_SHELL;
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.AZURE_ARC;
return ManagedIdentitySourceType.AZURE_ARC;
} else {
managedIdentitySourceType = ManagedIdentitySourceType.DEFAULT_TO_IMDS;
return ManagedIdentitySourceType.DEFAULT_TO_IMDS;
}

return managedIdentitySourceType;
}

AbstractManagedIdentitySource managedIdentitySource;
Expand All @@ -64,11 +52,7 @@ ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParameters par
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
ServiceBundle serviceBundle) {

if (managedIdentitySourceType == null || managedIdentitySourceType == ManagedIdentitySourceType.NONE) {
managedIdentitySourceType = getManagedIdentitySource();
}

switch (managedIdentitySourceType) {
switch (getManagedIdentitySource()) {
case SERVICE_FABRIC:
return ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle);
case APP_SERVICE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,22 @@ private HttpResponse expectedResponse(int statusCode, String response) {
void managedIdentity_GetManagedIdentitySource(ManagedIdentitySourceType source, String endpoint, ManagedIdentitySourceType expectedSource) {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();

ManagedIdentitySourceType managedIdentitySourceType = ManagedIdentityClient.getManagedIdentitySource();
assertEquals(expectedSource, managedIdentitySourceType);
miApp = ManagedIdentityApplication
.builder(ManagedIdentityId.systemAssigned())
.build();

ManagedIdentitySourceType miClientSourceType = ManagedIdentityClient.getManagedIdentitySource();
ManagedIdentitySourceType miAppSourceType = miApp.managedIdentitySource;
assertEquals(expectedSource, miClientSourceType);
assertEquals(expectedSource, miAppSourceType);
}

@ParameterizedTest
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData")
void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
Expand Down Expand Up @@ -201,7 +205,6 @@ void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySource
void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource, id))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
Expand All @@ -228,7 +231,6 @@ void managedIdentityTest_RefreshOnHalfOfExpiresOn() throws Exception {
// so any of the MI options should let us verify that it's being set correctly
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.APP_SERVICE, appServiceEndpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(ManagedIdentitySourceType.APP_SERVICE, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
Expand All @@ -255,7 +257,6 @@ void managedIdentityTest_RefreshOnHalfOfExpiresOn() throws Exception {
void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

miApp = ManagedIdentityApplication
Expand Down Expand Up @@ -292,7 +293,6 @@ void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceT

IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
Expand Down Expand Up @@ -326,7 +326,6 @@ void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceT
void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

if (environmentVariables.getEnvironmentVariable("SourceType").equals(ManagedIdentitySourceType.CLOUD_SHELL.toString())) {
Expand Down Expand Up @@ -365,7 +364,6 @@ void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String en
void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

miApp = ManagedIdentityApplication
Expand Down Expand Up @@ -416,7 +414,6 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint
void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, String endpoint) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, ""));
Expand Down Expand Up @@ -451,7 +448,6 @@ void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, S
void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source, String endpoint) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, ""));
Expand Down Expand Up @@ -486,7 +482,6 @@ void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source
void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType source, String endpoint) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenThrow(new SocketException("A socket operation was attempted to an unreachable network."));
Expand Down Expand Up @@ -520,7 +515,6 @@ void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType
void azureArcManagedIdentity_MissingAuthHeader() throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

HttpResponse response = new HttpResponse();
Expand Down Expand Up @@ -559,7 +553,6 @@ void azureArcManagedIdentity_MissingAuthHeader() throws Exception {
void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
Expand Down Expand Up @@ -600,7 +593,6 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi
void azureArcManagedIdentity_InvalidAuthHeader() throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

HttpResponse response = new HttpResponse();
Expand Down Expand Up @@ -639,7 +631,6 @@ void azureArcManagedIdentity_InvalidAuthHeader() throws Exception {
void azureArcManagedIdentityAuthheaderValidationTest() throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

//Both a missing file and an invalid path structure should throw an exception
Expand Down

0 comments on commit c46e5d2

Please sign in to comment.