Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance class search #484

Merged
merged 5 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@ import com.microsoft.java.debug.core.adapter.ISourceLookUpProvider

import java.net.URI
import scala.collection.parallel.immutable.ParVector
import ch.epfl.scala.debugadapter.internal.evaluator.NameTransformer

private[debugadapter] final class SourceLookUpProvider(
private[internal] val classPathEntries: Seq[ClassEntryLookUp],
sourceUriToClassPathEntry: Map[URI, ClassEntryLookUp],
fqcnToClassPathEntry: Map[String, ClassEntryLookUp]
) extends ISourceLookUpProvider {
private val classSearch = new SourceEntrySearchMap
for (entries <- classPathEntries; fqcn <- entries.fullyQualifiedNames) classSearch.insert(fqcn)

val classSearch =
classPathEntries
.flatMap { _.fullyQualifiedNames.filterNot { _.contains("$$anon$") } }
.groupBy { SourceLookUpProvider.getScalaClassName }

override def supportsRealtimeBreakpointVerification(): Boolean = true

Expand Down Expand Up @@ -65,8 +69,8 @@ private[debugadapter] final class SourceLookUpProvider(
classPathEntries.flatMap(_.fullyQualifiedNames)
private[internal] def allOrphanClasses: Iterable[ClassFile] =
classPathEntries.flatMap(_.orphanClassFiles)
private[internal] def classesByName(name: String) =
classSearch.find(name)
private[internal] def classesByName(name: String): Seq[String] =
classSearch.get(name).getOrElse(Seq.empty)

private[internal] def getScalaSig(fqcn: String): Option[ScalaSig] = {
for {
Expand All @@ -86,6 +90,13 @@ private[debugadapter] object SourceLookUpProvider {
def empty: SourceLookUpProvider =
new SourceLookUpProvider(Seq.empty, Map.empty, Map.empty)

def getScalaClassName(className: String): String = {
val lastDot = className.lastIndexOf('.') + 1
val decoded = NameTransformer.decode { className.drop(lastDot) }
val lastDollar = decoded.stripSuffix("$").lastIndexOf('$') + 1
decoded.drop { lastDollar }
}

def apply(entries: Seq[ClassEntry], logger: Logger): SourceLookUpProvider = {
val parrallelEntries = ParVector(entries*)
val sourceFilesByEntry = parrallelEntries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,13 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val sourceLookUp: SourceLookU
expression match {
case value: Term.Name => validateName(value.value, thisTree).orElse(validateClass(value.value, thisTree))
case Term.Select(qual, name) =>
for {
qual <- validateWithClass(qual)
name <- validateName(name.value, Valid(qual)).orElse(validateClass(name.value, Valid(qual)))
} yield name
validateWithClass(qual).transform {
case qual: Valid[?] =>
validateName(name.value, qual)
.orElse { validateClass(name.value, qual) }
case _: Invalid =>
searchClassesQCN(qual.toString + "." + name.value)
}
case _ => validate(expression)
}

Expand Down Expand Up @@ -115,21 +118,23 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val sourceLookUp: SourceLookU
private def inCompanion(name: Option[String], moduleName: String) = name
.filter(_.endsWith("$"))
.map(n => loadClass(n.stripSuffix("$")))
.exists(_.filterNot {
_.`type`.methodsByName(moduleName.stripSuffix("$")).isEmpty()
}.isValid)
.exists {
_.filterNot {
_.methodsByName(moduleName.stripSuffix("$")).isEmpty()
}.isValid
}

def validateModule(name: String, of: Option[RuntimeTree]): Validation[RuntimeEvaluableTree] = {
val moduleName = if (name.endsWith("$")) name else name + "$"
val ofName = of.map(_.`type`.name())
searchClasses(moduleName, ofName).flatMap { moduleCls =>
val isInModule = inCompanion(ofName, moduleName)

(isInModule, moduleCls.`type`, of) match {
(isInModule, moduleCls, of) match {
case (true, _, _) => CompilerRecoverable(s"Cannot access module ${name} from ${ofName}")
case (_, Module(module), _) => Valid(TopLevelModuleTree(module))
case (_, Module(_), _) => Valid(TopLevelModuleTree(moduleCls))
case (_, cls, Some(instance: RuntimeEvaluableTree)) =>
if (cls.name().startsWith(instance.`type`.name()))
if (cls.name.startsWith(instance.`type`.name()))
moduleInitializer(cls, instance)
else Recoverable(s"Cannot access module $moduleCls from ${instance.`type`.name()}")
case _ => Recoverable(s"Cannot access module $moduleCls")
Expand All @@ -139,14 +144,18 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val sourceLookUp: SourceLookU

def validateClass(name: String, of: Validation[RuntimeTree]): Validation[ClassTree] =
searchClasses(name.stripSuffix("$"), of.map(_.`type`.name()).toOption)
.transform { cls =>
(cls, of) match {
case (Valid(_), Valid(_: RuntimeEvaluableTree) | _: Invalid) => cls
case (Valid(c), Valid(ct: ClassTree)) =>
if (c.`type`.isStatic()) cls
else CompilerRecoverable(s"Cannot access non-static class ${c.`type`.name} from ${ct.`type`.name()}")
case (_, Valid(value)) => validateOuter(value).flatMap(o => validateClass(name, Valid(o)))
case (_, _: Invalid) => Recoverable(s"Cannot find class $name")
.flatMap { cls =>
of match {
case Valid(_: RuntimeEvaluableTree) | _: Invalid => Valid(ClassTree(cls))
case Valid(ct: ClassTree) =>
if (cls.isStatic()) Valid(ClassTree(cls))
else CompilerRecoverable(s"Cannot access non-static class ${cls.name} from ${ct.`type`.name()}")
}
}
.orElse {
of match {
case Valid(value) => validateOuter(value).flatMap(o => validateClass(name, Valid(o)))
case _ => Recoverable(s"Cannot access class $name")
}
}

Expand All @@ -165,16 +174,15 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val sourceLookUp: SourceLookU
of
.flatMap { of =>
member
.orElse(validateModule(name, Some(of)))
.orElse(validateModule(value, Some(of)))
.orElse(validateOuter(of).flatMap(o => validateName(value, Valid(o), methodFirst)))
}
.orElse {
of match {
case Valid(_: ThisTree) | _: Recoverable => localVarTreeByName(name)
case _ => Recoverable(s"${name} is not a local variable")
case _ => Recoverable(s"${value} is not a local variable")
}
}
.orElse { validateModule(name, None) }
}

/* -------------------------------------------------------------------------- */
Expand Down Expand Up @@ -216,19 +224,36 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val sourceLookUp: SourceLookU

def validateMethod(call: Call): Validation[RuntimeEvaluableTree] = {
lazy val preparedCall = call.fun match {
case select: Term.Select =>
PreparedCall(validateWithClass(select.qual), select.name.value)
case Term.Select(qual, name) =>
val qualTree = validateWithClass(qual).orElse {
searchClassesQCN(qual.toString + "$")
}
PreparedCall(qualTree, name.value)
case name: Term.Name => PreparedCall(thisTree, name.value)
}

for {
args <- call.argClause.map(validate).traverse
val validatedArgs = call.argClause.map(validate).traverse

val method = for {
args <- validatedArgs
lhs <- preparedCall.qual
methodTree <-
PrimitiveUnaryOpTree(lhs, preparedCall.name)
.orElse { PrimitiveBinaryOpTree(lhs, args, preparedCall.name) }
.orElse { findMethod(lhs, preparedCall.name, args) }
} yield methodTree

call.fun match {
case Term.Select(qual, name) =>
method.orElse {
for {
cls <- searchClassesQCN(qual.toString + "." + name.value)
args <- validatedArgs
m <- validateApply(cls, args)
} yield m
}
case _ => method
}
}

/* -------------------------------------------------------------------------- */
Expand All @@ -249,8 +274,8 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val sourceLookUp: SourceLookU
for {
args <- argClauses.flatMap(_.map(validate(_))).traverse
(outer, cls) <- validateType(tpe, thisTree.toOption)(validate)
allArgs = outer.filter(_ => needsOuter(cls.`type`)).toSeq ++ args
newInstance <- newInstanceTreeByArgs(cls.`type`, allArgs)
allArgs = outer.filter(_ => needsOuter(cls)).toSeq ++ args
newInstance <- newInstanceTreeByArgs(cls, allArgs)
} yield newInstance
}

Expand All @@ -265,7 +290,7 @@ class RuntimeDefaultValidator(val frame: JdiFrame, val sourceLookUp: SourceLookU
/* -------------------------------------------------------------------------- */

def validateIf(tree: Term.If): Validation[RuntimeEvaluableTree] = {
lazy val objType = loadClass("java.lang.Object").get.`type`
lazy val objType = loadClass("java.lang.Object").get
for {
cond <- validate(tree.cond)
thenp <- validate(tree.thenp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,13 @@ private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame, sourceLookup:
/* -------------------------------------------------------------------------- */
/* Class helpers */
/* -------------------------------------------------------------------------- */
def loadClass(name: String): Validation[ClassTree] =
Validation.fromTry(frame.classLoader().flatMap(_.loadClass(name)).extract(c => ClassTree(c.cls)))
def loadClass(name: String): Validation[ClassType] =
Validation.fromTry {
frame
.classLoader()
.flatMap(_.loadClass(name))
.extract(_.cls)
}

def checkClassStatus(tpe: => Type) = Try(tpe) match {
case Failure(e: ClassNotLoadedException) => loadClass(e.className())
Expand All @@ -248,8 +253,8 @@ private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame, sourceLookup:

// ! TO REFACTOR :sob:
def resolveInnerType(qual: Type, name: String) = {
var tpe: Validation[ClassTree] = Recoverable(s"Cannot find outer class for $qual")
def loop(on: Type): Validation[ClassTree] =
var tpe: Validation[ClassType] = Recoverable(s"Cannot find outer class for $qual")
def loop(on: Type): Validation[ClassType] =
on match {
case _: ArrayType | _: PrimitiveType | _: VoidType =>
Recoverable("Cannot find outer class on non reference type")
Expand Down Expand Up @@ -289,7 +294,7 @@ private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame, sourceLookup:

def validateType(tpe: MType, thisTree: Option[RuntimeEvaluableTree])(
termValidation: Term => Validation[RuntimeEvaluableTree]
): Validation[(Option[RuntimeEvaluableTree], ClassTree)] =
): Validation[(Option[RuntimeEvaluableTree], ClassType)] =
tpe match {
case MType.Name(name) =>
// won't work if the class is defined in one of the outer of this
Expand All @@ -299,9 +304,11 @@ private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame, sourceLookup:
qual <- termValidation(qual)
tpe <- resolveInnerType(qual.`type`, name.value)
} yield
if (tpe.`type`.isStatic()) (None, tpe)
if (tpe.isStatic()) (None, tpe)
else (Some(qual), tpe)
cls.orElse(searchClasses(qual.toString + "." + name.value, thisTree.map(_.`type`.name)).map((None, _)))
cls.orElse {
searchClassesQCN(qual.toString + "." + name.value).map(c => (None, c.`type`.asInstanceOf[ClassType]))
}
case _ => Recoverable("Type not supported at runtime")
}

Expand All @@ -314,55 +321,48 @@ private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame, sourceLookup:
.orElse {
removeLastInnerTypeFromFQCN(tree.`type`.name())
.map(name => loadClass(name + "$")) match {
case Some(Valid(Module(mod))) => Valid(mod)
case Some(Valid(Module(mod))) => Valid(TopLevelModuleTree(mod))
case res => Recoverable(s"Cannot find $$outer for ${tree.`type`.name()}}")
}
}
case _ => Recoverable(s"Cannot find $$outer for non-reference type ${tree.`type`.name()}}")
}

def searchAllClassesFor(name: String, in: Option[String]): Validation[ClassTree] = {
def declaringTypeName = frame.current().location().declaringType().name()
def searchClasses(name: String, in: Option[String]): Validation[ClassType] = {
def baseName = in.getOrElse { frame.current().location().declaringType().name() }

val candidates = sourceLookup.classesByName(name)

val bestMatch = candidates.size match {
case 0 | 1 => candidates
case _ =>
candidates.filter(_.startsWith {
in.getOrElse(declaringTypeName)
})
candidates.filter { name =>
name.contains(s".$baseName") || name.startsWith(baseName)
}
}

bestMatch
.validateSingle(s"Cannot find class $name")
.flatMap { loadClass }
}

def searchClasses(name: String, in: Option[String]): Validation[ClassTree] = {
if (isFQCN(name)) loadClass(name)
else searchAllClassesFor(name, in)
}

def isFQCN(name: String): Boolean = {
val regex = """([\w\.]+)\.([^\.]+)""".r
name match {
case regex(_, _) => true
case _ => false
}
def searchClassesQCN(partialClassName: String): Validation[RuntimeTree] = {
val name = SourceLookUpProvider.getScalaClassName(partialClassName)
searchClasses(name + "$", Some(partialClassName))
.map { TopLevelModuleTree(_) }
.orElse { searchClasses(name, Some(partialClassName)).map { ClassTree(_) } }
}

/* -------------------------------------------------------------------------- */
/* Initialize module */
/* -------------------------------------------------------------------------- */
def moduleInitializer(modCls: ClassType, of: RuntimeEvaluableTree): Validation[NestedModuleTree] =
for {
initMethodName <- Validation.fromOption(getLastInnerType(modCls.name()))
initMethod <- of.`type` match {
case ref: ReferenceType => zeroArgMethodByName(ref, initMethodName)
case _ => Recoverable(s"Cannot find module initializer for non-reference type $modCls")
}
} yield NestedModuleTree(modCls, InstanceMethodTree(initMethod, Seq(), of))
of.`type` match {
case ref: ReferenceType =>
zeroArgMethodByName(ref, SourceLookUpProvider.getScalaClassName(modCls.name).stripSuffix("$"))
.map(m => NestedModuleTree(modCls, InstanceMethodTree(m, Seq.empty, of)))
case _ => Recoverable(s"Cannot find module initializer for non-reference type $modCls")
}

def illegalAccess(x: Any, typeName: String) = Fatal {
new ClassCastException(s"Cannot cast $x to $typeName")
Expand All @@ -383,19 +383,13 @@ private[evaluator] class RuntimeEvaluationHelpers(frame: JdiFrame, sourceLookup:
/* -------------------------------------------------------------------------- */
/* Nested types regex */
/* -------------------------------------------------------------------------- */
def getLastInnerType(className: String): Option[String] = {
val pattern = """(.+\$)([^$]+)$""".r
className.stripSuffix("$") match {
case pattern(_, innerType) => Some(innerType)
case _ => None
}
}

def removeLastInnerTypeFromFQCN(className: String): Option[String] = {
val pattern = """(.+)\$[\w]+\${0,1}$""".r
className match {
case pattern(baseName) => Some(baseName)
case _ => None
val (packageName, clsName) = className.splitAt(className.lastIndexOf('.') + 1)
val name = NameTransformer.decode(clsName)
val lastDollar = name.stripSuffix("$").lastIndexOf('$')
lastDollar match {
case -1 => None
case _ => Some(packageName + name.dropRight(name.length - lastDollar))
}
}

Expand Down
Loading
Loading