264 lines
6.3 KiB
Go
264 lines
6.3 KiB
Go
package sms
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha1"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"git.toowon.com/jimmy/go-common/config"
|
|
"git.toowon.com/jimmy/go-common/logger"
|
|
)
|
|
|
|
// SendResponse is the decoded JSON reply of a SendSms request.
type SendResponse struct {
	// RequestID is read from the "RequestId" field of the reply.
	RequestID string `json:"RequestId"`
	// Code is the status code of the reply; SendSMS treats every value other
	// than "OK" as a failure.
	Code string `json:"Code"`
	// Message is the text accompanying Code; SendSMS includes it in the error
	// it reports for a failed send.
	Message string `json:"Message"`
	// BizID is read from the "BizId" field of the reply.
	BizID string `json:"BizId"`
}
|
|
|
|
// SMS 短信发送器
|
|
type SMS struct {
|
|
config *config.SMSConfig
|
|
|
|
async bool
|
|
queue chan smsTask
|
|
workers int
|
|
wg sync.WaitGroup
|
|
closed bool
|
|
mu sync.Mutex
|
|
dropped atomic.Uint64
|
|
}
|
|
|
|
// smsTask is one queued send request handed from SendSMSAsync to a worker.
type smsTask struct {
	phones        []string // recipient numbers, copied from the caller's slice
	templateParam any      // template variables, passed through to SendSMS unchanged
	templateCode  string   // template code override; empty when the caller gave none
	requestID     string   // request ID taken from the caller's context, used in worker logs
}
|
|
|
|
// NewSMS 创建短信发送器
|
|
func NewSMS(cfg *config.Config) *SMS {
|
|
if cfg == nil || cfg.SMS == nil {
|
|
return &SMS{config: nil}
|
|
}
|
|
s := &SMS{
|
|
config: cfg.SMS,
|
|
async: cfg.SMS.IsAsync(),
|
|
workers: cfg.SMS.Workers,
|
|
}
|
|
if s.workers <= 0 {
|
|
s.workers = 2
|
|
}
|
|
queueSize := cfg.SMS.QueueSize
|
|
if queueSize <= 0 {
|
|
queueSize = 1000
|
|
}
|
|
if s.async {
|
|
s.queue = make(chan smsTask, queueSize)
|
|
for i := 0; i < s.workers; i++ {
|
|
s.wg.Add(1)
|
|
go s.worker()
|
|
}
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *SMS) worker() {
|
|
defer s.wg.Done()
|
|
for task := range s.queue {
|
|
if _, err := s.SendSMS(task.phones, task.templateParam, task.templateCode); err != nil {
|
|
logger.FromContext(context.Background()).Error("async sms send failed", map[string]any{
|
|
"error": err.Error(),
|
|
"request_id": task.requestID,
|
|
"phones": task.phones,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SMS) getSMSConfig() (*config.SMSConfig, error) {
|
|
if s.config == nil {
|
|
return nil, fmt.Errorf("SMS config is nil")
|
|
}
|
|
if s.config.AccessKeyID == "" {
|
|
return nil, fmt.Errorf("AccessKeyID is required")
|
|
}
|
|
if s.config.AccessKeySecret == "" {
|
|
return nil, fmt.Errorf("AccessKeySecret is required")
|
|
}
|
|
if s.config.SignName == "" {
|
|
return nil, fmt.Errorf("SignName is required")
|
|
}
|
|
if s.config.Region == "" {
|
|
s.config.Region = "cn-hangzhou"
|
|
}
|
|
if s.config.Timeout == 0 {
|
|
s.config.Timeout = 5
|
|
}
|
|
return s.config, nil
|
|
}
|
|
|
|
// SendSMS 同步发送短信
|
|
func (s *SMS) SendSMS(phoneNumbers []string, templateParam interface{}, templateCode ...string) (*SendResponse, error) {
|
|
cfg, err := s.getSMSConfig()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(phoneNumbers) == 0 {
|
|
return nil, fmt.Errorf("phone numbers are required")
|
|
}
|
|
|
|
templateCodeValue := cfg.TemplateCode
|
|
if len(templateCode) > 0 && templateCode[0] != "" {
|
|
templateCodeValue = templateCode[0]
|
|
}
|
|
if templateCodeValue == "" {
|
|
return nil, fmt.Errorf("template code is required")
|
|
}
|
|
|
|
var templateParamJSON string
|
|
if templateParam != nil {
|
|
switch v := templateParam.(type) {
|
|
case string:
|
|
templateParamJSON = v
|
|
default:
|
|
paramBytes, err := json.Marshal(v)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal template param: %w", err)
|
|
}
|
|
templateParamJSON = string(paramBytes)
|
|
}
|
|
} else {
|
|
templateParamJSON = "{}"
|
|
}
|
|
|
|
params := map[string]string{
|
|
"Action": "SendSms",
|
|
"Version": "2017-05-25",
|
|
"RegionId": cfg.Region,
|
|
"AccessKeyId": cfg.AccessKeyID,
|
|
"Format": "JSON",
|
|
"SignatureMethod": "HMAC-SHA1",
|
|
"SignatureVersion": "1.0",
|
|
"SignatureNonce": fmt.Sprint(time.Now().UnixNano()),
|
|
"Timestamp": time.Now().UTC().Format("2006-01-02T15:04:05Z"),
|
|
"PhoneNumbers": strings.Join(phoneNumbers, ","),
|
|
"SignName": cfg.SignName,
|
|
"TemplateCode": templateCodeValue,
|
|
"TemplateParam": templateParamJSON,
|
|
}
|
|
params["Signature"] = s.calculateSignature(params, "POST", cfg.AccessKeySecret)
|
|
|
|
endpoint := cfg.Endpoint
|
|
if endpoint == "" {
|
|
endpoint = "https://dysmsapi.aliyuncs.com"
|
|
}
|
|
|
|
formData := url.Values{}
|
|
for k, v := range params {
|
|
formData.Set(k, v)
|
|
}
|
|
|
|
httpReq, err := http.NewRequest(http.MethodPost, endpoint, strings.NewReader(formData.Encode()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
httpReq.Header.Set("Accept", "application/json")
|
|
|
|
client := &http.Client{Timeout: time.Duration(cfg.Timeout) * time.Second}
|
|
resp, err := client.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
|
}
|
|
|
|
var sendResp SendResponse
|
|
if err := json.Unmarshal(body, &sendResp); err != nil {
|
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
if sendResp.Code != "OK" {
|
|
return &sendResp, fmt.Errorf("SMS send failed: Code=%s, Message=%s", sendResp.Code, sendResp.Message)
|
|
}
|
|
return &sendResp, nil
|
|
}
|
|
|
|
// SendSMSAsync 异步发送短信
|
|
func (s *SMS) SendSMSAsync(ctx context.Context, phoneNumbers []string, templateParam interface{}, templateCode ...string) {
|
|
code := ""
|
|
if len(templateCode) > 0 {
|
|
code = templateCode[0]
|
|
}
|
|
task := smsTask{
|
|
phones: append([]string(nil), phoneNumbers...),
|
|
templateParam: templateParam,
|
|
templateCode: code,
|
|
requestID: logger.RequestIDFromContext(ctx),
|
|
}
|
|
if !s.async {
|
|
_, _ = s.SendSMS(task.phones, task.templateParam, task.templateCode)
|
|
return
|
|
}
|
|
select {
|
|
case s.queue <- task:
|
|
default:
|
|
s.dropped.Add(1)
|
|
logger.FromContext(ctx).Error("sms queue full, task dropped", map[string]any{
|
|
"phones": phoneNumbers,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Close 关闭异步 worker
|
|
func (s *SMS) Close() error {
|
|
if !s.async {
|
|
return nil
|
|
}
|
|
s.mu.Lock()
|
|
if s.closed {
|
|
s.mu.Unlock()
|
|
return nil
|
|
}
|
|
s.closed = true
|
|
s.mu.Unlock()
|
|
close(s.queue)
|
|
s.wg.Wait()
|
|
return nil
|
|
}
|
|
|
|
func (s *SMS) calculateSignature(params map[string]string, method, accessKeySecret string) string {
|
|
keys := make([]string, 0, len(params))
|
|
for k := range params {
|
|
keys = append(keys, k)
|
|
}
|
|
sort.Strings(keys)
|
|
|
|
var queryParts []string
|
|
for _, k := range keys {
|
|
queryParts = append(queryParts, url.QueryEscape(k)+"="+url.QueryEscape(params[k]))
|
|
}
|
|
queryString := strings.Join(queryParts, "&")
|
|
stringToSign := method + "&" + url.QueryEscape("/") + "&" + url.QueryEscape(queryString)
|
|
|
|
mac := hmac.New(sha1.New, []byte(accessKeySecret+"&"))
|
|
mac.Write([]byte(stringToSign))
|
|
return base64.StdEncoding.EncodeToString(mac.Sum(nil))
|
|
}
|