[手写系列]Go手写db — — 第三版(实现分组、排序、聚合函数等)

发布于:2025-09-08 ⋅ 阅读:(20) ⋅ 点赞:(0)

[手写系列]Go手写db — — 第三版

第一版文章地址:https://blog.csdn.net/weixin_45565886/article/details/147839627
第二版文章地址:https://blog.csdn.net/weixin_45565886/article/details/150869791

  • 🏠整体项目Github地址:https://github.com/ziyifast/ZiyiDB
  • 🚀请大家多多支持,也欢迎大家star⭐️和共同维护这个项目~

序言:只要接触过后端开发,必不可少会使用到关系型数据库,比如:MySQL、Oracle等,那么我们经常使用的字段默认值、以及聚合函数底层是如何实现的呢?本文会给大家提供一些思路,实现相关功能。

主要介绍如何在 ZiyiDB之前的基础上,实现更多新功能,给大家提供实现数据库的简单思路,以及数据库底层实现的流程,后续更多功能,大家可以参考着实现。

一、功能列表

  1. 默认值支持(DEFAULT 关键字)
  2. 聚合函数支持(COUNT, SUM, AVG, MAX, MIN)
  3. Group by分组能力
  4. Order by 排序能力

二、实现细节

1. 默认值实现

设计思路

默认值是数据库中一个重要的数据完整性特性。当插入数据时,如果没有为某列提供值,数据库会自动使用该列的默认值。

在 ZiyiDB 中,默认值的实现需要考虑以下几点:

  • 语法解析:在 CREATE TABLE 语句中识别 DEFAULT 关键字和默认值
  • 存储:在表结构中保存每列的默认值
  • 执行:在 INSERT 语句中应用默认值

1.在lexer/token.go中新增default字符,然后在lexer/lexer.go的lookupIdentifier方法中新增对于default的case语句,用于匹配识别用户输入的SQL

token.go:
在这里插入图片描述
lexer.go:
在这里插入图片描述
2. internal/ast/ast.go抽象语法树中新增DefaultExpression,同时列定义中新增默认值字段,用于存储列的默认值
在这里插入图片描述
在这里插入图片描述
3. parser中的parseCreateTableStatement函数新增对create SQL中默认值的读取和封装,解析用户输入SQL中的字段默认值类型和value
在这里插入图片描述
4. internal/storage/memory.go 存储引擎处理Insert方法时,新增对默认值的处理。
在这里插入图片描述

代码实现

1.语法解析层(Parser)

在 internal/parser/parser.go 中,parseCreateTableStatement 方法被增强以支持默认值:

// parseCreateTableStatement 解析CREATE TABLE语句
func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) {
    stmt := &ast.CreateTableStatement{Token: p.curToken}
    // ... 其他代码

    // 解析列定义
    for !p.peekTokenIs(lexer.RPAREN) {
        p.nextToken()

        if !p.curTokenIs(lexer.IDENT) {
            return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
        }

        col := ast.ColumnDefinition{
            Name: p.curToken.Literal,
        }

        if !p.expectPeek(lexer.INT) &&
            !p.expectPeek(lexer.TEXT) &&
            !p.expectPeek(lexer.FLOAT) &&
            !p.expectPeek(lexer.DATETIME) {
            return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
        }
        col.Type = string(p.curToken.Type)

        if p.peekTokenIs(lexer.PRIMARY) {
            p.nextToken()
            if !p.expectPeek(lexer.KEY) {
                return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
            }
            col.Primary = true
        }

        if p.peekTokenIs(lexer.DEFAULT) {
            p.nextToken() // 消费 DEFAULT 关键字
            p.nextToken() // 移动到默认值表达式开始位置

            // 解析复杂默认值表达式(支持函数调用、数学表达式等)
            defaultValue, err := p.parseExpression()
            if err != nil {
                return nil, fmt.Errorf("Invalid default value for column '%s': %v", col.Name, err)
            }

            // 创建 DefaultExpression 节点
            col.Default = &ast.DefaultExpression{
                Token: p.curToken,
                Value: defaultValue,
            }
        }

        stmt.Columns = append(stmt.Columns, col)

        if p.peekTokenIs(lexer.COMMA) {
            p.nextToken()
        }
    }
    // ... 其他代码
}

2.AST 定义

