简易实现 Go 的 ORM 框架(上) | 青训营

一、数据库基础

1.啥是ORM啊

对象关系映射(Object Relational Mapping,简称ORM)是通过使用描述对象和数据库之间映射的元数据,将面向对象语言程序中的对象自动持久化到关系数据库中。通过ORM框架,应用程序中的类和对象可以直接映射到数据库的表和列,ORM框架负责管理实体对象和关系数据库之间的映射。开发人员无需编写复杂的SQL语句,只需要使用面向对象编程语言中的类和对象来进行数据库操作,因此减少了程序开发的工作量和开发时间。

如果你学过Java的话,可以认为我们在写一个go语言的Mybatis,具体包括以下功能:

  • 表的创建、删除、迁移。
  • 记录的增删查改,查询条件的链式操作。
  • 单一主键的设置(primary key)。
  • 钩子(在创建/更新/删除/查找之前或之后)
  • 事务(transaction)

2.玩玩Sqlite

首先准备一台Linux的虚拟机或是服务器,不出意外的话,都内置了Sqlite,在命令行中输入sqlite3

即可看到版本信息:

SQLite version 3.7.17 2013-05-20 00:56:22
Enter ".help" for instructions
Enter SQL statements terminated with a ";"

如果没有安装sqlite,一行命令即可解决:

apt-get install sqlite3

接下来,让我们输入几行命令玩一玩:

[root@VM-12-14-centos ~]# sqlite3 gee.db                                   //创建名为gee的数据库
SQLite version 3.7.17 2013-05-20 00:56:22
Enter ".help" for instructions
Enter SQL statements terminated with a ";"
sqlite> CREATE TABLE User(Name text, Age integer);                         //创建名为User的表,拥有Name、Age字段
sqlite> INSERT INTO User(Name, Age) VALUES ("Tom", 18), ("Jack", 25);      //向表中插入信息
sqlite> .head on                                                           //显示表头信息
sqlite> select * from User;                                                //查询User表中的所有内容
Name|Age
Tom|18
Jack|25
sqlite> 
sqlite> .table                                                             //查看库中拥有的表格
User

以上命令熟悉以后,就用go语言来实现吧!

import (
	"database/sql"
	"log"

	_ "github.com/mattn/go-sqlite3"
)

func main() {
	db, _ := sql.Open("sqlite3", "gee.db")
	defer func() { _ = db.Close() }()
	_, _ = db.Exec("DEOP TABLE IF EXISTS User;")
	_, _ = db.Exec("CREATE TABLE User(Name text);")
	result, err := db.Exec("INSERT INTO User(`Name`) values (?),(?)", "Tom", "Sam")
	if err == nil {
		affected, _ := result.RowsAffected()
		log.Println(affected)
	}
	row := db.QueryRow("SELECT Name FROM User LIMIT 1")
	var name string
	if err := row.Scan(&name); err == nil {
		log.Println(name)
	}
}

尽管没有任何注释,但也比较容易看懂。Open创建了sqlite的连接,Exec执行相应的sql语句,?用作占位符防止sql注入,运行结果如下所示: image.png 到此,算是熟悉了一下database/sql这个包。

3.实现一个log库

Go原生的log库没有实现日志分级、打印文件行号等操作,对于后面的调试和定位出错不太友好,因此封装一个具有如下功能的log库:

  • 支持日志分级(Info、Error、Disabled 三级)。
  • 不同层级日志显示时使用不同的颜色区分。
  • 显示打印日志代码对应的文件名和行号。

开始编写log/log.go

var (
	errorLog = log.New(os.Stdout, "\033[31m[error]\033[0m ", log.LstdFlags|log.Lshortfile)
	infoLog  = log.New(os.Stdout, "\033[34m[error]\033[0m ", log.LstdFlags|log.Lshortfile)
	loggers  = []*log.Logger{errorLog, infoLog}
	mu       sync.Mutex
)

var (
	Error  = errorLog.Println
	Errorf = errorLog.Printf
	Info   = infoLog.Println
	Infof  = infoLog.Printf
)

const (
	InfoLevel = iota
	ErrorLevel
	Disabled
)

func SetLevel(level int) {
	mu.Lock()
	defer mu.Unlock()
	for _, logger := range loggers {
		logger.SetOutput(os.Stdout)
	}
	if ErrorLevel < level {
		errorLog.SetOutput(ioutil.Discard)
	}
	if InfoLevel < level {
		infoLog.SetOutput(ioutil.Discard)
	}
}

