对抽象语法树(AST)求值的思路(3)

隐藏

前言

函数调用这一部分是非常关键的地方,也是我比较模糊的地方, 这里根据我的分析思考如何实现, 可能跟标准的、成熟的做法不符。

思路

有些函数可以作为函数指针用,或者类似lambda表达式, 所以函数跟string,bool等作为同一个类别,函数调用也只是一种操作,对这种类别的操作而已。 函数调用的难点在于参数与返回值。

函数分两种,一种是预定义的,相当于一个运行时环境, 另一种是代码内定义的,预定义的运行时环境可以说是必须的, 例如C语言中printf函数,否则解释器很多能力受限。 而代码内自定义的函数能进一步扩展设计的语言的能力。

先考虑如何实现代码内自定义函数:

假设实现一个类似于:

func add(i int, j int) {
    return i + j
}

可以想象,调用函数时,有以下几个步骤:

  • 首先跟查普通变量一样,在环境中查找到该变量
  • 确定该变量的类型为函数,而不是string、float之类。
  • 传递参数,并确定参数类型一致,参数作用域在函数内,函数执行前创建一个环境。
  • 执行语句,这里可能有多条语句,用一个[]Node
  • 定义一个ReturnNode,遇到return语句节点时返回结果

对于函数节点FunctionNode,调用每一条Node执行, 直到遇到ReturnNode,将其结果作为返回结果, 问题就在于ReturnNode如何通知FunctionNode,而无需继续执行下去, 这里就必须在Eval()后留下特殊的信息。 我的解决方法是添加一个TypeReturn类型,跟TypeString、TypeFloat等一个级别, 如果是TypeReturn就将该节点内的数据取出,作为返回,并于返回值类型比较。

实现

改进环境层次结构

之前将所有变量看作是全局变量,但因为函数定义参数时用的变量名不应该放到全局变量里, 所以先改进环境,将其设计为层次结构(stack),在本层查找失败,再去更外层查找, 直到找到或者已到最外层(global)。

之所以Push、Pop函数需要用**Environment,是为了改变环境, 否则对原环境(类型为*Environment)没有影响。

type Environment struct {
    name  string
    upper *Environment
    store map[string]NodeValue
}

//global env
func NewEnvironment(n string) *Environment {
    return &Environment{
        name:  n,
        upper: nil,
        store: make(map[string]NodeValue),
    }
}

func PushNewEnv(env **Environment, n string) {
    top := &Environment{
        name:  n,
        upper: *env,
        store: make(map[string]NodeValue),
    }
    *env = top
}

func PopEnv(env **Environment) {
    *env = (*env).upper
}

func (e *Environment) SaveVariable(varName string, value NodeValue) {
    e.store[varName] = value
}

func (e *Environment) GetVariable(varName string) NodeValue {
    value, exists := e.store[varName]
    if !exists {
        if e.upper != nil {
            value = e.upper.GetVariable(varName)
        } else {
            panic("Variable not found.")
        }
    }
    return value
}

函数声明定义

func quote(i int) string {
    ret = "\"" + i + "\""
    return ret
}

对于一个函数定义如上所示,可以看到有以下信息需要保存,

  • 函数名quote
  • 符号 i 及其对应的类型 int,之所以需要符号i,因为函数体内会引用i
  • 返回值类型,简单起见不考虑多返回值,类型是为了在函数调用时作类型检查
  • 函数体,即{...}内的部分,是一个block,即0到多个表达式(一个表达式对应一个Node树)

然后将函数注册到当前环境即可,代码如下:

//函数声明
type FuncDeclarationNode struct {
    env  **Environment
    name string
    Para []Symbol
    Ret  NodeValueType
    Body []Node
}

func NewFunctionDeclarationNode(e **Environment, n string, p []Symbol, r NodeValueType, b ...Node) *FuncDeclarationNode {
    return &FuncDeclarationNode{
        env:  e,
        name: n,
        Para: p,
        Ret:  r,
        Body: b,
    }
}

