Skip to content

Commit

Permalink
Multiple outers (#495)
Browse files Browse the repository at this point in the history
* indexed classes

put back nested module initialization in validation

allow fqcn + some cleanup

class search and static US bug fixes

fix class not loaded for arrays

applied review

rebase fix

* fixes #496
  • Loading branch information
iusildra authored Jul 13, 2023
1 parent 084ce54 commit c6b2734
Show file tree
Hide file tree
Showing 14 changed files with 286 additions and 199 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ private[internal] class EvaluationProvider(
override def evaluate(expression: String, thread: ThreadReference, depth: Int): CompletableFuture[Value] = {
val frame = JdiFrame(thread, depth)
val evaluator = RuntimeDefaultEvaluator(frame, logger)
val validator = RuntimePreEvaluationValidator(frame, logger, evaluator)
val validator = RuntimePreEvaluationValidator(frame, evaluator, sourceLookUp, logger)
val evaluation = for {
preparedExpression <- prepare(expression, frame, validator)
evaluation <- evaluate(preparedExpression, frame)
Expand All @@ -85,7 +85,7 @@ private[internal] class EvaluationProvider(
if (breakpoint.getCompiledExpression(locationCode) != null) {
breakpoint.getCompiledExpression(locationCode).asInstanceOf[Try[PreparedExpression]]
} else if (breakpoint.containsConditionalExpression) {
prepare(breakpoint.getCondition, frame, RuntimeDefaultValidator(frame, logger))
prepare(breakpoint.getCondition, frame, RuntimeDefaultValidator(frame, sourceLookUp, logger))
} else if (breakpoint.containsLogpointExpression) {
prepareLogMessage(breakpoint.getLogMessage, frame)
} else {
Expand Down Expand Up @@ -147,7 +147,7 @@ private[internal] class EvaluationProvider(
} else {
val tripleQuote = "\"\"\""
val expression = s"""println(s$tripleQuote$message$tripleQuote)"""
prepare(expression, frame, RuntimeDefaultValidator(frame, logger))
prepare(expression, frame, RuntimeDefaultValidator(frame, sourceLookUp, logger))
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package ch.epfl.scala.debugadapter.internal

import scala.collection.mutable.Map

trait ClassSearch {
def insert(fqcn: String): Unit
def find(className: String): List[String]
}

class SourceEntrySearchMap extends ClassSearch {
private val classes = Map[String, List[String]]().withDefaultValue(List.empty)

private def getLastInnerType(className: String): Option[String] = {
val pattern = """(.+\$)(.+)$""".r
className match {
case pattern(_, innerType) => Some(innerType)
case _ => None
}
}

def insert(fqcn: String) = {
val classWithOuters = fqcn.split('.').last
var className = getLastInnerType(classWithOuters).getOrElse(classWithOuters)

classes.update(className, fqcn :: classes(className))
}

def find(className: String): List[String] = classes(className)
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ private[debugadapter] final class SourceLookUpProvider(
sourceUriToClassPathEntry: Map[URI, ClassEntryLookUp],
fqcnToClassPathEntry: Map[String, ClassEntryLookUp]
) extends ISourceLookUpProvider {
private val classSearch = new SourceEntrySearchMap
for (entries <- classPathEntries; fqcn <- entries.fullyQualifiedNames) classSearch.insert(fqcn)

override def supportsRealtimeBreakpointVerification(): Boolean = true

override def getSourceFileURI(fqcn: String, path: String): String = {
Expand Down Expand Up @@ -62,6 +65,8 @@ private[debugadapter] final class SourceLookUpProvider(
classPathEntries.flatMap(_.fullyQualifiedNames)
private[internal] def allOrphanClasses: Iterable[ClassFile] =
classPathEntries.flatMap(_.orphanClassFiles)
private[internal] def classesByName(name: String) =
classSearch.find(name)

private[internal] def getScalaSig(fqcn: String): Option[ScalaSig] = {
for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import com.sun.jdi._
import java.nio.file.Path
import RuntimeEvaluatorExtractors.IsAnyVal

import scala.jdk.CollectionConverters.*

private[internal] class JdiClassLoader(
reference: ClassLoaderReference,
thread: ThreadReference
Expand Down Expand Up @@ -122,7 +124,7 @@ private[internal] class JdiClassLoader(
def boxUnboxOnNeed(
expected: java.util.List[Type],
received: Seq[JdiValue]
): Safe[Seq[JdiValue]] = boxUnboxOnNeed(expected.asScalaSeq, received)
): Safe[Seq[JdiValue]] = boxUnboxOnNeed(expected.asScala.toSeq, received)

def createArray(arrayType: String, values: Seq[JdiValue]): Safe[JdiArray] =
for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ private[evaluator] class JdiObject(

// we use a Seq instead of a Map because the ScalaEvaluator rely on the order of the fields
def fields: Seq[(String, JdiValue)] =
reference.referenceType.fields.asScalaSeq
reference.referenceType.fields.asScala.toSeq
.map(f => (f.name, JdiValue(reference.getValue(f), thread)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package ch.epfl.scala.debugadapter.internal.evaluator
import ch.epfl.scala.debugadapter.Logger

class RuntimeDefaultEvaluator(val frame: JdiFrame, implicit val logger: Logger) extends RuntimeEvaluator {
val helper = new RuntimeEvaluationHelpers(frame)
def evaluate(stat: RuntimeEvaluableTree): Safe[JdiValue] =
eval(stat).map(_.derefIfRef)

Expand All @@ -23,7 +22,6 @@ class RuntimeDefaultEvaluator(val frame: JdiFrame, implicit val logger: Logger)
case array: ArrayElemTree => evaluateArrayElement(array)
case branching: IfTree => evaluateIf(branching)
case staticMethod: StaticMethodTree => invokeStatic(staticMethod)
case outer: OuterTree => evaluateOuter(outer)
case UnitTree => Safe(JdiValue(frame.thread.virtualMachine.mirrorOfVoid, frame.thread))
}

Expand All @@ -37,16 +35,6 @@ class RuntimeDefaultEvaluator(val frame: JdiFrame, implicit val logger: Logger)
result <- loader.mirrorOfLiteral(value)
} yield result

/* -------------------------------------------------------------------------- */
/* Outer evaluation */
/* -------------------------------------------------------------------------- */
def evaluateOuter(tree: OuterTree): Safe[JdiValue] =
tree match {
case OuterModuleTree(module) => evaluateModule(module)
case outerClass: OuterClassTree =>
eval(outerClass.inner).map(_.asObject.getField("$outer"))
}

/* -------------------------------------------------------------------------- */
/* Field evaluation */
/* -------------------------------------------------------------------------- */
Expand Down Expand Up @@ -99,8 +87,7 @@ class RuntimeDefaultEvaluator(val frame: JdiFrame, implicit val logger: Logger)
def evaluateModule(tree: ModuleTree): Safe[JdiValue] =
tree match {
case TopLevelModuleTree(mod) => Safe(JdiObject(mod.instances(1).get(0), frame.thread))
case NestedModuleTree(mod, of) => helper.initializeModule(mod, eval(of))
// TODO: change the $of attribute to be a Method validated by the validator to avoid crashes at evaluation time
case NestedModuleTree(_, init) => invoke(init)
}

/* -------------------------------------------------------------------------- */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@ import RuntimeEvaluatorExtractors.*
import scala.util.Failure
import scala.util.Success

import ch.epfl.scala.debugadapter.internal.SourceLookUpProvider

import scala.jdk.CollectionConverters.*

case class Call(fun: Term, argClause: Term.ArgClause)
case class PreparedCall(qual: Validation[RuntimeTree], name: String)

class RuntimeDefaultValidator(val frame: JdiFrame, implicit val logger: Logger) extends RuntimeValidator {
val helper = new RuntimeEvaluationHelpers(frame)
class RuntimeDefaultValidator(val frame: JdiFrame, val sourceLookUp: SourceLookUpProvider, implicit val logger: Logger)
extends RuntimeValidator {
val helper = new RuntimeEvaluationHelpers(frame, sourceLookUp)
import helper.*

protected def parse(expression: String): Validation[Stat] =
Expand Down Expand Up @@ -88,49 +93,52 @@ class RuntimeDefaultValidator(val frame: JdiFrame, implicit val logger: Logger)
.filter(_.`type`.name() != "scala.Function0", runtimeFatal = true)
.map(v => LocalVarTree(name, v.`type`))

// We might sometimes need to access a 'private' attribute of a class
private def fieldLookup(name: String, ref: ReferenceType) =
Option(ref.fieldByName(name))
.orElse { ref.visibleFields().asScala.find(_.name().endsWith("$" + name)) }

def fieldTreeByName(
of: Validation[RuntimeTree],
name: String
): Validation[RuntimeEvaluableTree] =
for {
ref <- extractReferenceType(of)
field <- Validation(ref.fieldByName(name))
_ = loadClassOnNeed(field)
fieldTree <- toStaticIfNeeded(field, of.get)
} yield fieldTree
of match {
case ReferenceTree(ref) =>
for {
field <- Validation.fromOption { fieldLookup(name, ref) }
_ = loadClassOnNeed(field)
fieldTree <- toStaticIfNeeded(field, of.get)
} yield fieldTree
case _ => Recoverable(s"Cannot access field $name from non reference type ${of.get.`type`.name()}")
}

private def inCompanion(name: Option[String], moduleName: String) = name
.filter(_.endsWith("$"))
.map(n => loadClass(n.stripSuffix("$")))
.map {
_.withFilterNot {
_.cls.methodsByName(moduleName.stripSuffix("$")).isEmpty()
}.extract
}
.exists(_.filterNot {
_.`type`.methodsByName(moduleName.stripSuffix("$")).isEmpty()
}.isValid)

// ? When RuntimeTree is not a ThisTree, it might be meaningful to directly load by concatenating the qualifier type's name and the target's name
def validateModule(name: String, of: Option[RuntimeTree]): Validation[RuntimeEvaluableTree] = {
val moduleName = if (name.endsWith("$")) name else name + "$"
val ofName = of.map(_.`type`.name())
searchAllClassesFor(moduleName, ofName).flatMap { moduleCls =>
searchClasses(moduleName, ofName).flatMap { moduleCls =>
val isInModule = inCompanion(ofName, moduleName)

(isInModule, moduleCls.`type`, of) match {
case (Some(Success(cls: JdiClass)), _, _) =>
CompilerRecoverable(s"Cannot access module ${name} from ${ofName}")
case (true, _, _) => CompilerRecoverable(s"Cannot access module ${name} from ${ofName}")
case (_, Module(module), _) => Valid(TopLevelModuleTree(module))
case (_, cls, Some(instance: RuntimeEvaluableTree)) =>
if (cls.name().startsWith(instance.`type`.name()))
Valid(NestedModuleTree(cls, instance))
moduleInitializer(cls, instance)
else Recoverable(s"Cannot access module $moduleCls from ${instance.`type`.name()}")
case _ => Recoverable(s"Cannot access module $moduleCls")
}
}
}

// ? Same as validateModule, but for classes
def validateClass(name: String, of: Validation[RuntimeTree]): Validation[ClassTree] =
searchAllClassesFor(name.stripSuffix("$"), of.map(_.`type`.name()).toOption)
searchClasses(name.stripSuffix("$"), of.map(_.`type`.name()).toOption)
.transform { cls =>
(cls, of) match {
case (Valid(_), Valid(_: RuntimeEvaluableTree) | _: Invalid) => cls
Expand All @@ -149,7 +157,7 @@ class RuntimeDefaultValidator(val frame: JdiFrame, implicit val logger: Logger)
): Validation[RuntimeEvaluableTree] = {
val name = NameTransformer.encode(value)
def field = fieldTreeByName(of, name)
def zeroArg = of.flatMap(methodTreeByNameAndArgs(_, name, List.empty))
def zeroArg = of.flatMap(zeroArgMethodTreeByName(_, name))
def member =
if (methodFirst) zeroArg.orElse(field)
else field.orElse(zeroArg)
Expand Down Expand Up @@ -249,20 +257,15 @@ class RuntimeDefaultValidator(val frame: JdiFrame, implicit val logger: Logger)
/* -------------------------------------------------------------------------- */
/* Looking for $outer */
/* -------------------------------------------------------------------------- */
def validateOuter(tree: RuntimeTree): Validation[RuntimeEvaluableTree] = {
for {
ref <- extractReferenceType(tree)
outer <- outerLookup(ref)
outerTree <- OuterTree(tree, outer)
} yield outerTree
}
def validateOuter(tree: RuntimeTree): Validation[RuntimeEvaluableTree] =
outerLookup(tree)

/* -------------------------------------------------------------------------- */
/* Flow control validation */
/* -------------------------------------------------------------------------- */

def validateIf(tree: Term.If): Validation[RuntimeEvaluableTree] = {
lazy val objType = loadClass("java.lang.Object").extract.get.cls
lazy val objType = loadClass("java.lang.Object").get.`type`
for {
cond <- validate(tree.cond)
thenp <- validate(tree.thenp)
Expand All @@ -279,6 +282,6 @@ class RuntimeDefaultValidator(val frame: JdiFrame, implicit val logger: Logger)
}

object RuntimeDefaultValidator {
def apply(frame: JdiFrame, logger: Logger) =
new RuntimeDefaultValidator(frame, logger)
def apply(frame: JdiFrame, sourceLookUp: SourceLookUpProvider, logger: Logger) =
new RuntimeDefaultValidator(frame, sourceLookUp, logger)
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ trait RuntimeEvaluator {

def evaluateLiteral(tree: LiteralTree): Safe[JdiValue]

def evaluateOuter(tree: OuterTree): Safe[JdiValue]

def evaluateField(tree: InstanceFieldTree): Safe[JdiValue]

def evaluateStaticField(tree: StaticFieldTree): Safe[JdiValue]
Expand Down
Loading

0 comments on commit c6b2734

Please sign in to comment.