代码中设置了[info]为蓝色,[error]为红色,并暴露ErrorErrorfInfoInfof4个方法;将Info、Error、Disabled声明为三个常量,通过Output来控制日志是否打印。至此,我们完成了log库的编写。

4.核心结构Session

下面就在session/raw.go写一下核心的数据结构Session,用于对数据库进行操作:

type Session struct {
	db      *sql.DB
	sql     strings.Builder
	sqlVars []interface{}
}

func New(db *sql.DB) *Session {
	return &Session{db: db}
}

func (s *Session) Clear() {
	s.sql.Reset()
	s.sqlVars = nil
}

func (s *Session) DB() *sql.DB {
	return s.db
}

func (s *Session) Raw(sql string, values ...interface{}) *Session {
	s.sql.WriteString(sql)
	s.sql.WriteString(" ")
	s.sqlVars = append(s.sqlVars, values...)
	return s
}

func (s *Session) Exec() (result sql.Result, err error) {
	defer s.Clear()
	log.Info(s.sql.String(), s.sqlVars)
	if result, err = s.DB().Exec(s.sql.String(), s.sqlVars...); err != nil {
		log.Error(err)
	}
	return
}

func (s *Session) QueryRow() *sql.Row {
	defer s.Clear()
	log.Info(s.sql.String(), s.sqlVars)
	return s.DB().QueryRow(s.sql.String(), s.sqlVars)
}

func (s *Session) QueryRows() (rows *sql.Rows, err error) {
	defer s.Clear()
	log.Info(s.sql.String(), s.sqlVars)
	if rows, err = s.DB().Query(s.sql.String(), s.sqlVars...); err != nil {
		log.Error(err)
	}
	return
}
  • db *sql.DB,即使用 sql.Open() 方法连接数据库成功之后返回的指针。
  • 第二个和第三个成员变量用来拼接 SQL 语句和 SQL 语句中占位符的对应值。用户调用 Raw() 方法即可改变这两个变量的值。
  • Exec()QueryRow()QueryRows()三个方法无非是对原有方法的封装,加上了日志打印部分,且每次执行完后都清空s.sqls.sqlVars,实现了开启一次会话,执行多次SQL的功能。

个人觉得这部分很好理解,就是把比较底层的东西封装了一层,封装完以后继续封装它们对应的方法。

5.核心结构Engine

Session的功能是和数据库交互,但是在交互之前,需要验证数据库是否连接成功,在交互以后,还要处理数据库连接的关闭操作,因此在geeorm.go中封装Engine解决这个问题:

type Engine struct {
	db *sql.DB
}

func NewEngine(driver, source string) (e *Engine, err error) {
	db, err := sql.Open(driver, source)
	if err != nil {
		log.Error(err)
		return
	}
	if err = db.Ping(); err != nil {
		log.Error(err)
		return
	}
	e = &Engine{db: db}
	log.Info("Connect database success")
	return
}

func (e *Engine) Close() {
	if err := e.db.Close(); err != nil {
		log.Error("Failed to close database")
	}
	log.Info("Close database success")
}

func (e *Engine) NewSession() *session.Session {
	return session.New(e.db)
}

Engine的代码也不难,NewEngine()方法返回了*sql.DB,还用db.Ping()检验了连接是否成功。NewSession()则可以通过Engine创建Session,进而与数据库进行交互。

二、对象表结构映射

这部分要解决两个问题:

(1)类型转换:Go语言中的类型和数据库中的类型是有差异的,例如Go中的intint8int16对应Sqlite中的Integer类型;

(2)对象和表的转换:把Go语言中的结构体转换为数据库中的表;

1.类型转换

为了兼容不同数据库数据类型的差异,我们将其提取出来实现最大程度的复用和解耦,下面编写dialect/dialect.go

var dialectsMap = map[string]Dialect{}

type Dialect interface {
	DataTypeOf(typ reflect.Value) string
	TableExistSQL(tableName string) (string, []interface{})
}

func RegisterDialect(name string, dialect Dialect) {
	dialectsMap[name] = dialect
}

func GetDialect(name string) (dialect Dialect, ok bool) {
	dialect, ok = dialectsMap[name]
	return
}

DataTypeOf():输入一个 reflect.Value 类型的数据,返回它在数据库中对应的数据类型的字符串。

