Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Evaluate flow control #472

Merged
merged 4 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class RuntimeDefaultEvaluator(val frame: JdiFrame, val logger: Logger) extends R
case instance: NewInstanceTree => instantiate(instance)
case method: InstanceMethodTree => invoke(method)
case array: ArrayElemTree => evaluateArrayElement(array)
case branching: IfTree => evaluateIf(branching)
case staticMethod: StaticMethodTree => invokeStatic(staticMethod)
case outer: OuterTree => evaluateOuter(outer)
}
Expand Down Expand Up @@ -120,6 +121,15 @@ class RuntimeDefaultEvaluator(val frame: JdiFrame, val logger: Logger) extends R
array <- eval(tree.array)
index <- eval(tree.index).flatMap(_.unboxIfPrimitive).flatMap(_.toInt)
} yield array.asArray.getValue(index)

/* -------------------------------------------------------------------------- */
/* If tree evaluation */
/* -------------------------------------------------------------------------- */
def evaluateIf(tree: IfTree): Safe[JdiValue] =
for {
predicate <- eval(tree.p).flatMap(_.unboxIfPrimitive).flatMap(_.toBoolean)
value <- if (predicate) eval(tree.thenp) else eval(tree.elsep)
} yield value
}

object RuntimeDefaultEvaluator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val logger: Logger) extends R

def validate(expression: Stat): Validation[RuntimeEvaluableTree] =
expression match {
case lit: Lit => validateLiteral(lit)
case value: Term.Name => validateName(value.value, thisTree)
case _: Term.This => thisTree
case sup: Term.Super => Recoverable("Super not (yet) supported at runtime")
case _: Term.Apply | _: Term.ApplyInfix | _: Term.ApplyUnary => validateMethod(extractCall(expression))
case select: Term.Select => validateSelect(select)
case lit: Lit => validateLiteral(lit)
case branch: Term.If => validateIf(branch)
case instance: Term.New => validateNew(instance)
case _ => Recoverable("Expression not supported at runtime")
}
Expand All @@ -55,10 +56,10 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val logger: Logger) extends R
/* -------------------------------------------------------------------------- */
/* Literal validation */
/* -------------------------------------------------------------------------- */
def validateLiteral(lit: Lit): Validation[LiteralTree] =
def validateLiteral(lit: Lit): Validation[RuntimeEvaluableTree] =
frame.classLoader().map(loader => LiteralTree(fromLitToValue(lit, loader))).extract match {
case Success(value) => value
case Failure(e) => Fatal(e)
case Failure(e) => CompilerRecoverable(e)
}

/* -------------------------------------------------------------------------- */
Expand Down Expand Up @@ -151,12 +152,19 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val logger: Logger) extends R
if (methodFirst) zeroArg.orElse(field)
else field.orElse(zeroArg)

of.flatMap { of =>
member
.orElse(validateModule(name, Some(of)))
.orElse(findOuter(of).flatMap(o => validateName(value, Valid(o), methodFirst)))
}.orElse(localVarTreeByName(name))
.orElse(validateModule(name, None))
of
.flatMap { of =>
member
.orElse(validateModule(name, Some(of)))
.orElse(findOuter(of).flatMap(o => validateName(value, Valid(o), methodFirst)))
}
.orElse {
of match {
case Valid(_: ThisTree) | _: Recoverable => localVarTreeByName(name)
case _ => Recoverable(s"${name} is not a local variable")
}
}
.orElse { validateModule(name, None) }
}

/* -------------------------------------------------------------------------- */
Expand Down Expand Up @@ -252,6 +260,26 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val logger: Logger) extends R
outerTree <- OuterTree(tree, outer)
} yield outerTree
}

/* -------------------------------------------------------------------------- */
/* Flow control validation */
/* -------------------------------------------------------------------------- */

def validateIf(tree: Term.If): Validation[RuntimeEvaluableTree] = {
lazy val objType = loadClass("java.lang.Object").extract.get.cls
for {
cond <- validate(tree.cond)
thenp <- validate(tree.thenp)
elsep <- validate(tree.elsep)
ifTree <- IfTree(
cond,
thenp,
elsep,
isAssignableFrom(_, _),
extractCommonSuperClass(thenp.`type`, elsep.`type`).getOrElse(objType)
)
} yield ifTree
}
}

