对抽象语法树(AST)求值的思路(3)
前言
函数调用这一部分是非常关键的地方,也是我比较模糊的地方, 这里根据我的分析思考如何实现, 可能跟标准的、成熟的做法不符。
思路
有些函数可以作为函数指针用,或者类似lambda表达式, 所以函数跟string,bool等作为同一个类别,函数调用也只是一种操作,对这种类别的操作而已。 函数调用的难点在于参数与返回值。
函数分两种,一种是预定义的,相当于一个运行时环境, 另一种是代码内定义的,预定义的运行时环境可以说是必须的, 例如C语言中printf函数,否则解释器很多能力受限。 而代码内自定义的函数能进一步扩展设计的语言的能力。
先考虑如何实现代码内自定义函数:
假设实现一个类似于:
func add(i int, j int) {
return i + j
}
可以想象,调用函数时,有以下几个步骤:
- 首先跟查普通变量一样,在环境中查找到该变量
- 确定该变量的类型为函数,而不是string、float之类。
- 传递参数,并确定参数类型一致,参数作用域在函数内,函数执行前创建一个环境。
- 执行语句,这里可能有多条语句,用一个
[]Node
- 定义一个ReturnNode,遇到return语句节点时返回结果
对于函数节点FunctionNode,调用每一条Node执行,
直到遇到ReturnNode,将其结果作为返回结果,
问题就在于ReturnNode如何通知FunctionNode,而无需继续执行下去,
这里就必须在Eval()
后留下特殊的信息。
我的解决方法是添加一个TypeReturn类型,跟TypeString、TypeFloat等一个级别,
如果是TypeReturn就将该节点内的数据取出,作为返回,并于返回值类型比较。
实现
改进环境层次结构
之前将所有变量看作是全局变量,但因为函数定义参数时用的变量名不应该放到全局变量里, 所以先改进环境,将其设计为层次结构(stack),在本层查找失败,再去更外层查找, 直到找到或者已到最外层(global)。
之所以Push、Pop函数需要用**Environment
,是为了改变环境,
否则对原环境(类型为*Environment
)没有影响。
type Environment struct {
name string
upper *Environment
store map[string]NodeValue
}
//global env
func NewEnvironment(n string) *Environment {
return &Environment{
name: n,
upper: nil,
store: make(map[string]NodeValue),
}
}
func PushNewEnv(env **Environment, n string) {
top := &Environment{
name: n,
upper: *env,
store: make(map[string]NodeValue),
}
*env = top
}
func PopEnv(env **Environment) {
*env = (*env).upper
}
func (e *Environment) SaveVariable(varName string, value NodeValue) {
e.store[varName] = value
}
func (e *Environment) GetVariable(varName string) NodeValue {
value, exists := e.store[varName]
if !exists {
if e.upper != nil {
value = e.upper.GetVariable(varName)
} else {
panic("Variable not found.")
}
}
return value
}
函数声明定义
func quote(i int) string {
ret = "\"" + i + "\""
return ret
}
对于一个函数定义如上所示,可以看到有以下信息需要保存,
- 函数名quote
- 符号 i 及其对应的类型 int,之所以需要符号i,因为函数体内会引用i
- 返回值类型,简单起见不考虑多返回值,类型是为了在函数调用时作类型检查
- 函数体,即
{...}
内的部分,是一个block,即0到多个表达式(一个表达式对应一个Node树)
然后将函数注册到当前环境即可,代码如下:
//函数声明
type FuncDeclarationNode struct {
env **Environment
name string
Para []Symbol
Ret NodeValueType
Body []Node
}
func NewFunctionDeclarationNode(e **Environment, n string, p []Symbol, r NodeValueType, b ...Node) *FuncDeclarationNode {
return &FuncDeclarationNode{
env: e,
name: n,
Para: p,
Ret: r,
Body: b,
}
}
func (d *FuncDeclarationNode) Eval() NodeValue {
n := NodeValue{Typ: TypeFunction, Value: d}
//函数注册到环境
(*(d.env)).SaveVariable(d.name, n)
return n
}
函数调用
函数调用过程,以分析了其执行思路:
- 首先跟查普通变量一样在环境中查找到该变量
- 确定该变量的类型为函数,而不是string、float之类。
- 传递参数,并确定参数类型一致,参数作用域在函数内,函数执行前创建一个环境。
- 执行语句,这里可能有多条语句,用一个
[]Node
- 定义一个ReturnNode,遇到return语句节点时返回结果
实现参见如下代码:
//函数调用
type FuncCallNode struct {
env **Environment
name string
arg []Node
}
func NewFuncCallNode(e **Environment, n string, a ...Node) *FuncCallNode {
return &FuncCallNode{
env: e,
name: n,
arg: a,
}
}
func (n *FuncCallNode) Eval() NodeValue {
var fd *FuncDeclarationNode
if val := (*(n.env)).GetVariable(n.name); val.Typ != TypeFunction {
panic("Call invalid function.")
} else {
fd = val.Value.(*FuncDeclarationNode)
}
PushNewEnv(n.env, n.name+"_func")
//实参赋给形参、类型需一致
matchPara(*n.env, fd.Para, n.arg)
var nodeValue NodeValue
for i := 0; i < len(fd.Body); i++ {
nodeValue = fd.Body[i].Eval()
if nodeValue.Typ == TypeReturn {
ret := nodeValue.Value.(NodeValue)
if ret.Typ != fd.Ret {
panic("Return type mismatch.")
} else {
PopEnv(n.env)
return ret
}
}
}
ret := nodeValue
if ret.Typ != fd.Ret {
panic("Return type mismatch.")
} else {
PopEnv(n.env)
return ret
}
}
func matchPara(env *Environment, para []Symbol, arg []Node) {
if arg == nil && para == nil {
return
} else if arg == nil || para == nil {
panic("Argument & parameter mismatch, one is nil.")
}
paraIdx := 0
argIdx := 0
for ; paraIdx < len(para) && argIdx < len(arg); {
argValue := arg[argIdx].Eval()
if para[paraIdx].Type != argValue.Typ {
panic("Argument & parameter's type mismatch.")
}
env.SaveVariable(para[paraIdx].Name, argValue)
paraIdx++
argIdx++
}
if paraIdx != len(para) || argIdx != len(arg) {
panic("Argument & parameter's length mismatch.")
}
}
完整代码
package main
import (
"fmt"
)
type NodeValueType uint8
const (
TypeUnknown = NodeValueType(iota)
TypeInteger
TypeFloat
TypeString
TypeBool
TypeFunction
TypeReturn
)
type NodeValue struct {
Typ NodeValueType
Value interface{}
}
type Environment struct {
name string
upper *Environment
store map[string]NodeValue
}
//global env
func NewEnvironment(n string) *Environment {
return &Environment{
name: n,
upper: nil,
store: make(map[string]NodeValue),
}
}
func PushNewEnv(env **Environment, n string) {
top := &Environment{
name: n,
upper: *env,
store: make(map[string]NodeValue),
}
*env = top
}
func PopEnv(env **Environment) {
*env = (*env).upper
}
func (e *Environment) SaveVariable(varName string, value NodeValue) {
e.store[varName] = value
}
func (e *Environment) GetVariable(varName string) NodeValue {
value, exists := e.store[varName]
if !exists {
if e.upper != nil {
value = e.upper.GetVariable(varName)
} else {
panic("Variable not found.")
}
}
return value
}
func (r *NodeValue) ToString() string {
switch r.Typ {
case TypeInteger, TypeString, TypeFloat, TypeBool:
return fmt.Sprint(r.Value)
default:
return ""
}
}
func (r *NodeValue) ToFloat() float64 {
switch r.Typ {
case TypeFloat:
return r.Value.(float64)
case TypeInteger:
return float64(r.Value.(int64))
default:
panic("Unexcept Behavior.")
}
}
type Node interface {
Eval() NodeValue
}
//整数节点
type IntegerNode struct {
val int64
}
func NewIntegerNode(v int64) *IntegerNode {
return &IntegerNode{val: v}
}
func (n *IntegerNode) Eval() NodeValue {
return NodeValue{
Typ: TypeInteger,
Value: n.val,
}
}
//float节点
type FloatNode struct {
val float64
}
func NewFloatNode(v float64) *FloatNode {
return &FloatNode{val: v}
}
func (n *FloatNode) Eval() NodeValue {
return NodeValue{
Typ: TypeFloat,
Value: n.val,
}
}
//string节点
type StringNode struct {
val string
}
func NewStringNode(v string) *StringNode {
return &StringNode{val: v}
}
func (n *StringNode) Eval() NodeValue {
return NodeValue{
Typ: TypeString,
Value: n.val,
}
}
//bool节点
type BoolNode struct {
val bool
}
func NewBoolNode(v bool) *BoolNode {
return &BoolNode{val: v}
}
func (n *BoolNode) Eval() NodeValue {
return NodeValue{
Typ: TypeBool,
Value: n.val,
}
}
//声明变量
type DeclarationNode struct {
varName string
val Node
env **Environment
}
func NewDeclarationNode(e **Environment, n string, v Node) *DeclarationNode {
return &DeclarationNode{
varName: n,
val: v,
env: e,
}
}
func (n *DeclarationNode) Eval() NodeValue {
value := n.val.Eval()
(*(n.env)).SaveVariable(n.varName, value)
return value
}
//获取变量
type VariableNode struct {
name string
env **Environment
}
func NewVariableNode(e **Environment, n string) *VariableNode {
return &VariableNode{
name: n,
env: e,
}
}
func (n *VariableNode) Eval() NodeValue {
return (*(n.env)).GetVariable(n.name)
}
//Return节点
type ReturnNode struct {
retStatement Node
}
func NewReturnNode(r Node) *ReturnNode {
return &ReturnNode{retStatement: r}
}
func (r *ReturnNode) Eval() NodeValue {
ret := r.retStatement.Eval()
return NodeValue{
Typ: TypeReturn,
Value: ret,
}
}
type Symbol struct {
Name string
Type NodeValueType
}
//函数声明
type FuncDeclarationNode struct {
env **Environment
name string
Para []Symbol
Ret NodeValueType
Body []Node
}
func NewFunctionDeclarationNode(e **Environment, n string, p []Symbol, r NodeValueType, b ...Node) *FuncDeclarationNode {
return &FuncDeclarationNode{
env: e,
name: n,
Para: p,
Ret: r,
Body: b,
}
}
func (d *FuncDeclarationNode) Eval() NodeValue {
n := NodeValue{Typ: TypeFunction, Value: d}
//函数注册到环境
(*(d.env)).SaveVariable(d.name, n)
return n
}
//函数调用
type FuncCallNode struct {
env **Environment
name string
arg []Node
}
func NewFuncCallNode(e **Environment, n string, a ...Node) *FuncCallNode {
return &FuncCallNode{
env: e,
name: n,
arg: a,
}
}
func (n *FuncCallNode) Eval() NodeValue {
var fd *FuncDeclarationNode
if val := (*(n.env)).GetVariable(n.name); val.Typ != TypeFunction {
panic("Call invalid function.")
} else {
fd = val.Value.(*FuncDeclarationNode)
}
PushNewEnv(n.env, n.name+"_func")
//实参赋给形参、类型需一致
matchPara(*n.env, fd.Para, n.arg)
var nodeValue NodeValue
for i := 0; i < len(fd.Body); i++ {
nodeValue = fd.Body[i].Eval()
if nodeValue.Typ == TypeReturn {
ret := nodeValue.Value.(NodeValue)
if ret.Typ != fd.Ret {
panic("Return type mismatch.")
} else {
PopEnv(n.env)
return ret
}
}
}
ret := nodeValue
if ret.Typ != fd.Ret {
panic("Return type mismatch.")
} else {
PopEnv(n.env)
return ret
}
}
func matchPara(env *Environment, para []Symbol, arg []Node) {
if arg == nil && para == nil {
return
} else if arg == nil || para == nil {
panic("Argument & parameter mismatch, one is nil.")
}
paraIdx := 0
argIdx := 0
for ; paraIdx < len(para) && argIdx < len(arg); {
argValue := arg[argIdx].Eval()
if para[paraIdx].Type != argValue.Typ {
panic("Argument & parameter's type mismatch.")
}
env.SaveVariable(para[paraIdx].Name, argValue)
paraIdx++
argIdx++
}
if paraIdx != len(para) || argIdx != len(arg) {
panic("Argument & parameter's length mismatch.")
}
}
//加法节点
type AddNode struct {
left Node
right Node
}
func NewAddNode(l, r Node) *AddNode {
return &AddNode{
left: l,
right: r,
}
}
func (n *AddNode) Eval() NodeValue {
leftEval := n.left.Eval()
rightEval := n.right.Eval()
switch {
case leftEval.Typ == TypeString || rightEval.Typ == TypeString:
return NodeValue{
Typ: TypeString,
Value: leftEval.ToString() + rightEval.ToString(),
}
case leftEval.Typ == TypeFloat || rightEval.Typ == TypeFloat:
return NodeValue{
Typ: TypeString,
Value: leftEval.ToFloat() + rightEval.ToFloat(),
}
case leftEval.Typ == TypeInteger && rightEval.Typ == TypeInteger:
return NodeValue{
Typ: TypeInteger,
Value: leftEval.Value.(int64) + rightEval.Value.(int64),
}
default:
return NodeValue{
Typ: TypeUnknown,
Value: nil,
}
}
}
func main() {
env := NewEnvironment("global")
// func quote(i int) string {
// ret = "\"" + i + "\""
// return ret
// }
declareFunc := NewFunctionDeclarationNode(&env, "quote",
[]Symbol{ {"i", TypeInteger}, }, TypeString,
NewDeclarationNode(&env, "ret",
NewAddNode(
NewAddNode(
NewStringNode("\""),
NewVariableNode(&env, "i")),
NewStringNode("\""))),
NewReturnNode(NewVariableNode(&env, "ret")))
declareFunc.Eval()
// quote(100 + 10) + 200 = "110"200
ast := NewAddNode(
NewFuncCallNode(&env, "quote", NewAddNode(NewIntegerNode(100), NewIntegerNode(10))),
NewIntegerNode(200))
r := ast.Eval()
if r.Typ == TypeString {
fmt.Println("quote(100 + 10) + 200 = ", r.Value.(string))
}
}