TableExistSQL():输入表名,返回查找该表是否存在的 SQL 语句和查询参数。

为了支持不同的数据库,可以根据需要实现具体的Dialect接口,并通过RegisterDialect方法将其注册到全局的方言映射表dialectsMap中。其中,name是用于标识这个具体方言实例的字符串,dialect是实现了Dialect接口的具体方言实例。通过 GetDialect方法,可以根据name获取对应的具体方言实例。

紧接着,我们就在dialect/sqlite3.go中添加框架对于Sqlite的支持:

type sqlite3 struct {
}

var _ Dialect = (*sqlite3)(nil)

func init() {
	RegisterDialect("sqlite3", &sqlite3{})
}

func (s *sqlite3) DataTypeOf(typ reflect.Value) string {
	switch typ.Kind() {
	case reflect.Bool:
		return "bool"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
		return "integer"
	case reflect.Int64, reflect.Uint64:
		return "bigint"
	case reflect.Float32, reflect.Float64:
		return "real"
	case reflect.String:
		return "text"
	case reflect.Array, reflect.Slice:
		return "blob"
	case reflect.Struct:
		if _, ok := typ.Interface().(time.Time); ok {
			return "datatime"
		}
	}
	panic(fmt.Sprintf("invalid sql type %s (%s)", typ.Type().Name(), typ.Kind()))
}

func (s *sqlite3) TableExistSQL(tableName string) (string, []interface{}) {
	args := []interface{}{tableName}
	return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
}

在定义了sqlite3类型并指定其实现Dialect接口后,代码通过init()函数将sqlite3实例注册到全局的方言映射表dialectsMap中,方言名称为sqlite3DataTypeOf()通过一系列的switch-case语句,把Go语言中的类型转换为sqlite中的类型,TableExistSQL()方法输入表名,返回一个SQL语句字符串和一个参数列表,用于后续查询数据库中是否存在该表。

下面讲一下我比较疑惑的一段代码:

var _ Dialect = (*sqlite3)(nil)

sqlite3类型转换为Dialect接口类型,并使用空指针对其进行初始化。这样一来,如果sqlite3 类型没有实现Dialect接口中的某个方法,编译器会在代码编译期间直接提示编译错误,告诉我们应该在sqlite3类型中实现哪些方法。

这个技巧被称为 "duck typing"(鸭子类型)。即“如果一个对象走起路来像鸭子,叫起来也像鸭子,那么它就可以被当作鸭子”。在这里借助了反射机制,将sqlite3类型当做Dialect类型看待,并在编译期检查了其是否实现了Dialect接口中的所有方法,实现了编译时类型检查的效果。

2.对象和表的转换

在进行对象和表的转换时,要注意以下问题:

  • 表名(table name) —— 结构体名(struct name)
  • 字段名和字段类型 —— 成员变量和类型。
  • 额外的约束条件(例如非空、主键等) —— 成员变量的Tag(Go 语言通过 Tag 实现,Java、Python 等语言通过注解实现)

接下来编写schema/schema.go实现上述功能:

type Field struct {
	Name string
	Type string
	Tag  string
}

type Schema struct {
	Model      interface{}
	Name       string
	Fields     []*Field
	FieldNames []string
	fieldMap   map[string]*Field
}

func (s *Schema) GetField(name string) *Field {
	return s.fieldMap[name]
}

func Parse(dest interface{}, d dialect.Dialect) *Schema {
	modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
	schema := &Schema{
		Model:    dest,
		Name:     modelType.Name(),
		fieldMap: make(map[string]*Field),
	}
	for i := 0; i < modelType.NumField(); i++ {
		p := modelType.Field(i)
		if !p.Anonymous && ast.IsExported(p.Name) {
			field := &Field{
				Name: p.Name,
				Type: d.DataTypeOf(reflect.Indirect(reflect.New(p.Type))),
			}
			if v, ok := p.Tag.Lookup("geeorm"); ok {
				field.Tag = v
			}
			schema.Fields = append(schema.Fields, field)
			schema.FieldNames = append(schema.FieldNames, p.Name)
			schema.fieldMap[p.Name] = field
		}
	}
	return schema
}

数据库中的表是由多个列组成的。每个列都有自己的字段名、数据类型和其他属性(如是否允许为空、是否唯一等),所以用Field 结构体表示一个字段的元信息,可以理解为对应数据库表中的列。

