diff --git a/context/amf_ue.go b/context/amf_ue.go index b64239be..16d01531 100644 --- a/context/amf_ue.go +++ b/context/amf_ue.go @@ -108,6 +108,7 @@ type AmfUe struct { RoutingIndicator string AuthenticationCtx *models.UeAuthenticationCtx AuthFailureCauseSynchFailureTimes int + IdentityRequestSendTimes int ABBA []uint8 Kseaf string Kamf string @@ -538,6 +539,7 @@ func (ue *AmfUe) ClearRegistrationRequestData(accessType models.AccessType) { ue.RegistrationType5GS = 0 ue.IdentityTypeUsedForRegistration = 0 ue.AuthFailureCauseSynchFailureTimes = 0 + ue.IdentityRequestSendTimes = 0 ue.ServingAmfChanged = false ue.RegistrationAcceptForNon3GPPAccess = nil if ranUe := ue.RanUe[accessType]; ranUe != nil { diff --git a/gmm/handler.go b/gmm/handler.go index 9dce8332..bcf5850f 100644 --- a/gmm/handler.go +++ b/gmm/handler.go @@ -1476,6 +1476,7 @@ func AuthenticationProcedure(ue *context.AmfUe, accessType models.AccessType) (b } } else { // Request UE's SUCI by sending identity request + ue.IdentityRequestSendTimes++ gmm_message.SendIdentityRequest(ue.RanUe[accessType], accessType, nasMessage.MobileIdentity5GSTypeSuci) return false, nil } @@ -1914,7 +1915,8 @@ func HandleAuthenticationResponse(ue *context.AmfUe, accessType models.AccessTyp if hResStar != av5gAka.HxresStar { ue.GmmLog.Errorf("HRES* Validation Failure (received: %s, expected: %s)", hResStar, av5gAka.HxresStar) - if ue.IdentityTypeUsedForRegistration == nasMessage.MobileIdentity5GSType5gGuti { + if ue.IdentityTypeUsedForRegistration == nasMessage.MobileIdentity5GSType5gGuti && ue.IdentityRequestSendTimes == 0 { + ue.IdentityRequestSendTimes++ gmm_message.SendIdentityRequest(ue.RanUe[accessType], accessType, nasMessage.MobileIdentity5GSTypeSuci) return nil } else { @@ -1947,7 +1949,8 @@ func HandleAuthenticationResponse(ue *context.AmfUe, accessType models.AccessTyp ArgEAPMessage: "", }) case models.AuthResult_FAILURE: - if ue.IdentityTypeUsedForRegistration == nasMessage.MobileIdentity5GSType5gGuti { + if ue.IdentityTypeUsedForRegistration == nasMessage.MobileIdentity5GSType5gGuti && ue.IdentityRequestSendTimes == 0 { + ue.IdentityRequestSendTimes++ gmm_message.SendIdentityRequest(ue.RanUe[accessType], accessType, nasMessage.MobileIdentity5GSTypeSuci) return nil } else { @@ -1982,7 +1985,8 @@ func HandleAuthenticationResponse(ue *context.AmfUe, accessType models.AccessTyp ArgEAPMessage: response.EapPayload, }) case models.AuthResult_FAILURE: - if ue.IdentityTypeUsedForRegistration == nasMessage.MobileIdentity5GSType5gGuti { + if ue.IdentityTypeUsedForRegistration == nasMessage.MobileIdentity5GSType5gGuti && ue.IdentityRequestSendTimes == 0 { + ue.IdentityRequestSendTimes++ gmm_message.SendAuthenticationResult(ue.RanUe[accessType], false, response.EapPayload) gmm_message.SendIdentityRequest(ue.RanUe[accessType], accessType, nasMessage.MobileIdentity5GSTypeSuci) return nil diff --git a/gmm/sm.go b/gmm/sm.go index df15c3d2..0b554050 100644 --- a/gmm/sm.go +++ b/gmm/sm.go @@ -8,6 +8,7 @@ import ( "github.com/free5gc/amf/logger" "github.com/free5gc/fsm" "github.com/free5gc/nas" + "github.com/free5gc/nas/nasConvert" "github.com/free5gc/nas/nasMessage" "github.com/free5gc/openapi/models" ) @@ -167,6 +168,10 @@ func Authentication(state *fsm.State, event fsm.EventType, args fsm.ArgsType) { if err := HandleIdentityResponse(amfUe, gmmMessage.IdentityResponse); err != nil { logger.GmmLog.Errorln(err) } + // update identity type used for reauthentication + mobileIdentityContents := gmmMessage.IdentityResponse.MobileIdentity.GetMobileIdentityContents() + amfUe.IdentityTypeUsedForRegistration = nasConvert.GetTypeOfIdentity(mobileIdentityContents[0]) + err := GmmFSM.SendEvent(state, AuthRestartEvent, fsm.ArgsType{ArgAmfUe: amfUe, ArgAccessType: accessType}) if err != nil { logger.GmmLog.Errorln(err) @@ -205,6 +210,7 @@ func Authentication(state *fsm.State, event fsm.EventType, args fsm.ArgsType) { amfUe.GmmLog.Debugln(event) amfUe.AuthenticationCtx = nil amfUe.AuthFailureCauseSynchFailureTimes = 0 + amfUe.IdentityRequestSendTimes = 0 default: logger.GmmLog.Errorf("Unknown event [%+v]", event) }