package wxpay import ( "bytes" "crypto/md5" "encoding/xml" "fmt" "net/http" "sort" "strings" "git.rosy.net.cn/baseapi/platformapi" "git.rosy.net.cn/baseapi/utils" "github.com/clbanning/mxj" ) const ( prodURL = "https://api.mch.weixin.qq.com/pay" sandboxURL = "https://api.mch.weixin.qq.com/sandboxnew/pay" ) const ( ResponseCodeSuccess = "SUCCESS" ResponseCodeFail = "FAIL" sigKey = "sign" sigTypeKey = "sign_type" sigType = "MD5" ) type RequestBase struct { XMLName xml.Name `json:"-" xml:"xml"` AppID string `json:"appid" xml:"appid"` MchID string `json:"mch_id" xml:"mch_id"` NonceStr string `json:"nonce_str" xml:"nonce_str"` Sign string `json:"sign" xml:"sign"` SignType string `json:"sign_type,omitempty" xml:"sign_type,omitempty"` } type ResponseResult struct { XMLName xml.Name `json:"-" xml:"xml"` ReturnCode string `json:"return_code" xml:"return_code"` ReturnMsg string `json:"return_msg" xml:"return_msg"` AppID string `json:"appid" xml:"appid"` MchID string `json:"mch_id" xml:"mch_id"` NonceStr string `json:"nonce_str" xml:"nonce_str"` Sign string `json:"sign" xml:"sign"` SignType string `json:"sign_type,omitempty" xml:"sign_type,omitempty"` ResultCode string `json:"result_code" xml:"result_code"` ErrCode string `json:"err_code,omitempty" xml:"err_code,omitempty"` ErrCodeDes string `json:"err_code_des,omitempty" xml:"err_code_des,omitempty"` DeviceInfo string `json:"device_info,omitempty" xml:"device_info,omitempty"` OpenID string `json:"openid,omitempty" xml:"openid,omitempty"` } type OrderQueryParam struct { RequestBase TransactionID string `json:"transaction_id" xml:"transaction_id"` OutTradeNo string `json:"out_trade_no" xml:"out_trade_no"` } type API struct { appID string appKey string mchID string client *http.Client config *platformapi.APIConfig } func New(appID, appKey, mchID string, config ...*platformapi.APIConfig) *API { curConfig := platformapi.DefAPIConfig if len(config) > 0 { curConfig = *config[0] } return &API{ appID: appID, appKey: appKey, mchID: mchID, client: &http.Client{Timeout: curConfig.ClientTimeout}, config: &curConfig, } } func (a *API) GetAppID() string { return a.appID } func (a *API) GetMchID() string { return a.mchID } func (a *API) signParam(params map[string]interface{}) (sig string) { var valueList []string for k, v := range params { if k != sigKey { if str := fmt.Sprint(v); str != "" { valueList = append(valueList, fmt.Sprintf("%s=%s", k, str)) } } } sort.Sort(sort.StringSlice(valueList)) valueList = append(valueList, fmt.Sprintf("key=%s", a.appKey)) sig = strings.Join(valueList, "&") sig = fmt.Sprintf("%X", md5.Sum([]byte(sig))) return sig } func unmarshalXML(data []byte, result interface{}) error { d := xml.NewDecoder(bytes.NewReader(data)) return d.Decode(result) } func mustMarshalXML(obj interface{}) []byte { byteArr, err := xml.Marshal(obj) if err != nil { panic(fmt.Sprintf("err when Marshal obj:%v with error:%v", obj, err)) } return byteArr } func (a *API) AccessAPI(action string, params interface{}, baseParam *RequestBase) (retVal *ResponseResult, err error) { baseParam.AppID = a.appID baseParam.MchID = a.mchID baseParam.NonceStr = utils.GetUUID() baseParam.SignType = sigType baseParam.Sign = a.signParam(utils.Struct2FlatMap(params)) fullURL := utils.GenerateGetURL(prodURL, action, nil) err = platformapi.AccessPlatformAPIWithRetry(a.client, func() *http.Request { request, _ := http.NewRequest(http.MethodPost, fullURL, bytes.NewReader(mustMarshalXML(params))) return request }, a.config, func(response *http.Response, bodyStr string, jsonResult1 map[string]interface{}) (errLevel string, err error) { if jsonResult1 == nil { return platformapi.ErrLevelRecoverableErr, fmt.Errorf("mapData is nil") } retVal, errLevel, err = a.checkResult(jsonResult1[platformapi.KeyData].(string)) return errLevel, err }) return retVal, err } func (a *API) checkResult(xmlStr string) (result *ResponseResult, errLevel string, err error) { err = unmarshalXML([]byte(xmlStr), &result) if err != nil { errLevel = platformapi.ErrLevelGeneralFail } else { if result.ReturnCode != ResponseCodeSuccess { errLevel = platformapi.ErrLevelGeneralFail err = utils.NewErrorCode(result.ReturnMsg, result.ReturnCode) result = nil } else { // if result.ResultCode != ResponseCodeSuccess { // errLevel = platformapi.ErrLevelGeneralFail // err = utils.NewErrorCode(result.ErrCodeDes, result.ErrCode) // result = nil // } else { // } } } return result, errLevel, err } func (a *API) AccessAPIByMap(action string, params map[string]interface{}) (retVal map[string]interface{}, err error) { params2 := utils.MergeMaps(params, map[string]interface{}{ "appid": a.appID, "mch_id": a.mchID, "nonce_str": utils.GetUUID(), "sign_type": sigType, }) params2[sigKey] = a.signParam(params2) fullURL := utils.GenerateGetURL(prodURL, action, nil) xmlBytes, err := mxj.Map(params2).Xml("xml") if err != nil { return nil, err } err = platformapi.AccessPlatformAPIWithRetry(a.client, func() *http.Request { request, _ := http.NewRequest(http.MethodPost, fullURL, bytes.NewReader(xmlBytes)) return request }, a.config, func(response *http.Response, bodyStr string, jsonResult1 map[string]interface{}) (errLevel string, err error) { if jsonResult1 == nil { return platformapi.ErrLevelRecoverableErr, fmt.Errorf("mapData is nil") } retVal, errLevel, err = a.checkResultAsMap(jsonResult1[platformapi.KeyData].(string)) return errLevel, err }) return retVal, err } func (a *API) checkResultAsMap(xmlStr string) (result map[string]interface{}, errLevel string, err error) { mv, err := mxj.NewMapXml([]byte(xmlStr)) if err != nil { errLevel = platformapi.ErrLevelGeneralFail } else { result = mv["xml"].(map[string]interface{}) returnCode := utils.Interface2String(result["return_code"]) if returnCode != ResponseCodeSuccess { errLevel = platformapi.ErrLevelGeneralFail err = utils.NewErrorCode(utils.Interface2String(result["return_msg"]), returnCode) result = nil } else { // if utils.Interface2String(result["result_code"]) != ResponseCodeSuccess { // errLevel = platformapi.ErrLevelGeneralFail // err = utils.NewErrorCode(utils.Interface2String(result["err_code_desc"]), utils.Interface2String(result["err_code"])) // result = nil // } else { // } } } return result, errLevel, err } func (a *API) OrderQuery(transactionID, outTradeNo string) (retVal map[string]interface{}, err error) { // param := &OrderQueryParam{ // TransactionID: transactionID, // OutTradeNo: outTradeNo, // } // retVal, err = a.AccessAPI("orderquery", param, ¶m.RequestBase) param := map[string]interface{}{} if transactionID != "" { param["transaction_id"] = transactionID } else { param["out_trade_no"] = outTradeNo } retVal, err = a.AccessAPIByMap("orderquery", param) return retVal, err }