结构体是 Schema代表的就是数据库中的表,在它的定义中,我们使用 Fields 存储每个字段(即每个列)对应的 Field 结构体,Name 则表示了表的名字。同时,Schema 还提供了一个 GetField 方法,用于根据字段名字获取对应的 Field 结构体,便于进行 ORM 操作。

Paese()主要的作用是将传入的Go语言结构体类型dest中的每个字段作为数据库表的一列进行解析,并生成对应的Field结构体对象,将其保存在Schema结构体对象中,最终返回该Schema结构体对象。

具体实现上,首先使用reflect.ValueOf获取dest这个结构体的值,但因为设计的入参是一个对象的指针,所以需要 reflect.Indirect()获取指针指向的实例,紧接着Type()获取该指针指向的实例对应的值的类型赋给modelType,并通过 modelType.NumField() 获取dest中字段的数量。然后通过循环遍历每个字段,将字段的NameTypeTag等元信息解析出来,并创建对应的Field结构体对象,将其添加到Schema结构体中的Fields数组中,并将Name保存在Schema结构体的FieldNames数组中,同时将NameField结构体对象的指针插入到fieldMap中。

在for循环中,p.Anonymous表示p是否为匿名字段,如果是匿名字段则返回true,否则返回false。这里使用取反操作来判断p是否为非匿名字段。ast.IsExported(p.Name)判断p.Name是否为导出字段的名称,如果是导出字段则返回true,否则返回false。导出字段指的是首字母大写的字段,可以在包外被访问。

3.Session

由于新增了DialectSchemaSession的字段也要对应进行调整:

type Session struct {
	db       *sql.DB
	dialect  dialect.Dialect
	refTable *schema.Schema
	sql      strings.Builder
	sqlVars  []interface{}
}

func New(db *sql.DB,dialect dialect.Dialect) *Session {
	return &Session{
		db: db,
		dialect: dialect,
	}
}

前面在定义Session的时候设计它是对数据库进行操作的部分,因此在文件夹session下新建table.go用于放置操作数据库表相关的代码session/table.go

func (s *Session) Model(value interface{}) *Session {
	if s.refTable == nil || reflect.TypeOf(value) != reflect.TypeOf(s.refTable.Model) {
		s.refTable = schema.Parse(value, s.dialect)
	}
	return s
}

func (s *Session) RefTable() *schema.Schema {
	if s.refTable == nil {
		log.Error("Model is not set")
	}
	return s.refTable
}

func (s *Session) CreateTable() error {
	table := s.RefTable()
	var colums []string
	for _, field := range table.Fields {
		colums = append(colums, fmt.Sprintf("%s %s %s", field.Name, field.Type, field.Tag))
	}
	desc := strings.Join(colums, ",")
	_, err := s.Raw(fmt.Sprintf("CREATE TABLE %s (%s);", table.Name, desc)).Exec()
	return err
}

func (s *Session) DropTable() error {
	_, err := s.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.RefTable().Name)).Exec()
	return err
}

func (s *Session) HasTable() bool {
	sql, values := s.dialect.TableExistSQL(s.RefTable().Name)
	row := s.Raw(sql, values...).QueryRow()
	var tmp string
	_ = row.Scan(&tmp)
	return tmp == s.refTable.Name
}

