Files
go-common/logger/logger.go

362 lines
7.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package logger
import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sync"
"sync/atomic"
"time"
"git.toowon.com/jimmy/go-common/config"
)
// ctxKey is an unexported key type, so values stored by this package cannot
// collide with context keys defined by other packages.
type ctxKey int

const requestIDKey ctxKey = iota

// WithRequestID returns a copy of ctx that carries id as the request ID.
// A nil ctx is treated as context.Background(). This mirrors
// RequestIDFromContext, which also accepts a nil context. Without it,
// context.WithValue would panic on a nil parent.
func WithRequestID(ctx context.Context, id string) context.Context {
	if ctx == nil {
		ctx = context.Background()
	}
	return context.WithValue(ctx, requestIDKey, id)
}

// RequestIDFromContext returns the request ID stored by WithRequestID.
// It returns "" when ctx is nil or carries no request ID.
func RequestIDFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	id, _ := ctx.Value(requestIDKey).(string)
	return id
}
var (
	// defaultLogger is the package-wide logger handed out by getDefaultLogger.
	// It may be nil and is guarded by defaultMux after init.
	defaultLogger *Logger
	defaultMux    sync.RWMutex
)

// init installs a logger built from NewLogger's nil-config defaults as the
// package default. If NewLogger fails, defaultLogger stays nil.
func init() {
	if l, err := NewLogger(nil); err == nil {
		defaultLogger = l
	}
}
// SetDefaultLogger installs l as the package-wide default logger. That logger
// is used by FromContext, and by FromContextWithLogger when it is given a nil
// logger. l may be nil. The previously installed logger, if it differs from l,
// is closed before SetDefaultLogger returns.
func SetDefaultLogger(l *Logger) {
	defaultMux.Lock()
	prev := defaultLogger
	defaultLogger = l
	defaultMux.Unlock()

	// Close outside the lock. Flushing an async queue or closing files can
	// block, and holding defaultMux meanwhile would stall every
	// getDefaultLogger / FromContext caller in the process.
	if prev != nil && prev != l {
		_ = prev.Close() // best effort: there is no caller to report a flush error to
	}
}
// getDefaultLogger returns the current package-wide logger, which may be nil.
// It reads the value under defaultMux's read lock.
func getDefaultLogger() *Logger {
	defaultMux.RLock()
	current := defaultLogger
	defaultMux.RUnlock()
	return current
}
// logEntry is one log record built by emit and consumed by writeEntry.
// In async mode it travels through Logger.logChan.
type logEntry struct {
level string // level name: "debug", "info" or "error"
message string // log message text
fields map[string]any // caller-supplied structured fields (may be nil)
}
// Logger writes JSON log lines to one or more writers. It writes either inline
// (sync) or via a buffered channel drained by a background goroutine (async).
type Logger struct {
level string // minimum level threshold, interpreted by shouldLog
writers []io.Writer // every formatted line is written to each of these
prefix string // prepended to each line: "" or the configured prefix plus a space
async bool // true: emit enqueues onto logChan; false: emit writes inline
logChan chan logEntry // async queue with capacity BufferSize; nil for sync loggers
done chan struct{} // closed by Close as a shutdown signal; nil for sync loggers
wg sync.WaitGroup // tracks the processLogs goroutine so Close can wait for it
closed bool // set once by Close; guarded by mu
mu sync.RWMutex // guards closed
dropped atomic.Uint64 // number of entries discarded because logChan was full
}
// NewLogger builds a Logger from cfg.
//
// A nil cfg means: level "info", output "stdout", Async set to true and a
// BufferSize of 1000. Otherwise an empty Level or Output falls back to "info"
// or "stdout", and a non-positive BufferSize falls back to 1000.
//
// cfg is only read. Defaults are applied to a private copy, so the caller's
// struct is never modified.
//
// When the configuration is async, a background goroutine (processLogs) is
// started. Close stops it after flushing. An error is returned if the output
// writers cannot be created (see createWriters).
func NewLogger(cfg *config.LoggerConfig) (*Logger, error) {
	c := config.LoggerConfig{
		Level:      "info",
		Output:     "stdout",
		Async:      config.BoolPtr(true),
		BufferSize: 1000,
	}
	if cfg != nil {
		// Copy rather than writing defaults into the caller's struct. That
		// struct may be shared application config read concurrently elsewhere.
		c = *cfg
	}
	if c.Level == "" {
		c.Level = "info"
	}
	if c.Output == "" {
		c.Output = "stdout"
	}
	if c.BufferSize <= 0 {
		c.BufferSize = 1000
	}

	writers, err := createWriters(&c)
	if err != nil {
		return nil, err
	}

	prefix := ""
	if c.Prefix != "" {
		prefix = c.Prefix + " "
	}

	async := c.IsAsync()
	l := &Logger{
		level:   c.Level,
		writers: writers,
		prefix:  prefix,
		async:   async,
	}
	if async {
		l.logChan = make(chan logEntry, c.BufferSize)
		l.done = make(chan struct{})
		l.wg.Add(1)
		go l.processLogs()
	}
	return l, nil
}
// createWriters maps cfg.Output to the writers a Logger should use:
//
//	"stdout" -> os.Stdout
//	"stderr" -> os.Stderr
//	"file"   -> the file at cfg.FilePath (required)
//	"both"   -> os.Stdout followed by the file at cfg.FilePath (required)
//
// Any other Output value is an error, as is a missing FilePath for "file" or
// "both". Errors from openLogFile are returned unchanged.
func createWriters(cfg *config.LoggerConfig) ([]io.Writer, error) {
	var out []io.Writer
	wantFile := false

	switch cfg.Output {
	case "stdout":
		out = []io.Writer{os.Stdout}
	case "stderr":
		out = []io.Writer{os.Stderr}
	case "file":
		if cfg.FilePath == "" {
			return nil, fmt.Errorf("file path is required when output is file")
		}
		wantFile = true
	case "both":
		if cfg.FilePath == "" {
			return nil, fmt.Errorf("file path is required when output is both")
		}
		out = []io.Writer{os.Stdout}
		wantFile = true
	default:
		return nil, fmt.Errorf("invalid output type: %s", cfg.Output)
	}

	if !wantFile {
		return out, nil
	}
	fileWriter, err := openLogFile(cfg.FilePath)
	if err != nil {
		return nil, err
	}
	return append(out, fileWriter), nil
}
// openLogFile opens path for appending and returns the *os.File as an
// io.Writer. It first creates any missing parent directories (mode 0755).
// The file is created with mode 0666 (before umask) if it does not exist.
// The caller owns the returned file and is responsible for closing it.
func openLogFile(path string) (io.Writer, error) {
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return nil, fmt.Errorf("failed to create log directory: %w", err)
	}
	const flags = os.O_CREATE | os.O_WRONLY | os.O_APPEND
	file, err := os.OpenFile(path, flags, 0o666)
	if err != nil {
		return nil, fmt.Errorf("failed to open log file: %w", err)
	}
	return file, nil
}
// processLogs is the body of the background goroutine that NewLogger starts
// for async loggers. It writes every queued entry in order. It returns only
// after logChan has been closed (Close does this) and fully drained, then
// marks wg done.
func (l *Logger) processLogs() {
	defer l.wg.Done()
	// Ranging until the channel is closed guarantees that every entry that was
	// successfully enqueued is written before Close's wg.Wait returns.
	// A done-triggered non-blocking drain could instead return while a
	// concurrent emit was still enqueuing, losing that entry without counting
	// it as dropped. Close always closes logChan, so done is not needed to
	// terminate this loop.
	for entry := range l.logChan {
		l.writeEntry(entry)
	}
}
// isClosed reports whether Close has been called, reading the flag under the
// read lock.
func (l *Logger) isClosed() bool {
	l.mu.RLock()
	closed := l.closed
	l.mu.RUnlock()
	return closed
}
// shouldLog reports whether an entry at level passes the logger's threshold:
//
//	threshold "debug" -> every level passes
//	threshold "error" -> only "error" passes
//	threshold "info" or anything unrecognised -> everything except "debug"
func (l *Logger) shouldLog(level string) bool {
	switch {
	case l.level == "debug":
		return true
	case l.level == "error":
		return level == "error"
	default:
		return level != "debug"
	}
}
// emit filters an entry by level. Unless the logger is closed, it then either
// enqueues the entry for the background goroutine (async) or writes it inline
// (sync). When the async queue is full, the entry is discarded and counted in
// dropped rather than blocking the caller.
func (l *Logger) emit(level, message string, fields map[string]any) {
	if !l.shouldLog(level) {
		return
	}
	// Hold the read lock across both the closed check and the enqueue/write.
	// Close sets closed under the write lock before it closes logChan. So once
	// this call sees closed == false, the channel cannot be closed until we
	// release the lock. Without this, a racing emit could send on a closed
	// channel, which panics even inside a select with a default case.
	l.mu.RLock()
	defer l.mu.RUnlock()
	if l.closed {
		return
	}

	if !l.async {
		l.writeEntry(logEntry{level: level, message: message, fields: fields})
		return
	}

	// The entry is formatted later on another goroutine, so take a shallow
	// copy of the caller's map. A caller that reuses or mutates it after this
	// call returns would otherwise race with the iteration in formatLine.
	// Values are shared, not deep-copied.
	var snapshot map[string]any
	if len(fields) > 0 {
		snapshot = make(map[string]any, len(fields))
		for k, v := range fields {
			snapshot[k] = v
		}
	}
	select {
	case l.logChan <- logEntry{level: level, message: message, fields: snapshot}:
	default:
		l.dropped.Add(1)
	}
}
// writeEntry formats entry and writes it, newline-terminated, to every
// configured writer. Write errors are deliberately ignored: the logger has
// nowhere better to report failures of its own output.
func (l *Logger) writeEntry(entry logEntry) {
	text := l.formatLine(entry.level, entry.message, entry.fields) + "\n"
	for _, dst := range l.writers {
		_, _ = io.WriteString(dst, text)
	}
}
// formatLine renders one log record as l.prefix followed by a JSON object.
// The object contains the caller's fields plus the reserved keys "time"
// (local time, "2006-01-02 15:04:05"), "level" and "message".
//
// The timestamp is taken here, when the line is formatted. For async loggers
// that happens on the background goroutine, so it can lag the original log
// call when the queue is backed up.
//
// If the payload cannot be JSON-encoded (for example, a field value is a
// channel or func), a plain-text "[level] message fields" line is returned
// instead.
func (l *Logger) formatLine(level, message string, fields map[string]any) string {
	ts := time.Now().Format("2006-01-02 15:04:05")
	payload := make(map[string]any, len(fields)+3)
	for k, v := range fields {
		payload[k] = v
	}
	// Reserved keys are written after the caller's fields. A field named
	// "time", "level" or "message" therefore cannot overwrite (or forge) the
	// record's core metadata.
	payload["time"] = ts
	payload["level"] = level
	payload["message"] = message

	b, err := json.Marshal(payload)
	if err != nil {
		return fmt.Sprintf("%s[%s] %s %v", l.prefix, level, message, fields)
	}
	return l.prefix + string(b)
}
// Debug logs msg with fields at "debug" level. It is written only when the
// logger's threshold is "debug".
func (l *Logger) Debug(msg string, fields map[string]any) {
	l.emit("debug", msg, fields)
}