在 internal/ast/ast.go 中,我们添加了 DefaultExpression 类型来表示默认值:

// DefaultExpression 表示DEFAULT表达式
type DefaultExpression struct {
    Token lexer.Token
    Value Expression
}

func (de *DefaultExpression) expressionNode()      {}
func (de *DefaultExpression) TokenLiteral() string { return de.Token.Literal }

同时,ColumnDefinition 结构也被更新以包含默认值:

// ColumnDefinition 表示列定义
type ColumnDefinition struct {
    Name     string
    Type     string
    Primary  bool
    Nullable bool
    Default  interface{} //列默认值
}

3.存储引擎实现

在 internal/storage/memory.go 中,Insert 方法被增强以支持默认值:

// Insert 插入数据
func (b *MemoryBackend) Insert(stmt *ast.InsertStatement) error {
    table, exists := b.tables[stmt.TableName]
    if !exists {
        return fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
    }
    // 构建列名到表列索引的映射
    colIndexMap := make(map[string]int)
    for idx, col := range table.Columns {
        colIndexMap[col.Name] = idx
    }
    // 初始化行数据(长度为表的总列数)
    row := make([]ast.Cell, len(table.Columns))
    // 处理插入列列表(用户显式指定的列或隐式全列)
    var insertCols []*ast.Identifier
    //用户SQL需要插入的列名、值的映射
    userColMap := make(map[string]ast.Expression)
    if len(stmt.Columns) > 0 {
        insertCols = stmt.Columns
        for i, col := range stmt.Columns {
            userColMap[col.Token.Literal] = stmt.Values[i]
        }
    } else {
        // 未指定列时默认使用表的所有列
        insertCols = make([]*ast.Identifier, len(table.Columns))
        for i, col := range table.Columns {
            insertCols[i] = &ast.Identifier{Value: col.Name}
            userColMap[col.Name] = stmt.Values[i]
        }
    }
    // 检查值数量与指定列数量是否匹配
    if len(stmt.Values) != len(insertCols) {
        return fmt.Errorf("Column count doesn't match value count at row 1 (got %d, want %d)", len(stmt.Values), len(insertCols))
    }

    // 转换值
    // 填充行数据(处理用户值或默认值)
    for i, tableCol := range table.Columns {
        // 优先使用用户提供的值,否则使用默认值
        var expr ast.Expression
        expr = userColMap[tableCol.Name]
        if expr == nil && tableCol.Default != nil {
            expr = tableCol.Default.(*ast.DefaultExpression).Value
        }
        //获取当前列名
        colName := table.Columns[i].Name
        tableColIdx, ok := colIndexMap[colName]
        if !ok {
            return fmt.Errorf("Unknown column '%s' in INSERT statement", colName)
        }
        // 转换值类型
        value, err := evaluateExpression(expr)
        if err != nil {
            return fmt.Errorf("invalid value for column '%s': %v", colName, err)
        }

        // 类型转换
        switch v := value.(type) {
        case string:
            if tableCol.Type == "INT" {
                intVal, err := strconv.ParseInt(v, 10, 32)
                if err != nil {
                    return fmt.Errorf("Incorrect integer value: '%s' for column '%s'", v, tableCol.Name)
                }
                row[tableColIdx] = ast.Cell{Type: ast.CellTypeInt, IntValue: int32(intVal)}
            } else {
                row[tableColIdx] = ast.Cell{Type: ast.CellTypeText, TextValue: v}
            }
        case int32:
            row[tableColIdx] = ast.Cell{Type: ast.CellTypeInt, IntValue: v}
        case float32:
            row[tableColIdx] = ast.Cell{Type: ast.CellTypeFloat, FloatValue: v}
        case time.Time:
            row[tableColIdx] = ast.Cell{Type: ast.CellTypeDateTime, TimeValue: v.Format("2006-01-02 15:04:05")}
        default:
            return fmt.Errorf("Unsupported value type: %T for column '%s'", value, tableCol.Name)
        }
    }
    // ... 其他代码
}

测试

测试SQL:

-- 创建带默认值的表
CREATE TABLE users (
    id INT PRIMARY KEY,
    name TEXT,
    age INT DEFAULT 18,
    score FLOAT,
    ctime DATETIME DEFAULT '2023-07-04 12:00:00'
);