Model(方法用于指定「模型」(即数据库表的映射结构),该方法接收一个 value 参数,它可以是任意一个 Go 结构体对象。在方法内部,会通过反射获取 value 的类型,然后判断该类型是否与现有的 refTable 中的模型类型相同(也就是上一次执行 Model 方法时传入的结构体类型是否相同),如果不同,则调用 schema.Parse 方法对 value 进行解析,并返回其对应的模型结构体,同时将这个模型结构体保存到 Session 结构体的 refTable 字段中,以便 Session 对象执行后续的数据库操作。

RefTable()方法 用于获取会话结构体 Session 中保存的当前操作的该表的模型结构体。如果该模型结构体还未被设置(即 refTable 字段为空),则会输出错误日志并返回空值。如果该模型结构体已经被设置,则返回保存在 refTable 字段中的值。

后面三个方法的逻辑都比较相似,分别为创建表、删除表、判断是否具有表,都是利用RefTable()返回的数据库表和字段的信息,拼接出 SQL 语句,调用原生SQL接口执行。

4.Engine

同理,添加了新字段,Engine也要对应更新:

type Engine struct {
	db      *sql.DB
	dialect dialect.Dialect
}

func NewEngine(driver, source string) (e *Engine, err error) {
	db, err := sql.Open(driver, source)
	if err != nil {
		log.Error(err)
		return
	}
	if err = db.Ping(); err != nil {
		log.Error(err)
		return
	}
	dial, ok := dialect.GetDialect(driver)
	if !ok {
		log.Errorf("dialect %s Not Found", driver)
		return
	}
	e = &Engine{db: db, dialect: dial}
	log.Info("Connect database success")
	return
}

func (e *Engine) NewSession() *session.Session {
	return session.New(e.db, e.dialect)
}

具体修改为:创建Engine实例时,获取driver对应的dialect;创建Session实例时,传递dialect给构造函数New

三、记录新增和查询

这部分要实现两个功能:

  • 实现新增(insert)记录的功能。
  • 使用反射(reflect)将数据库的记录转换为对应的结构体实例,实现查询(select)功能。

1.Clause构造SQL语句

如果要写一条查询语句,那么它大概长这样:

SELECT col1, col2, ...
    FROM table_name
    WHERE [ conditions ]
    GROUP BY col1
    HAVING [ conditions ]

所以要一次性实现SQL语句的拼接会比较困难,因此将构造语句这一部分独立出来,写在clause/generator.go中:

type generator func(values ...interface{}) (string, []interface{})

var generators map[Type]generator

func init() {
	generators = make(map[Type]generator)
	generators[INSERT] = _insert
	generators[VALUES] = _values
	generators[SELECT] = _select
	generators[LIMIT] = _limit
	generators[WHERE] = _where
	generators[ORDERBY] = _orderBy
}

func genBindVars(num int) string {
	var vars []string
	for i := 0; i < num; i++ {
		vars = append(vars, "?")
	}
	return strings.Join(vars, ", ")
}

func _insert(values ...interface{}) (string, []interface{}) {
	tableName := values[0]
	fields := strings.Join(values[1].([]string), ",")
	return fmt.Sprintf("INSERT INTO %s (%v)", tableName, fields), []interface{}{}
}

func _values(values ...interface{}) (string, []interface{}) {
	var bindStr string
	var sql strings.Builder
	var vars []interface{}
	sql.WriteString("VALUES ")
	for i, value := range values {
		v := value.([]interface{})
		if bindStr == "" {
			bindStr = genBindVars(len(v))
		}
		sql.WriteString(fmt.Sprintf("(%v)", bindStr))
		if i+1 != len(values) {
			sql.WriteString(", ")
		}
		vars = append(vars, v...)
	}
	return sql.String(), vars
}

func _select(values ...interface{}) (string, []interface{}) {
	tableName := values[0]
	fields := strings.Join(values[1].([]string), ",")
	return fmt.Sprintf("SELECT %v FROM %s", fields, tableName), []interface{}{}
}

func _limit(values ...interface{}) (string, []interface{}) {
	return "LIMIT ?", values
}

func _where(values ...interface{}) (string, []interface{}) {
	desc, vars := values[0], values[1:]
	return fmt.Sprintf("WHERE %s", desc), vars
}

func _orderBy(values ...interface{}) (string, []interface{}) {
	return fmt.Sprintf("ORDER BY %s", values[0]), []interface{}{}
}

首先要说明的是,Type以及generators中的key部分会爆红,先别着急,后面会解决,下面讲解这部分代码:

变量generators是一个map,它的key表示SQL语句的一部分,如INSERTVALUESSELECT等等。而value是一个函数,用于根据输入的参数动态地生成对应的SQL语句和绑定变量列表。在init函数中,我们使用make方法初始化了generators,并将每个SQL语句的构建函数注册到generators中去。

genBindVars() 方法用于生成指定数量的SQL绑定变量(即问号)。它接收一个num参数,表示需要生成多少个绑定变量,然后使用 strings.Join 方法将多个问号连成一个逗号分隔的字符串返回。

_insert()方法接收一个变长参数列表values,其中包含要插入的表名和字段列表。函数首先通过 fmt.Sprintf 将表名和字段列表拼接成一个 INSERT INTO 语句,然后返回该语句和一个空的绑定变量列表。其中,values[1]表示函数_insert中传入的第二个参数,是一个字符串切片,存储了所有要插入的字段名。values[1].([]string)表示将values[1]转换为字符串切片类型。因为 values[1] 是一个空接口类型,需要通过类型断言将其转换为指定的类型[]string,以便访问其中的元素。strings.Join(values[1].([]string), ",") 将这个字符串切片中的所有元素用逗号 , 连接起来,并返回一个字符串类型的结果。这个结果将作为插入语句中的字段名。例如,如果values[1][]string{"id", "name", "age"},那么这个函数返回的结果就是"id,name,age"

_values()方法接收一个变长参数列表values,其中每个元素都表示要插入的记录。函数遍历所有的记录并将它们拼接成 VALUES 语句,最终返回该语句和一个包含所有绑定变量的列表。

_select()方法接收两个参数,表名和字段列表。函数通过fmt.Sprintf将表名和字段列表拼接成一个 SELECT 语句,最终返回该语句和一个空的绑定变量列表。

_limit()方法接收一个表示数量的参数values,它将这个参数包装成一个LIMIT语句,并返回该语句和一个只包含LIMIT值的绑定变量列表。

_where()方法接收一个字符串格式的查询条件和一组绑定变量。函数使用fmt.Sprintf将查询条件包装成WHERE语句,最终返回该语句和对应的绑定变量列表。

_orderBy()方法接收一个描述排序规则的字符串,将其拼接成一个ORDER BY语句。注意,该函数不需要绑定变量,因此返回一个空的绑定变量列表。

通过这些函数的组合,可以构建出包含多个 SQL 语句组成的完整的查询语句,让用户可以方便地将其应用到实际的数据库操作当中。

接着,在clause/clause.go中实现结构体Clause拼接各个独立的子句。

type Clause struct {
	sql     map[Type]string
	sqlVars map[Type][]interface{}
}

type Type int

const (
	INSERT Type = iota
	VALUES
	SELECT
	LIMIT
	WHERE
	ORDERBY
)

func (c *Clause) Set(name Type, vars ...interface{}) {
	if c.sql == nil {
		c.sql = make(map[Type]string)
		c.sqlVars = make(map[Type][]interface{})
	}
	sql, vars := generators[name](vars...)
	c.sql[name] = sql
	c.sqlVars[name] = vars
}

func (c *Clause) Build(orders ...Type) (string, []interface{}) {
	var sqls []string
	var vars []interface{}
	for _, order := range orders {
		if sql, ok := c.sql[order]; ok {
			sqls = append(sqls, sql)
			vars = append(vars, c.sqlVars[orders]...)
		}
	}
	return strings.Join(sqls, " "), vars
}

在这个文件下声明的Type和const里面的关键字,解决了generator.go中的爆红问题。

Clause为SQL语句生成器的结构体类型。该类型包含两个字段:sqlsqlVarssql用于存储生成的SQL语句,sqlVars用于存储语句中的绑定变量。

Set():用于设置各种SQL语句的函数。该方法接受一个名为name的Type值,表示要设置的SQL语句类型,以及可选的变长参数列表vars,表示该SQL语句中的参数值。该方法首先根据参数获取相应的SQL语句和绑定变量,并存储在结构体中的sqlsqlVars字段中。

Build():用于组装各种SQL语句的函数。该方法接受一个名为orders的可变长参数列表,表示要查询的SQL语句类型。该方法根据传入的orders列表,从结构体中的sqlsqlVars字段中取出各种SQL语句和绑定变量,并拼接成一个完整的SQL语句和绑定变量数组,最终返回这两个值。

2.实现Insert功能

首先为Session添加成员变量clause

type Session struct {
	db       *sql.DB
	dialect  dialect.Dialect
	refTable *schema.Schema
	clause   clause.Clause
	sql      strings.Builder
	sqlVars  []interface{}
}

func (s *Session) Clear() {
	s.sql.Reset()
	s.sqlVars = nil
	s.clause = clause.Clause{}
}

接着让我们看一下INSERT对应的Sql语句:

INSERT INTO table_name(col1, col2, col3, ...) VALUES
    (A1, A2, A3, ...),
    (B1, B2, B3, ...),
    ...

那么在调用ORM框架时,应该要这么书写代码:

s := geeorm.NewEngine("sqlite3", "gee.db").NewSession()
u1 := &User{Name: "Tom", Age: 18}
u2 := &User{Name: "Sam", Age: 25}
s.Insert(u1, u2, ...)

因此还需要一个步骤:根据数据库中列的顺序,从对象中找到对应的值,按顺序平铺。即u1u2转换为("Tom", 18), ("Same", 25)这样的格式,所以要先给Schema新增一个函数RecordValues完成上述的转换,以下为schema/schema.go

func (s *Schema) RecordValues(dest interface{}) []interface{} {
	destValue := reflect.Indirect(reflect.ValueOf(dest))
	var fieldValues []interface{}
	for _, field := range s.Fields {
		fieldValues = append(fieldValues, destValue.FieldByName(field.Name).Interface())
	}
	return fieldValues
}

此方法先通过反射获取dest的的值,然后遍历Fields中所有字段,获取每个字段在结构体实例中的值,并将其转换成Interface()类型,添加到fieldValues数组中,最终返回。

下面,在session文件夹下新建record.go,用于实现记录增删查改相关的代码:

func (s *Session) Insert(values ...interface{}) (int64, error) {
	recordValues := make([]interface{}, 0)
	for _, value := range values {
		table := s.Model(value).RefTable()
		s.clause.Set(clause.INSERT, table.Name, table.FieldNames)
		recordValues = append(recordValues, table.RecordValues(value))
	}
	s.clause.Set(clause.INSERT, recordValues...)
	sql, vars := s.clause.Build(clause.INSERT, clause.VALUES)
	result, err := s.Raw(sql, vars...).Exec()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}

该方法遍历每个传入的values参数,获取该结构体对象对应的表信息,并使用RefTable函数获取其映射到数据库的元数据信息,然后通过RecordValues函数处理它在数据库中对应的值集合,将处理后的结果追加到recordValues中;接着使用clause.Set函数设置INSERT操作,并将表名和字段列表传递给clause.Set函数;再将recordValues中的所有记录传递给clause.Set函数,用于设置当前的INSERT指令;然后使用clause.Build函数构建与当前INSERT操作对应的SQL语句,并获取语句中对应待绑定的变量vars;最后使用s.Raw函数执行构造后的SQL语句,并将绑定变量作为可变参数列表一同传入。

后续所有构造SQL语句的方式都将与Insert中构造 SQL 语句的方式一致。分两步:

  • 1)多次调用clause.Set()构造好每一个子句。
  • 2)调用一次clause.Build()按照传入的顺序构造出最终的 SQL 语句。

