自己实现一个简单Golang ORM函数库
前言
通过该项目,对go的反射有了更深入的了解。特意记录下。将要使用的sql驱动为github.com/go-sql-driver/mysql
正文
数据库初始化
任何sql操作都离不开初始化,调用sql.Open(dbType,dataSource)
;即可初始化数据库。但需要注意的是该函数是golang官方的数据库规范接口其具体实现交由第三处理。所以需要在需要初始化的包中导入并初始化第三方包的想关函数。
package db
import (
"database/sql"
"fmt"
"time"
//init sql
_ "github.com/go-sql-driver/mysql"
)
//DB db oprate instance
var DB *sql.DB
var sqlType = "mysql"
var dataSource = "root:123456@tcp(localhost)/alming"
func init() {
DB, _ = sql.Open(sqlType, dataSource)
DB.SetConnMaxLifetime(time.Minute * 3)
DB.SetMaxOpenConns(10)
DB.SetMaxIdleConns(10)
if err := DB.Ping(); err == nil {
fmt.Println("Connect success:")
} else {
fmt.Println("Connect fail:", err)
}
}
单结果查询操作
需要解决的问题是如何将Sql操作的结果集映射到struct中。首先看下常规查询操作
//以下代码基本为伪代码,未经测试仅展示流程
type User struct{
Username string
Password string
}
rows,err:=DB.Query("select * from user")
user:=new(User)
rows.Next()
rows.Scan(&user.Username,&user.Password)
可以看到,Scan()
方法接收的是指针类型参数,所以说要创建一个指针容器用于存放结果集。那么有两个问题:容器内指针是什么类型,容器的大小又是多少。这时我们需要使用rows实例的另一个函数ColumnTypes()
它返回一个[]*sql.ColumnType
其数组内元素包含每列结果的数据库类型。有了数据库类型就可以根据数据库类型创建go类型参数。而该返回值的个数也就是我们需要创建的容器大小。详情见代码
//aldb.go
//获得结果集所有列信息以创建接收结果的容器
rc, err := rows.ColumnTypes()
if err != nil {
log.Println("Get column types fail")
}
container := createContainer(rc)
if !rows.Next() {
return false
}
column, _ := rows.Columns()
rows.Scan(container...)
success := mapResult(container, column, rs)
//dbutil.go
func createContainer(columnTyes []*sql.ColumnType) (params []interface{}) {
params = make([]interface{}, len(columnTyes))
for i, ct := range columnTyes {
params[i] = createSlot(ct.DatabaseTypeName())
}
return
}
//这里也是仅列出了常用的类型,如需扩展再进行类型添加
func createSlot(dbType string) interface{} {
switch dbType {
case "INT", "TINYINT", "BIGINT":
return new(int)
case "MEDIUMINT":
case "DOUBLE":
return new(float32)
case "DECIMAL":
case "CHAR":
return new(byte)
case "VARCHAR", "TEXT", "LONGTEXT":
return &sql.NullString{String: "", Valid: true}
case "BIT":
return new(interface{})
case "DATE":
return &sql.NullString{String: "", Valid: false}
case "DATETIME":
return &sql.NullString{String: "", Valid: false}
case "TIMESTAMP":
return &sql.NullString{String: "", Valid: false}
}
return nil
}
这里有一个坑就是想要映射为golang的string类型时需要使用sql.NullString,否则当驱动扫描到一个值为NULL的列时将不会继续扫描后面的结果将会获取不到
另外单结果查询我们还需要判断结果集是否为多个,因为有些业务只允许返回一个结果集,返回多个视为错误。实现起来也非常简单
if rows.Next() {
panic("QueryOne except one result but get no more one")
}
多结果查询操作
多结果查询与单结果类似,只是在单结果上多了一个for循环
rc, err := rows.ColumnTypes()
if err != nil {
log.Println("Get column types fail")
}
column, _ := rows.Columns()
var oneMoreSet bool = false
for rows.Next() {
container := createContainer(rc)
err = rows.Scan(container...)
if err != nil {
panic("Scan rows error")
}
oneMoreSet = mapResult(container, column, rs)
}
结果集映射
可以看到前文中mapResult(container, column, rs)
即为结果集映射函数,多结果与单结果共用一个函数,内部通过if判断区分以写操作。在接下来的源码中您可能会看到toPascalCase(columns[i])
函数,该函数是一个工具函数它将sql列命映射成为Golang命名规范的变量命方便使用反射。映射规则是将首字母大写,_
后第一个字母大写,其源码为
func toPascalCase(src string) string {
var dst = make([]uint8, 0)
if src[0] > 96 && src[0] < 123 {
dst = append(dst, src[0]-32)
} else {
dst = append(dst, src[0])
}
for i := 1; i < len(src); {
if src[i] == '_' {
if src[0] > 96 && src[0] < 123 {
dst = append(dst, src[i+1]-32)
}
i += 2
} else {
dst = append(dst, src[i])
i++
}
}
return string(dst)
}
然后继续看映射部分
//mapResult 将sql rows扫描到的数据填入给定的结构中(结构体或slice)
//container :单条结果容器,columns 结果集对应数据库中的列名,value
//被映射对象
func mapResult(container []interface{}, columns []string, value reflect.Value) bool {
var slot reflect.Value
var arr = make([]reflect.Value, 0)
//判断待映射类型,结构以与slice分别处理
if value.Elem().Kind() == reflect.Struct {
slot = value.Elem()
} else {
//slice内数据类型的实例
slot = reflect.New(value.Type().Elem().Elem()).Elem()
}
var oneMoreSet = false
//遍历一行结果集找到其在结构体中的位置并赋值
for i, v := range container {
//找到对应结构体的属性
slotField := slot.FieldByName(toPascalCase(columns[i]))
if slotField.CanSet() {
switch value := v.(type) {
case *int:
//只有与其结构体类型匹配才赋值
if slotField.Kind() == reflect.Int {
slotField.SetInt(int64(*value))
}
oneMoreSet = true
case *string:
if slotField.Kind() == reflect.String {
slotField.SetString(*value)
}
oneMoreSet = true
case *sql.NullString:
if slotField.Kind() == reflect.String {
slotField.SetString(value.String)
}
oneMoreSet = true
}
}
}
//如果被映射对象是slice也就是多结果集映射要通过反射将映射出的
//结构体实例追加到结果集中
if value.Elem().Kind() == reflect.Slice {
arr = append(arr, slot)
added := reflect.Append(value.Elem(), arr...)
value.Elem().Set(added)
}
return oneMoreSet
}
插入更新操作
这部分我实现了一个自定义SQL格式,使用时需按该格式编写sql。规定:sql中参数都使用:数据库列名
代替,它看起来是下面这样
update user set username=:username where id=:id
使用时会像下面这样
u := user{
Id: 2,
Username: "alming_update",
}
Exec(&u, "update user set username=:username where id=:id")
其内部实现原理也非常简单,直接看源码
//Exec excute sql with the params in the struct you give
func Exec(structure interface{}, sqlStr string) (success bool) {
rs := reflect.ValueOf(structure)
pointTo := rs.Elem()
//自定义sql 表达式中 ?由[]:变量名]代替,找到这些变量名并由反射根据改名称获取所给
//结构体实例当中的数据作为参数传递给Exec函数
reg, _ := regexp.Compile(`:[a-zA-z_]+`)
regFind := reg.FindAllString(sqlStr, -1)
//通过反射创建参数列表的容器
params := make([]interface{}, len(regFind))
//通过自定义sql表达式获取sql
SQLParsed := reg.ReplaceAllString(sqlStr, "?")
//通过自定义sql中:找到对应的参数
for i, sqlArgs := range regFind {
parseArg := strings.TrimPrefix(sqlArgs, `:`)
fieldName := toPascalCase(parseArg)
field := pointTo.FieldByName(fieldName)
switch field.Kind() {
case reflect.Int:
//将参数添加到参数容器中
params[i] = field.Int()
case reflect.String:
params[i] = field.String()
case reflect.Float32, reflect.Float64:
params[i] = field.Float()
}
}
var res sql.Result
var err error
if len(params) > 0 {
res, err = DB.Exec(SQLParsed, params...)
} else {
res, err = DB.Exec(SQLParsed)
}
if err == nil {
rowAf, _ := res.RowsAffected()
return rowAf > 0
}
return false
}
关于一对多问题
该操作实现的过于笨重且限制较多,就不班门弄斧了。感兴趣可以看下源码。
func QueryOneToMany(slice interface{}, sqlStr string, outPk string, inPk string, params ...interface{}) (resMatched bool) {
defer catchPanic()
rs := reflect.ValueOf(slice)
pointTo := rs.Elem()
if pointTo.Kind() != reflect.Slice {
panic("QueryOne must to map to a slice,please check your structure parameter")
}
var rows *sql.Rows
var err error
if len(params) == 0 {
rows, err = DB.Query(sqlStr)
} else {
rows, err = DB.Query(sqlStr, params...)
}
if err != nil {
log.Println("An error occerred when exec query sql", err)
}
rc, err := rows.ColumnTypes()
if err != nil {
log.Println("Get column types fail")
}
column, _ := rows.Columns()
var allRows = make([][]interface{}, 0)
for rows.Next() {
container := createContainer(rc)
err = rows.Scan(container...)
if err != nil {
panic("Scan rows error")
}
allRows = append(allRows, container)
}
//outPk,对应“一”的主键,inPk对应“多”的主键
mapRes(allRows, column, rs, 0, outPk, inPk)
//别忘改
return true
}
//mapRes 将查询的结果集按一对多形式映射到结构当中
//allRows 所有结果集,columns 结果集对应数据库中的列名,value
//被映射对象,height工具属性与可变参数pk配合使用,pk(primary
//key)设计目的是为了兼容QueryOne与Query的结果集映射。实际
//这两个方法有单独的映射函数
func mapRes(allRows [][]interface{}, columns []string, value reflect.Value, height int, pk ...string) {
in := value.Elem()
inType := in.Type().Elem()
var inSlot reflect.Value
var inSlotName string
//查找给定结构的slice属性并为其
for i := 0; i < inType.NumField(); i++ {
if inType.Field(i).Type.Kind() == reflect.Slice {
//记录改属性属性名方便之后通过反射获取改属性并为其赋值
inSlotName = inType.Field(i).Name
inSlot = reflect.New(inType.Field(i).Type)
mapRes(allRows, columns, inSlot, height+1, pk...)
}
}
//mark为一个标识,以sql primary key为map,通过它标识同一元素是否被重复扫描
mark := make(map[interface{}]byte)
//主键在column中索引位置,方便获取主键值并配合mark判断是否重复扫描
var pkIdx = -1
if len(pk) > 0 {
pkIdx = getColIndex(columns, pk[height])
}
var arr = make([]reflect.Value, 0)
for _, row := range allRows {
if mark[pkValue(row[pkIdx])] == 1 {
continue
}
outSlot := reflect.New(inType).Elem()
var oneMoreSet = false
for i, v := range row {
slot := outSlot.FieldByName(toPascalCase(columns[i]))
if slot.CanSet() {
switch setValue := v.(type) {
case *int:
if slot.Kind() == reflect.Int {
slot.SetInt(int64(*setValue))
oneMoreSet = true
}
case *string:
if slot.Kind() == reflect.String {
slot.SetString(*setValue)
oneMoreSet = true
}
case *sql.NullString:
if slot.Kind() == reflect.String {
slot.SetString(setValue.String)
oneMoreSet = true
}
}
}
}
slot := outSlot.FieldByName(inSlotName)
if slot.CanSet() {
slot.Set(inSlot.Elem())
}
if oneMoreSet {
if len(pk) > 0 {
mark[pkValue(row[pkIdx])] = 1
}
}
arr = append(arr, outSlot)
}
added := reflect.Append(in, arr...)
in.Set(added)
}
func getColIndex(colunms []string, col string) int {
for idx, item := range colunms {
if item == col {
return idx
}
}
return -1
}
func pkValue(pkContent interface{}) interface{} {
switch v := pkContent.(type) {
case *int:
return *v
case *byte:
return *v
case *float32:
return *v
case *string:
return *v
case *sql.NullString:
return v.String
default:
return nil
}
}
总结
关于Go反射
go反射不像java,go必须在已有实例上进行反射。
go使用反射修改实例内容时需要反射的内容必须为指针类型(可通过
CanSet()
判断该属性是否可以赋值),并且修改时需要调用Elem()
方法获取其指向的元素。反射slice添加元素比较复杂详情见代码。
Elem()
返回指针所指向的元素,如果是数组类型则返回其内部元素的类型。可以通过
reflect.New()
创建新的实例,但与第一条不冲突(创建实例所需的类型参数由反射已有实例获得)
附录
源代码:alming_backend
一些平台禁止外链 https://github.com/ALMing530/alming_backend
进入该项目db文件夹下查看
作者:小艾咪
原文链接:https://www.jianshu.com/p/322d687aa60e