-- 插入部分列数据(未指定的列将使用默认值)
INSERT INTO users (id, name, score) VALUES (1, 'Alice', 90.0);
INSERT INTO users (id, name, age, score) VALUES (2, 'Bob', 25, 85.5);

-- 查询数据验证默认值
SELECT * FROM users;

效果:
在这里插入图片描述

2. 聚合函数实现

设计思路

聚合函数是 SQL 中用于对一组值执行计算并返回单个值的函数。在 ZiyiDB 中,我们实现了以下聚合函数:

  • COUNT:计算行数
  • SUM:计算数值列的总和
  • AVG:计算数值列的平均值
  • MAX:找出列中的最大值
  • MIN:找出列中的最小值

聚合函数的实现需要考虑以下几点:
语法解析:在 SELECT 语句中识别函数调用
执行逻辑:在存储引擎中计算聚合结果
结果返回:以统一的格式返回结果

这里以count聚合函数为例,其他聚合函数同理

  1. internal/ast/ast.go中新增FunctionCall函数调用类型,用于后续执行函数调用,比如count、max等聚合函数
    在这里插入图片描述
  2. internal/parser/parser.go中新增对函数类型的解析和封装
    在这里插入图片描述
  3. internal/storage/memory.go存储引擎Select方法中新增对聚合函数的判断
    在这里插入图片描述
    同时memory.go中添加calculateFunctionResults方法,实现对函数的执行和底层实现
    在这里插入图片描述
    在这里插入图片描述

代码实现

  1. 语法解析层(Parser)

在 internal/parser/parser.go 中,我们增强了 parseSelectStatement 方法来支持函数调用:

// parseSelectStatement 解析SELECT语句
func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) {
    stmt := &ast.SelectStatement{Token: p.curToken}

    // 解析选择列表
    for !p.peekTokenIs(lexer.FROM) {
        p.nextToken()

        if p.curToken.Type == lexer.ASTERISK {
            stmt.Fields = append(stmt.Fields, &ast.StarExpression{})
            break
        }

        expr, err := p.parseExpression()
        if err != nil {
            return nil, err
        }

        stmt.Fields = append(stmt.Fields, expr)

        if p.peekTokenIs(lexer.COMMA) {
            p.nextToken()
        }
    }
    // ... 其他代码
}

parseExpression 方法也进行了增强,以支持函数调用的解析:

// parseExpression 解析表达式
func (p *Parser) parseExpression() (ast.Expression, error) {
    switch p.curToken.Type {
    // ... 其他情况
    case lexer.IDENT:
        if p.peekTokenIs(lexer.LPAREN) {
            return p.parseFunctionCall()
        }
        return &ast.Identifier{
            Token: p.curToken,
            Value: p.curToken.Literal,
        }, nil
    // ...
    }
}

// parseFunctionCall 解析函数调用
func (p *Parser) parseFunctionCall() (ast.Expression, error) {
    fn := &ast.FunctionCall{
        Token:  p.curToken,
        Name:   p.curToken.Literal,
        Params: []ast.Expression{},
    }

    // 检查下一个token是否为左括号
    if !p.expectPeek(lexer.LPAREN) {
        return nil, fmt.Errorf("expected ( after function name")
    }

    // 如果是右括号,说明没有参数
    if p.peekTokenIs(lexer.RPAREN) {
        p.nextToken()
        return fn, nil
    }

    // 解析参数列表
    for !p.peekTokenIs(lexer.RPAREN) {
        p.nextToken()
        param, err := p.parseExpression()
        if err != nil {
            return nil, err
        }
        fn.Params = append(fn.Params, param)

        if p.peekTokenIs(lexer.COMMA) {
            p.nextToken()
        } else if !p.peekTokenIs(lexer.RPAREN) {
            return nil, fmt.Errorf("expected comma or closing parenthesis in function call")
        }
    }

    if !p.expectPeek(lexer.RPAREN) {
        return nil, fmt.Errorf("Missing closing parenthesis for function call")
    }

    return fn, nil
}
  1. AST 定义

在 internal/ast/ast.go 中,我们添加了 FunctionCall 类型来表示函数调用:

// FunctionCall 表示函数调用
type FunctionCall struct {
    Token  lexer.Token
    Name   string
    Params []Expression
}

func (fc *FunctionCall) expressionNode()      {}
func (fc *FunctionCall) TokenLiteral() string { return fc.Token.Literal }
  1. 存储引擎实现

