package sms import ( "context" "crypto/hmac" "crypto/sha1" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/url" "sort" "strings" "sync" "sync/atomic" "time" "git.toowon.com/jimmy/go-common/config" "git.toowon.com/jimmy/go-common/logger" ) // SendResponse 发送短信响应 type SendResponse struct { RequestID string `json:"RequestId"` Code string `json:"Code"` Message string `json:"Message"` BizID string `json:"BizId"` } // SMS 短信发送器 type SMS struct { config *config.SMSConfig async bool queue chan smsTask workers int wg sync.WaitGroup closed bool mu sync.Mutex dropped atomic.Uint64 } type smsTask struct { phones []string templateParam interface{} templateCode string requestID string } // NewSMS 创建短信发送器 func NewSMS(cfg *config.Config) *SMS { if cfg == nil || cfg.SMS == nil { return &SMS{config: nil} } s := &SMS{ config: cfg.SMS, async: cfg.SMS.IsAsync(), workers: cfg.SMS.Workers, } if s.workers <= 0 { s.workers = 2 } queueSize := cfg.SMS.QueueSize if queueSize <= 0 { queueSize = 1000 } if s.async { s.queue = make(chan smsTask, queueSize) for i := 0; i < s.workers; i++ { s.wg.Add(1) go s.worker() } } return s } func (s *SMS) worker() { defer s.wg.Done() for task := range s.queue { if _, err := s.SendSMS(task.phones, task.templateParam, task.templateCode); err != nil { logger.FromContext(context.Background()).Error("async sms send failed", map[string]any{ "error": err.Error(), "request_id": task.requestID, "phones": task.phones, }) } } } func (s *SMS) getSMSConfig() (*config.SMSConfig, error) { if s.config == nil { return nil, fmt.Errorf("SMS config is nil") } if s.config.AccessKeyID == "" { return nil, fmt.Errorf("AccessKeyID is required") } if s.config.AccessKeySecret == "" { return nil, fmt.Errorf("AccessKeySecret is required") } if s.config.SignName == "" { return nil, fmt.Errorf("SignName is required") } if s.config.Region == "" { s.config.Region = "cn-hangzhou" } if s.config.Timeout == 0 { s.config.Timeout = 5 } return s.config, nil } // SendSMS 同步发送短信 func (s *SMS) SendSMS(phoneNumbers []string, templateParam interface{}, templateCode ...string) (*SendResponse, error) { cfg, err := s.getSMSConfig() if err != nil { return nil, err } if len(phoneNumbers) == 0 { return nil, fmt.Errorf("phone numbers are required") } templateCodeValue := cfg.TemplateCode if len(templateCode) > 0 && templateCode[0] != "" { templateCodeValue = templateCode[0] } if templateCodeValue == "" { return nil, fmt.Errorf("template code is required") } var templateParamJSON string if templateParam != nil { switch v := templateParam.(type) { case string: templateParamJSON = v default: paramBytes, err := json.Marshal(v) if err != nil { return nil, fmt.Errorf("failed to marshal template param: %w", err) } templateParamJSON = string(paramBytes) } } else { templateParamJSON = "{}" } params := map[string]string{ "Action": "SendSms", "Version": "2017-05-25", "RegionId": cfg.Region, "AccessKeyId": cfg.AccessKeyID, "Format": "JSON", "SignatureMethod": "HMAC-SHA1", "SignatureVersion": "1.0", "SignatureNonce": fmt.Sprint(time.Now().UnixNano()), "Timestamp": time.Now().UTC().Format("2006-01-02T15:04:05Z"), "PhoneNumbers": strings.Join(phoneNumbers, ","), "SignName": cfg.SignName, "TemplateCode": templateCodeValue, "TemplateParam": templateParamJSON, } params["Signature"] = s.calculateSignature(params, "POST", cfg.AccessKeySecret) endpoint := cfg.Endpoint if endpoint == "" { endpoint = "https://dysmsapi.aliyuncs.com" } formData := url.Values{} for k, v := range params { formData.Set(k, v) } httpReq, err := http.NewRequest(http.MethodPost, endpoint, strings.NewReader(formData.Encode())) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") httpReq.Header.Set("Accept", "application/json") client := &http.Client{Timeout: time.Duration(cfg.Timeout) * time.Second} resp, err := client.Do(httpReq) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } var sendResp SendResponse if err := json.Unmarshal(body, &sendResp); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } if sendResp.Code != "OK" { return &sendResp, fmt.Errorf("SMS send failed: Code=%s, Message=%s", sendResp.Code, sendResp.Message) } return &sendResp, nil } // SendSMSAsync 异步发送短信 func (s *SMS) SendSMSAsync(ctx context.Context, phoneNumbers []string, templateParam interface{}, templateCode ...string) { code := "" if len(templateCode) > 0 { code = templateCode[0] } task := smsTask{ phones: append([]string(nil), phoneNumbers...), templateParam: templateParam, templateCode: code, requestID: logger.RequestIDFromContext(ctx), } if !s.async { _, _ = s.SendSMS(task.phones, task.templateParam, task.templateCode) return } select { case s.queue <- task: default: s.dropped.Add(1) logger.FromContext(ctx).Error("sms queue full, task dropped", map[string]any{ "phones": phoneNumbers, }) } } // Close 关闭异步 worker func (s *SMS) Close() error { if !s.async { return nil } s.mu.Lock() if s.closed { s.mu.Unlock() return nil } s.closed = true s.mu.Unlock() close(s.queue) s.wg.Wait() return nil } func (s *SMS) calculateSignature(params map[string]string, method, accessKeySecret string) string { keys := make([]string, 0, len(params)) for k := range params { keys = append(keys, k) } sort.Strings(keys) var queryParts []string for _, k := range keys { queryParts = append(queryParts, url.QueryEscape(k)+"="+url.QueryEscape(params[k])) } queryString := strings.Join(queryParts, "&") stringToSign := method + "&" + url.QueryEscape("/") + "&" + url.QueryEscape(queryString) mac := hmac.New(sha1.New, []byte(accessKeySecret+"&")) mac.Write([]byte(stringToSign)) return base64.StdEncoding.EncodeToString(mac.Sum(nil)) }