Skip to content

Commit

Permalink
Merge pull request #472 from iusildra/evaluate-flow-control
Browse files Browse the repository at this point in the history
Evaluate flow control
  • Loading branch information
adpi2 authored Jul 5, 2023
2 parents 904bc0e + 396bef8 commit 093dcc7
Show file tree
Hide file tree
Showing 9 changed files with 389 additions and 181 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class RuntimeDefaultEvaluator(val frame: JdiFrame, val logger: Logger) extends R
case instance: NewInstanceTree => instantiate(instance)
case method: InstanceMethodTree => invoke(method)
case array: ArrayElemTree => evaluateArrayElement(array)
case branching: IfTree => evaluateIf(branching)
case staticMethod: StaticMethodTree => invokeStatic(staticMethod)
case outer: OuterTree => evaluateOuter(outer)
}
Expand Down Expand Up @@ -120,6 +121,15 @@ class RuntimeDefaultEvaluator(val frame: JdiFrame, val logger: Logger) extends R
array <- eval(tree.array)
index <- eval(tree.index).flatMap(_.unboxIfPrimitive).flatMap(_.toInt)
} yield array.asArray.getValue(index)

/* -------------------------------------------------------------------------- */
/* If tree evaluation */
/* -------------------------------------------------------------------------- */
def evaluateIf(tree: IfTree): Safe[JdiValue] =
for {
predicate <- eval(tree.p).flatMap(_.unboxIfPrimitive).flatMap(_.toBoolean)
value <- if (predicate) eval(tree.thenp) else eval(tree.elsep)
} yield value
}

object RuntimeDefaultEvaluator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val logger: Logger) extends R

def validate(expression: Stat): Validation[RuntimeEvaluableTree] =
expression match {
case lit: Lit => validateLiteral(lit)
case value: Term.Name => validateName(value.value, thisTree)
case _: Term.This => thisTree
case sup: Term.Super => Recoverable("Super not (yet) supported at runtime")
case _: Term.Apply | _: Term.ApplyInfix | _: Term.ApplyUnary => validateMethod(extractCall(expression))
case select: Term.Select => validateSelect(select)
case lit: Lit => validateLiteral(lit)
case branch: Term.If => validateIf(branch)
case instance: Term.New => validateNew(instance)
case _ => Recoverable("Expression not supported at runtime")
}
Expand All @@ -55,10 +56,10 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val logger: Logger) extends R
/* -------------------------------------------------------------------------- */
/* Literal validation */
/* -------------------------------------------------------------------------- */
def validateLiteral(lit: Lit): Validation[LiteralTree] =
def validateLiteral(lit: Lit): Validation[RuntimeEvaluableTree] =
frame.classLoader().map(loader => LiteralTree(fromLitToValue(lit, loader))).extract match {
case Success(value) => value
case Failure(e) => Fatal(e)
case Failure(e) => CompilerRecoverable(e)
}

/* -------------------------------------------------------------------------- */
Expand Down Expand Up @@ -151,12 +152,19 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val logger: Logger) extends R
if (methodFirst) zeroArg.orElse(field)
else field.orElse(zeroArg)

of.flatMap { of =>
member
.orElse(validateModule(name, Some(of)))
.orElse(findOuter(of).flatMap(o => validateName(value, Valid(o), methodFirst)))
}.orElse(localVarTreeByName(name))
.orElse(validateModule(name, None))
of
.flatMap { of =>
member
.orElse(validateModule(name, Some(of)))
.orElse(findOuter(of).flatMap(o => validateName(value, Valid(o), methodFirst)))
}
.orElse {
of match {
case Valid(_: ThisTree) | _: Recoverable => localVarTreeByName(name)
case _ => Recoverable(s"${name} is not a local variable")
}
}
.orElse { validateModule(name, None) }
}

/* -------------------------------------------------------------------------- */
Expand Down Expand Up @@ -252,6 +260,26 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val logger: Logger) extends R
outerTree <- OuterTree(tree, outer)
} yield outerTree
}

/* -------------------------------------------------------------------------- */
/* Flow control validation */
/* -------------------------------------------------------------------------- */

def validateIf(tree: Term.If): Validation[RuntimeEvaluableTree] = {
lazy val objType = loadClass("java.lang.Object").extract.get.cls
for {
cond <- validate(tree.cond)
thenp <- validate(tree.thenp)
elsep <- validate(tree.elsep)
ifTree <- IfTree(
cond,
thenp,
elsep,
isAssignableFrom(_, _),
extractCommonSuperClass(thenp.`type`, elsep.`type`).getOrElse(objType)
)
} yield ifTree
}
}