在 internal/storage/memory.go 中,Select 方法被增强以支持聚合函数:

// Select 查询数据
func (b *MemoryBackend) Select(stmt *ast.SelectStatement) (*ast.Results, error) {
    table, exists := b.tables[stmt.TableName]
    if !exists {
        return nil, fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
    }

    results := &ast.Results{
        Columns: make([]ast.ResultColumn, 0),
        Rows:    make([][]ast.Cell, 0),
    }

    // 检查是否为聚合函数查询
    isAggregation := false
    var aggregateFunc *ast.FunctionCall

    // 处理select列表
    if len(stmt.Fields) == 1 {
        // 检查是否为 SELECT *
        if _, ok := stmt.Fields[0].(*ast.StarExpression); ok {
            // SELECT *
            for _, col := range table.Columns {
                results.Columns = append(results.Columns, ast.ResultColumn{
                    Name: col.Name,
                    Type: col.Type,
                })
            }
        } else if fn, ok := stmt.Fields[0].(*ast.FunctionCall); ok {
            // 处理函数调用
            isAggregation = true
            aggregateFunc = fn
            results.Columns = append(results.Columns, ast.ResultColumn{
                Name: fn.Name,
                Type: "FUNCTION",
            })
        }
        // ... 其他情况
    }
    // ... 其他情况

    // 如果是聚合函数查询,直接计算结果
    if isAggregation {
        // 处理WHERE子句
        filteredRows := make([][]ast.Cell, 0)
        for _, row := range table.Rows {
            if stmt.Where != nil {
                match, err := evaluateWhereCondition(stmt.Where, row, table.Columns)
                if err != nil {
                    return nil, err
                }
                if !match {
                    continue
                }
            }
            filteredRows = append(filteredRows, row)
        }

        functionResult := calculateFunctionResults(aggregateFunc, table, filteredRows)
        results.Rows = [][]ast.Cell{functionResult}
        return results, nil
    }
    // ... 非聚合函数的处理
}

每个聚合函数都有对应的计算方法:

// calculateFunctionResults 计算函数结果
func calculateFunctionResults(fn *ast.FunctionCall, table *Table, rows [][]ast.Cell) []ast.Cell {
    // 根据函数类型计算结果
    switch strings.ToUpper(fn.Name) {
    case "COUNT":
        return calculateCount(fn, table, rows)
    case "SUM":
        return calculateSum(fn, table, rows)
    case "AVG":
        return calculateAvg(fn, table, rows)
    case "MAX":
        return calculateMax(fn, table, rows)
    case "MIN":
        return calculateMin(fn, table, rows)
    default:
        return []ast.Cell{{Type: ast.CellTypeText, TextValue: fmt.Sprintf("ERROR: Unknown function '%s'", fn.Name)}}
    }
}

// calculateCount 计算COUNT函数结果
func calculateCount(fn *ast.FunctionCall, table *Table, rows [][]ast.Cell) []ast.Cell {
    return []ast.Cell{{Type: ast.CellTypeInt, IntValue: int32(len(rows))}}
}

// calculateSum 计算SUM函数结果
func calculateSum(fn *ast.FunctionCall, table *Table, rows [][]ast.Cell) []ast.Cell {
    // 处理 SUM(column) 情况
    if len(fn.Params) != 1 {
        return []ast.Cell{{Type: ast.CellTypeText, TextValue: "ERROR: SUM function requires exactly one parameter"}}
    }
    var columnName string
    // 检查参数类型
    switch param := fn.Params[0].(type) {
    case *ast.Identifier:
        columnName = param.Value
    default:
        return []ast.Cell{{Type: ast.CellTypeText, TextValue: fmt.Sprintf("ERROR: SUM function requires a column name, got %T", param)}}
    }

    // 查找列索引
    colIndex := -1
    for i, col := range table.Columns {
        if col.Name == columnName {
            colIndex = i
            break
        }
    }

    if colIndex == -1 {
        return []ast.Cell{{Type: ast.CellTypeText, TextValue: fmt.Sprintf("ERROR: Unknown column '%s'", columnName)}}
    }

    // 计算SUM值
    var sumInt int32 = 0
    var sumFloat float32 = 0.0
    hasFloat := false

    for _, row := range rows {
        cell := row[colIndex]
        switch cell.Type {
        case ast.CellTypeInt:
            sumInt += cell.IntValue
        case ast.CellTypeFloat:
            // 如果之前有整数,需要转换为浮点数
            if !hasFloat {
                sumFloat = float32(sumInt)
                hasFloat = true
            }
            sumFloat += cell.FloatValue
        }
    }

    // 返回结果
    if hasFloat {
        return []ast.Cell{{Type: ast.CellTypeFloat, FloatValue: sumFloat}}
    }
    return []ast.Cell{{Type: ast.CellTypeInt, IntValue: sumInt}}
}
// ... 其他聚合函数的实现