func (d *FuncDeclarationNode) Eval() NodeValue {
    n := NodeValue{Typ: TypeFunction, Value: d}
    //函数注册到环境
    (*(d.env)).SaveVariable(d.name, n)
    return n
}

函数调用

函数调用过程,以分析了其执行思路:

  • 首先跟查普通变量一样在环境中查找到该变量
  • 确定该变量的类型为函数,而不是string、float之类。
  • 传递参数,并确定参数类型一致,参数作用域在函数内,函数执行前创建一个环境。
  • 执行语句,这里可能有多条语句,用一个[]Node
  • 定义一个ReturnNode,遇到return语句节点时返回结果

实现参见如下代码:

//函数调用
type FuncCallNode struct {
    env  **Environment
    name string
    arg  []Node
}

func NewFuncCallNode(e **Environment, n string, a ...Node) *FuncCallNode {
    return &FuncCallNode{
        env:  e,
        name: n,
        arg:  a,
    }
}

func (n *FuncCallNode) Eval() NodeValue {
    var fd *FuncDeclarationNode
    if val := (*(n.env)).GetVariable(n.name); val.Typ != TypeFunction {
        panic("Call invalid function.")
    } else {
        fd = val.Value.(*FuncDeclarationNode)
    }
    PushNewEnv(n.env, n.name+"_func")
    //实参赋给形参、类型需一致
    matchPara(*n.env, fd.Para, n.arg)

    var nodeValue NodeValue
    for i := 0; i < len(fd.Body); i++ {
        nodeValue = fd.Body[i].Eval()
        if nodeValue.Typ == TypeReturn {
            ret := nodeValue.Value.(NodeValue)
            if ret.Typ != fd.Ret {
                panic("Return type mismatch.")
            } else {
                PopEnv(n.env)
                return ret
            }
        }
    }
    ret := nodeValue
    if ret.Typ != fd.Ret {
        panic("Return type mismatch.")
    } else {
        PopEnv(n.env)
        return ret
    }
}

func matchPara(env *Environment, para []Symbol, arg []Node) {
    if arg == nil && para == nil {
        return
    } else if arg == nil || para == nil {
        panic("Argument & parameter mismatch, one is nil.")
    }
    paraIdx := 0
    argIdx := 0
    for ; paraIdx < len(para) && argIdx < len(arg); {
        argValue := arg[argIdx].Eval()
        if para[paraIdx].Type != argValue.Typ {
            panic("Argument & parameter's type mismatch.")
        }
        env.SaveVariable(para[paraIdx].Name, argValue)
        paraIdx++
        argIdx++
    }
    if paraIdx != len(para) || argIdx != len(arg) {
        panic("Argument & parameter's length mismatch.")
    }
}

完整代码

package main

import (
    "fmt"
)

type NodeValueType uint8

const (
    TypeUnknown  = NodeValueType(iota)
    TypeInteger
    TypeFloat
    TypeString
    TypeBool
    TypeFunction
    TypeReturn
)

type NodeValue struct {
    Typ   NodeValueType
    Value interface{}
}

type Environment struct {
    name  string
    upper *Environment
    store map[string]NodeValue
}

//global env
func NewEnvironment(n string) *Environment {
    return &Environment{
        name:  n,
        upper: nil,
        store: make(map[string]NodeValue),
    }
}

func PushNewEnv(env **Environment, n string) {
    top := &Environment{
        name:  n,
        upper: *env,
        store: make(map[string]NodeValue),
    }
    *env = top
}

func PopEnv(env **Environment) {
    *env = (*env).upper
}

func (e *Environment) SaveVariable(varName string, value NodeValue) {
    e.store[varName] = value
}

func (e *Environment) GetVariable(varName string) NodeValue {
    value, exists := e.store[varName]
    if !exists {
        if e.upper != nil {
            value = e.upper.GetVariable(varName)
        } else {
            panic("Variable not found.")
        }
    }
    return value
}

