feat: Change Fabric Cog Service Token to Support Billing #2291

Merged · 3 commits · Oct 10, 2024

Changes from all commits:

@@ -6,7 +6,7 @@ package com.microsoft.azure.synapse.ml.services
 import com.microsoft.azure.synapse.ml.codegen.Wrappable
 import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
 import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
-import com.microsoft.azure.synapse.ml.fabric.{FabricClient, TokenLibrary}
+import com.microsoft.azure.synapse.ml.fabric.FabricClient
 import com.microsoft.azure.synapse.ml.io.http._
 import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
 import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
@@ -330,7 +330,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
     val providedCustomAuthHeader = getValueOpt(row, CustomAuthHeader)
     if (providedCustomAuthHeader .isEmpty && PlatformDetails.runningOnFabric()) {
       logInfo("Using Default AAD Token On Fabric")
-      Option(TokenLibrary.getAuthHeader)
+      Option(FabricClient.getCognitiveMWCTokenAuthHeader)
     } else {
       providedCustomAuthHeader
     }

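Note (illustrative, not part of the diff): with this change, any transformer mixing in HasCognitiveServiceInput that has no custom auth header configured now defaults, on Fabric, to the MWC token header from FabricClient.getCognitiveMWCTokenAuthHeader instead of the AAD header from TokenLibrary.getAuthHeader. A minimal sketch of the resulting behaviour, assuming a Fabric Spark session; the deployment and column names are hypothetical:

    import com.microsoft.azure.synapse.ml.services.openai.OpenAIChatCompletion

    val chat = new OpenAIChatCompletion()
      .setDeploymentName("gpt-4o")       // hypothetical deployment name
      .setMessagesCol("messages")
      .setOutputCol("chatCompletions")

    // With no custom auth header set, getCustomAuthHeader(row) falls through to
    // Option(FabricClient.getCognitiveMWCTokenAuthHeader), so the outgoing request carries an
    // "MwcToken <token>" header scoped to the workspace/artifact (presumably for billing)
    // rather than the previous "Bearer <aad-token>" header. Off Fabric, or when a custom
    // header is supplied, behaviour is unchanged.
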
@@ -4,7 +4,7 @@
 package com.microsoft.azure.synapse.ml.services.openai
 
 import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
-import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting, OpenAITokenLibrary}
+import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting}
 import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
 import com.microsoft.azure.synapse.ml.param.ServiceParam
 import com.microsoft.azure.synapse.ml.services._
@@ -277,18 +277,6 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
   }
 }
 
-trait HasOpenAICognitiveServiceInput extends HasCognitiveServiceInput {
-  override protected def getCustomAuthHeader(row: Row): Option[String] = {
-    val providedCustomHeader = getValueOpt(row, CustomAuthHeader)
-    if (providedCustomHeader.isEmpty && PlatformDetails.runningOnFabric()) {
-      logInfo("Using Default OpenAI Token On Fabric")
-      Option(OpenAITokenLibrary.getAuthHeader)
-    } else {
-      providedCustomHeader
-    }
-  }
-}
-
 abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String)
   with HasOpenAISharedParams with OpenAIFabricSetting {
   setDefault(timeout -> 360.0)

@@ -5,10 +5,9 @@ package com.microsoft.azure.synapse.ml.services.openai
 
 import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
 import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
-import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
+import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser}
 import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
 import org.apache.spark.ml.ComplexParamsReadable
-import org.apache.spark.ml.param.Param
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.types._
@@ -20,7 +19,7 @@ import scala.language.existentials
 object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]
 
 class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
-  with HasOpenAITextParams with HasMessagesInput with HasOpenAICognitiveServiceInput
+  with HasOpenAITextParams with HasMessagesInput with HasCognitiveServiceInput
   with HasInternalJsonOutputParser with SynapseMLLogging {
   logClass(FeatureNames.AiServices.OpenAI)
 

@@ -5,7 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai
 
 import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
 import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
-import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
+import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser}
 import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
 import org.apache.spark.ml.ComplexParamsReadable
 import org.apache.spark.ml.util._
@@ -19,7 +19,7 @@ import scala.language.existentials
 object OpenAICompletion extends ComplexParamsReadable[OpenAICompletion]
 
 class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid)
-  with HasOpenAITextParams with HasPromptInputs with HasOpenAICognitiveServiceInput
+  with HasOpenAITextParams with HasPromptInputs with HasCognitiveServiceInput
   with HasInternalJsonOutputParser with SynapseMLLogging {
   logClass(FeatureNames.AiServices.OpenAI)
 

@@ -7,6 +7,7 @@ import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
 import com.microsoft.azure.synapse.ml.io.http.JSONOutputParser
 import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
 import com.microsoft.azure.synapse.ml.param.ServiceParam
+import com.microsoft.azure.synapse.ml.services.HasCognitiveServiceInput
 import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
 import org.apache.spark.ml.ComplexParamsReadable
 import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
@@ -22,7 +23,7 @@ import scala.language.existentials
 object OpenAIEmbedding extends ComplexParamsReadable[OpenAIEmbedding]
 
 class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
-  with HasOpenAIEmbeddingParams with HasOpenAICognitiveServiceInput with SynapseMLLogging {
+  with HasOpenAIEmbeddingParams with HasCognitiveServiceInput with SynapseMLLogging {
   logClass(FeatureNames.AiServices.OpenAI)
 
   def this() = this(Identifiable.randomUID("OpenAIEmbedding"))

@@ -3,17 +3,16 @@
 
 package com.microsoft.azure.synapse.ml.services.openai
 
-import com.microsoft.azure.synapse.ml.services._
 import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
 import com.microsoft.azure.synapse.ml.core.spark.Functions
 import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL}
 import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
 import com.microsoft.azure.synapse.ml.param.StringStringMapParam
+import com.microsoft.azure.synapse.ml.services._
 import org.apache.http.entity.AbstractHttpEntity
 import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
-import org.apache.spark.sql.Row.unapplySeq
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.functions.udf
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -28,7 +27,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
   with HasErrorCol with HasOutputCol
   with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
   with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
-  with HasOpenAICognitiveServiceInput
+  with HasCognitiveServiceInput
   with ComplexParamsWritable with SynapseMLLogging {
 
   logClass(FeatureNames.AiServices.OpenAI)

@@ -122,7 +122,7 @@ object FabricClient extends RESTUtils {
 
   private def getHeaders: Map[String, String] = {
     Map(
-      "Authorization" -> s"Bearer ${TokenLibrary.getAccessToken}",
+      "Authorization" -> s"${getMLWorkloadAADAuthHeader}",
       "RequestId" -> UUID.randomUUID().toString,
      "Content-Type" -> "application/json",
       "x-ms-workload-resource-moniker" -> UUID.randomUUID().toString
@@ -143,4 +143,10 @@
   def usagePost(url: String, body: String): JsValue = {
     usagePost(url, body, getHeaders);
   }
+
+  def getMLWorkloadAADAuthHeader: String = TokenLibrary.getMLWorkloadAADAuthHeader
+
+  def getCognitiveMWCTokenAuthHeader: String = {
+    TokenLibrary.getCognitiveMwcTokenAuthHeader(WorkspaceID.getOrElse(""), ArtifactID.getOrElse(""))
+  }
 }

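Note (not part of the diff): a short sketch of how the two new FabricClient helpers differ. It assumes a Fabric Spark session, since both ultimately reach the Trident TokenLibrary via reflection; the printlns are only for illustration.

    import com.microsoft.azure.synapse.ml.fabric.FabricClient

    // AAD bearer header, still used for FabricClient's own usage/telemetry REST calls
    // (the underlying access-token argument changes from "pbi" to "ml" in this PR):
    val aadHeader = FabricClient.getMLWorkloadAADAuthHeader      // "Bearer <aad-token>"

    // Workspace/artifact-scoped MWC token header, now attached to cognitive/OpenAI requests
    // when no custom auth header is supplied:
    val mwcHeader = FabricClient.getCognitiveMWCTokenAuthHeader  // "MwcToken <mwc-token>"

    println(aadHeader.takeWhile(_ != ' '))  // Bearer
    println(mwcHeader.takeWhile(_ != ' '))  // MwcToken
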

This file was deleted.

@@ -6,11 +6,7 @@ package com.microsoft.azure.synapse.ml.fabric
 import scala.reflect.runtime.currentMirror
 import scala.reflect.runtime.universe._
 
-trait AuthHeaderProvider {
-  def getAuthHeader: String
-}
-
-object TokenLibrary extends AuthHeaderProvider {
+object TokenLibrary {
   def getAccessToken: String = {
     val objectName = "com.microsoft.azure.trident.tokenlibrary.TokenLibrary"
     val mirror = currentMirror
@@ -27,9 +23,29 @@ object TokenLibrary {
       }
     }.getOrElse(throw new NoSuchMethodException(s"Method $methodName with argument type $argType not found"))
     val methodMirror = mirror.reflect(obj).reflectMethod(selectedMethodSymbol.asMethod)
-    methodMirror("pbi").asInstanceOf[String]
+    methodMirror("ml").asInstanceOf[String]
   }
 
+  def getSparkMwcToken(workspaceId: String, artifactId: String): String = {
+    val objectName = "com.microsoft.azure.trident.tokenlibrary.TokenLibrary"
+    val mirror = currentMirror
+    val module = mirror.staticModule(objectName)
+    val obj = mirror.reflectModule(module).instance
+    val objType = mirror.reflect(obj).symbol.toType
+    val methodName = "getMwcToken"
+    val methodSymbols = objType.decl(TermName(methodName)).asTerm.alternatives
+    val argTypes = List(typeOf[String], typeOf[String], typeOf[Integer], typeOf[String])
+    val selectedMethodSymbol = methodSymbols.find { m =>
+      m.asMethod.paramLists.flatten.map(_.typeSignature).zip(argTypes).forall { case (a, b) => a =:= b }
+    }.getOrElse(throw new NoSuchMethodException(s"Method $methodName with argument type not found"))
+    val methodMirror = mirror.reflect(obj).reflectMethod(selectedMethodSymbol.asMethod)
+    methodMirror(workspaceId, artifactId, 2, "SparkCore")
+      .asInstanceOf[String]
+  }
+
+
+  def getMLWorkloadAADAuthHeader: String = "Bearer " + getAccessToken
 
-  def getAuthHeader: String = "Bearer " + getAccessToken
+  def getCognitiveMwcTokenAuthHeader(workspaceId: String, artifactId: String): String = "MwcToken " +
+    getSparkMwcToken(workspaceId, artifactId)
 }

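Note (not part of the PR): the new getSparkMwcToken follows the same reflective method-lookup pattern as getAccessToken. Below is a self-contained sketch of that pattern using a hypothetical Greeter object as a stand-in for com.microsoft.azure.trident.tokenlibrary.TokenLibrary (which only exists on the Fabric Spark classpath); it compiles anywhere with the scala-reflect module available.

    import scala.reflect.runtime.currentMirror
    import scala.reflect.runtime.universe._

    // Stand-in for the Trident TokenLibrary singleton.
    object Greeter {
      def greet(name: String, times: Integer): String = (s"hello $name " * times.intValue).trim
    }

    object ReflectionSketch extends App {
      // Locate the singleton instance by its (fully qualified) name at runtime.
      val obj     = currentMirror.reflectModule(currentMirror.staticModule("Greeter")).instance
      val objType = currentMirror.reflect(obj).symbol.toType

      // Pick the overload whose parameter types match the expected signature, as the diff does.
      val argTypes = List(typeOf[String], typeOf[Integer])
      val selected = objType.decl(TermName("greet")).asTerm.alternatives.find { m =>
        m.asMethod.paramLists.flatten.map(_.typeSignature).zip(argTypes).forall { case (a, b) => a =:= b }
      }.getOrElse(throw new NoSuchMethodException("greet(String, Integer) not found"))

      // Invoke the resolved method reflectively, analogous to
      // getMwcToken(workspaceId, artifactId, 2, "SparkCore") in the diff.
      val result = currentMirror.reflect(obj).reflectMethod(selected.asMethod)("fabric", Int.box(2))
      println(result.asInstanceOf[String]) // hello fabric hello fabric
    }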