测试

测试SQL:

-- 创建测试表
CREATE TABLE users (id INT PRIMARY KEY, name TEXT, age INT);

-- 插入测试数据
INSERT INTO users VALUES (1, 'Alice', 20);
INSERT INTO users VALUES (2, 'Bob', 25);
INSERT INTO users VALUES (3, 'Charlie', 30);

-- 使用聚合函数
SELECT COUNT(*) FROM users;
SELECT SUM(age) FROM users;
SELECT AVG(age) FROM users;
SELECT MAX(age) FROM users;
SELECT MIN(age) FROM users;

-- 带WHERE条件的聚合函数
SELECT COUNT(*) FROM users WHERE age > 25;
SELECT SUM(age) FROM users WHERE age >= 25;

效果:
在这里插入图片描述

3. group by 实现

设计思路

1.语法解析:

首先在internal/lexer/token.go中新增group by关键字

在这里插入图片描述然后在internal/lexer/lexer.go词法分析器的lookupIdentifier方法中新增对group by关键字的识别
在这里插入图片描述
接下来在internal/parser/parser.go词法分析器中的parseSelectStatement方法中添加 GROUP 和 BY 关键字的解析,将其解析并封装为ast的一部分
在这里插入图片描述
在 internal/ast/ast.go 中添加 GroupBy 字段到 SelectStatement 结构体
在这里插入图片描述
2. 执行引擎:

首先在internal/storage/memory.go存储引擎中的Select方法实现对分组逻辑的调用
在这里插入图片描述
接着selectWithGroupBy方法,实现底层分组原理,按指定列对数据进行分组
在这里插入图片描述

在这里插入图片描述
3. internal/storage/memory.go中的selectWithGroupBy对聚合函数进行处理,确保查询结果列是聚合函数列或者分组列
在这里插入图片描述

代码实现

  1. 在词法分析器中添加新的关键字
// internal/lexer/token.go
const (
    // ... 其他关键字
    GROUP   TokenType = "GROUP"
    BY      TokenType = "BY"
)

// internal/lexer/lexer.go
func (l *Lexer) lookupIdentifier(ident string) TokenType {
    switch strings.ToUpper(ident) {
    // ... 其他关键字
    case "GROUP":
        return GROUP
    case "BY":
        return BY
    default:
        return IDENT
    }
}
  1. 在 AST 中添加新的结构体以支持 GROUP BY
// internal/ast/ast.go

// SelectStatement 表示SELECT语句
type SelectStatement struct {
    Token     lexer.Token
    Fields    []Expression
    TableName string
    Where     Expression
    GroupBy   []Expression    // 添加 GroupBy 字段
}
  1. 在语法分析器中添加对 GROUP BY 子句的解析
// internal/parser/parser.go

// parseSelectStatement 解析SELECT语句
func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) {
    stmt := &ast.SelectStatement{Token: p.curToken}

    // ... 解析选择列表和 FROM 子句 ...

    // 解析WHERE子句
    if p.peekTokenIs(lexer.WHERE) {
        p.nextToken()
        whereExpr, err := p.parseWhereClause()
        if err != nil {
            return nil, err
        }
        stmt.Where = whereExpr
    }

    // 解析GROUP BY子句
    if p.peekTokenIs(lexer.GROUP) {
        p.nextToken() // 跳过 GROUP
        if !p.expectPeek(lexer.BY) {
            return nil, fmt.Errorf("expected BY after GROUP")
        }

        // 解析GROUP BY字段列表
        for {
            p.nextToken()
            if !p.curTokenIs(lexer.IDENT) {
                return nil, fmt.Errorf("expected identifier in GROUP BY clause")
            }

            expr := &ast.Identifier{
                Token: p.curToken,
                Value: p.curToken.Literal,
            }
            stmt.GroupBy = append(stmt.GroupBy, expr)

            if !p.peekTokenIs(lexer.COMMA) {
                break
            }
            p.nextToken() // 跳过逗号
        }
    }

    return stmt, nil
}
  1. 在存储引擎中实现 GROUP BY 的执行逻辑
