使用CST创建AST

发布于:2022-12-29 ⋅ 阅读:(330) ⋅ 点赞:(0)

以计算器为例:

grammar:

grammar LabeledExpr; // rename to distinguish from Expr.g4

compileUnit
    :   expr EOF
    ;

expr
    :   '(' expr ')'                         # parensExpr
    |   op=('+'|'-') expr                    # unaryExpr
    |   left=expr op=('*'|'/') right=expr    # infixExpr
    |   left=expr op=('+'|'-') right=expr    # infixExpr
    |   value=NUM                            # numberExpr
    ;

OP_ADD: '+';
OP_SUB: '-';
OP_MUL: '*';
OP_DIV: '/';

NUM :   [0-9]+ ('.' [0-9]+)? ([eE] [+-]? [0-9]+)?;
ID  :   [a-zA-Z]+;
WS  :   [ \t\r\n] -> channel(HIDDEN);

第一步:创建AST结点类。根据计算器的语法,需要创建数字结点,二元运算结点和一元运算结点

可以将二元运算(+,-,*,/)根据运算符创建为不同种类的结点,也可以创建为一种结点,此时运算符作为结点的一个属性。这里创建为不同的结点

创建AST结点可以参考 

https://www.jianshu.com/p/0cf054a8365f

Antlr4如何自动解析得到AST而不是ParseTree - 知乎

node.py

class Node(Object):
    """
    所有结点的基类
    """
    pass


class ExpressionNode(Node):
    """
    表达式结点的基类
    """
    pass


class InfixExpressionNode(ExpressionNode):
    """
    二元表达式结点的基类
    """
    def __init__(self):
        self.Left = None
        self.Right = None

# 二元运算结点
class AdditionNode(InfixExpressionNode):
    """
    加法结点
    """
    pass


class SubtractionNode(InfixExpressionNode):
    pass


class MultiplicationNode(InfixExpressionNode):
    pass


class DivisionNode(InfixExpressionNode):
    pass


class NegateNode(ExpressionNode):
    """
    一元运算(+1 / -1)结点
    """
    def __init__(self, InnerNode):
        self.InnerNode = InnerNode


class NumberNode(ExpressionNode):
    """
    数字结点
    """
    def __init__(self, value):
        self.Value = value

第二步 :使用antlr提供的遍历CST的方法遍历CST,遍历过程中将创建AST所需的信息存储到AST结点。创建AST的过程是递归的

AstBuilder.py

from ASTTest.gen.LabeledExprParser import LabeledExprParser
from ASTTest.gen.LabeledExprVisitor import LabeledExprVisitor
from ASTTest.node.node import *

# 继承antlr提供的vositor,重载访问语法结点的方法,创建并返回AST结点
class ASTBuilder(LabeledExprVisitor):
    def visitCompileUnit(self, ctx: LabeledExprParser.CompileUnitContext):
        return self.visit(ctx.expr())

    # left=expr op=('*'|'/') right=expr    # infixExpr
    # left=expr op=('+'|'-') right=expr    # infixExpr
    def visitInfixExpr(self, ctx: LabeledExprParser.InfixExprContext):
        # 根据运算符创建对应的二元操作结点
        if (ctx.op.type == LabeledExprParser.OP_ADD):
            node = AdditionNode()
        elif (ctx.op.type == LabeledExprParser.OP_SUB):
            node = SubtractionNode()
        elif (ctx.op.type == LabeledExprParser.OP_MUL):
            node = MultiplicationNode()
        elif (ctx.op.type == LabeledExprParser.OP_DIV):
            node = DivisionNode
        else:
            print("符号匹配错误")
        # 二元操作的左右值存入结点
        node.Left = self.visit(ctx.left)
        node.Right = self.visit(ctx.right)
        return node

    # op=('+'|'-') expr                    # unaryExpr
    def visitUnaryExpr(self, ctx: LabeledExprParser.UnaryExprContext):
        if (ctx.op == LabeledExprParser.OP_ADD):
            return self.visit(ctx.expr())
        else:
            innerNode = self.visit(ctx.expr())
            node = NegateNode(innerNode)
            return node

    # value=NUM                            # numberExpr
    def visitNumberExpr(self, ctx: LabeledExprParser.NumberExprContext):
        # 数字结点
        return NumberNode(ctx.value.text)

    # '(' expr ')'                         # parensExpr
    def visitParensExpr(self, ctx: LabeledExprParser.ParensExprContext):
        # AST不需要括号,AST的结构隐含了括号信息,因此只返回expr
        return self.visit(ctx.expr())

第三步:第二步已经创建好AST,这一步使用AST实现计算器的功能

ast_visitor.py

from ASTTest.node.node import *


class ASTVisitor:
    # 递归访问AST结点
    def visit(self, node):
        if type(node) is AdditionNode:
            return int(self.visit(node.Left)) + int(self.visit(node.Right))
        elif type(node) is SubtractionNode:
            return int(self.visit(node.Left)) - int(self.visit(node.Right))
        elif type(node) is MultiplicationNode:
            return int(self.visit(node.Left)) * int(self.visit(node.Right))
        elif type(node) is DivisionNode:
            return int(self.visit(node.Left)) / int(self.visit(node.Right))
        elif type(node) is NegateNode:
            print("InnerNode:", self.visit(node.InnerNode))
            return -int(self.visit(node.InnerNode))
        # 数字结点直接返回数值
        elif type(node) is NumberNode:
            return node.Value
        else:
            return None

第四步: 将前面几步串起来,实现计算器

from antlr4 import CommonTokenStream, FileStream

from ASTTest.ASTBuilder import ASTBuilder
from ASTTest.ast.ast_visitor import ASTVisitor
from ASTTest.ast.evaluator import EvaluateExpressionVisitor
from ASTTest.gen.LabeledExprLexer import LabeledExprLexer
from ASTTest.gen.LabeledExprParser import LabeledExprParser
from ASTTest.node.node import Node

if __name__ == '__main__':
    inputStream = FileStream("input.txt")
    lexer = LabeledExprLexer(inputStream)
    tokenStream = CommonTokenStream(lexer)
    parser = LabeledExprParser(tokenStream)
    # antlr生成解析树
    cst = parser.compileUnit()

    # 遍历解析树,创建ast
    ast = ASTBuilder().visitCompileUnit(cst)

    # 遍历ast,实现计算器功能
    value = ASTVisitor().visit(ast)
    print("计算结果:", value)

运行:

输入:2*(3+1)-3

输出:

 输入:-3

输出:

本文含有隐藏内容,请 开通VIP 后查看