Skip to content

Commit

Permalink
compile: make closures compile properly
Browse files Browse the repository at this point in the history
  • Loading branch information
ncw committed Apr 30, 2015
1 parent 329523c commit 5a2a35a
Show file tree
Hide file tree
Showing 5 changed files with 428 additions and 63 deletions.
254 changes: 203 additions & 51 deletions compile/compile.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// compile python code

package compile

import (
Expand Down Expand Up @@ -40,12 +39,26 @@ func (ls loopstack) Top() *loop {
return &ls[len(ls)-1]
}

type compilerScopeType uint8

const (
compilerScopeModule compilerScopeType = iota
compilerScopeClass
compilerScopeFunction
compilerScopeLambda
compilerScopeComprehension
)

// State for the compiler
type compiler struct {
Code *py.Code // code being built up
OpCodes Instructions
loops loopstack
SymTable *symtable.SymTable
Code *py.Code // code being built up
OpCodes Instructions
loops loopstack
SymTable *symtable.SymTable
scopeType compilerScopeType
qualname string
parent *compiler
depth int
}

// Set in py to avoid circular import
Expand All @@ -66,7 +79,7 @@ func init() {
// the effects of any future statements in effect in the code calling
// compile; if absent or zero these statements do influence the compilation,
// in addition to any features explicitly specified.
func Compile(str, filename, mode string, flags int, dont_inherit bool) (py.Object, error) {
func Compile(str, filename, mode string, futureFlags int, dont_inherit bool) (py.Object, error) {
// Parse Ast
Ast, err := parser.ParseString(str, mode)
if err != nil {
Expand All @@ -77,11 +90,27 @@ func Compile(str, filename, mode string, flags int, dont_inherit bool) (py.Objec
if err != nil {
return nil, err
}
return CompileAst(Ast, filename, flags, dont_inherit, SymTable)
c := newCompiler(nil, compilerScopeModule)
return c.compileAst(Ast, filename, futureFlags, dont_inherit, SymTable)
}

// Make a new compiler
func newCompiler(parent *compiler, scopeType compilerScopeType) *compiler {
c := &compiler{
// Code: code,
// SymTable: SymTable,
parent: parent,
scopeType: scopeType,
depth: 1,
}
if parent != nil {
c.depth = parent.depth + 1
}
return c
}

// As Compile but takes an Ast
func CompileAst(Ast ast.Ast, filename string, flags int, dont_inherit bool, SymTable *symtable.SymTable) (code *py.Code, err error) {
func (c *compiler) compileAst(Ast ast.Ast, filename string, futureFlags int, dont_inherit bool, SymTable *symtable.SymTable) (code *py.Code, err error) {
defer func() {
if r := recover(); r != nil {
err = py.MakeException(r)
Expand All @@ -90,14 +119,19 @@ func CompileAst(Ast ast.Ast, filename string, flags int, dont_inherit bool, SymT
//fmt.Println(ast.Dump(Ast))
code = &py.Code{
Filename: filename,
Firstlineno: 1, // FIXME
Name: "<module>", // FIXME
Flags: int32(flags | py.CO_NOFREE), // FIXME
}
c := &compiler{
Code: code,
SymTable: SymTable,
Firstlineno: 1, // FIXME
Name: "<module>", // FIXME
// Argcount: int32(len(node.Args.Args)),
// Name: string(node.Name),
// Kwonlyargcount: int32(len(node.Args.Kwonlyargs)),
// Nlocals: int32(len(SymTable.Varnames)),
}
c.Code = code
c.SymTable = SymTable
code.Varnames = append(code.Varnames, SymTable.Varnames...)
code.Cellvars = SymTable.Find(symtable.ScopeCell, 0)
code.Freevars = SymTable.Find(symtable.ScopeFree, symtable.DefFreeClass)
code.Flags = c.codeFlags(SymTable) | int32(futureFlags&py.CO_COMPILER_FLAGS_MASK)
valueOnStack := false
switch node := Ast.(type) {
case *ast.Module:
Expand All @@ -112,10 +146,13 @@ func CompileAst(Ast ast.Ast, filename string, flags int, dont_inherit bool, SymT
case ast.Expr:
// Make None the first constant as lambda can't have a docstring
c.Const(py.None)
c.Code.Name = "<lambda>"
code.Name = "<lambda>"
c.setQualname() // FIXME is this in the right place!
c.Expr(node)
valueOnStack = true
case *ast.FunctionDef:
code.Name = string(node.Name)
c.setQualname() // FIXME is this in the right place!
c.Stmts(c.docString(node.Body))
default:
panic(py.ExceptionNewf(py.SyntaxError, "Unknown ModuleBase: %v", Ast))
Expand All @@ -130,6 +167,7 @@ func CompileAst(Ast ast.Ast, filename string, flags int, dont_inherit bool, SymT
}
code.Code = c.OpCodes.Assemble()
code.Stacksize = int32(c.OpCodes.StackDepth())
code.Nlocals = int32(len(code.Varnames))
return code, nil
}

Expand Down Expand Up @@ -168,14 +206,23 @@ func (c *compiler) LoadConst(obj py.Object) {
c.OpArg(vm.LOAD_CONST, c.Const(obj))
}

// Returns the index into the slice provided, updating the slice if necessary
func (c *compiler) Index(Id string, Names *[]string) uint32 {
// Finds the Id in the slice provided, returning -1 if not found
func (c *compiler) FindId(Id string, Names []string) int {
// FIXME back this with a dict to stop O(N**2) behaviour on lots of vars
for i, s := range *Names {
for i, s := range Names {
if Id == s {
return uint32(i)
return i
}
}
return -1
}

// Returns the index into the slice provided, updating the slice if necessary
func (c *compiler) Index(Id string, Names *[]string) uint32 {
i := c.FindId(Id, *Names)
if i >= 0 {
return uint32(i)
}
*Names = append(*Names, Id)
return uint32(len(*Names) - 1)
}
Expand Down Expand Up @@ -234,6 +281,131 @@ func (c *compiler) Stmts(stmts []ast.Stmt) {
}
}

/* The test for LOCAL must come before the test for FREE in order to
handle classes where name is both local and free. The local var is
a method and the free var is a free var referenced within a method.
*/
func (c *compiler) getRefType(name string) symtable.Scope {
if c.scopeType == compilerScopeClass && name == "__class__" {
return symtable.ScopeCell
}
scope := c.SymTable.GetScope(name)
if scope == symtable.ScopeInvalid {
panic(fmt.Sprintf("compile: getRefType: unknown scope for %s in %s\nsymbols: %s\nlocals: %s\nglobals: %s", name, c.Code.Name, c.SymTable.Symbols, c.Code.Varnames, c.Code.Names))
}
return scope
}

// makeClosure constructs the function or closure for a func/class/lambda etc
func (c *compiler) makeClosure(code *py.Code, args uint32, child *compiler) {
free := uint32(len(code.Freevars))
qualname := child.qualname
if qualname == "" {
qualname = c.qualname
}

if free == 0 {
c.LoadConst(code)
c.LoadConst(py.String(qualname))
c.OpArg(vm.MAKE_FUNCTION, args)
return
}
for i := range code.Freevars {
/* Bypass com_addop_varname because it will generate
LOAD_DEREF but LOAD_CLOSURE is needed.
*/
name := code.Freevars[i]

/* Special case: If a class contains a method with a
free variable that has the same name as a method,
the name will be considered free *and* local in the
class. It should be handled by the closure, as
well as by the normal name loookup logic.
*/
reftype := c.getRefType(name)
arg := 0
if reftype == symtable.ScopeCell {
arg = c.FindId(name, c.Code.Cellvars)
} else { /* (reftype == FREE) */
arg = c.FindId(name, c.Code.Freevars)
}
if arg < 0 {
panic(fmt.Sprintf("compile: makeClosure: lookup %q in %q %v %v\nfreevars of %q: %v\n", name, c.SymTable.Name, reftype, arg, code.Name, code.Freevars))
}
c.OpArg(vm.LOAD_CLOSURE, uint32(arg))
}
c.OpArg(vm.BUILD_TUPLE, free)
c.LoadConst(code)
c.LoadConst(py.String(qualname))
c.OpArg(vm.MAKE_CLOSURE, args)
}

// Compute the flags for the current Code
func (c *compiler) codeFlags(st *symtable.SymTable) (flags int32) {
if st.Type == symtable.FunctionBlock {
flags |= py.CO_NEWLOCALS
if st.Unoptimized == 0 {
flags |= py.CO_OPTIMIZED
}
if st.Nested {
flags |= py.CO_NESTED
}
if st.Generator {
flags |= py.CO_GENERATOR
}
if st.Varargs {
flags |= py.CO_VARARGS
}
if st.Varkeywords {
flags |= py.CO_VARKEYWORDS
}
}

/* (Only) inherit compilerflags in PyCF_MASK */
flags |= c.Code.Flags & py.CO_COMPILER_FLAGS_MASK

if len(c.Code.Freevars) == 0 && len(c.Code.Cellvars) == 0 {
flags |= py.CO_NOFREE
}

return flags
}

// Sets the qualname
func (c *compiler) setQualname() {
var base string
if c.depth > 1 {
force_global := false
parent := c.parent
if parent == nil {
panic("compile: setQualname: expecting a parent")
}
if c.scopeType == compilerScopeFunction || c.scopeType == compilerScopeClass {
// FIXME mangled = _Py_Mangle(parent.u_private, u.u_name)
mangled := c.Code.Name
scope := parent.SymTable.GetScope(mangled)
if scope == symtable.ScopeGlobalImplicit {
panic("compile: setQualname: not expecting scopeGlobalImplicit")
}
if scope == symtable.ScopeGlobalExplicit {
force_global = true
}
}
if !force_global {
if parent.scopeType == compilerScopeFunction || parent.scopeType == compilerScopeLambda {
base = parent.qualname + ".<locals>"
} else {
base = parent.qualname
}
}
}
if base != "" {
c.qualname = base + "." + c.Code.Name
} else {
c.qualname = c.Code.Name
}
}

// Compile statement
func (c *compiler) Stmt(stmt ast.Stmt) {
switch node := stmt.(type) {
Expand All @@ -247,34 +419,15 @@ func (c *compiler) Stmt(stmt ast.Stmt) {
if newSymTable == nil {
panic("No symtable found for function")
}
code, err := CompileAst(node, c.Code.Filename, int(c.Code.Flags)|py.CO_OPTIMIZED|py.CO_NEWLOCALS, false, newSymTable) // FIXME pass on compile args
newC := newCompiler(c, compilerScopeFunction)
code, err := newC.compileAst(node, c.Code.Filename, 0, false, newSymTable)
if err != nil {
panic(err)
}
// FIXME need these set in code before we compile - (pass in node?)
code.Argcount = int32(len(node.Args.Args))
code.Name = string(node.Name)
code.Kwonlyargcount = int32(len(node.Args.Kwonlyargs))
code.Nlocals = code.Kwonlyargcount + int32(len(node.Args.Args))
if code.Kwonlyargcount > 0 {
code.Flags |= py.CO_VARARGS
}

// Arguments
for _, arg := range node.Args.Args {
c.Index(string(arg.Arg), &code.Varnames)
}
for _, arg := range node.Args.Kwonlyargs {
c.Index(string(arg.Arg), &code.Varnames)
}
if node.Args.Vararg != nil {
code.Nlocals++
c.Index(string(node.Args.Vararg.Arg), &code.Varnames)
}
if node.Args.Kwarg != nil {
code.Nlocals++
c.Index(string(node.Args.Kwarg.Arg), &code.Varnames)
code.Flags |= py.CO_VARKEYWORDS
}

// Defaults
posdefaults := uint32(len(node.Args.Defaults))
Expand Down Expand Up @@ -316,10 +469,10 @@ func (c *compiler) Stmt(stmt ast.Stmt) {
c.LoadConst(annotations)
}

c.LoadConst(code)
c.LoadConst(py.String(node.Name))
c.OpArg(vm.MAKE_FUNCTION, posdefaults+(kwdefaults<<8)+(num_annotations<<16))
c.OpArg(vm.STORE_NAME, c.Name(node.Name))
args := uint32(posdefaults + (kwdefaults << 8) + (num_annotations << 16))
c.makeClosure(code, args, newC)
c.NameOp(string(node.Name), ast.Store)

case *ast.ClassDef:
// Name Identifier
// Bases []Expr
Expand Down Expand Up @@ -750,16 +903,15 @@ func (c *compiler) Expr(expr ast.Expr) {
if newSymTable == nil {
panic("No symtable found for lambda")
}
code, err := CompileAst(node.Body, c.Code.Filename, int(c.Code.Flags)|py.CO_OPTIMIZED|py.CO_NEWLOCALS, false, newSymTable) // FIXME pass on compile args
newC := newCompiler(c, compilerScopeLambda)
code, err := newC.compileAst(node.Body, c.Code.Filename, 0, false, newSymTable)
if err != nil {
panic(err)
}

code.Argcount = int32(len(node.Args.Args))
c.LoadConst(code)
c.LoadConst(py.String("<lambda>"))
// FIXME node.Args
c.OpArg(vm.MAKE_FUNCTION, 0)
// FIXME node.Args - more work on lambda needed
c.makeClosure(code, 0, newC)
case *ast.IfExp:
// Test Expr
// Body Expr
Expand Down
Loading

0 comments on commit 5a2a35a

Please sign in to comment.