// internal/storage/memory.go

// Select 查询数据
func (b *MemoryBackend) Select(stmt *ast.SelectStatement) (*Results, error) {
    table, exists := b.tables[stmt.TableName]
    if !exists {
        return nil, fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
    }

    // 如果有 GROUP BY 子句
    if len(stmt.GroupBy) > 0 {
        return b.selectWithGroupBy(stmt, table)
    }

    // ... 原有的查询逻辑 ...
}

// selectWithGroupBy 处理带有 GROUP BY 的查询
func (b *MemoryBackend) selectWithGroupBy(stmt *ast.SelectStatement, table *Table) (*Results, error) {
    results := &Results{
        Columns: make([]ResultColumn, 0),
        Rows:    make([][]Cell, 0),
    }

    // 验证 GROUP BY 字段存在于表中
    groupByIndices := make([]int, len(stmt.GroupBy))
    for i, expr := range stmt.GroupBy {
        if identifier, ok := expr.(*ast.Identifier); ok {
            found := false
            for j, col := range table.Columns {
                if col.Name == identifier.Value {
                    groupByIndices[i] = j
                    found = true
                    break
                }
            }
            if !found {
                return nil, fmt.Errorf("Unknown column '%s' in 'group statement'", identifier.Value)
            }
        } else {
            return nil, fmt.Errorf("GROUP BY only supports column names")
        }
    }

    // 构建结果列
    for _, expr := range stmt.Fields {
        switch e := expr.(type) {
        case *ast.Identifier:
            found := false
            for _, col := range table.Columns {
                if col.Name == e.Value {
                    results.Columns = append(results.Columns, ResultColumn{
                        Name: col.Name,
                        Type: col.Type,
                    })
                    found = true
                    break
                }
            }
            if !found {
                return nil, fmt.Errorf("Unknown column '%s' in 'field list'", e.Value)
            }
        case *ast.FunctionCall:
            results.Columns = append(results.Columns, ResultColumn{
                Name: e.Name,
                Type: "FUNCTION",
            })
        case *ast.StarExpression:
            for _, col := range table.Columns {
                results.Columns = append(results.Columns, ResultColumn{
                    Name: col.Name,
                    Type: col.Type,
                })
            }
        default:
            return nil, fmt.Errorf("Unsupported select expression type")
        }
    }

    // 处理WHERE子句
    filteredRows := make([][]Cell, 0)
    for _, row := range table.Rows {
        if stmt.Where != nil {
            match, err := evaluateWhereCondition(stmt.Where, row, table.Columns)
            if err != nil {
                return nil, err
            }
            if !match {
                continue
            }
        }
        filteredRows = append(filteredRows, row)
    }

    // 按 GROUP BY 字段分组
    groups := make(map[string][][]Cell)
    for _, row := range filteredRows {
        // 构建分组键
        groupKey := ""
        for _, idx := range groupByIndices {
            groupKey += row[idx].String() + "|"
        }

        // 将行添加到对应的组中
        groups[groupKey] = append(groups[groupKey], row)
    }

    // 为每个组计算结果
    for _, groupRows := range groups {
        if len(groupRows) == 0 {
            continue
        }

        resultRow := make([]Cell, len(results.Columns))
        colIndex := 0

        // 处理非聚合字段(GROUP BY 字段)
        for _, expr := range stmt.Fields {
            if identifier, ok := expr.(*ast.Identifier); ok {
                // 检查是否为 GROUP BY 字段
                isGroupByField := false
                for _, groupByExpr := range stmt.GroupBy {
                    if groupByIdent, ok := groupByExpr.(*ast.Identifier); ok {
                        if groupByIdent.Value == identifier.Value {
                            isGroupByField = true
                            break
                        }
                    }
                }

                if isGroupByField {
                    // 对于 GROUP BY 字段,取第一个值(所有行应该相同)
                    for k, tableCol := range table.Columns {
                        if tableCol.Name == identifier.Value {
                            resultRow[colIndex] = groupRows[0][k]
                            break
                        }
                    }
                }
                colIndex++
            }
        }

        // 处理聚合函数
        for i, expr := range stmt.Fields {
            if fn, ok := expr.(*ast.FunctionCall); ok {
                functionResult := calculateFunctionResults(fn, table, groupRows)
                resultRow[i] = functionResult[0]
            }
        }

        results.Rows = append(results.Rows, resultRow)
    }

    return results, nil
}