构造完成后,调用Raw().Exec()方法执行。

3.实现Find功能

在调用ORM的Find功能时,代码一般会这么写:

s := geeorm.NewEngine("sqlite3", "gee.db").NewSession()
var users []User
s.Find(&users);

Find功能和Insert恰好反了过来。Insert需要将已经存在的对象的每一个字段的值平铺开来,而Find则是需要根据平铺开的字段的值构造出对象。同样,也需要用到反射(reflect)。

func (s *Session) Find(values interface{}) error {
	destSlice := reflect.Indirect(reflect.ValueOf(values))
	destType := destSlice.Type().Elem()
	table := s.Model(reflect.New(destType).Elem().Interface()).RefTable()
	s.clause.Set(clause.SELECT, table.Name, table.FieldNames)
	sql, vars := s.clause.Build(clause.SELECT, clause.WHERE, clause.ORDERBY, clause.LIMIT)
	rows, err := s.Raw(sql, vars...).QueryRows()
	if err != nil {
		return err
	}
	for rows.Next() {
		dest := reflect.New(destType).Elem()
		var values []interface{}
		for _, name := range table.FieldNames {
			values = append(values, dest.FieldByName(name).Addr().Interface())
		}
		if err := rows.Scan(values...); err != nil {
			return err
		}
		destSlice.Set(reflect.Append(destSlice, dest))
	}
	return rows.Close()
}

该方法获取传入参数变量values的反射值,并通过该值获取要查询的结构体类型;接着获取该结构体类型对象对应的表信息,并使用 clause.Set函数设置SELECT操作,传递表名和字段列表给clause.Set函数;构建 SQL 语句,并通过调用Raw函数执行SQL查询操作,并获取查询结果集;再遍历查询结果,并使用reflect.New函数创建一个新的变量,然后通过该变量的反射值dest来获取每个字段的指针;将值的指针添加到与其对应的值的values数组中,然后使用rows.Scan函数将查询结果绑定到指定值的指针上。 reflect.Append 函数将匹配的结果添加到切片对象 destSlice中。最后返回结果集,并关闭结果集的游标。

全部评论

相关推荐

10-30 22:18
已编辑
毛坦厂中学 C++
点赞 评论 收藏
分享
点赞 收藏 评论
分享
牛客网
牛客企业服务