// Info logs msg with fields at "info" level. It is written unless the
// logger's threshold is "error".
func (l *Logger) Info(msg string, fields map[string]any) {
	l.emit("info", msg, fields)
}

// Error logs msg with fields at "error" level. It passes every threshold.
func (l *Logger) Error(msg string, fields map[string]any) {
	l.emit("error", msg, fields)
}
// Close shuts the logger down. It does three things:
//
//   - marks the logger closed, so later log calls become no-ops (sync and
//     async loggers alike);
//   - for async loggers, stops the queue and waits for the background
//     goroutine to write what was already queued;
//   - closes every writer that implements io.Closer, except os.Stdout and
//     os.Stderr.
//
// Calling Close more than once is safe; calls after the first return nil.
// The first error from closing a writer is returned.
func (l *Logger) Close() error {
	l.mu.Lock()
	if l.closed {
		l.mu.Unlock()
		return nil
	}
	l.closed = true
	l.mu.Unlock()

	if l.async {
		close(l.done)
		close(l.logChan)
		l.wg.Wait() // the background goroutine exits once logChan is drained
	}

	// The writers include files opened by openLogFile. If they are not closed
	// here, every file-backed logger leaks a descriptor, whether sync or async.
	var firstErr error
	for _, w := range l.writers {
		if w == os.Stdout || w == os.Stderr {
			continue
		}
		c, ok := w.(io.Closer)
		if !ok {
			continue
		}
		if err := c.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}
// DroppedCount reports how many entries were discarded because the async
// queue was full. Only the async path increments this counter.
func (l *Logger) DroppedCount() uint64 {
	n := l.dropped.Load()
	return n
}
// ContextLogger pairs a Logger with a context. Each logging call adds the
// context's request ID, if any, to a copy of the fields under "request_id".
// NOTE(review): storing a context in a struct goes against the usual Go
// guidance. It appears deliberate here, since the logger is meant to be
// short-lived and per-request.
type ContextLogger struct {
base *Logger // target logger; when nil, every logging method is a no-op
ctx context.Context // source of the request ID, read via RequestIDFromContext
}
// FromContext returns a ContextLogger that logs through the package default
// logger and attaches ctx's request ID to every entry.
//
// If no default logger is installed (for example after
// SetDefaultLogger(nil)), a synchronous stdout logger at level "info" is used
// instead. If even that cannot be built, the returned ContextLogger is a
// no-op.
func FromContext(ctx context.Context) *ContextLogger {
	base := getDefaultLogger()
	if base == nil {
		// The fallback is synchronous on purpose. An async logger starts a
		// background goroutine that only Close stops, and nothing ever closes
		// this throwaway instance, so every call would leak a goroutine.
		fallback, err := NewLogger(&config.LoggerConfig{
			Level:  "info",
			Output: "stdout",
			Async:  config.BoolPtr(false),
		})
		if err == nil {
			base = fallback
		}
	}
	return &ContextLogger{base: base, ctx: ctx}
}
// FromContextWithLogger returns a ContextLogger that logs through base and
// attaches ctx's request ID to every entry. A nil base falls back to the
// package default logger. That logger may itself be nil, in which case the
// result is a no-op logger.
func FromContextWithLogger(ctx context.Context, base *Logger) *ContextLogger {
	target := base
	if target == nil {
		target = getDefaultLogger()
	}
	return &ContextLogger{ctx: ctx, base: target}
}
// mergeFields returns a fresh map containing fields plus, when the context
// carries one, the request ID under "request_id". The request ID overrides
// any caller field with that key. The caller's map is never modified.
func (c *ContextLogger) mergeFields(fields map[string]any) map[string]any {
	merged := make(map[string]any, len(fields)+1)
	for key, val := range fields {
		merged[key] = val
	}
	reqID := RequestIDFromContext(c.ctx)
	if reqID == "" {
		return merged
	}
	merged["request_id"] = reqID
	return merged
}
// Debug logs message at "debug" level through the wrapped Logger, adding the
// context's request ID to a copy of fields. It is a no-op without a base
// logger.
func (c *ContextLogger) Debug(message string, fields map[string]any) {
	if base := c.base; base != nil {
		base.Debug(message, c.mergeFields(fields))
	}
}

// Info logs message at "info" level through the wrapped Logger, adding the
// context's request ID to a copy of fields. It is a no-op without a base
// logger.
func (c *ContextLogger) Info(message string, fields map[string]any) {
	if base := c.base; base != nil {
		base.Info(message, c.mergeFields(fields))
	}
}

// Error logs message at "error" level through the wrapped Logger, adding the
// context's request ID to a copy of fields. It is a no-op without a base
// logger.
func (c *ContextLogger) Error(message string, fields map[string]any) {
	if base := c.base; base != nil {
		base.Error(message, c.mergeFields(fields))
	}
}