func (r *NodeValue) ToString() string {
    switch r.Typ {
    case TypeInteger, TypeString, TypeFloat, TypeBool:
        return fmt.Sprint(r.Value)
    default:
        return ""
    }
}

func (r *NodeValue) ToFloat() float64 {
    switch r.Typ {
    case TypeFloat:
        return r.Value.(float64)
    case TypeInteger:
        return float64(r.Value.(int64))
    default:
        panic("Unexcept Behavior.")
    }
}

type Node interface {
    Eval() NodeValue
}

//整数节点
type IntegerNode struct {
    val int64
}

func NewIntegerNode(v int64) *IntegerNode {
    return &IntegerNode{val: v}
}

func (n *IntegerNode) Eval() NodeValue {
    return NodeValue{
        Typ:   TypeInteger,
        Value: n.val,
    }
}

//float节点
type FloatNode struct {
    val float64
}

func NewFloatNode(v float64) *FloatNode {
    return &FloatNode{val: v}
}

func (n *FloatNode) Eval() NodeValue {
    return NodeValue{
        Typ:   TypeFloat,
        Value: n.val,
    }
}

//string节点
type StringNode struct {
    val string
}

func NewStringNode(v string) *StringNode {
    return &StringNode{val: v}
}

func (n *StringNode) Eval() NodeValue {
    return NodeValue{
        Typ:   TypeString,
        Value: n.val,
    }
}

//bool节点
type BoolNode struct {
    val bool
}

func NewBoolNode(v bool) *BoolNode {
    return &BoolNode{val: v}
}

func (n *BoolNode) Eval() NodeValue {
    return NodeValue{
        Typ:   TypeBool,
        Value: n.val,
    }
}

//声明变量
type DeclarationNode struct {
    varName string
    val     Node
    env     **Environment
}

func NewDeclarationNode(e **Environment, n string, v Node) *DeclarationNode {
    return &DeclarationNode{
        varName: n,
        val:     v,
        env:     e,
    }
}

func (n *DeclarationNode) Eval() NodeValue {
    value := n.val.Eval()
    (*(n.env)).SaveVariable(n.varName, value)
    return value
}

//获取变量
type VariableNode struct {
    name string
    env  **Environment
}

func NewVariableNode(e **Environment, n string) *VariableNode {
    return &VariableNode{
        name: n,
        env:  e,
    }
}

func (n *VariableNode) Eval() NodeValue {
    return (*(n.env)).GetVariable(n.name)
}

//Return节点
type ReturnNode struct {
    retStatement Node
}

func NewReturnNode(r Node) *ReturnNode {
    return &ReturnNode{retStatement: r}
}

func (r *ReturnNode) Eval() NodeValue {
    ret := r.retStatement.Eval()
    return NodeValue{
        Typ:   TypeReturn,
        Value: ret,
    }
}

type Symbol struct {
    Name string
    Type NodeValueType
}

//函数声明
type FuncDeclarationNode struct {
    env  **Environment
    name string
    Para []Symbol
    Ret  NodeValueType
    Body []Node
}

func NewFunctionDeclarationNode(e **Environment, n string, p []Symbol, r NodeValueType, b ...Node) *FuncDeclarationNode {
    return &FuncDeclarationNode{
        env:  e,
        name: n,
        Para: p,
        Ret:  r,
        Body: b,
    }
}

func (d *FuncDeclarationNode) Eval() NodeValue {
    n := NodeValue{Typ: TypeFunction, Value: d}
    //函数注册到环境
    (*(d.env)).SaveVariable(d.name, n)
    return n
}

//函数调用
type FuncCallNode struct {
    env  **Environment
    name string
    arg  []Node
}

func NewFuncCallNode(e **Environment, n string, a ...Node) *FuncCallNode {
    return &FuncCallNode{
        env:  e,
        name: n,
        arg:  a,
    }
}

