Skip to content

Commit

Permalink
Adding Clause Interleaving to method definitions (#14019)
Browse files Browse the repository at this point in the history
This aims to add the ability to declare functions with many clauses of
type parameters, instead of at most one, and to allow those clauses to
be interleaved with term clauses:
```scala
def foo[A](x: A)[B](y: B) = (x, y)
```
All user-facing details can be found in the [Scala 3 new features
doc](https://github.com/lampepfl/dotty/blob/2a079f92bfc05420e90987d908c41589ac94418d/docs/_docs/reference/other-new-features/generalized-method-syntax.md),
and in the [SIP
proposal](scala/improvement-proposals#47)

The implementation details are described below in the commit messages

The community's opinion of the feature can be found on [the scala
contributors
forum](https://contributors.scala-lang.org/t/clause-interweaving-allowing-def-f-t-x-t-u-y-u)
(note however that the description there is somewhat outdated)

### Dependencies

This feature has been [accepted by the SIP committee for
implementation](scala/improvement-proposals#47 (comment)),
it can therefore become part of the language as an experimental feature
at any time.

The feature will be available with `import
scala.language.experimental.clauseInterleaving`

### How to make non-experimental
`git revert` the commits named "Make clause interleaving experimental"
and "Add import to tests"

### Future Work
1. Implement given aliases with clause interweaving: (to have types
depends on terms)
```scala
given myGiven[T](using x: T)[U](using y: U) = (x, y)
```
2. Add interleaved clauses to the left-hand side of extension methods:
```scala
extension (using Int)[A](using A)(a: A)[B](using B)
  def foo: (A, B) = ???
```
3. Investigate usefulness/details of clause interweaving for classes and
type currying for types:
```scala
class Foo[A](a: A)[B](b: B)
new Foo(0)("Hello!") // type: Foo[Int][String] ?

type Bar[A][B] = Map[A, B]
Bar[Char] // should this mean [B] =>> Bar[Char][B] ?
```

(Started as a semester project with supervision from @smarter, now part
of my missions as an intern at the scala center)
  • Loading branch information
Quentin Bernet authored Feb 3, 2023
2 parents af95ceb + 65091c3 commit ef815fd
Show file tree
Hide file tree
Showing 45 changed files with 942 additions and 228 deletions.
12 changes: 7 additions & 5 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -915,16 +915,16 @@ object desugar {
name = normalizeName(mdef, mdef.tpt).asTermName,
paramss =
if mdef.name.isRightAssocOperatorName then
val (typaramss, paramss) = mdef.paramss.span(isTypeParamClause) // first extract type parameters
val (rightTyParams, paramss) = mdef.paramss.span(isTypeParamClause) // first extract type parameters

paramss match
case params :: paramss1 => // `params` must have a single parameter and without `given` flag
case rightParam :: paramss1 => // `rightParam` must have a single parameter and without `given` flag

def badRightAssoc(problem: String) =
report.error(em"right-associative extension method $problem", mdef.srcPos)
extParamss ++ mdef.paramss

params match
rightParam match
case ValDefs(vparam :: Nil) =>
if !vparam.mods.is(Given) then
// we merge the extension parameters with the method parameters,
Expand All @@ -934,8 +934,10 @@ object desugar {
// def %:[E](f: F)(g: G)(using H): Res = ???
// will be encoded as
// def %:[A](using B)[E](f: F)(c: C)(using D)(g: G)(using H): Res = ???
val (leadingUsing, otherExtParamss) = extParamss.span(isUsingOrTypeParamClause)
leadingUsing ::: typaramss ::: params :: otherExtParamss ::: paramss1
//
// If you change the names of the clauses below, also change them in right-associative-extension-methods.md
val (leftTyParamsAndLeadingUsing, leftParamAndTrailingUsing) = extParamss.span(isUsingOrTypeParamClause)
leftTyParamsAndLeadingUsing ::: rightTyParams ::: rightParam :: leftParamAndTrailingUsing ::: paramss1
else
badRightAssoc("cannot start with using clause")
case _ =>
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ object Feature:
val symbolLiterals = deprecated("symbolLiterals")
val fewerBraces = experimental("fewerBraces")
val saferExceptions = experimental("saferExceptions")
val clauseInterleaving = experimental("clauseInterleaving")
val pureFunctions = experimental("pureFunctions")
val captureChecking = experimental("captureChecking")
val into = experimental("into")
Expand Down Expand Up @@ -76,6 +77,8 @@ object Feature:

def namedTypeArgsEnabled(using Context) = enabled(namedTypeArguments)

def clauseInterleavingEnabled(using Context) = enabled(clauseInterleaving)

def genericNumberLiteralsEnabled(using Context) = enabled(genericNumberLiterals)

def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros)
Expand Down
144 changes: 100 additions & 44 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3079,6 +3079,42 @@ object Parsers {

/* -------- PARAMETERS ------------------------------------------- */

/** DefParamClauses ::= DefParamClause { DefParamClause } -- and two DefTypeParamClause cannot be adjacent
* DefParamClause ::= DefTypeParamClause
* | DefTermParamClause
* | UsingParamClause
*/
def typeOrTermParamClauses(
ownerKind: ParamOwner,
numLeadParams: Int = 0
): List[List[TypeDef] | List[ValDef]] =

def recur(firstClause: Boolean, numLeadParams: Int, prevIsTypeClause: Boolean): List[List[TypeDef] | List[ValDef]] =
newLineOptWhenFollowedBy(LPAREN)
newLineOptWhenFollowedBy(LBRACKET)
if in.token == LPAREN then
val paramsStart = in.offset
val params = termParamClause(
numLeadParams,
firstClause = firstClause)
val lastClause = params.nonEmpty && params.head.mods.flags.is(Implicit)
params :: (
if lastClause then Nil
else recur(firstClause = false, numLeadParams + params.length, prevIsTypeClause = false))
else if in.token == LBRACKET then
if prevIsTypeClause then
syntaxError(
em"Type parameter lists must be separated by a term or using parameter list",
in.offset
)
typeParamClause(ownerKind) :: recur(firstClause, numLeadParams, prevIsTypeClause = true)
else Nil
end recur

recur(firstClause = true, numLeadParams = numLeadParams, prevIsTypeClause = false)
end typeOrTermParamClauses


/** ClsTypeParamClause::= ‘[’ ClsTypeParam {‘,’ ClsTypeParam} ‘]’
* ClsTypeParam ::= {Annotation} [‘+’ | ‘-’]
* id [HkTypeParamClause] TypeParamBounds
Expand Down Expand Up @@ -3132,34 +3168,39 @@ object Parsers {

/** ContextTypes ::= FunArgType {‘,’ FunArgType}
*/
def contextTypes(ofClass: Boolean, nparams: Int, impliedMods: Modifiers): List[ValDef] =
def contextTypes(ofClass: Boolean, numLeadParams: Int, impliedMods: Modifiers): List[ValDef] =
val tps = commaSeparated(funArgType)
var counter = nparams
var counter = numLeadParams
def nextIdx = { counter += 1; counter }
val paramFlags = if ofClass then LocalParamAccessor else Param
tps.map(makeSyntheticParameter(nextIdx, _, paramFlags | Synthetic | impliedMods.flags))

/** ClsParamClause ::= ‘(’ [‘erased’] ClsParams ‘)’ | UsingClsParamClause
* UsingClsParamClause::= ‘(’ ‘using’ [‘erased’] (ClsParams | ContextTypes) ‘)’
/** ClsTermParamClause ::= ‘(’ [‘erased’] ClsParams ‘)’ | UsingClsTermParamClause
* UsingClsTermParamClause::= ‘(’ ‘using’ [‘erased’] (ClsParams | ContextTypes) ‘)’
* ClsParams ::= ClsParam {‘,’ ClsParam}
* ClsParam ::= {Annotation}
*
* TypelessClause ::= DefTermParamClause
* | UsingParamClause
*
* DefParamClause ::= ‘(’ [‘erased’] DefParams ‘)’ | UsingParamClause
* UsingParamClause ::= ‘(’ ‘using’ [‘erased’] (DefParams | ContextTypes) ‘)’
* DefParams ::= DefParam {‘,’ DefParam}
* DefParam ::= {Annotation} [‘inline’] Param
* DefTermParamClause::= [nl] ‘(’ [DefTermParams] ‘)’
* UsingParamClause ::= ‘(’ ‘using’ [‘erased’] (DefTermParams | ContextTypes) ‘)’
* DefImplicitClause ::= [nl] ‘(’ ‘implicit’ DefTermParams ‘)’
* DefTermParams ::= DefTermParam {‘,’ DefTermParam}
* DefTermParam ::= {Annotation} [‘inline’] Param
*
* Param ::= id `:' ParamType [`=' Expr]
*
* @return the list of parameter definitions
*/
def paramClause(nparams: Int, // number of parameters preceding this clause
ofClass: Boolean = false, // owner is a class
ofCaseClass: Boolean = false, // owner is a case class
prefix: Boolean = false, // clause precedes name of an extension method
givenOnly: Boolean = false, // only given parameters allowed
firstClause: Boolean = false // clause is the first in regular list of clauses
): List[ValDef] = {
def termParamClause(
numLeadParams: Int, // number of parameters preceding this clause
ofClass: Boolean = false, // owner is a class
ofCaseClass: Boolean = false, // owner is a case class
prefix: Boolean = false, // clause precedes name of an extension method
givenOnly: Boolean = false, // only given parameters allowed
firstClause: Boolean = false // clause is the first in regular list of clauses
): List[ValDef] = {
var impliedMods: Modifiers = EmptyModifiers

def addParamMod(mod: () => Mod) = impliedMods = addMod(impliedMods, atSpan(in.skipToken()) { mod() })
Expand Down Expand Up @@ -3224,7 +3265,7 @@ object Parsers {
checkVarArgsRules(rest)
}

// begin paramClause
// begin termParamClause
inParens {
if in.token == RPAREN && !prefix && !impliedMods.is(Given) then Nil
else
Expand All @@ -3239,41 +3280,43 @@ object Parsers {
|| startParamTokens.contains(in.token)
|| isIdent && (in.name == nme.inline || in.lookahead.isColon)
if isParams then commaSeparated(() => param())
else contextTypes(ofClass, nparams, impliedMods)
else contextTypes(ofClass, numLeadParams, impliedMods)
checkVarArgsRules(clause)
clause
}
}

/** ClsParamClauses ::= {ClsParamClause} [[nl] ‘(’ [‘implicit’] ClsParams ‘)’]
* DefParamClauses ::= {DefParamClause} [[nl] ‘(’ [‘implicit’] DefParams ‘)’]
/** ClsTermParamClauses ::= {ClsTermParamClause} [[nl] ‘(’ [‘implicit’] ClsParams ‘)’]
* TypelessClauses ::= TypelessClause {TypelessClause}
*
* @return The parameter definitions
*/
def paramClauses(ofClass: Boolean = false,
ofCaseClass: Boolean = false,
givenOnly: Boolean = false,
numLeadParams: Int = 0): List[List[ValDef]] =
def termParamClauses(
ofClass: Boolean = false,
ofCaseClass: Boolean = false,
givenOnly: Boolean = false,
numLeadParams: Int = 0
): List[List[ValDef]] =

def recur(firstClause: Boolean, nparams: Int): List[List[ValDef]] =
def recur(firstClause: Boolean, numLeadParams: Int): List[List[ValDef]] =
newLineOptWhenFollowedBy(LPAREN)
if in.token == LPAREN then
val paramsStart = in.offset
val params = paramClause(
nparams,
val params = termParamClause(
numLeadParams,
ofClass = ofClass,
ofCaseClass = ofCaseClass,
givenOnly = givenOnly,
firstClause = firstClause)
val lastClause = params.nonEmpty && params.head.mods.flags.is(Implicit)
params :: (
if lastClause then Nil
else recur(firstClause = false, nparams + params.length))
else recur(firstClause = false, numLeadParams + params.length))
else Nil
end recur

recur(firstClause = true, numLeadParams)
end paramClauses
end termParamClauses

/* -------- DEFS ------------------------------------------- */

Expand Down Expand Up @@ -3514,11 +3557,15 @@ object Parsers {
}
}



/** DefDef ::= DefSig [‘:’ Type] ‘=’ Expr
* | this ParamClause ParamClauses `=' ConstrExpr
* | this TypelessClauses [DefImplicitClause] `=' ConstrExpr
* DefDcl ::= DefSig `:' Type
* DefSig ::= id [DefTypeParamClause] DefParamClauses
* | ExtParamClause [nl] [‘.’] id DefParamClauses
* DefSig ::= id [DefTypeParamClause] DefTermParamClauses
*
* if clauseInterleaving is enabled:
* DefSig ::= id [DefParamClauses] [DefImplicitClause]
*/
def defDefOrDcl(start: Offset, mods: Modifiers, numLeadParams: Int = 0): DefDef = atSpan(start, nameStart) {

Expand All @@ -3537,7 +3584,7 @@ object Parsers {

if (in.token == THIS) {
in.nextToken()
val vparamss = paramClauses(numLeadParams = numLeadParams)
val vparamss = termParamClauses(numLeadParams = numLeadParams)
if (vparamss.isEmpty || vparamss.head.take(1).exists(_.mods.isOneOf(GivenOrImplicit)))
in.token match {
case LBRACKET => syntaxError(em"no type parameters allowed here")
Expand All @@ -3555,9 +3602,18 @@ object Parsers {
val mods1 = addFlag(mods, Method)
val ident = termIdent()
var name = ident.name.asTermName
val tparams = typeParamClauseOpt(ParamOwner.Def)
val vparamss = paramClauses(numLeadParams = numLeadParams)
val paramss =
if in.featureEnabled(Feature.clauseInterleaving) then
// If you are making interleaving stable manually, please refer to the PR introducing it instead, section "How to make non-experimental"
typeOrTermParamClauses(ParamOwner.Def, numLeadParams = numLeadParams)
else
val tparams = typeParamClauseOpt(ParamOwner.Def)
val vparamss = termParamClauses(numLeadParams = numLeadParams)

joinParams(tparams, vparamss)

var tpt = fromWithinReturnType { typedOpt() }

if (migrateTo3) newLineOptWhenFollowedBy(LBRACE)
val rhs =
if in.token == EQUALS then
Expand All @@ -3574,7 +3630,7 @@ object Parsers {
accept(EQUALS)
expr()

val ddef = DefDef(name, joinParams(tparams, vparamss), tpt, rhs)
val ddef = DefDef(name, paramss, tpt, rhs)
if (isBackquoted(ident)) ddef.pushAttachment(Backquoted, ())
finalizeDef(ddef, mods1, start)
}
Expand Down Expand Up @@ -3695,12 +3751,12 @@ object Parsers {
val templ = templateOpt(constr)
finalizeDef(TypeDef(name, templ), mods, start)

/** ClassConstr ::= [ClsTypeParamClause] [ConstrMods] ClsParamClauses
/** ClassConstr ::= [ClsTypeParamClause] [ConstrMods] ClsTermParamClauses
*/
def classConstr(isCaseClass: Boolean = false): DefDef = atSpan(in.lastOffset) {
val tparams = typeParamClauseOpt(ParamOwner.Class)
val cmods = fromWithinClassConstr(constrModsOpt())
val vparamss = paramClauses(ofClass = true, ofCaseClass = isCaseClass)
val vparamss = termParamClauses(ofClass = true, ofCaseClass = isCaseClass)
makeConstructor(tparams, vparamss).withMods(cmods)
}

Expand Down Expand Up @@ -3802,7 +3858,7 @@ object Parsers {
newLineOpt()
val vparamss =
if in.token == LPAREN && in.lookahead.isIdent(nme.using)
then paramClauses(givenOnly = true)
then termParamClauses(givenOnly = true)
else Nil
newLinesOpt()
val noParams = tparams.isEmpty && vparamss.isEmpty
Expand Down Expand Up @@ -3837,32 +3893,32 @@ object Parsers {
finalizeDef(gdef, mods1, start)
}

/** Extension ::= ‘extension’ [DefTypeParamClause] {UsingParamClause} ‘(’ DefParam ‘)’
/** Extension ::= ‘extension’ [DefTypeParamClause] {UsingParamClause} ‘(’ DefTermParam ‘)’
* {UsingParamClause} ExtMethods
*/
def extension(): ExtMethods =
val start = in.skipToken()
val tparams = typeParamClauseOpt(ParamOwner.Def)
val leadParamss = ListBuffer[List[ValDef]]()
def nparams = leadParamss.map(_.length).sum
def numLeadParams = leadParamss.map(_.length).sum
while
val extParams = paramClause(nparams, prefix = true)
val extParams = termParamClause(numLeadParams, prefix = true)
leadParamss += extParams
isUsingClause(extParams)
do ()
leadParamss ++= paramClauses(givenOnly = true, numLeadParams = nparams)
leadParamss ++= termParamClauses(givenOnly = true, numLeadParams = numLeadParams)
if in.isColon then
syntaxError(em"no `:` expected here")
in.nextToken()
val methods: List[Tree] =
if in.token == EXPORT then
exportClause()
else if isDefIntro(modifierTokens) then
extMethod(nparams) :: Nil
extMethod(numLeadParams) :: Nil
else
in.observeIndented()
newLineOptWhenFollowedBy(LBRACE)
if in.isNestedStart then inDefScopeBraces(extMethods(nparams))
if in.isNestedStart then inDefScopeBraces(extMethods(numLeadParams))
else { syntaxErrorOrIncomplete(em"Extension without extension methods") ; Nil }
val result = atSpan(start)(ExtMethods(joinParams(tparams, leadParamss.toList), methods))
val comment = in.getDocComment(start)
Expand Down
29 changes: 15 additions & 14 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -896,30 +896,31 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
if isExtension then
val paramss =
if tree.name.isRightAssocOperatorName then
// If you change the names of the clauses below, also change them in right-associative-extension-methods.md
// we have the following encoding of tree.paramss:
// (leadingTyParamss ++ leadingUsing
// ++ rightTyParamss ++ rightParamss
// ++ leftParamss ++ trailingUsing ++ rest)
// (leftTyParams ++ leadingUsing
// ++ rightTyParams ++ rightParam
// ++ leftParam ++ trailingUsing ++ rest)
// e.g.
// extension [A](using B)(c: C)(using D)
// def %:[E](f: F)(g: G)(using H): Res = ???
// will have the following values:
// - leadingTyParamss = List(`[A]`)
// - leftTyParams = List(`[A]`)
// - leadingUsing = List(`(using B)`)
// - rightTyParamss = List(`[E]`)
// - rightParamss = List(`(f: F)`)
// - leftParamss = List(`(c: C)`)
// - rightTyParams = List(`[E]`)
// - rightParam = List(`(f: F)`)
// - leftParam = List(`(c: C)`)
// - trailingUsing = List(`(using D)`)
// - rest = List(`(g: G)`, `(using H)`)
// we need to swap (rightTyParams ++ rightParamss) with (leftParamss ++ trailingUsing)
val (leadingTyParamss, rest1) = tree.paramss.span(isTypeParamClause)
// we need to swap (rightTyParams ++ rightParam) with (leftParam ++ trailingUsing)
val (leftTyParams, rest1) = tree.paramss.span(isTypeParamClause)
val (leadingUsing, rest2) = rest1.span(isUsingClause)
val (rightTyParamss, rest3) = rest2.span(isTypeParamClause)
val (rightParamss, rest4) = rest3.splitAt(1)
val (leftParamss, rest5) = rest4.splitAt(1)
val (rightTyParams, rest3) = rest2.span(isTypeParamClause)
val (rightParam, rest4) = rest3.splitAt(1)
val (leftParam, rest5) = rest4.splitAt(1)
val (trailingUsing, rest6) = rest5.span(isUsingClause)
if leftParamss.nonEmpty then
leadingTyParamss ::: leadingUsing ::: leftParamss ::: trailingUsing ::: rightTyParamss ::: rightParamss ::: rest6
if leftParam.nonEmpty then
leftTyParams ::: leadingUsing ::: leftParam ::: trailingUsing ::: rightTyParams ::: rightParam ::: rest6
else
tree.paramss // it wasn't a binary operator, after all.
else
Expand Down
Loading

0 comments on commit ef815fd

Please sign in to comment.