测试

测试SQL:

CREATE TABLE sales (id INT PRIMARY KEY, product TEXT, category TEXT, amount FLOAT);
INSERT INTO sales VALUES (1, 'Apple', 'Fruit', 10.5);
INSERT INTO sales VALUES (2, 'Banana', 'Fruit', 8.0);
INSERT INTO sales VALUES (3, 'Carrot', 'Vegetable', 5.2);
INSERT INTO sales VALUES (4, 'Broccoli', 'Vegetable', 7.3);
INSERT INTO sales VALUES (5, 'Orange', 'Fruit', 9.8);
SELECT category, COUNT(*) FROM sales GROUP BY category;
SELECT category, SUM(amount) FROM sales GROUP BY category;
SELECT category, AVG(amount) FROM sales GROUP BY category;

效果:
在这里插入图片描述

4. order by 实现

设计思路

与group by实现基本一致

1.语法解析:

在词法分析器中添加 ORDER、BY、ASC 和 DESC 关键字

  • internal/lexer/token.go:
    在这里插入图片描述
  • internal/lexer/lexer.go的lookupIdentifier方法:
    在这里插入图片描述

在语法分析器中解析 ORDER BY 子句:
在这里插入图片描述

在 internal/ast/ast.go中添加 OrderBy 字段到 SelectStatement 结构体
在这里插入图片描述

2.执行引擎:

在internal/storage/memory.go存储引擎的Select方法中实现对order by的解析调用:
在这里插入图片描述
同时实现排序逻辑,使用 Go 标准库的 sort.Slice 进行排序同时实现自定义比较函数以支持不同数据类型的比较:
在这里插入图片描述

在这里插入图片描述

代码实现

  1. 在词法分析器中添加新的关键字
// internal/lexer/token.go
const (
    // ... 其他关键字
    ORDER   TokenType = "ORDER"
    ASC     TokenType = "ASC"
    DESC    TokenType = "DESC"
)

// internal/lexer/lexer.go
func (l *Lexer) lookupIdentifier(ident string) TokenType {
    switch strings.ToUpper(ident) {
    // ... 其他关键字
    case "ORDER":
        return ORDER
    case "ASC":
        return ASC
    case "DESC":
        return DESC
    default:
        return IDENT
    }
}
  1. 在 AST 中添加新的结构体以支持 ORDER BY
// internal/ast/ast.go

// SelectStatement 表示SELECT语句
type SelectStatement struct {
    Token     lexer.Token
    Fields    []Expression
    TableName string
    Where     Expression
    OrderBy   []OrderByClause // 添加 OrderBy 字段
}

// OrderByClause 表示 ORDER BY 子句中的排序项
type OrderByClause struct {
    Expression Expression
    Direction  string // "ASC" 或 "DESC"
}
  1. 在语法分析器中添加对 ORDER BY 子句的解析
// internal/parser/parser.go

// parseSelectStatement 解析SELECT语句
func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) {
    stmt := &ast.SelectStatement{Token: p.curToken}

    // ... 解析选择列表、FROM 子句和 WHERE 子句 ...

    // 解析GROUP BY子句(如果有的话)
    if p.peekTokenIs(lexer.GROUP) {
        // ... GROUP BY 解析逻辑 ...
    }

    // 解析ORDER BY子句
    if p.peekTokenIs(lexer.ORDER) {
        orderExprs, err := p.parseOrderByClause()
        if err != nil {
            return nil, err
        }
        stmt.OrderBy = orderExprs
    }

    return stmt, nil
}