object RuntimeDefaultValidator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ trait RuntimeValidator {
*/
protected def validateWithClass(expression: Stat): Validation[RuntimeTree]

def validateLiteral(lit: Lit): Validation[LiteralTree]
def validateLiteral(lit: Lit): Validation[RuntimeEvaluableTree]

def localVarTreeByName(name: String): Validation[RuntimeEvaluableTree]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
package ch.epfl.scala.debugadapter.internal.evaluator

import com.sun.jdi._
import scala.meta.trees.*

import scala.meta.Lit
import scala.util.Success
import RuntimeEvaluatorExtractors.*
import scala.meta.Stat
import scala.meta.Term
import scala.meta.trees.*
import scala.meta.{Type => MType}
import scala.util.Failure
import scala.util.Success
import scala.util.Try
import scala.jdk.CollectionConverters.*

import RuntimeEvaluatorExtractors.*

private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame) {
import RuntimeEvaluationHelpers.*
def fromLitToValue(literal: Lit, classLoader: JdiClassLoader): (Safe[Any], Type) = {
Expand Down Expand Up @@ -136,12 +138,12 @@ private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame) {
}

(got, expected) match {
case (g: ArrayType, at: ArrayType) => isAssignableFrom(g.componentType, at.componentType)
case (g: ArrayType, at: ArrayType) => g.componentType().equals(at.componentType()) // TODO: check this
case (g: PrimitiveType, pt: PrimitiveType) => got.equals(pt)
case (g: ReferenceType, ref: ReferenceType) => referenceTypesMatch(g, ref)
case (_: VoidType, _: VoidType) => true

case (g: ClassType, pt: PrimitiveType) =>
case (g: ReferenceType, pt: PrimitiveType) =>
isAssignableFrom(g, frame.getPrimitiveBoxedClass(pt))
case (g: PrimitiveType, ct: ReferenceType) =>
isAssignableFrom(frame.getPrimitiveBoxedClass(g), ct)
Expand Down Expand Up @@ -215,6 +217,20 @@ private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame) {

loop(qual)
}

def extractCommonSuperClass(tpe1: Type, tpe2: Type): Option[Type] = {
def getSuperClasses(of: Type): Array[ClassType] =
of match {
case cls: ClassType =>
Iterator.iterate(cls)(cls => cls.superclass()).takeWhile(_ != null).toArray
case _ => Array()
}

val superClasses1 = getSuperClasses(tpe1)
val superClasses2 = getSuperClasses(tpe2)
superClasses1.find(superClasses2.contains)
}

def validateType(tpe: MType, thisType: Option[RuntimeEvaluableTree])(
termValidation: Term => Validation[RuntimeEvaluableTree]
): Validation[(Option[RuntimeEvaluableTree], ClassTree)] =
Expand Down Expand Up @@ -374,7 +390,9 @@ private[evaluator] object RuntimeEvaluationHelpers {
case (_, Module(mod)) => Valid(InstanceFieldTree(field, mod))
case (_, eval: RuntimeEvaluableTree) =>
if (field.isStatic())
Fatal(s"Accessing static field $field from instance ${eval.`type`} can lead to unexpected behavior")
CompilerRecoverable(
s"Accessing static field $field from instance ${eval.`type`} can lead to unexpected behavior"
)
else Valid(InstanceFieldTree(field, eval))
}

Expand All @@ -387,7 +405,9 @@ private[evaluator] object RuntimeEvaluationHelpers {
case Module(mod) => Valid(InstanceMethodTree(method, args, mod))
case eval: RuntimeEvaluableTree =>
if (method.isStatic())
Fatal(s"Accessing static method $method from instance ${eval.`type`} can lead to unexpected behavior")
CompilerRecoverable(
s"Accessing static method $method from instance ${eval.`type`} can lead to unexpected behavior"
)
else Valid(InstanceMethodTree(method, args, eval))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@ protected[internal] object RuntimeEvaluatorExtractors {
}

object MethodCall {
def unapply(tree: RuntimeTree): Option[RuntimeTree] =
def unapply(tree: RuntimeTree): Boolean =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is not an extractor (it returns a Boolean), it should not be called unapply.

tree match {
case mt: NestedModuleTree => unapply(mt.of)
case ft: InstanceFieldTree => unapply(ft.qual)
case oct: OuterClassTree => unapply(oct.inner)
case OuterModuleTree(module) => unapply(module)
case _: MethodTree | _: NewInstanceTree => Some(tree)
case _: LiteralTree | _: LocalVarTree | _: PreEvaluatedTree | _: ThisTree => None
case _: StaticFieldTree | _: ClassTree | _: TopLevelModuleTree => None
case _: PrimitiveBinaryOpTree | _: PrimitiveUnaryOpTree | _: ArrayElemTree => None
case IfTree(p, t, f, _) => unapply(p) || unapply(t) || unapply(f)
case _: MethodTree | _: NewInstanceTree => true
case _: LiteralTree | _: LocalVarTree | _: PreEvaluatedTree | _: ThisTree => false
case _: StaticFieldTree | _: ClassTree | _: TopLevelModuleTree => false
case _: PrimitiveBinaryOpTree | _: PrimitiveUnaryOpTree | _: ArrayElemTree => false
}
def unapply(tree: Validation[RuntimeTree]): Option[RuntimeTree] =
tree.toOption.filter { unapply(_).isDefined }
def unapply(tree: Validation[RuntimeTree]): Validation[RuntimeTree] =
tree.filter(unapply)
}

object ReferenceTree {
Expand All @@ -61,6 +62,18 @@ protected[internal] object RuntimeEvaluatorExtractors {
}
}

object BooleanTree {
def unapply(p: Validation[RuntimeEvaluableTree]): Validation[RuntimeEvaluableTree] =
p.flatMap(unapply)

def unapply(p: RuntimeEvaluableTree): Validation[RuntimeEvaluableTree] = p.`type` match {
case bt: BooleanType => Valid(p)
case rt: ReferenceType if rt.name() == "java.lang.Boolean" =>
Valid(p)
case _ => CompilerRecoverable(s"The predicate must be a boolean expression, found ${p.`type`}")
}
}

object PrimitiveTest {
object IsIntegral {
def unapply(x: Type): Boolean = x match {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package ch.epfl.scala.debugadapter.internal.evaluator

import ch.epfl.scala.debugadapter.Logger
import scala.util.Success
import scala.meta.Term
import scala.meta.Lit

class RuntimePreEvaluationValidator(
override val frame: JdiFrame,
Expand All @@ -13,8 +16,8 @@ class RuntimePreEvaluationValidator(
Validation.fromTry(tpe).map(PreEvaluatedTree(value, _))
}

override lazy val thisTree: Validation[PreEvaluatedTree] =
ThisTree(frame.thisObject).flatMap(preEvaluate)
override def validateLiteral(lit: Lit): Validation[RuntimeEvaluableTree] =
super.validateLiteral(lit).flatMap(preEvaluate)

override def localVarTreeByName(name: String): Validation[PreEvaluatedTree] =
super.localVarTreeByName(name).flatMap(preEvaluate)
Expand All @@ -41,6 +44,22 @@ class RuntimePreEvaluationValidator(
preEvaluate(tree)
case tree => Valid(tree)
}

override def validateIf(tree: Term.If): Validation[RuntimeEvaluableTree] =
super.validateIf(tree).transform {
case tree @ Valid(IfTree(p: PreEvaluatedTree, thenp, elsep, _)) =>
val predicate = for {
pValue <- p.value
unboxed <- pValue.unboxIfPrimitive
bool <- unboxed.toBoolean
} yield bool
predicate.extract match {
case Success(true) => Valid(thenp)
case Success(false) => Valid(elsep)
case _ => tree
}
case tree => tree
}
}

object RuntimePreEvaluationValidator {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package ch.epfl.scala.debugadapter.internal.evaluator

import com.sun.jdi._
import RuntimeEvaluatorExtractors.{IsAnyVal, Module}
import RuntimeEvaluatorExtractors.{BooleanTree, IsAnyVal, Module}
import scala.util.Success

/* -------------------------------------------------------------------------- */
Expand Down Expand Up @@ -51,7 +51,7 @@ case class LiteralTree private (
object LiteralTree {
def apply(value: (Safe[Any], Type)): Validation[LiteralTree] = value._1 match {
case Safe(Success(_: String)) | Safe(Success(IsAnyVal(_))) => Valid(new LiteralTree(value._1, value._2))
case _ => Fatal(s"Unsupported literal type: ${value.getClass}")
case _ => CompilerRecoverable(s"Unsupported literal type: ${value.getClass}")
}
}

Expand Down Expand Up @@ -327,3 +327,56 @@ case class PreEvaluatedTree(
object PreEvaluatedTree {
def apply(value: (Safe[JdiValue], Type)) = new PreEvaluatedTree(value._1, value._2)
}

/* -------------------------------------------------------------------------- */
/* Flow control trees */
/* -------------------------------------------------------------------------- */
case class IfTree private (
p: RuntimeEvaluableTree,
thenp: RuntimeEvaluableTree,
elsep: RuntimeEvaluableTree,
`type`: Type
) extends RuntimeEvaluableTree {
override def prettyPrint(depth: Int): String = {
val indent = "\t" * (depth + 1)
s"""|IfTree(
|${indent}p= ${p.prettyPrint(depth + 1)},
|${indent}ifTrue= ${thenp.prettyPrint(depth + 1)},
|${indent}ifFalse= ${elsep.prettyPrint(depth + 1)}
|${indent}t= ${`type`}
|${indent.dropRight(1)})""".stripMargin
}
}

object IfTree {

/**
* Returns the type of the branch that is chosen, if any
*
* @param t1
* @param t2
* @return Some(true) if t1 is chosen, Some(false) if t2 is chosen, None if no branch is chosen
*/
def apply(
p: RuntimeEvaluableTree,
ifTrue: RuntimeEvaluableTree,
ifFalse: RuntimeEvaluableTree,
assignableFrom: (
Type,
Type
) => Boolean, // ! This is a hack, passing a wrong method would lead to inconsistent trees
objType: => Type
): Validation[IfTree] = {
val pType = p.`type`
val tType = ifTrue.`type`
val fType = ifFalse.`type`

p match {
case BooleanTree(_) =>
if (assignableFrom(tType, fType)) Valid(IfTree(p, ifTrue, ifFalse, tType))
else if (assignableFrom(fType, tType)) Valid(IfTree(p, ifTrue, ifFalse, fType))
else Valid(IfTree(p, ifTrue, ifFalse, objType))
case _ => CompilerRecoverable("A predicate must be a boolean")
}
}
}
Loading
Loading