【golang】27、用 golang 实现一个数据库:lex、parse 解析、操作 sql

发布于:2024-03-11 ⋅ 阅读:(53) ⋅ 点赞:(0)

我们将会用 golang 实现一个数据库,完整源码见 https://github.com/eatonphil/gosql。

首先,我们实现一个 parser 来解析 CREATE、INSERT、SELECT 语句;然后在内存中实现一个支持 TEXT 和 INT 类型的数据库后端(db server);最后实现一个 REPL(交互式终端)。

我们希望达到如下效果:

$ go run *.go
Welcome to gosql.
# CREATE TABLE users (id INT, name TEXT);
ok
# INSERT INTO users VALUES (1, 'Phil');
ok
# SELECT id, name FROM users;
| id | name |
====================
| 1 |  Phil |
ok
# INSERT INTO users VALUES (2, 'Kate');
ok
# SELECT name, id FROM users;
| name | id |
====================
| Phil |  1 |
| Kate |  2 |
ok

首先要实现 lexing(词法分析),即把 SQL 文本切分为一串 tokens。然后,我们会调用 parse 函数来识别单条 SQL 语句(如 SELECT);这些 parse 函数会依次调用它们各自的 helper 函数,来查找可递归解析的块、关键字、符号(比如圆括号)、标识符(比如 table name)、数字、字符串字面量等。

然后,我们会实现一个内存的 server,它基于 AST 操作。最终实现一个 REPL 来接受 SQL 文本,并发送给 server backend。

一、lexing 词法分析

词法分析负责把输入的字符切分成不同类别的 token,主要包括 keywords(关键字)、identifiers(标识符)、numbers、strings 和 symbols(符号)。

逻辑是由各个 helper function 完成的

  • 如果成功找到了 token,则返回该 token、true,以及下次开始处理的位置(cursor)。
  • 否则,返回 false 和原始 cursor,由主循环继续尝试下一个 lexer;若所有 lexer 都失败,则报错并指明行号和列号。

1.1 数据结构

首先,在 lexer.go 定义类型和常量:

package gosql

import (
    "fmt"
    "strings"
)

// location is a position in the SQL source text. A fresh cursor starts at
// line 0, column 0; the lexers advance col per consumed byte and lexSymbol
// bumps line (resetting col) when it consumes a newline.
type location struct {
	line, col uint
}

// keyword is a reserved SQL word. The values are lowercase so they can be
// compared against the lowercased text that longestMatch builds.
type keyword string

// Statement-introducing keywords.
const (
	selectKeyword keyword = "select"
	createKeyword keyword = "create"
	insertKeyword keyword = "insert"
)

// Clause keywords.
const (
	fromKeyword   keyword = "from"
	asKeyword     keyword = "as"
	tableKeyword  keyword = "table"
	intoKeyword   keyword = "into"
	valuesKeyword keyword = "values"
)

// Column type keywords.
const (
	intKeyword  keyword = "int"
	textKeyword keyword = "text"
)

// symbol is a punctuation token recognized by lexSymbol.
type symbol string

const (
	commaSymbol      symbol = "," // separates list items
	semicolonSymbol  symbol = ";" // terminates a statement
	asteriskSymbol   symbol = "*"
	leftparenSymbol  symbol = "(" // opens a column or value list
	rightparenSymbol symbol = ")" // closes a column or value list
)

// tokenKind classifies a lexed token. The zero value is keywordKind.
type tokenKind uint

const (
	keywordKind    tokenKind = iota // reserved word, e.g. select
	symbolKind                      // punctuation, e.g. ( or ;
	identifierKind                  // table or column name
	stringKind                      // delimited string literal
	numericKind                     // numeric literal, kept as text
)

// token is one lexical unit produced by lex.
type token struct {
	value string    // token text (lowercased for keywords and unquoted identifiers)
	kind  tokenKind // classification of the token
	loc   location  // where the token starts in the source
}

// cursor tracks lexing progress through the source string.
type cursor struct {
	pointer uint     // byte offset of the next unread character
	loc     location // line/column corresponding to pointer
}

// equals reports whether t and other have the same kind and text. The source
// location is not compared, so template tokens built without a location
// (tokenFromKeyword, tokenFromSymbol) match lexed tokens.
func (t *token) equals(other *token) bool {
	if t.kind != other.kind {
		return false
	}
	return t.value == other.value
}

// lexer is the signature shared by the lexX functions: given the source and a
// starting cursor, it returns the token found (nil for consumed whitespace),
// the cursor just past it, and whether it matched. By convention a non-match
// returns the input cursor unchanged.
type lexer func(source string, ic cursor) (*token, cursor, bool)

1.2 主框架

然后,实现 main loop 逻辑:

// lex splits source into tokens. At each position it tries the lexers in a
// fixed order (keyword, symbol, string, numeric, identifier) and takes the
// first that matches, so keywords win over identifiers. Whitespace is consumed
// but produces no token.
//
// If no lexer accepts the current position, it returns an error naming the
// line and column there and, when available, the previous token's text.
func lex(source string) ([]*token, error) {
	// The lexer list never changes; build it once rather than once per token.
	lexers := []lexer{lexKeyword, lexSymbol, lexString, lexNumeric, lexIdentifier}

	tokens := []*token{}
	cur := cursor{}

lex:
	for cur.pointer < uint(len(source)) {
		for _, l := range lexers {
			tok, newCursor, ok := l(source, cur)
			if !ok {
				continue
			}
			cur = newCursor

			// Valid but empty syntax such as whitespace yields a nil token.
			if tok != nil {
				tokens = append(tokens, tok)
			}

			// Matched: resume scanning from the new cursor.
			continue lex
		}

		// Every lexer failed here: report the position and what preceded it.
		hint := ""
		if len(tokens) > 0 {
			hint = " after " + tokens[len(tokens)-1].value
		}
		return nil, fmt.Errorf("Unable to lex token%s, at %d:%d", hint, cur.loc.line, cur.loc.col)
	}

	return tokens, nil
}

1.3 分析 number

因为 number 规则复杂,所以参考 PostgreSQL documentation (section 4.1.2.6) 的规则。

// lexNumeric recognizes a numeric constant at ic, following the forms in the
// PostgreSQL documentation (section 4.1.2.6):
//
//	digits
//	digits.[digits][e[+-]digits]
//	[digits].digits[e[+-]digits]
//	digitse[+-]digits
//
// On success it returns the literal text (unparsed) as a numericKind token and
// a cursor positioned exactly after the literal. It returns (nil, ic, false)
// when no number starts at ic or the number is malformed: two periods, two
// exponent markers, an exponent without digits, or a period with no digits.
func lexNumeric(source string, ic cursor) (*token, cursor, bool) {
	cur := ic

	periodFound := false
	expMarkerFound := false
	mantissaDigit := false // saw a digit before any exponent marker
	exponentDigit := false // saw a digit after the exponent marker

scan:
	for ; cur.pointer < uint(len(source)); cur.pointer++ {
		c := source[cur.pointer]

		isDigit := c >= '0' && c <= '9'
		isPeriod := c == '.'
		isExpMarker := c == 'e'

		switch {
		case cur.pointer == ic.pointer:
			// Must start with a digit or period.
			if !isDigit && !isPeriod {
				return nil, ic, false
			}
			periodFound = isPeriod
			mantissaDigit = isDigit
		case isDigit:
			if expMarkerFound {
				exponentDigit = true
			} else {
				mantissaDigit = true
			}
		case isPeriod:
			if periodFound {
				return nil, ic, false
			}
			periodFound = true
		case isExpMarker:
			// At most one exponent, and it needs mantissa digits before it.
			if expMarkerFound || !mantissaDigit {
				return nil, ic, false
			}
			// No periods allowed after the exponent marker.
			periodFound = true
			expMarkerFound = true

			// An optional sign may follow the marker; consume it here.
			if next := cur.pointer + 1; next < uint(len(source)) && (source[next] == '-' || source[next] == '+') {
				cur.pointer++
				cur.loc.col++
			}
		default:
			// Not part of the number: stop without consuming it, so the column
			// is not advanced past the end of the literal.
			break scan
		}

		cur.loc.col++
	}

	// Reject a lone period (no digits) and an exponent without digits.
	if !mantissaDigit || (expMarkerFound && !exponentDigit) {
		return nil, ic, false
	}

	return &token{
		value: source[ic.pointer:cur.pointer],
		loc:   ic.loc,
		kind:  numericKind,
	}, cur, true
}

1.4 分析 strings

字符串必须以单引号开头和结尾。若要在字符串中包含单引号,需要连写两个单引号(例如 'it''s' 表示 it's)。我们把这种“以分隔符包围”的词法逻辑放入一个辅助函数中,以便在分析(以双引号包围的)标识符时复用。

// lexCharacterDelimited lexes text enclosed in delimiter (e.g. '\'' for string
// literals, '"' for quoted identifiers) starting at ic. Inside the text a
// doubled delimiter stands for one literal delimiter (SQL escaping, not
// backslash): 'it''s' yields it's.
//
// On success it returns a stringKind token holding the unescaped contents
// (without the surrounding delimiters) and a cursor positioned just past the
// closing delimiter. It returns (nil, ic, false) if ic is not at a delimiter
// or the closing delimiter is missing.
func lexCharacterDelimited(source string, ic cursor, delimiter byte) (*token, cursor, bool) {
	cur := ic

	if cur.pointer >= uint(len(source)) || source[cur.pointer] != delimiter {
		return nil, ic, false
	}

	// Consume the opening delimiter.
	cur.loc.col++
	cur.pointer++

	var value []byte
	for ; cur.pointer < uint(len(source)); cur.pointer++ {
		c := source[cur.pointer]

		if c != delimiter {
			value = append(value, c)
			cur.loc.col++
			continue
		}

		// A delimiter not followed by another one closes the text. Consume it
		// so the next lexer starts after the literal, not on the quote.
		if cur.pointer+1 >= uint(len(source)) || source[cur.pointer+1] != delimiter {
			cur.pointer++
			cur.loc.col++
			return &token{
				value: string(value),
				loc:   ic.loc,
				kind:  stringKind,
			}, cur, true
		}

		// Doubled delimiter: keep exactly one copy and skip both characters;
		// the loop increment steps over the second one.
		value = append(value, delimiter)
		cur.pointer++
		cur.loc.col += 2
	}

	// Reached the end of input without a closing delimiter.
	return nil, ic, false
}

// lexString lexes a single-quoted SQL string literal at ic by delegating to
// lexCharacterDelimited with the single-quote delimiter.
func lexString(source string, ic cursor) (*token, cursor, bool) {
	const quote byte = '\''
	return lexCharacterDelimited(source, ic, quote)
}

1.5 分析符号和关键字

符号来自一组固定的字符串,因此很容易进行比较。空白字符应该被丢弃。

// lexSymbol lexes whitespace or punctuation at ic.
//
// Whitespace (space, tab, carriage return, newline) is consumed and reported as
// a successful match with a nil token; a newline also advances the line and
// resets the column. Otherwise the longest matching symbol is returned as a
// symbolKind token. It returns (nil, ic, false) for any other character.
// The caller must ensure ic.pointer < len(source).
func lexSymbol(source string, ic cursor) (*token, cursor, bool) {
	c := source[ic.pointer]
	cur := ic
	// Overwritten below when a multi-character symbol matches.
	cur.pointer++
	cur.loc.col++

	switch c {
	case '\n':
		cur.loc.line++
		cur.loc.col = 0
		return nil, cur, true
	case '\t', '\r', ' ':
		// Insignificant whitespace; '\r' lets CRLF input lex cleanly.
		return nil, cur, true
	}

	// Syntax that should be kept. The paren constants are declared as
	// leftparenSymbol/rightparenSymbol.
	options := []string{
		string(commaSymbol),
		string(leftparenSymbol),
		string(rightparenSymbol),
		string(semicolonSymbol),
		string(asteriskSymbol),
	}

	// Match from ic, not cur: c itself is the first character of the symbol.
	match := longestMatch(source, ic, options)
	if match == "" {
		return nil, ic, false
	}

	cur.pointer = ic.pointer + uint(len(match))
	cur.loc.col = ic.loc.col + uint(len(match))

	return &token{
		value: match,
		loc:   ic.loc,
		kind:  symbolKind,
	}, cur, true
}

Keywords 更简单:lexKeyword(本文未列出其代码)只需调用 longestMatch() 在所有关键字中查找最长匹配(例如区分 INT 与 INTO)。longestMatch() 的实现如下:

// longestMatch iterates through a source string starting at the given
// cursor to find the longest matching substring among the provided
// options
// longestMatch scans source from ic and returns the longest entry of options
// that fully matches the text there. Source bytes are lowercased before
// comparison, so options are expected to be lowercase. It returns "" when no
// option matches.
//
// An option is eliminated once the scanned text matches it completely or can
// no longer be a prefix of it. Scanning stops when every option is eliminated,
// so a shorter match (INT) is kept unless a longer one (INTO) also matches.
func longestMatch(source string, ic cursor, options []string) string {
	eliminated := make([]bool, len(options))
	remaining := len(options)
	var seen []byte
	best := ""

	for pos := ic.pointer; pos < uint(len(source)); {
		seen = append(seen, strings.ToLower(string(source[pos]))...)
		pos++
		width := pos - ic.pointer
		text := string(seen)

		for i, opt := range options {
			if eliminated[i] {
				continue
			}

			if opt == text {
				// Complete match; keep it if it beats the best so far.
				eliminated[i] = true
				remaining--
				if len(opt) > len(best) {
					best = opt
				}
				continue
			}

			if text != opt[:width] || len(text) > len(opt) {
				eliminated[i] = true
				remaining--
			}
		}

		if remaining == 0 {
			break
		}
	}

	return best
}

1.6 分析 identifiers(标识符)

标识符要么是用双引号包围的字符串,要么是以字母开头、后续可包含字母、数字、下划线和 $ 的一组字符。

// lexIdentifier lexes an identifier at ic; the token kind is identifierKind in
// both forms:
//   - A double-quoted identifier ("Users") is taken verbatim, without the
//     quotes and with its case preserved.
//   - An unquoted identifier starts with an ASCII letter, continues with ASCII
//     letters, digits, '$' or '_', and is lowercased because unquoted
//     identifiers are case-insensitive.
//
// It returns (nil, ic, false) if no identifier starts at ic. The caller must
// ensure ic.pointer < len(source).
func lexIdentifier(source string, ic cursor) (*token, cursor, bool) {
	// Quoted identifiers reuse the delimited-text lexer, which labels its
	// result as a string literal; relabel it so the parser sees an identifier.
	if tok, newCursor, ok := lexCharacterDelimited(source, ic, '"'); ok {
		tok.kind = identifierKind
		return tok, newCursor, true
	}

	cur := ic

	c := source[cur.pointer]
	// Only ASCII letters may start an identifier; non-ASCII is not supported yet.
	isAlphabetical := (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
	if !isAlphabetical {
		return nil, ic, false
	}
	cur.pointer++
	cur.loc.col++

	for ; cur.pointer < uint(len(source)); cur.pointer++ {
		c = source[cur.pointer]

		isAlphabetical := (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
		isNumeric := c >= '0' && c <= '9'
		if !(isAlphabetical || isNumeric || c == '$' || c == '_') {
			break
		}
		cur.loc.col++
	}

	return &token{
		// Unquoted identifiers are case-insensitive.
		value: strings.ToLower(source[ic.pointer:cur.pointer]),
		loc:   ic.loc,
		kind:  identifierKind,
	}, cur, true
}

至此,全部的 lexer 就完成了,可以运行 lexer_test.go 中的单测进行验证;单步调试一遍单测,就能完全理解上文的逻辑。下面是其中一个测试用例(注意:它使用的是导出形式的命名,如 Token、Location、InsertKeyword,与上文的未导出命名不同):

		{
			input: "insert into users Values (105, 233)",
			Tokens: []Token{
				{
					Loc:   Location{Col: 0, Line: 0},
					Value: string(InsertKeyword),
					Kind:  KeywordKind,
				},
				{
					Loc:   Location{Col: 7, Line: 0},
					Value: string(IntoKeyword),
					Kind:  KeywordKind,
				},
				{
					Loc:   Location{Col: 12, Line: 0},
					Value: "users",
					Kind:  IdentifierKind,
				},
				{
					Loc:   Location{Col: 18, Line: 0},
					Value: string(ValuesKeyword),
					Kind:  KeywordKind,
				},
				{
					Loc:   Location{Col: 25, Line: 0},
					Value: "(",
					Kind:  SymbolKind,
				},
				{
					Loc:   Location{Col: 26, Line: 0},
					Value: "105",
					Kind:  NumericKind,
				},
				{
					Loc:   Location{Col: 30, Line: 0},
					Value: ",",
					Kind:  SymbolKind,
				},
				{
					Loc:   Location{Col: 32, Line: 0},
					Value: "233",
					Kind:  NumericKind,
				},
				{
					Loc:   Location{Col: 36, Line: 0},
					Value: ")",
					Kind:  SymbolKind,
				},
			},
			err: nil,
		},

二、AST 抽象语法树

2.1 AST、Statement 数据结构

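在最高层级,AST 是 statements 的集合。(注意:下文 AST、parser 和 backend 代码开头的 package main 应与 lexer.go 一样写成 package gosql,因为第五节的 REPL 是以 gosql.Parse、gosql.NewMemoryBackend 的方式调用它们的。)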

package main

// Ast is the parsed form of a SQL input.
type Ast struct {
	Statements []*Statement // one entry per statement, in source order
}

目前,statement 是 INSERT、CREATE、SELECT 之一:

// AstKind identifies which pointer field of a Statement is populated.
type AstKind uint

const (
	SelectKind      AstKind = iota // Statement.SelectStatement is set
	CreateTableKind                // Statement.CreateTableStatement is set
	InsertKind                     // Statement.InsertStatement is set
)

// Statement is a tagged union over the supported statements: Kind says which
// one of the pointer fields parseStatement filled in; the others stay nil.
type Statement struct {
	SelectStatement      *SelectStatement      // set when Kind == SelectKind
	CreateTableStatement *CreateTableStatement // set when Kind == CreateTableKind
	InsertStatement      *InsertStatement      // set when Kind == InsertKind
	Kind                 AstKind
}

2.2 INSERT

目前,INSERT 有 table name 和一列 values

// InsertStatement is INSERT INTO table VALUES (values...).
type InsertStatement struct {
	table  token          // target table name (identifier token)
	values *[]*expression // values, matched to the table's columns by position
}

expression 目前是 literal token,将来可能是一个 function call,或者一个 inline operation:

// expressionKind identifies the form of an expression; only literals exist so far.
type expressionKind uint

// literalKind marks an expression that is a single token (identifier, number or string).
const literalKind expressionKind = iota

// expression is an item of a SELECT list or VALUES clause.
type expression struct {
	literal *token         // for literalKind: the identifier, number or string token
	kind    expressionKind // which form of expression this is
}

2.3 CREATE

目前,CREATE 有 table name 和一列 column names 和 column types

// columnDefinition is one "name type" pair from a CREATE TABLE column list;
// datatype is the keyword token naming the type (e.g. int or text).
type columnDefinition struct {
	name, datatype token
}

// CreateTableStatement is CREATE TABLE name ( column type [, ...] ).
type CreateTableStatement struct {
	name token                // table name (identifier token)
	cols *[]*columnDefinition // column definitions in declaration order
}

2.4 SELECT

目前,SELECT 有 table name 和一列 column names

// SelectStatement is SELECT item [, ...] [FROM from].
type SelectStatement struct {
	item []*expression // select list, in order
	from token         // table name; the zero token when FROM is omitted
}

至此,AST 完毕。

三、Parsing

Parse() 先调用 lex() 把 SQL 文本切分为 tokens,然后逐条解析语句,每条语句都必须以分号结尾。

通常的策略是,传递并递增一个 cursor 参数,每个 helper() 函数都会返回一个新的 cursor,其表示了后续的起始位置。

3.1 主框架

package main

import (
    "errors"
    "fmt"
)

// tokenFromKeyword builds a keyword token with no location, for comparing
// against lexed tokens via equals.
func tokenFromKeyword(k keyword) token {
	var t token
	t.kind = keywordKind
	t.value = string(k)
	return t
}

// tokenFromSymbol builds a symbol token with no location, for comparing
// against lexed tokens via equals.
func tokenFromSymbol(s symbol) token {
	var t token
	t.kind = symbolKind
	t.value = string(s)
	return t
}

// expectToken reports whether tokens[cursor] exists and has the same kind and
// value as t. An out-of-range cursor yields false.
func expectToken(tokens []*token, cursor uint, t token) bool {
	return cursor < uint(len(tokens)) && t.equals(tokens[cursor])
}

// helpMessage prints a parse diagnostic to stdout as
// "[line,col]: msg, got: value". It points at tokens[cursor], or at the last
// token when cursor is past the end (the input ended early). With no tokens
// at all it prints just msg.
func helpMessage(tokens []*token, cursor uint, msg string) {
	n := uint(len(tokens))
	if n == 0 {
		// Nothing to point at; tokens[cursor-1] would underflow and panic.
		fmt.Println(msg)
		return
	}

	// Past the end: report against the last token instead of indexing out of range.
	if cursor >= n {
		cursor = n - 1
	}
	c := tokens[cursor]

	fmt.Printf("[%d,%d]: %s, got: %s\n", c.loc.line, c.loc.col, msg, c.value)
}

// Parse lexes source and parses it into an Ast. Each statement must be
// followed by at least one semicolon; extra semicolons are skipped. On failure
// a diagnostic is printed via helpMessage and an error is returned.
func Parse(source string) (*Ast, error) {
	tokens, err := lex(source)
	if err != nil {
		return nil, err
	}

	semicolon := tokenFromSymbol(semicolonSymbol)
	ast := Ast{}
	pos := uint(0)
	for pos < uint(len(tokens)) {
		stmt, next, ok := parseStatement(tokens, pos, semicolon)
		if !ok {
			helpMessage(tokens, pos, "Expected statement")
			return nil, errors.New("Failed to parse, expected statement")
		}
		pos = next
		ast.Statements = append(ast.Statements, stmt)

		// Consume the terminating semicolon(s); at least one is required.
		start := pos
		for expectToken(tokens, pos, semicolon) {
			pos++
		}
		if pos == start {
			helpMessage(tokens, pos, "Expected semi-colon delimiter between statements")
			return nil, errors.New("Missing semi-colon between statements")
		}
	}

	return &ast, nil
}

3.2 Parsing Statements

每条语句将是 INSERT、CREATE 或 SELECT 之一。parseStatement 会依次尝试每种语句对应的 helper,只要其中一种解析成功,就返回 true。

// parseStatement parses one statement at initialCursor, trying SELECT, then
// INSERT, then CREATE TABLE. delimiter is the token that ends a statement; it
// is forwarded to each statement parser instead of being re-derived here. On
// success it returns the statement tagged with its AstKind and the cursor just
// past it; otherwise (nil, initialCursor, false).
func parseStatement(tokens []*token, initialCursor uint, delimiter token) (*Statement, uint, bool) {
	// Look for a SELECT statement
	if slct, newCursor, ok := parseSelectStatement(tokens, initialCursor, delimiter); ok {
		return &Statement{Kind: SelectKind, SelectStatement: slct}, newCursor, true
	}

	// Look for an INSERT statement
	if inst, newCursor, ok := parseInsertStatement(tokens, initialCursor, delimiter); ok {
		return &Statement{Kind: InsertKind, InsertStatement: inst}, newCursor, true
	}

	// Look for a CREATE TABLE statement
	if crtTbl, newCursor, ok := parseCreateTableStatement(tokens, initialCursor, delimiter); ok {
		return &Statement{Kind: CreateTableKind, CreateTableStatement: crtTbl}, newCursor, true
	}

	return nil, initialCursor, false
}

3.2.1 Parsing select statements

SELECT 的模式比较简单,如下:

  1. SELECT
  2. $expression [, …]
  3. FROM
  4. $table-name

按这四个部分的模式,得到框架如下:

// parseSelectStatement parses
//
//	SELECT $expression [, ...] [FROM $table-name]
//
// at initialCursor. The expression list ends at FROM or at delimiter. It
// returns (nil, initialCursor, false) if the tokens there do not start with
// SELECT or the statement is malformed.
func parseSelectStatement(tokens []*token, initialCursor uint, delimiter token) (*SelectStatement, uint, bool) {
	cursor := initialCursor
	if !expectToken(tokens, cursor, tokenFromKeyword(selectKeyword)) {
		return nil, initialCursor, false
	}
	cursor++

	slct := SelectStatement{}

	exps, newCursor, ok := parseExpressions(tokens, cursor, []token{tokenFromKeyword(fromKeyword), delimiter})
	if !ok {
		return nil, initialCursor, false
	}

	slct.item = *exps
	cursor = newCursor

	if expectToken(tokens, cursor, tokenFromKeyword(fromKeyword)) {
		cursor++

		from, newCursor, ok := parseToken(tokens, cursor, identifierKind)
		if !ok {
			// FROM was already consumed; what is missing is the table name.
			helpMessage(tokens, cursor, "Expected table name")
			return nil, initialCursor, false
		}

		slct.from = *from
		cursor = newCursor
	}

	return &slct, cursor, true
}

其中 parseToken() 会查找指定的 token 类型:

// parseToken returns tokens[initialCursor] and the following cursor when that
// token exists and has the given kind; otherwise (nil, initialCursor, false).
func parseToken(tokens []*token, initialCursor uint, kind tokenKind) (*token, uint, bool) {
	if initialCursor < uint(len(tokens)) {
		if tok := tokens[initialCursor]; tok.kind == kind {
			return tok, initialCursor + 1, true
		}
	}
	return nil, initialCursor, false
}

parseExpressions 会解析由逗号分隔的一组表达式,直到遇到任一分隔符为止(分隔符本身不会被消费)。它会调用下面的 parseExpression 来解析每个表达式。

// parseExpressions parses a comma-separated list of expressions at
// initialCursor, stopping (without consuming it) at the first token equal to
// any of delimiters. An empty list is allowed. It returns
// (nil, initialCursor, false) if tokens run out before a delimiter, a comma
// is missing, or an item is not an expression.
func parseExpressions(tokens []*token, initialCursor uint, delimiters []token) (*[]*expression, uint, bool) {
	pos := initialCursor
	comma := tokenFromSymbol(commaSymbol)
	exps := []*expression{}

	isDelimiter := func(tok *token) bool {
		for _, d := range delimiters {
			if d.equals(tok) {
				return true
			}
		}
		return false
	}

	for {
		if pos >= uint(len(tokens)) {
			return nil, initialCursor, false
		}
		if isDelimiter(tokens[pos]) {
			return &exps, pos, true
		}

		// Every item after the first must be preceded by a comma.
		if len(exps) > 0 {
			if !expectToken(tokens, pos, comma) {
				helpMessage(tokens, pos, "Expected comma")
				return nil, initialCursor, false
			}
			pos++
		}

		exp, next, ok := parseExpression(tokens, pos, comma)
		if !ok {
			helpMessage(tokens, pos, "Expected expression")
			return nil, initialCursor, false
		}
		pos = next
		exps = append(exps, exp)
	}
}

parseExpression helper 函数,会查找 number、string、 或者 identifier token。

// parseExpression parses one literal expression: an identifier, number or
// string token at initialCursor, tried in that order. The third parameter is
// currently unused.
func parseExpression(tokens []*token, initialCursor uint, _ token) (*expression, uint, bool) {
	for _, kind := range [...]tokenKind{identifierKind, numericKind, stringKind} {
		if t, next, ok := parseToken(tokens, initialCursor, kind); ok {
			return &expression{literal: t, kind: literalKind}, next, true
		}
	}
	return nil, initialCursor, false
}

至此,parse select 就完成了。

3.2.2 Parsing insert statements

模式如下:

  1. INSERT
  2. INTO
  3. $table-name
  4. VALUES
  5. (
  6. $expression [, ...]
  7. )

可以复用现有的 helper 函数,框架如下:

// parseInsertStatement parses
//
//	INSERT INTO $table-name VALUES ( $expression [, ...] )
//
// at initialCursor. A missing INSERT fails silently (it is just not this kind
// of statement); any later mismatch prints a diagnostic via helpMessage. On
// failure it returns (nil, initialCursor, false). delimiter is currently unused.
func parseInsertStatement(tokens []*token, initialCursor uint, delimiter token) (*InsertStatement, uint, bool) {
	pos := initialCursor

	// expect consumes tok at pos, or prints msg and reports failure.
	expect := func(tok token, msg string) bool {
		if !expectToken(tokens, pos, tok) {
			helpMessage(tokens, pos, msg)
			return false
		}
		pos++
		return true
	}

	if !expectToken(tokens, pos, tokenFromKeyword(insertKeyword)) {
		return nil, initialCursor, false
	}
	pos++

	if !expect(tokenFromKeyword(intoKeyword), "Expected into") {
		return nil, initialCursor, false
	}

	table, next, ok := parseToken(tokens, pos, identifierKind)
	if !ok {
		helpMessage(tokens, pos, "Expected table name")
		return nil, initialCursor, false
	}
	pos = next

	if !expect(tokenFromKeyword(valuesKeyword), "Expected VALUES") ||
		!expect(tokenFromSymbol(leftparenSymbol), "Expected left paren") {
		return nil, initialCursor, false
	}

	values, next, ok := parseExpressions(tokens, pos, []token{tokenFromSymbol(rightparenSymbol)})
	if !ok {
		return nil, initialCursor, false
	}
	pos = next

	if !expect(tokenFromSymbol(rightparenSymbol), "Expected right paren") {
		return nil, initialCursor, false
	}

	return &InsertStatement{table: *table, values: values}, pos, true
}

3.2.3 Parsing create statements

模式如下:

  1. CREATE
  2. TABLE
  3. $table-name
  4. (
  5. [$column-name $column-type [, ...]]
  6. )

通过一个新的 parseColumnDefinitions helper 勾勒出这一点,我们得到:

// parseCreateTableStatement parses
//
//	CREATE TABLE $table-name ( [$column-name $column-type [, ...]] )
//
// at initialCursor. A missing CREATE or TABLE fails silently; later mismatches
// print a diagnostic via helpMessage. On failure it returns
// (nil, initialCursor, false). delimiter is currently unused.
func parseCreateTableStatement(tokens []*token, initialCursor uint, delimiter token) (*CreateTableStatement, uint, bool) {
	pos := initialCursor

	for _, kw := range [...]keyword{createKeyword, tableKeyword} {
		if !expectToken(tokens, pos, tokenFromKeyword(kw)) {
			return nil, initialCursor, false
		}
		pos++
	}

	name, next, ok := parseToken(tokens, pos, identifierKind)
	if !ok {
		helpMessage(tokens, pos, "Expected table name")
		return nil, initialCursor, false
	}
	pos = next

	if !expectToken(tokens, pos, tokenFromSymbol(leftparenSymbol)) {
		helpMessage(tokens, pos, "Expected left parenthesis")
		return nil, initialCursor, false
	}
	pos++

	cols, next, ok := parseColumnDefinitions(tokens, pos, tokenFromSymbol(rightparenSymbol))
	if !ok {
		return nil, initialCursor, false
	}
	pos = next

	if !expectToken(tokens, pos, tokenFromSymbol(rightparenSymbol)) {
		helpMessage(tokens, pos, "Expected right parenthesis")
		return nil, initialCursor, false
	}
	pos++

	return &CreateTableStatement{name: *name, cols: cols}, pos, true
}

parseColumnDefinitions 会解析以逗号分隔的若干“column name + column type”对,直到遇到指定的分隔符(这里是右括号)为止:

// parseColumnDefinitions parses comma-separated "name type" pairs at
// initialCursor until (without consuming) delimiter. The name must be an
// identifier and the type a keyword token. It returns
// (nil, initialCursor, false) if tokens run out before delimiter or a pair is
// malformed.
func parseColumnDefinitions(tokens []*token, initialCursor uint, delimiter token) (*[]*columnDefinition, uint, bool) {
	pos := initialCursor
	cds := []*columnDefinition{}

	for pos < uint(len(tokens)) && !delimiter.equals(tokens[pos]) {
		// Every definition after the first must be preceded by a comma.
		if len(cds) > 0 {
			if !expectToken(tokens, pos, tokenFromSymbol(commaSymbol)) {
				helpMessage(tokens, pos, "Expected comma")
				return nil, initialCursor, false
			}
			pos++
		}

		name, next, ok := parseToken(tokens, pos, identifierKind)
		if !ok {
			helpMessage(tokens, pos, "Expected column name")
			return nil, initialCursor, false
		}
		pos = next

		datatype, next, ok := parseToken(tokens, pos, keywordKind)
		if !ok {
			helpMessage(tokens, pos, "Expected column type")
			return nil, initialCursor, false
		}
		pos = next

		cds = append(cds, &columnDefinition{name: *name, datatype: *datatype})
	}

	// The loop also ends when tokens run out, which means no delimiter was found.
	if pos >= uint(len(tokens)) {
		return nil, initialCursor, false
	}
	return &cds, pos, true
}

至此,parser 就结束了,可以在 parser_test.go 跑一遍单测。

四、in-memory backend

下面定义一个通用的后端接口(backend interface),允许用户 create table、insert 和 select 数据:

package main

import "errors"

// ColumnType is the declared type of a table column.
type ColumnType uint

const (
	TextType ColumnType = iota // TEXT column
	IntType                    // INT column
)

// Cell is one value in a result row. Callers choose the accessor that matches
// the column's ColumnType.
type Cell interface {
	AsInt() int32   // value of an INT column
	AsText() string // value of a TEXT column
}

// Results is the output of a SELECT: one entry in Columns per selected column,
// and one row in Rows per table row, each row aligned with Columns.
type Results struct {
    // Columns describes each output column. This anonymous struct type must stay
    // identical to the one the memory backend's Select builds, or the assignment
    // there stops compiling.
    Columns []struct {
        Type ColumnType
        Name string
    }
    // Rows holds the selected cells; Rows[r][i] belongs to Columns[i].
    Rows [][]Cell
}

// Errors returned by Backend implementations.
var (
	// ErrColumnDoesNotExist: a selected column is not in the table.
	ErrColumnDoesNotExist = errors.New("Column does not exist")
	// ErrInvalidDatatype: a value or column type is not supported.
	ErrInvalidDatatype = errors.New("Invalid datatype")
	// ErrInvalidSelectItem: a select-list item cannot be evaluated.
	ErrInvalidSelectItem = errors.New("Select item is not valid")
	// ErrMissingValues: the number of values does not match the number of columns.
	ErrMissingValues = errors.New("Missing values")
	// ErrTableDoesNotExist: the statement names an unknown table.
	ErrTableDoesNotExist = errors.New("Table does not exist")
)

// Backend executes parsed statements against some storage.
type Backend interface {
	// CreateTable defines a new table from a CREATE TABLE statement.
	CreateTable(stmt *CreateTableStatement) error
	// Insert adds the row described by an INSERT statement.
	Insert(stmt *InsertStatement) error
	// Select evaluates a SELECT statement and returns its rows.
	Select(stmt *SelectStatement) (*Results, error)
}

4.1 memory layout

内存中需要存储若干 tables,每个 table 有若干 columns 和 rows。每个 column 有 name 和 type;每个 row 是一组 cell,每个 cell(MemoryCell)本质上是一个 []byte。

package main

import (
    "bytes"
    "encoding/binary"
    "fmt"
    "strconv"
)

// MemoryCell is one stored value: a 4-byte big-endian int32 for numeric
// literals (the encoding tokenToCell writes) or the raw bytes of a string.
type MemoryCell []byte

// AsInt decodes the cell as a big-endian int32. It panics if the cell is
// shorter than 4 bytes, i.e. it does not hold an encoded integer.
func (mc MemoryCell) AsInt() int32 {
	if len(mc) < 4 {
		panic(fmt.Sprintf("MemoryCell.AsInt: need 4 bytes, got %d", len(mc)))
	}
	// Decode directly; binary.Read would allocate a reader and box the target.
	return int32(binary.BigEndian.Uint32(mc))
}

// AsText returns the cell's bytes as a string.
func (mc MemoryCell) AsText() string {
	return string(mc)
}

// table is an in-memory table. columns and columnTypes are parallel slices
// with one entry per column; rows are stored in insertion order.
type table struct {
	columns     []string       // column names, in declaration order
	columnTypes []ColumnType   // columnTypes[i] is the type of columns[i]
	rows        [][]MemoryCell // one cell per column in each row
}

// MemoryBackend is a Backend that keeps every table in process memory, keyed
// by table name. It does no locking.
type MemoryBackend struct {
	tables map[string]*table
}

// NewMemoryBackend returns a MemoryBackend with no tables.
func NewMemoryBackend() *MemoryBackend {
	mb := &MemoryBackend{tables: make(map[string]*table)}
	return mb
}

4.2 实现 Create table

会在 tables map 数据结构中,增加一项。然后 create columns。

// CreateTable registers a new, empty table named crt.name with the columns in
// crt.cols, whose types must be "int" or "text". If any column has another
// datatype it returns ErrInvalidDatatype and leaves the backend unchanged.
// A table with the same name is replaced.
func (mb *MemoryBackend) CreateTable(crt *CreateTableStatement) error {
	t := table{}

	if crt.cols != nil {
		for _, col := range *crt.cols {
			var dt ColumnType
			switch col.datatype.value {
			case "int":
				dt = IntType
			case "text":
				dt = TextType
			default:
				// Validate before registering, so a failed CREATE does not leave
				// a half-built table (names without types) behind.
				return ErrInvalidDatatype
			}

			t.columns = append(t.columns, col.name.value)
			t.columnTypes = append(t.columnTypes, dt)
		}
	}

	mb.tables[crt.name.value] = &t
	return nil
}

4.3 实现 insert

为简单起见,我们假定传入的 value 能正确映射到对应的 column type。我们用辅助函数 tokenToCell 把 token 的值转换为内部存储格式。

// Insert appends one row to the table named inst.table, mapping values to
// columns by position. It returns:
//   - ErrTableDoesNotExist for an unknown table;
//   - ErrMissingValues when the value count differs from the column count;
//   - ErrInvalidDatatype when a value is not a literal of the column's type
//     (a number for INT, a string for TEXT).
//
// On any error the table is unchanged. A nil value list is a no-op.
func (mb *MemoryBackend) Insert(inst *InsertStatement) error {
	table, ok := mb.tables[inst.table.value]
	if !ok {
		return ErrTableDoesNotExist
	}

	if inst.values == nil {
		return nil
	}

	values := *inst.values
	if len(values) != len(table.columns) {
		return ErrMissingValues
	}

	row := make([]MemoryCell, 0, len(values))
	for i, value := range values {
		// Skipping a value would leave a short row that SELECT later indexes
		// past; a missing column type would leave nothing to check against.
		if value.kind != literalKind || i >= len(table.columnTypes) {
			return ErrInvalidDatatype
		}

		// Reject mismatched literals now rather than storing a cell that the
		// column's accessor (e.g. AsInt on text) cannot decode.
		var want tokenKind
		switch table.columnTypes[i] {
		case IntType:
			want = numericKind
		case TextType:
			want = stringKind
		}
		if value.literal.kind != want {
			return ErrInvalidDatatype
		}

		row = append(row, mb.tokenToCell(value.literal))
	}

	table.rows = append(table.rows, row)
	return nil
}

tokenToCell 会把数字编码为 4 字节大端(big-endian)的 int32 二进制,把字符串直接存为其字节:

// tokenToCell encodes a literal token for storage:
//   - a numericKind token becomes a 4-byte big-endian int32;
//   - a stringKind token becomes its raw bytes;
//   - any other kind yields nil.
//
// It panics if a numeric literal is not an integer that fits in int32
// (e.g. "1.5" or "3000000000") rather than storing a wrapped value.
func (mb *MemoryBackend) tokenToCell(t *token) MemoryCell {
	switch t.kind {
	case numericKind:
		// ParseInt with bitSize 32 rejects out-of-range values; Atoi followed
		// by int32(i) would silently wrap them.
		i, err := strconv.ParseInt(t.value, 10, 32)
		if err != nil {
			panic(err)
		}

		buf := new(bytes.Buffer)
		if err := binary.Write(buf, binary.BigEndian, int32(i)); err != nil {
			panic(err)
		}
		return MemoryCell(buf.Bytes())
	case stringKind:
		return MemoryCell(t.value)
	}

	return nil
}

4.4 实现 select

最后,对于SELECT,我们将迭代表中的每一行,并根据AST指定的列返回单元格。

// Select evaluates a SELECT against the table named in slct.from. Each select
// item must be an identifier naming a column of that table. The result has
// one column per item (in select-list order) and one row per stored row (in
// insertion order).
//
// It returns ErrTableDoesNotExist for an unknown table, and
// ErrColumnDoesNotExist for an unknown column or a non-identifier item, even
// when the table has no rows.
func (mb *MemoryBackend) Select(slct *SelectStatement) (*Results, error) {
	// from is a token; its text is the table name.
	table, ok := mb.tables[slct.from.value]
	if !ok {
		return nil, ErrTableDoesNotExist
	}

	// Resolve the select list once, up front. This yields column headers even
	// for an empty table and avoids rescanning table.columns for every row.
	columns := []struct {
		Type ColumnType
		Name string
	}{}
	var indexes []int
	for _, exp := range slct.item {
		if exp.kind != literalKind {
			// Unsupported, doesn't currently exist, ignore.
			fmt.Println("Skipping non-literal expression.")
			continue
		}

		lit := exp.literal
		if lit.kind != identifierKind {
			return nil, ErrColumnDoesNotExist
		}

		idx := -1
		for i, name := range table.columns {
			if name == lit.value {
				idx = i
				break
			}
		}
		if idx < 0 {
			return nil, ErrColumnDoesNotExist
		}

		columns = append(columns, struct {
			Type ColumnType
			Name string
		}{
			Type: table.columnTypes[idx],
			Name: lit.value,
		})
		indexes = append(indexes, idx)
	}

	results := [][]Cell{}
	for _, row := range table.rows {
		result := make([]Cell, 0, len(indexes))
		for _, idx := range indexes {
			result = append(result, row[idx])
		}
		results = append(results, result)
	}

	return &Results{
		Columns: columns,
		Rows:    results,
	}, nil
}

五、REPL 交互式终端

最后,我们准备将解析器和内存后端包装在REPL中。最复杂的部分是显示SELECT查询的结果表。

package main

import (
    "bufio"
    "fmt"
    "os"
    "strings"

    "github.com/eatonphil/gosql"
)

// main runs the gosql REPL. It reads one line of SQL at a time from stdin,
// parses it, and executes each statement against an in-memory backend,
// printing "ok" or a result table. A parse or execution error is printed and
// the rest of that line is skipped. The REPL exits at end of input.
func main() {
	mb := gosql.NewMemoryBackend()

	reader := bufio.NewReader(os.Stdin)
	fmt.Println("Welcome to gosql.")

repl:
	for {
		fmt.Print("# ")
		text, err := reader.ReadString('\n')
		if err != nil && text == "" {
			// End of input (e.g. Ctrl-D) or a read error: stop instead of
			// looping forever on empty reads.
			fmt.Println()
			return
		}
		// TrimSpace also drops the '\r' left by CRLF line endings.
		text = strings.TrimSpace(text)

		ast, err := gosql.Parse(text)
		if err != nil {
			fmt.Println("Error:", err)
			continue
		}

		for _, stmt := range ast.Statements {
			switch stmt.Kind {
			case gosql.CreateTableKind:
				// Use the current statement, not Statements[0], so multi-statement
				// lines run the right CREATE.
				if err := mb.CreateTable(stmt.CreateTableStatement); err != nil {
					fmt.Println("Error:", err)
					continue repl
				}
				fmt.Println("ok")
			case gosql.InsertKind:
				if err := mb.Insert(stmt.InsertStatement); err != nil {
					fmt.Println("Error:", err)
					continue repl
				}
				fmt.Println("ok")
			case gosql.SelectKind:
				results, err := mb.Select(stmt.SelectStatement)
				if err != nil {
					fmt.Println("Error:", err)
					continue repl
				}

				for _, col := range results.Columns {
					fmt.Printf("| %s ", col.Name)
				}
				fmt.Println("|")
				fmt.Println(strings.Repeat("=", 20))

				for _, result := range results.Rows {
					fmt.Printf("|")
					for i, cell := range result {
						s := ""
						switch results.Columns[i].Type {
						case gosql.IntType:
							s = fmt.Sprintf("%d", cell.AsInt())
						case gosql.TextType:
							s = cell.AsText()
						}
						fmt.Printf(" %s | ", s)
					}
					fmt.Println()
				}

				fmt.Println("ok")
			}
		}
	}
}

结果如下:

$ go run *.go
Welcome to gosql.
# CREATE TABLE users (id INT, name TEXT);
ok
# INSERT INTO users VALUES (1, 'Phil');
ok
# SELECT id, name FROM users;
| id | name |
====================
| 1 |  Phil |
ok
# INSERT INTO users VALUES (2, 'Kate');
ok
# SELECT name, id FROM users;
| name | id |
====================
| Phil |  1 |
| Kate |  2 |
ok

至此,实现了简单的 SQL DB。后续会实现 filter、sorting、indexing 等更复杂的功能。

本文含有隐藏内容,请 开通VIP 后查看