// parseOrderByClause 解析ORDER BY子句
func (p *Parser) parseOrderByClause() ([]ast.OrderByClause, error) {
    // 跳过 ORDER 关键字
    if !p.expectPeek(lexer.ORDER) {
        return nil, fmt.Errorf("expected ORDER keyword")
    }

    // 跳过 BY 关键字
    if !p.expectPeek(lexer.BY) {
        return nil, fmt.Errorf("expected BY keyword")
    }

    var orderExprs []ast.OrderByClause

    for {
        p.nextToken()
        // 解析表达式(列名)
        if !p.curTokenIs(lexer.IDENT) {
            return nil, fmt.Errorf("expected identifier in ORDER BY clause")
        }

        expr := &ast.Identifier{
            Token: p.curToken,
            Value: p.curToken.Literal,
        }

        orderClause := ast.OrderByClause{
            Expression: expr,
            Direction:  "ASC", // 默认升序
        }

        // 检查是否有 ASC 或 DESC
        if p.peekTokenIs(lexer.ASC) || p.peekTokenIs(lexer.DESC) {
            p.nextToken()
            orderClause.Direction = p.curToken.Literal
        }

        orderExprs = append(orderExprs, orderClause)

        // 如果没有逗号,说明结束了
        if !p.peekTokenIs(lexer.COMMA) {
            break
        }
        p.nextToken() // 跳过逗号
    }

    return orderExprs, nil
}
  1. 在存储引擎中实现 ORDER BY 的执行逻辑
// internal/storage/memory.go

// Select 查询数据
func (b *MemoryBackend) Select(stmt *ast.SelectStatement) (*Results, error) {
    // ... 原有的查询逻辑 ...

    // 处理 ORDER BY
    if len(stmt.OrderBy) > 0 {
        var err error
        results.Rows, err = b.orderBy(results.Rows, results.Columns, stmt.OrderBy, table.Columns)
        if err != nil {
            return nil, err
        }
    }

    return results, nil
}

// orderBy 根据 ORDER BY 子句对结果进行排序
func (b *MemoryBackend) orderBy(rows [][]Cell, resultCols []ResultColumn, orderBy []ast.OrderByClause, tableCols []ast.ColumnDefinition) ([][]Cell, error) {
    // 创建列名到索引的映射
    colIndexMap := make(map[string]int)
    for i, col := range resultCols {
        colIndexMap[col.Name] = i
    }

    // 创建排序键的索引和方向
    type sortKey struct {
        index     int
        direction string
    }

    var sortKeys []sortKey
    for _, ob := range orderBy {
        identifier, ok := ob.Expression.(*ast.Identifier)
        if !ok {
            return nil, fmt.Errorf("ORDER BY only supports column names")
        }

        index, exists := colIndexMap[identifier.Value]
        if !exists {
            return nil, fmt.Errorf("Unknown column '%s' in 'order clause'", identifier.Value)
        }

        sortKeys = append(sortKeys, sortKey{
            index:     index,
            direction: ob.Direction,
        })
    }

    // 使用 sort.Slice 进行排序
    sort.Slice(rows, func(i, j int) bool {
        for _, key := range sortKeys {
            left := rows[i][key.index]
            right := rows[j][key.index]

            // 比较两个值
            result, err := compareValues(left, right, "<")
            if err != nil {
                // 如果比较出错,保持原有顺序
                return false
            }

            if result {
                // 如果是升序,返回 true
                // 如果是降序,返回 false
                return key.direction == "ASC"
            } else {
                // 检查是否相等
                equal, _ := compareValues(left, right, "=")
                if !equal {
                    // 如果是降序,返回 true
                    // 如果是升序,返回 false
                    return key.direction == "DESC"
                }
                // 如果相等,继续比较下一个排序键
            }
        }
        // 所有键都相等,保持原有顺序
        return false
    })

    return rows, nil
}

测试

测试SQL:

CREATE TABLE sales (id INT PRIMARY KEY, product TEXT, category TEXT, amount FLOAT);
INSERT INTO sales VALUES (1, 'Apple', 'Fruit', 10.5);
INSERT INTO sales VALUES (2, 'Banana', 'Fruit', 8.0);
INSERT INTO sales VALUES (3, 'Carrot', 'Vegetable', 5.2);
INSERT INTO sales VALUES (4, 'Broccoli', 'Vegetable', 7.3);
INSERT INTO sales VALUES (5, 'Orange', 'Fruit', 9.8);

SELECT * FROM sales ORDER BY amount;
SELECT * FROM sales ORDER BY amount DESC;
SELECT * FROM sales ORDER BY category, amount DESC;

效果:
在这里插入图片描述


网站公告

今日签到

点亮在社区的每一天
去签到