func (n *FuncCallNode) Eval() NodeValue {
    var fd *FuncDeclarationNode
    if val := (*(n.env)).GetVariable(n.name); val.Typ != TypeFunction {
        panic("Call invalid function.")
    } else {
        fd = val.Value.(*FuncDeclarationNode)
    }
    PushNewEnv(n.env, n.name+"_func")
    //实参赋给形参、类型需一致
    matchPara(*n.env, fd.Para, n.arg)

    var nodeValue NodeValue
    for i := 0; i < len(fd.Body); i++ {
        nodeValue = fd.Body[i].Eval()
        if nodeValue.Typ == TypeReturn {
            ret := nodeValue.Value.(NodeValue)
            if ret.Typ != fd.Ret {
                panic("Return type mismatch.")
            } else {
                PopEnv(n.env)
                return ret
            }
        }
    }
    ret := nodeValue
    if ret.Typ != fd.Ret {
        panic("Return type mismatch.")
    } else {
        PopEnv(n.env)
        return ret
    }
}

func matchPara(env *Environment, para []Symbol, arg []Node) {
    if arg == nil && para == nil {
        return
    } else if arg == nil || para == nil {
        panic("Argument & parameter mismatch, one is nil.")
    }
    paraIdx := 0
    argIdx := 0
    for ; paraIdx < len(para) && argIdx < len(arg); {
        argValue := arg[argIdx].Eval()
        if para[paraIdx].Type != argValue.Typ {
            panic("Argument & parameter's type mismatch.")
        }
        env.SaveVariable(para[paraIdx].Name, argValue)
        paraIdx++
        argIdx++
    }
    if paraIdx != len(para) || argIdx != len(arg) {
        panic("Argument & parameter's length mismatch.")
    }
}

//加法节点
type AddNode struct {
    left  Node
    right Node
}

func NewAddNode(l, r Node) *AddNode {
    return &AddNode{
        left:  l,
        right: r,
    }
}

func (n *AddNode) Eval() NodeValue {
    leftEval := n.left.Eval()
    rightEval := n.right.Eval()

    switch {
    case leftEval.Typ == TypeString || rightEval.Typ == TypeString:
        return NodeValue{
            Typ:   TypeString,
            Value: leftEval.ToString() + rightEval.ToString(),
        }
    case leftEval.Typ == TypeFloat || rightEval.Typ == TypeFloat:
        return NodeValue{
            Typ:   TypeString,
            Value: leftEval.ToFloat() + rightEval.ToFloat(),
        }
    case leftEval.Typ == TypeInteger && rightEval.Typ == TypeInteger:
        return NodeValue{
            Typ:   TypeInteger,
            Value: leftEval.Value.(int64) + rightEval.Value.(int64),
        }
    default:
        return NodeValue{
            Typ:   TypeUnknown,
            Value: nil,
        }
    }
}

func main() {
    env := NewEnvironment("global")

    // func quote(i int) string {
    //     ret = "\"" + i + "\""
    //     return ret
    // }
    declareFunc := NewFunctionDeclarationNode(&env, "quote",
        []Symbol{ {"i", TypeInteger}, }, TypeString,
        NewDeclarationNode(&env, "ret",
            NewAddNode(
                NewAddNode(
                    NewStringNode("\""),
                    NewVariableNode(&env, "i")),
                NewStringNode("\""))),
        NewReturnNode(NewVariableNode(&env, "ret")))
    declareFunc.Eval()

    // quote(100 + 10) + 200 = "110"200
    ast := NewAddNode(
        NewFuncCallNode(&env, "quote", NewAddNode(NewIntegerNode(100), NewIntegerNode(10))),
        NewIntegerNode(200))
    r := ast.Eval()
    if r.Typ == TypeString {
        fmt.Println("quote(100 + 10) + 200 = ", r.Value.(string))
    }
}
-----EOF-----

Categories: programming Tags: AST pattern