package logger import ( "context" "encoding/json" "fmt" "io" "os" "path/filepath" "sync" "sync/atomic" "time" "git.toowon.com/jimmy/go-common/config" ) type ctxKey int const requestIDKey ctxKey = iota // WithRequestID 将 Request ID 写入 context func WithRequestID(ctx context.Context, id string) context.Context { return context.WithValue(ctx, requestIDKey, id) } // RequestIDFromContext 从 context 读取 Request ID func RequestIDFromContext(ctx context.Context) string { if ctx == nil { return "" } if id, ok := ctx.Value(requestIDKey).(string); ok { return id } return "" } var ( defaultLogger *Logger defaultMux sync.RWMutex ) func init() { l, err := NewLogger(nil) if err != nil { defaultLogger = nil return } defaultLogger = l } // SetDefaultLogger 设置全局默认 logger func SetDefaultLogger(l *Logger) { defaultMux.Lock() defer defaultMux.Unlock() if defaultLogger != nil && defaultLogger != l { _ = defaultLogger.Close() } defaultLogger = l } func getDefaultLogger() *Logger { defaultMux.RLock() defer defaultMux.RUnlock() return defaultLogger } type logEntry struct { level string message string fields map[string]any } // Logger 日志记录器 type Logger struct { level string writers []io.Writer prefix string async bool logChan chan logEntry done chan struct{} wg sync.WaitGroup closed bool mu sync.RWMutex dropped atomic.Uint64 } // NewLogger 创建日志记录器 func NewLogger(cfg *config.LoggerConfig) (*Logger, error) { if cfg == nil { cfg = &config.LoggerConfig{ Level: "info", Output: "stdout", Async: config.BoolPtr(true), BufferSize: 1000, } } if cfg.Level == "" { cfg.Level = "info" } if cfg.Output == "" { cfg.Output = "stdout" } if cfg.BufferSize <= 0 { cfg.BufferSize = 1000 } writers, err := createWriters(cfg) if err != nil { return nil, err } prefix := "" if cfg.Prefix != "" { prefix = cfg.Prefix + " " } l := &Logger{ level: cfg.Level, writers: writers, prefix: prefix, async: cfg.IsAsync(), } if cfg.IsAsync() { l.logChan = make(chan logEntry, cfg.BufferSize) l.done = make(chan struct{}) l.wg.Add(1) go l.processLogs() } return l, nil } func createWriters(cfg *config.LoggerConfig) ([]io.Writer, error) { var writers []io.Writer switch cfg.Output { case "stdout": writers = append(writers, os.Stdout) case "stderr": writers = append(writers, os.Stderr) case "file": if cfg.FilePath == "" { return nil, fmt.Errorf("file path is required when output is file") } w, err := openLogFile(cfg.FilePath) if err != nil { return nil, err } writers = append(writers, w) case "both": writers = append(writers, os.Stdout) if cfg.FilePath == "" { return nil, fmt.Errorf("file path is required when output is both") } w, err := openLogFile(cfg.FilePath) if err != nil { return nil, err } writers = append(writers, w) default: return nil, fmt.Errorf("invalid output type: %s", cfg.Output) } return writers, nil } func openLogFile(path string) (io.Writer, error) { dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0o755); err != nil { return nil, fmt.Errorf("failed to create log directory: %w", err) } f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666) if err != nil { return nil, fmt.Errorf("failed to open log file: %w", err) } return f, nil } func (l *Logger) processLogs() { defer l.wg.Done() for { select { case msg, ok := <-l.logChan: if !ok { return } l.writeEntry(msg) case <-l.done: for { select { case msg, ok := <-l.logChan: if !ok { return } l.writeEntry(msg) default: return } } } } } func (l *Logger) isClosed() bool { l.mu.RLock() defer l.mu.RUnlock() return l.closed } func (l *Logger) shouldLog(level string) bool { switch l.level { case "debug": return true case "info": return level != "debug" case "error": return level == "error" default: return level != "debug" } } func (l *Logger) emit(level, message string, fields map[string]any) { if l.isClosed() || !l.shouldLog(level) { return } entry := logEntry{level: level, message: message, fields: fields} if l.async { select { case l.logChan <- entry: default: l.dropped.Add(1) } return } l.writeEntry(entry) } func (l *Logger) writeEntry(entry logEntry) { line := l.formatLine(entry.level, entry.message, entry.fields) for _, w := range l.writers { _, _ = fmt.Fprintln(w, line) } } func (l *Logger) formatLine(level, message string, fields map[string]any) string { ts := time.Now().Format("2006-01-02 15:04:05") payload := map[string]any{ "time": ts, "level": level, "message": message, } for k, v := range fields { payload[k] = v } b, err := json.Marshal(payload) if err != nil { return fmt.Sprintf("%s[%s] %s %v", l.prefix, level, message, fields) } return l.prefix + string(b) } // Debug 记录调试日志 func (l *Logger) Debug(message string, fields map[string]any) { l.emit("debug", message, fields) } // Info 记录信息日志 func (l *Logger) Info(message string, fields map[string]any) { l.emit("info", message, fields) } // Error 记录错误日志 func (l *Logger) Error(message string, fields map[string]any) { l.emit("error", message, fields) } // Close 刷盘并关闭异步队列 func (l *Logger) Close() error { if !l.async { return nil } l.mu.Lock() if l.closed { l.mu.Unlock() return nil } l.closed = true l.mu.Unlock() close(l.done) close(l.logChan) l.wg.Wait() return nil } // DroppedCount 返回因队列满而丢弃的日志条数 func (l *Logger) DroppedCount() uint64 { return l.dropped.Load() } // ContextLogger 带 context 的 logger(自动附加 request_id) type ContextLogger struct { base *Logger ctx context.Context } // FromContext 从 context 获取 logger,自动附加 request_id func FromContext(ctx context.Context) *ContextLogger { base := getDefaultLogger() if base == nil { if l, err := NewLogger(nil); err == nil { base = l } } return &ContextLogger{base: base, ctx: ctx} } // FromContextWithLogger 使用指定 logger 并从 context 附加 request_id func FromContextWithLogger(ctx context.Context, base *Logger) *ContextLogger { if base == nil { base = getDefaultLogger() } return &ContextLogger{base: base, ctx: ctx} } func (c *ContextLogger) mergeFields(fields map[string]any) map[string]any { out := make(map[string]any, len(fields)+1) for k, v := range fields { out[k] = v } if id := RequestIDFromContext(c.ctx); id != "" { out["request_id"] = id } return out } // Debug 记录调试日志 func (c *ContextLogger) Debug(message string, fields map[string]any) { if c.base == nil { return } c.base.Debug(message, c.mergeFields(fields)) } // Info 记录信息日志 func (c *ContextLogger) Info(message string, fields map[string]any) { if c.base == nil { return } c.base.Info(message, c.mergeFields(fields)) } // Error 记录错误日志 func (c *ContextLogger) Error(message string, fields map[string]any) { if c.base == nil { return } c.base.Error(message, c.mergeFields(fields)) }