diff --git a/controllers/controller.go b/controllers/controller.go index 527e32e..b5ba84f 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -5,6 +5,7 @@ import ( "fmt" "git.rosy.net.cn/jx-print/globals" "git.rosy.net.cn/jx-print/model" + putils "git.rosy.net.cn/jx-print/utils" "github.com/dchest/captcha" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" @@ -42,9 +43,10 @@ func callFunc(c *gin.Context, worker func() (retVal interface{}, errCode string, } else { token = cookie.Value } - if token != "token" { - err = fmt.Errorf("token 已过期,请重新登录!") + if user := putils.GetKet(token); user == nil { + err = fmt.Errorf("token过期或无效,请重新登录!") callBack.Desc = err.Error() + callBack.Code = model.ErrCodeToken c.JSON(http.StatusOK, callBack) return false } diff --git a/controllers/user_controller.go b/controllers/user_controller.go index b49275e..bcc4446 100644 --- a/controllers/user_controller.go +++ b/controllers/user_controller.go @@ -43,7 +43,7 @@ func GetUsers(c *gin.Context) { return } -type RegisterUserParam struct { +type UserParam struct { Name string `json:"name" form:"name" binding:"required"` //用户名 Password string `json:"password" form:"password" binding:"required"` //密码,md5后的 Code string `json:"code" form:"code" binding:"required"` //验证码 @@ -53,7 +53,7 @@ type RegisterUserParam struct { func RegisterUser(c *gin.Context) { var ( err error - user = &RegisterUserParam{} + user = &UserParam{} ) globals.SugarLogger.Debugf("Begin API :%s params: %v ip: %s", c.Request.URL, c.Params, c.ClientIP()) if err = c.Bind(&user); err != nil { @@ -81,17 +81,11 @@ func RegisterUser(c *gin.Context) { return } -type LoginParam struct { - Name string `json:"name" form:"name" binding:"required"` //用户名 - Password string `json:"password" form:"password" binding:"required"` //密码,md5后的 - Code string `json:"code" form:"code" binding:"required"` //验证码 -} - //登录 POST func Login(c *gin.Context) { var ( err error - user = &LoginParam{} + user = &UserParam{} ) globals.SugarLogger.Debugf("Begin API :%s params: %v ip: %s", c.Request.URL, c.Params, c.ClientIP()) if err = c.Bind(&user); err != nil { @@ -111,7 +105,85 @@ func Login(c *gin.Context) { return } if !callFunc(c, func() (retVal interface{}, errCode string, err error) { - retVal, err = services.Login(c, user.Name, user.Password, user.Code) + retVal, err = services.Login(c, user.Name, user.Password) + return retVal, "", err + }) { + return + } + return +} + +//自动登录,获取token POST +func GetTokenInfo(c *gin.Context) { + var ( + err error + user = &struct { + Token string `json:"token" form:"token"` + }{} + ) + globals.SugarLogger.Debugf("Begin API :%s params: %v ip: %s", c.Request.URL, c.Params, c.ClientIP()) + if err = c.Bind(&user); err != nil { + c.JSON(http.StatusOK, &CallBack{ + Code: model.ErrCodeNormal, + Desc: err.Error(), + }) + globals.SugarLogger.Debugf("End API :%s error:%v:", c.Request.URL, err) + return + } + if !callFunc(c, func() (retVal interface{}, errCode string, err error) { + retVal, err = services.GetTokenInfo(c, user.Token) + return retVal, "", err + }) { + return + } + return +} + +//登出,删token POST +func Logout(c *gin.Context) { + var ( + err error + user = &struct { + Token string `json:"token" form:"token"` + }{} + ) + globals.SugarLogger.Debugf("Begin API :%s params: %v ip: %s", c.Request.URL, c.Params, c.ClientIP()) + if err = c.Bind(&user); err != nil { + c.JSON(http.StatusOK, &CallBack{ + Code: model.ErrCodeNormal, + Desc: err.Error(), + }) + globals.SugarLogger.Debugf("End API :%s error:%v:", c.Request.URL, err) + return + } + if !callFunc(c, func() (retVal interface{}, errCode string, err error) { + err = services.Logout(c, user.Token) + return retVal, "", err + }) { + return + } + return +} + +//更新用户信息 POST +func UpdateUser(c *gin.Context) { + var ( + err error + user = &struct { + Payload string `json:"payload" form:"payload"` //user 的json格式数据 + }{} + ) + globals.SugarLogger.Debugf("Begin API :%s params: %v ip: %s", c.Request.URL, c.Params, c.ClientIP()) + if err = c.Bind(&user); err != nil { + c.JSON(http.StatusOK, &CallBack{ + Code: model.ErrCodeNormal, + Desc: err.Error(), + }) + globals.SugarLogger.Debugf("End API :%s error:%v:", c.Request.URL, err) + return + } + if !callFunc(c, func() (retVal interface{}, errCode string, err error) { + err = services.UpdateUser(c, user.Payload) return retVal, "", err }) { return diff --git a/dao/dao.go b/dao/dao.go index b7ccde4..9358690 100644 --- a/dao/dao.go +++ b/dao/dao.go @@ -1,6 +1,7 @@ package dao import ( + "fmt" putils "git.rosy.net.cn/jx-print/utils" "github.com/jmoiron/sqlx" "reflect" @@ -13,16 +14,17 @@ func Insert(db *sqlx.DB, obj interface{}) (err error) { var ( value = reflect.ValueOf(obj) stype = reflect.TypeOf(obj) - sname = stype.Name() sql, values = strings.Builder{}, strings.Builder{} sqlParams = []interface{}{} direct reflect.Value ) if stype.Kind() != reflect.Struct { direct = reflect.Indirect(value) + stype = stype.Elem() } else { direct = value } + sname := stype.Name() sql.WriteString("INSERT INTO ") for i := 0; i < stype.NumField()-1; i++ { if stype.Field(i).Type.String() == "*time.Time" { @@ -50,3 +52,108 @@ func Insert(db *sqlx.DB, obj interface{}) (err error) { _, err = db.DB.Exec(sql.String(), sqlParams...) return err } + +func Update(db *sqlx.DB, obj interface{}, fields ...string) (err error) { + var ( + value = reflect.ValueOf(obj) + stype = reflect.TypeOf(obj) + sql = strings.Builder{} + sqlParams = []interface{}{} + direct reflect.Value + fieldsMap = make(map[string]string) + ) + if stype.Kind() != reflect.Struct { + direct = reflect.Indirect(value) + stype = stype.Elem() + } else { + direct = value + } + sname := stype.Name() + sql.WriteString("UPDATE ") + sql.WriteString(putils.UnMarshalHr(sname) + " SET ") + fieldsStr := []string{} + for _, v := range fields { + fieldsStr = append(fieldsStr, v+"=?") + fieldsMap[v] = v + } + sql.WriteString(strings.Join(fieldsStr, ",")) + sql.WriteString(" WHERE id = ?") + for i := 0; i < stype.NumField()-1; i++ { + if fieldsMap[stype.Field(i).Tag.Get("json")] != "" { + if stype.Field(i).Type.String() == "*time.Time" { + if direct.Field(i).Interface().(*time.Time) != nil { + sqlParams = append(sqlParams, direct.Field(i).Interface()) + } + } else { + if !direct.Field(i).IsZero() { + sqlParams = append(sqlParams, direct.Field(i).Interface()) + } + } + } + } + if direct.Field(0).Int() == 0 { + return err + } else { + sqlParams = append(sqlParams, direct.Field(0).Int()) + } + _, err = db.DB.Exec(sql.String(), sqlParams...) + return err +} + +//更新两个结构体中不同的字段 +//obj是作为参数,obj2是原本的要更新的 +func UpdateDiff(db *sqlx.DB, obj interface{}, obj2 interface{}) (err error) { + var ( + value = reflect.ValueOf(obj) + stype = reflect.TypeOf(obj) + value2 = reflect.ValueOf(obj2) + stype2 = reflect.TypeOf(obj2) + sql = strings.Builder{} + sqlParams = []interface{}{} + fieldMap1 = make(map[string]interface{}) + fieldMap2 = make(map[string]interface{}) + fields = make(map[string]interface{}) + ) + if stype.Kind() != reflect.Struct { + stype = stype.Elem() + value = reflect.Indirect(value) + } + if stype2.Kind() != reflect.Struct { + stype2 = stype2.Elem() + value2 = reflect.Indirect(value2) + } + sname := stype.Name() + sname2 := stype2.Name() + if sname != sname2 { + return fmt.Errorf("请传入两个类型相同的结构体!") + } + for i := 1; i < stype.NumField()-1; i++ { + fieldMap1[stype.Field(i).Tag.Get("json")] = value.Field(i).Interface() + } + for i := 1; i < stype2.NumField()-1; i++ { + fieldMap2[stype2.Field(i).Tag.Get("json")] = value2.Field(i).Interface() + } + for k, v := range fieldMap1 { + if fieldMap2[k] != nil { + if fieldMap2[k] != v { + fields[k] = v + } + } + } + sql.WriteString("UPDATE ") + sql.WriteString(putils.UnMarshalHr(sname) + " SET ") + fieldsStr := []string{} + for k, v := range fields { + fieldsStr = append(fieldsStr, k+"=?") + sqlParams = append(sqlParams, v) + } + sql.WriteString(strings.Join(fieldsStr, ",")) + sql.WriteString(" WHERE id = ?") + if value2.Field(0).Int() == 0 { + return err + } else { + sqlParams = append(sqlParams, value2.Field(0).Int()) + } + _, err = db.DB.Exec(sql.String(), sqlParams...) + return err +} diff --git a/dao/user_dao.go b/dao/user_dao.go index dfed479..68f073a 100644 --- a/dao/user_dao.go +++ b/dao/user_dao.go @@ -29,3 +29,17 @@ func GetUsers(db *sqlx.DB, userID, name, mobile string) (users []*model.User, er } return users, err } + +func GetUserForLogin(db *sqlx.DB, name, password string) (user *model.User, err error) { + var users []*model.User + sql := ` + SELECT * + FROM user + WHERE name = ? AND password = ? + ` + sqlParams := []interface{}{name, password} + if err = db.Select(&users, sql, sqlParams...); err == nil { + return users[0], err + } + return user, err +} diff --git a/routers/router.go b/routers/router.go index 7514e34..feebea5 100644 --- a/routers/router.go +++ b/routers/router.go @@ -10,11 +10,14 @@ func Init(r *gin.Engine) { //user user := v2.Group("/user") user.GET("/getUsers", controllers.GetUsers) + user.GET("/getTokenInfo", controllers.GetTokenInfo) + user.GET("/logout", controllers.Logout) + user.GET("/updateUser", controllers.UpdateUser) //v1是不需要token的 v1 := r.Group("v1") userw := v1.Group("/user") - user.GET("/login", controllers.Login) + userw.GET("/login", controllers.Login) userw.GET("/refreshCode", controllers.RefreshCode) userw.GET("/register", controllers.RegisterUser) } diff --git a/services/user.go b/services/user.go index e048a8a..7451265 100644 --- a/services/user.go +++ b/services/user.go @@ -2,15 +2,24 @@ package services import ( "crypto/md5" + "encoding/json" "fmt" "git.rosy.net.cn/baseapi/utils" "git.rosy.net.cn/jx-print/dao" "git.rosy.net.cn/jx-print/globals" "git.rosy.net.cn/jx-print/model" + putils "git.rosy.net.cn/jx-print/utils" "github.com/gin-gonic/gin" + "strings" "time" ) +const ( + TokenHeader = "TOKEN" + TokenVer = "V2" + TokenTypeSep = "." +) + func GetUsers(c *gin.Context, name, mobile, userID string) (users []*model.User, err error) { return dao.GetUsers(globals.GetDB(), userID, name, mobile) } @@ -18,7 +27,7 @@ func GetUsers(c *gin.Context, name, mobile, userID string) (users []*model.User, func RegisterUser(c *gin.Context, name, password string) (err error) { var ( db = globals.GetDB() - user = model.User{} + user = &model.User{} now = time.Now() ) if users, _ := dao.GetUsers(db, "", name, ""); len(users) > 0 { @@ -34,10 +43,84 @@ func RegisterUser(c *gin.Context, name, password string) (err error) { return err } -func Login(c *gin.Context, name, password, code string) (user *model.User, err error) { - //var ( - // db = globals.GetDB() - //) +type LoginResult struct { + model.User + Token string `json:"token"` //token +} +func Login(c *gin.Context, name, password string) (loginResult *LoginResult, err error) { + var ( + db = globals.GetDB() + now = time.Now() + user = &model.User{} + token string + ) + loginResult = &LoginResult{} + if users, _ := dao.GetUsers(db, "", name, ""); len(users) == 0 { + return loginResult, fmt.Errorf("用户名不存在!") + } + if user, err = dao.GetUserForLogin(db, name, fmt.Sprintf("%x", md5.Sum([]byte(model.RegisterKey+password)))); err != nil { + return loginResult, err + } else if user == nil { + return loginResult, fmt.Errorf("密码错误!") + } + loginResult.User = *user + //创建token + token, err = setToken(user) + loginResult.Token = token + //更新登录时间和ip + user.LastLoginAt = &now + user.LastLoginIP = c.ClientIP() + err = dao.Update(db, user, "last_login_at", "last_login_ip") + return loginResult, err +} + +func setToken(user *model.User) (token string, err error) { + token = createToken(user) + err = putils.SetKey(token, user, putils.DefTokenDuration) + return token, err +} + +func createToken(user *model.User) (token string) { + return strings.Join([]string{ + TokenHeader, + TokenVer, + user.UserID, + time.Now().Format("20060102-150405"), + utils.GetUUID(), + user.Name, + }, TokenTypeSep) +} + +func GetTokenInfo(c *gin.Context, token string) (user *model.User, err error) { + result := putils.GetKet(token) + if user, ok := result.(*model.User); !ok { + return user, err + } return user, err } + +func Logout(c *gin.Context, token string) (err error) { + return putils.DelKey(token) +} + +func UpdateUser(c *gin.Context, payload string) (err error) { + var ( + db = globals.GetDB() + userp = &model.User{} + user = &model.User{} + ) + if err = json.Unmarshal([]byte(payload), &userp); err != nil { + return err + } + if userp.ID == 0 && userp.UserID == "" { + return fmt.Errorf("id 和 user_id 至少传一个!") + } + if users, err := dao.GetUsers(db, userp.UserID, "", ""); err != nil { + return err + } else { + user = users[0] + } + err = dao.UpdateDiff(db, userp, user) + return err +} diff --git a/utils/redis.go b/utils/redis.go new file mode 100644 index 0000000..be4e849 --- /dev/null +++ b/utils/redis.go @@ -0,0 +1,46 @@ +package utils + +import ( + "git.rosy.net.cn/baseapi/utils" + "git.rosy.net.cn/jx-print/globals" + "github.com/go-redis/redis" + "time" +) + +const ( + DefTokenDuration = time.Hour * 24 * 7 +) + +var ( + client *redis.Client +) + +func init() { + globals.SugarLogger.Debugf("redis init..") + client = redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", + }) + pong, err := client.Ping().Result() + globals.SugarLogger.Debugf("redis pong %v, err: %v", pong, err) +} + +func SetKey(key string, value interface{}, expiration time.Duration) error { + strValue := string(utils.MustMarshal(value)) + return client.Set(key, strValue, expiration).Err() +} + +func DelKey(key string) error { + return client.Del(key).Err() +} + +func GetKet(key string) interface{} { + result, err := client.Get(key).Result() + if err == nil { + var retVal interface{} + if err = utils.UnmarshalUseNumber([]byte(result), &retVal); err == nil { + return retVal + } + } + return nil +}