Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Runtime evaluator #405

Merged
merged 20 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ target/
.bsp/
.vscode/
.metals/
**/metals.sbt
**/metals.sbt
*.worksheet.sc
*.plantuml
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ lazy val core = projectMatrix
Dependencies.scalaReflect(scalaVersion.value),
Dependencies.asm,
Dependencies.asmUtil,
Dependencies.sbtTestAgent
Dependencies.sbtTestAgent,
Dependencies.scalaMeta
),
libraryDependencies += onScalaVersion(
scala212 = Dependencies.scalaCollectionCompat,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,19 @@ object DebugConfig {

sealed trait EvaluationMode {
def allowScalaEvaluation: Boolean = false
def allowSimpleEvaluation: Boolean = false
def allowRuntimeEvaluation: Boolean = false
}

case object ScalaEvaluationOnly extends EvaluationMode {
override def allowScalaEvaluation: Boolean = true
}
case object SimpleEvaluationOnly extends EvaluationMode {
override def allowSimpleEvaluation: Boolean = true
case object RuntimeEvaluationOnly extends EvaluationMode {
iusildra marked this conversation as resolved.
Show resolved Hide resolved
override def allowRuntimeEvaluation: Boolean = true
}

case object NoEvaluation extends EvaluationMode
case object MixedEvaluation extends EvaluationMode {
override def allowScalaEvaluation: Boolean = true
override def allowSimpleEvaluation: Boolean = true
override def allowRuntimeEvaluation: Boolean = true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ case class ScalaVersion(value: String) {
}

object ScalaVersion {
val `2.11` = ScalaVersion(value = "2.11.12")
val `2.12` = ScalaVersion(BuildInfo.scala212)
val `2.13` = ScalaVersion(BuildInfo.scala213)
val `3.0` = ScalaVersion(BuildInfo.scala30)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@ import ch.epfl.scala.debugadapter.JavaRuntime
import ch.epfl.scala.debugadapter.Logger
import ch.epfl.scala.debugadapter.ManagedEntry
import ch.epfl.scala.debugadapter.UnmanagedEntry
import ch.epfl.scala.debugadapter.internal.evaluator.RuntimeEvaluation
import ch.epfl.scala.debugadapter.internal.evaluator.RuntimeExpression
import ch.epfl.scala.debugadapter.internal.evaluator.CompiledExpression
import ch.epfl.scala.debugadapter.internal.evaluator.JdiFrame
import ch.epfl.scala.debugadapter.internal.evaluator.JdiObject
import ch.epfl.scala.debugadapter.internal.evaluator.JdiValue
import ch.epfl.scala.debugadapter.internal.evaluator.LocalValue
import ch.epfl.scala.debugadapter.internal.evaluator.MessageLogger
import ch.epfl.scala.debugadapter.internal.evaluator.MethodInvocationFailed
import ch.epfl.scala.debugadapter.internal.evaluator.PlainLogMessage
import ch.epfl.scala.debugadapter.internal.evaluator.PreparedExpression
import ch.epfl.scala.debugadapter.internal.evaluator.ScalaEvaluator
import ch.epfl.scala.debugadapter.internal.evaluator.SimpleEvaluator
import ch.epfl.scala.debugadapter.internal.evaluator.{Recoverable, Valid, CompilerRecoverable, Fatal}
import evaluator.RuntimeEvaluatorExtractors.MethodCall
import com.microsoft.java.debug.core.IEvaluatableBreakpoint
import com.microsoft.java.debug.core.adapter.IDebugAdapterContext
import com.microsoft.java.debug.core.adapter.IEvaluationProvider
Expand All @@ -34,10 +36,9 @@ import scala.util.Success
import scala.util.Try

import ScalaExtension.*

import ch.epfl.scala.debugadapter.internal.evaluator.RuntimeEvaluationTree
private[internal] class EvaluationProvider(
sourceLookUp: SourceLookUpProvider,
simpleEvaluator: SimpleEvaluator,
messageLogger: MessageLogger,
scalaEvaluators: Map[ClassEntry, ScalaEvaluator],
mode: DebugConfig.EvaluationMode,
Expand Down Expand Up @@ -107,7 +108,7 @@ private[internal] class EvaluationProvider(
.invoke(methodName, methodSignature, wrappedArgs)
.recover {
// if invocation throws an exception, we return that exception as the result
case MethodInvocationFailed(msg, exception) => exception
case MethodInvocationFailed(msg, Some(exception)) => exception
}
.map(_.value)
}
Expand All @@ -126,7 +127,7 @@ private[internal] class EvaluationProvider(
m.scalaVersion match {
case None => s"Unsupported evaluation in Java classpath entry: ${entry.name}"
case Some(sv) =>
s"""|Missing scala-expression-compiler_{$sv} with version ${BuildInfo.version}.
s"""|Missing scala-expression-compiler_$sv with version ${BuildInfo.version}.
|You can open an issue at https://github.com/scalacenter/scala-debug-adapter.""".stripMargin
}
case _: JavaRuntime => s"Unsupported evaluation in JDK: ${entry.name}"
Expand All @@ -144,11 +145,8 @@ private[internal] class EvaluationProvider(
}
}

private def prepare(expression: String, frame: JdiFrame): Try[PreparedExpression] = {
lazy val simpleExpression = simpleEvaluator.prepare(expression, frame)
if (mode.allowSimpleEvaluation && simpleExpression.isDefined) {
Success(simpleExpression.get)
} else if (mode.allowScalaEvaluation) {
private def compilePrepare(expression: String, frame: JdiFrame) =
if (mode.allowScalaEvaluation) {
val fqcn = frame.current().location.declaringType.name
for {
evaluator <- getScalaEvaluator(fqcn)
Expand All @@ -160,17 +158,27 @@ private[internal] class EvaluationProvider(
} else {
Failure(new EvaluationFailed(s"Cannot evaluate '$expression' with $mode mode"))
}
}

private def evaluate(expression: PreparedExpression, frame: JdiFrame): Try[Value] = {
private def prepare(expression: String, frame: JdiFrame): Try[PreparedExpression] =
if (mode.allowRuntimeEvaluation)
RuntimeEvaluation(frame, logger).validate(expression) match {
case MethodCall(tree: RuntimeEvaluationTree) if mode.allowScalaEvaluation =>
compilePrepare(expression, frame).orElse(Success(RuntimeExpression(tree)))
case Valid(tree) => Success(RuntimeExpression(tree))
case Fatal(e) => Failure(e)
case Recoverable(_) | CompilerRecoverable(_) => compilePrepare(expression, frame)
}
else compilePrepare(expression, frame)

private def evaluate(expression: PreparedExpression, frame: JdiFrame): Try[Value] = evaluationBlock {
expression match {
case logMessage: PlainLogMessage => messageLogger.log(logMessage, frame)
case localValue: LocalValue => simpleEvaluator.evaluate(localValue, frame)
case RuntimeExpression(tree) => RuntimeEvaluation(frame, logger).evaluate(tree).getResult.map(_.value)
iusildra marked this conversation as resolved.
Show resolved Hide resolved
case expression: CompiledExpression =>
val fqcn = frame.current().location.declaringType.name
for {
evaluator <- getScalaEvaluator(fqcn)
compiledExpression <- evaluationBlock { evaluator.evaluate(expression, frame) }
compiledExpression <- evaluator.evaluate(expression, frame)
} yield compiledExpression
}
}
Expand Down Expand Up @@ -201,14 +209,12 @@ private[internal] object EvaluationProvider {
logger: Logger,
config: DebugConfig
): IEvaluationProvider = {
val simpleEvaluator = new SimpleEvaluator(logger, config.testMode)
val scalaEvaluators = tools.expressionCompilers.view.map { case (entry, compiler) =>
(entry, new ScalaEvaluator(entry, compiler, logger, config.testMode))
}.toMap
val messageLogger = new MessageLogger()
new EvaluationProvider(
tools.sourceLookUp,
simpleEvaluator,
messageLogger,
scalaEvaluators,
config.evaluationMode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,12 @@ private[debugadapter] object ScalaExtension {
}
}
}

implicit class TrySeq[A](seq: Seq[Try[A]]) {
def traverse: Try[Seq[A]] = {
seq.foldRight(Try(Seq.empty[A])) { (safeHead, safeTail) =>
safeTail.flatMap(tail => safeHead.map(head => head +: tail))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import scala.jdk.CollectionConverters.*
import scala.util.control.NonFatal

private[internal] class JdiClass(
cls: ClassType,
val cls: ClassType,
thread: ThreadReference
) extends JdiObject(cls.classObject, thread) {

Expand All @@ -15,16 +15,16 @@ private[internal] class JdiClass(
override def classLoader: JdiClassLoader = JdiClassLoader(cls.classLoader, thread)

def newInstance(args: Seq[JdiValue]): Safe[JdiObject] = {
val ctr = cls.methodsByName("<init>").asScala.head
val ctr = cls.methodsByName("<init>").get(0)
newInstance(ctr, args)
}

def newInstance(signature: String, args: Seq[JdiValue]): Safe[JdiObject] = {
val ctr = cls.methodsByName("<init>", signature).asScala.head
val ctr = cls.methodsByName("<init>", signature).get(0)
newInstance(ctr, args)
}

private def newInstance(ctr: Method, args: Seq[JdiValue]): Safe[JdiObject] =
def newInstance(ctr: Method, args: Seq[JdiValue]): Safe[JdiObject] =
for {
_ <- prepareMethod(ctr)
instance <- Safe(cls.newInstance(thread, ctr, args.map(_.value).asJava, ObjectReference.INVOKE_SINGLE_THREADED))
Expand Down Expand Up @@ -52,16 +52,16 @@ private[internal] class JdiClass(
Safe(cls.getValue(cls.fieldByName(fieldName))).map(JdiValue(_, thread))

def invokeStatic(methodName: String, args: Seq[JdiValue]): Safe[JdiValue] = {
val method = cls.methodsByName(methodName).asScala.head
val method = cls.methodsByName(methodName).get(0)
invokeStatic(method, args)
}

def invokeStatic(methodName: String, signature: String, args: Seq[JdiValue]): Safe[JdiValue] = {
val method = cls.methodsByName(methodName, signature).asScala.head
val method = cls.methodsByName(methodName, signature).get(0)
invokeStatic(method, args)
}

private def invokeStatic(method: Method, args: Seq[JdiValue]): Safe[JdiValue] =
def invokeStatic(method: Method, args: Seq[JdiValue]): Safe[JdiValue] =
Safe(cls.invokeMethod(thread, method, args.map(_.value).asJava, ObjectReference.INVOKE_SINGLE_THREADED))
.map(JdiValue(_, thread))
.recoverWith(wrapInvocationException(thread))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ch.epfl.scala.debugadapter.internal.evaluator
import com.sun.jdi._

import java.nio.file.Path
import RuntimeEvaluatorExtractors.IsAnyVal

private[internal] class JdiClassLoader(
reference: ClassLoaderReference,
Expand Down Expand Up @@ -30,48 +31,109 @@ private[internal] class JdiClassLoader(
Safe(thread.virtualMachine.mirrorOf(str)).map(new JdiString(_, thread))

def mirrorOf(boolean: Boolean): JdiValue =
new JdiValue(thread.virtualMachine.mirrorOf(boolean), thread)
JdiValue(thread.virtualMachine.mirrorOf(boolean), thread)

def mirrorOf(integer: Int): JdiValue =
new JdiValue(thread.virtualMachine.mirrorOf(integer), thread)
def mirrorOf(byte: Byte): JdiValue =
JdiValue(thread.virtualMachine.mirrorOf(byte), thread)

def mirrorOf(char: Char): JdiValue =
JdiValue(thread.virtualMachine.mirrorOf(char), thread)

def mirrorOf(double: Double): JdiValue =
JdiValue(thread.virtualMachine.mirrorOf(double), thread)

def mirrorOf(float: Float): JdiValue =
JdiValue(thread.virtualMachine.mirrorOf(float), thread)

def mirrorOf(int: Int): JdiValue =
JdiValue(thread.virtualMachine.mirrorOf(int), thread)

def mirrorOf(long: Long): JdiValue =
JdiValue(thread.virtualMachine.mirrorOf(long), thread)

def mirrorOf(short: Short): JdiValue =
JdiValue(thread.virtualMachine.mirrorOf(short), thread)

def mirrorOfVoid(): JdiValue =
JdiValue(thread.virtualMachine.mirrorOfVoid(), thread)

def mirrorOfAnyVal(value: AnyVal): JdiValue = value match {
case d: Double => mirrorOf(d)
case f: Float => mirrorOf(f)
case l: Long => mirrorOf(l)
case i: Int => mirrorOf(i)
case s: Short => mirrorOf(s)
case c: Char => mirrorOf(c)
case b: Byte => mirrorOf(b)
case b: Boolean => mirrorOf(b)
}

def mirrorOfLiteral(value: Any): Safe[JdiValue] = value match {
case IsAnyVal(value) => Safe(mirrorOfAnyVal(value))
case value: String => mirrorOf(value)
case () => Safe(mirrorOfVoid())
case _ => Safe.failed(new IllegalArgumentException(s"Unsupported literal $value"))
}

def boxIfPrimitive(value: JdiValue): Safe[JdiValue] =
value.value match {
case value: BooleanValue => box(value.value)
case value: CharValue => box(value.value)
case value: DoubleValue => box(value.value)
case value: FloatValue => box(value.value)
case value: IntegerValue => box(value.value)
case value: LongValue => box(value.value)
case value: ShortValue => box(value.value)
case value => Safe(JdiValue(value, thread))
case _: PrimitiveValue => box(value)
case _ => Safe(value)
}

def box(value: AnyVal): Safe[JdiObject] =
def box(value: JdiValue): Safe[JdiObject] = {
for {
jdiValue <- mirrorOf(value.toString)
jdiValue <- value.value match {
case c: CharValue => Safe(value)
case _ => mirrorOf(value.value.toString)
}
_ = getClass
(className, sig) = value match {
case _: Boolean => ("java.lang.Boolean", "(Ljava/lang/String;)Ljava/lang/Boolean;")
case _: Byte => ("java.lang.Byte", "(Ljava/lang/String;)Ljava/lang/Byte;")
case _: Char => ("java.lang.Character", "(Ljava/lang/String;)Ljava/lang/Character;")
case _: Double => ("java.lang.Double", "(Ljava/lang/String;)Ljava/lang/Double;")
case _: Float => ("java.lang.Float", "(Ljava/lang/String;)Ljava/lang/Float;")
case _: Int => ("java.lang.Integer", "(Ljava/lang/String;)Ljava/lang/Integer;")
case _: Long => ("java.lang.Long", "(Ljava/lang/String;)Ljava/lang/Long;")
case _: Short => ("java.lang.Short", "(Ljava/lang/String;)Ljava/lang/Short;")
(className, sig) = value.value match {
case _: BooleanValue => ("java.lang.Boolean", "(Ljava/lang/String;)Ljava/lang/Boolean;")
case _: ByteValue => ("java.lang.Byte", "(Ljava/lang/String;)Ljava/lang/Byte;")
case _: CharValue => ("java.lang.Character", "(C)Ljava/lang/Character;")
case _: DoubleValue => ("java.lang.Double", "(Ljava/lang/String;)Ljava/lang/Double;")
case _: FloatValue => ("java.lang.Float", "(Ljava/lang/String;)Ljava/lang/Float;")
case _: IntegerValue => ("java.lang.Integer", "(Ljava/lang/String;)Ljava/lang/Integer;")
case _: LongValue => ("java.lang.Long", "(Ljava/lang/String;)Ljava/lang/Long;")
case _: ShortValue => ("java.lang.Short", "(Ljava/lang/String;)Ljava/lang/Short;")
}
clazz <- loadClass(className)
objectRef <- clazz.invokeStatic("valueOf", sig, List(jdiValue))
} yield objectRef.asObject
}

def boxUnboxOnNeed(
expected: Seq[Type],
received: Seq[JdiValue]
): Safe[Seq[JdiValue]] = {
expected
.zip(received)
.map { case (expect: Type, got: JdiValue) =>
(expect, got.value) match {
case (argType: ReferenceType, arg: PrimitiveValue) => boxIfPrimitive(got)
case (argType: PrimitiveType, arg: ObjectReference) => got.unboxIfPrimitive
case (argType, arg) => Safe(got)
}
}
.traverse
}

def boxUnboxOnNeed(
expected: java.util.List[Type],
received: Seq[JdiValue]
): Safe[Seq[JdiValue]] = boxUnboxOnNeed(expected.asScalaSeq, received)

def createArray(arrayType: String, values: Seq[JdiValue]): Safe[JdiArray] =
for {
arrayTypeClass <- loadClass(arrayType)
arrayClass <- loadClass("java.lang.reflect.Array")
size = mirrorOf(values.size)
array <- arrayClass
.invokeStatic("newInstance", "(Ljava/lang/Class;I)Ljava/lang/Object;", Seq(arrayTypeClass, size))
.invokeStatic(
"newInstance",
"(Ljava/lang/Class;I)Ljava/lang/Object;",
Seq(arrayTypeClass, mirrorOf(values.size))
)
.map(_.asArray)
} yield {
array.setValues(values)
Expand Down
Loading