object RuntimeDefaultValidator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ trait RuntimeValidator {
*/
protected def validateWithClass(expression: Stat): Validation[RuntimeTree]

def validateLiteral(lit: Lit): Validation[LiteralTree]
def validateLiteral(lit: Lit): Validation[RuntimeEvaluableTree]

def localVarTreeByName(name: String): Validation[RuntimeEvaluableTree]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
package ch.epfl.scala.debugadapter.internal.evaluator

import com.sun.jdi._
import scala.meta.trees.*

import scala.meta.Lit
import scala.util.Success
import RuntimeEvaluatorExtractors.*
import scala.meta.Stat
import scala.meta.Term
import scala.meta.trees.*
import scala.meta.{Type => MType}
import scala.util.Failure
import scala.util.Success
import scala.util.Try
import scala.jdk.CollectionConverters.*

import RuntimeEvaluatorExtractors.*

private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame) {
import RuntimeEvaluationHelpers.*
def fromLitToValue(literal: Lit, classLoader: JdiClassLoader): (Safe[Any], Type) = {
Expand Down Expand Up @@ -136,12 +138,12 @@ private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame) {
}

(got, expected) match {
case (g: ArrayType, at: ArrayType) => isAssignableFrom(g.componentType, at.componentType)
case (g: ArrayType, at: ArrayType) => g.componentType().equals(at.componentType()) // TODO: check this
case (g: PrimitiveType, pt: PrimitiveType) => got.equals(pt)
case (g: ReferenceType, ref: ReferenceType) => referenceTypesMatch(g, ref)
case (_: VoidType, _: VoidType) => true

case (g: ClassType, pt: PrimitiveType) =>
case (g: ReferenceType, pt: PrimitiveType) =>
isAssignableFrom(g, frame.getPrimitiveBoxedClass(pt))
case (g: PrimitiveType, ct: ReferenceType) =>
isAssignableFrom(frame.getPrimitiveBoxedClass(g), ct)
Expand Down Expand Up @@ -215,6 +217,20 @@ private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame) {

loop(qual)
}

def extractCommonSuperClass(tpe1: Type, tpe2: Type): Option[Type] = {
def getSuperClasses(of: Type): Array[ClassType] =
of match {
case cls: ClassType =>
Iterator.iterate(cls)(cls => cls.superclass()).takeWhile(_ != null).toArray
case _ => Array()
}

val superClasses1 = getSuperClasses(tpe1)
val superClasses2 = getSuperClasses(tpe2)
superClasses1.find(superClasses2.contains)
}

def validateType(tpe: MType, thisType: Option[RuntimeEvaluableTree])(
termValidation: Term => Validation[RuntimeEvaluableTree]
): Validation[(Option[RuntimeEvaluableTree], ClassTree)] =
Expand Down Expand Up @@ -374,7 +390,9 @@ private[evaluator] object RuntimeEvaluationHelpers {
case (_, Module(mod)) => Valid(InstanceFieldTree(field, mod))
case (_, eval: RuntimeEvaluableTree) =>
if (field.isStatic())
Fatal(s"Accessing static field $field from instance ${eval.`type`} can lead to unexpected behavior")
CompilerRecoverable(
s"Accessing static field $field from instance ${eval.`type`} can lead to unexpected behavior"
)
else Valid(InstanceFieldTree(field, eval))
}

Expand All @@ -387,7 +405,9 @@ private[evaluator] object RuntimeEvaluationHelpers {
case Module(mod) => Valid(InstanceMethodTree(method, args, mod))
case eval: RuntimeEvaluableTree =>
if (method.isStatic())
Fatal(s"Accessing static method $method from instance ${eval.`type`} can lead to unexpected behavior")
CompilerRecoverable(
s"Accessing static method $method from instance ${eval.`type`} can lead to unexpected behavior"
)
else Valid(InstanceMethodTree(method, args, eval))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@ protected[internal] object RuntimeEvaluatorExtractors {
}

object MethodCall {
def unapply(tree: RuntimeTree): Option[RuntimeTree] =
def unapply(tree: RuntimeTree): Boolean =
tree match {
case mt: NestedModuleTree => unapply(mt.of)
case ft: InstanceFieldTree => unapply(ft.qual)
case oct: OuterClassTree => unapply(oct.inner)
case OuterModuleTree(module) => unapply(module)
case _: MethodTree | _: NewInstanceTree => Some(tree)
case _: LiteralTree | _: LocalVarTree | _: PreEvaluatedTree | _: ThisTree => None
case _: StaticFieldTree | _: ClassTree | _: TopLevelModuleTree => None
case _: PrimitiveBinaryOpTree | _: PrimitiveUnaryOpTree | _: ArrayElemTree => None
case IfTree(p, t, f, _) => unapply(p) || unapply(t) || unapply(f)
case _: MethodTree | _: NewInstanceTree => true
case _: LiteralTree | _: LocalVarTree | _: PreEvaluatedTree | _: ThisTree => false
case _: StaticFieldTree | _: ClassTree | _: TopLevelModuleTree => false
case _: PrimitiveBinaryOpTree | _: PrimitiveUnaryOpTree | _: ArrayElemTree => false
}
def unapply(tree: Validation[RuntimeTree]): Option[RuntimeTree] =
tree.toOption.filter { unapply(_).isDefined }
def unapply(tree: Validation[RuntimeTree]): Validation[RuntimeTree] =
tree.filter(unapply)
}

object ReferenceTree {
Expand All @@ -61,6 +62,18 @@ protected[internal] object RuntimeEvaluatorExtractors {
}
}

object BooleanTree {
def unapply(p: Validation[RuntimeEvaluableTree]): Validation[RuntimeEvaluableTree] =
p.flatMap(unapply)

def unapply(p: RuntimeEvaluableTree): Validation[RuntimeEvaluableTree] = p.`type` match {
case bt: BooleanType => Valid(p)
case rt: ReferenceType if rt.name() == "java.lang.Boolean" =>
Valid(p)
case _ => CompilerRecoverable(s"The predicate must be a boolean expression, found ${p.`type`}")
}
}

object PrimitiveTest {
object IsIntegral {
def unapply(x: Type): Boolean = x match {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package ch.epfl.scala.debugadapter.internal.evaluator

import ch.epfl.scala.debugadapter.Logger
import scala.util.Success
import scala.meta.Term
import scala.meta.Lit

class RuntimePreEvaluationValidator(
override val frame: JdiFrame,
Expand All @@ -13,8 +16,8 @@ class RuntimePreEvaluationValidator(
Validation.fromTry(tpe).map(PreEvaluatedTree(value, _))
}

override lazy val thisTree: Validation[PreEvaluatedTree] =
ThisTree(frame.thisObject).flatMap(preEvaluate)
override def validateLiteral(lit: Lit): Validation[RuntimeEvaluableTree] =
super.validateLiteral(lit).flatMap(preEvaluate)

override def localVarTreeByName(name: String): Validation[PreEvaluatedTree] =
super.localVarTreeByName(name).flatMap(preEvaluate)
Expand All @@ -41,6 +44,22 @@ class RuntimePreEvaluationValidator(
preEvaluate(tree)
case tree => Valid(tree)
}

override def validateIf(tree: Term.If): Validation[RuntimeEvaluableTree] =
super.validateIf(tree).transform {
case tree @ Valid(IfTree(p: PreEvaluatedTree, thenp, elsep, _)) =>
val predicate = for {
pValue <- p.value
unboxed <- pValue.unboxIfPrimitive
bool <- unboxed.toBoolean
} yield bool
predicate.extract match {
case Success(true) => Valid(thenp)
case Success(false) => Valid(elsep)
case _ => tree
}
case tree => tree
}
}

object RuntimePreEvaluationValidator {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package ch.epfl.scala.debugadapter.internal.evaluator

import com.sun.jdi._
import RuntimeEvaluatorExtractors.{IsAnyVal, Module}
import RuntimeEvaluatorExtractors.{BooleanTree, IsAnyVal, Module}
import scala.util.Success

/* -------------------------------------------------------------------------- */
Expand Down Expand Up @@ -51,7 +51,7 @@ case class LiteralTree private (
object LiteralTree {
def apply(value: (Safe[Any], Type)): Validation[LiteralTree] = value._1 match {
case Safe(Success(_: String)) | Safe(Success(IsAnyVal(_))) => Valid(new LiteralTree(value._1, value._2))
case _ => Fatal(s"Unsupported literal type: ${value.getClass}")
case _ => CompilerRecoverable(s"Unsupported literal type: ${value.getClass}")
}
}

Expand Down Expand Up @@ -327,3 +327,56 @@ case class PreEvaluatedTree(
object PreEvaluatedTree {
def apply(value: (Safe[JdiValue], Type)) = new PreEvaluatedTree(value._1, value._2)
}

/* -------------------------------------------------------------------------- */
/* Flow control trees */
/* -------------------------------------------------------------------------- */
case class IfTree private (
p: RuntimeEvaluableTree,
thenp: RuntimeEvaluableTree,
elsep: RuntimeEvaluableTree,
`type`: Type
) extends RuntimeEvaluableTree {
override def prettyPrint(depth: Int): String = {
val indent = "\t" * (depth + 1)
s"""|IfTree(
|${indent}p= ${p.prettyPrint(depth + 1)},
|${indent}ifTrue= ${thenp.prettyPrint(depth + 1)},
|${indent}ifFalse= ${elsep.prettyPrint(depth + 1)}
|${indent}t= ${`type`}
|${indent.dropRight(1)})""".stripMargin
}
}

object IfTree {

/**
* Returns the type of the branch that is chosen, if any
*
* @param t1
* @param t2
* @return Some(true) if t1 is chosen, Some(false) if t2 is chosen, None if no branch is chosen
*/
def apply(
p: RuntimeEvaluableTree,
ifTrue: RuntimeEvaluableTree,
ifFalse: RuntimeEvaluableTree,
assignableFrom: (
Type,
Type
) => Boolean, // ! This is a hack, passing a wrong method would lead to inconsistent trees
objType: => Type
): Validation[IfTree] = {
val pType = p.`type`
val tType = ifTrue.`type`
val fType = ifFalse.`type`

p match {
case BooleanTree(_) =>
if (assignableFrom(tType, fType)) Valid(IfTree(p, ifTrue, ifFalse, tType))
else if (assignableFrom(fType, tType)) Valid(IfTree(p, ifTrue, ifFalse, fType))
else Valid(IfTree(p, ifTrue, ifFalse, objType))
case _ => CompilerRecoverable("A predicate must be a boolean")
}
}
}
Loading

0 comments on commit 093dcc7

Please sign in to comment.