362 lines
7.1 KiB
Go
362 lines
7.1 KiB
Go
package logger
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"os"
|
||
"path/filepath"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"git.toowon.com/jimmy/go-common/config"
|
||
)
|
||
|
||
// ctxKey is an unexported type for this package's context keys, so values
// stored here cannot collide with keys defined by any other package.
type ctxKey int

// requestIDKey is the context key under which WithRequestID stores the request ID.
const requestIDKey ctxKey = iota

// WithRequestID returns a copy of ctx that carries id as the request ID.
//
// A nil ctx is treated as context.Background(). context.WithValue panics on a
// nil parent, and RequestIDFromContext already accepts nil, so the two helpers
// are symmetric.
func WithRequestID(ctx context.Context, id string) context.Context {
	if ctx == nil {
		ctx = context.Background()
	}
	return context.WithValue(ctx, requestIDKey, id)
}

// RequestIDFromContext returns the request ID stored by WithRequestID. It
// returns "" when ctx is nil or carries no request ID.
func RequestIDFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	id, _ := ctx.Value(requestIDKey).(string)
	return id
}
|
||
|
||
var (
|
||
defaultLogger *Logger
|
||
defaultMux sync.RWMutex
|
||
)
|
||
|
||
func init() {
|
||
l, err := NewLogger(nil)
|
||
if err != nil {
|
||
defaultLogger = nil
|
||
return
|
||
}
|
||
defaultLogger = l
|
||
}
|
||
|
||
// SetDefaultLogger 设置全局默认 logger
|
||
func SetDefaultLogger(l *Logger) {
|
||
defaultMux.Lock()
|
||
defer defaultMux.Unlock()
|
||
if defaultLogger != nil && defaultLogger != l {
|
||
_ = defaultLogger.Close()
|
||
}
|
||
defaultLogger = l
|
||
}
|
||
|
||
func getDefaultLogger() *Logger {
|
||
defaultMux.RLock()
|
||
defer defaultMux.RUnlock()
|
||
return defaultLogger
|
||
}
|
||
|
||
// logEntry is one log record, as passed from emit to writeEntry. In async mode
// it is queued on Logger.logChan.
type logEntry struct {
	level, message string
	fields         map[string]any
}
|
||
|
||
// Logger 日志记录器
|
||
type Logger struct {
|
||
level string
|
||
writers []io.Writer
|
||
prefix string
|
||
async bool
|
||
logChan chan logEntry
|
||
done chan struct{}
|
||
wg sync.WaitGroup
|
||
closed bool
|
||
mu sync.RWMutex
|
||
dropped atomic.Uint64
|
||
}
|
||
|
||
// NewLogger 创建日志记录器
|
||
func NewLogger(cfg *config.LoggerConfig) (*Logger, error) {
|
||
if cfg == nil {
|
||
cfg = &config.LoggerConfig{
|
||
Level: "info",
|
||
Output: "stdout",
|
||
Async: config.BoolPtr(true),
|
||
BufferSize: 1000,
|
||
}
|
||
}
|
||
|
||
if cfg.Level == "" {
|
||
cfg.Level = "info"
|
||
}
|
||
if cfg.Output == "" {
|
||
cfg.Output = "stdout"
|
||
}
|
||
if cfg.BufferSize <= 0 {
|
||
cfg.BufferSize = 1000
|
||
}
|
||
|
||
writers, err := createWriters(cfg)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
prefix := ""
|
||
if cfg.Prefix != "" {
|
||
prefix = cfg.Prefix + " "
|
||
}
|
||
|
||
l := &Logger{
|
||
level: cfg.Level,
|
||
writers: writers,
|
||
prefix: prefix,
|
||
async: cfg.IsAsync(),
|
||
}
|
||
|
||
if cfg.IsAsync() {
|
||
l.logChan = make(chan logEntry, cfg.BufferSize)
|
||
l.done = make(chan struct{})
|
||
l.wg.Add(1)
|
||
go l.processLogs()
|
||
}
|
||
|
||
return l, nil
|
||
}
|
||
|
||
func createWriters(cfg *config.LoggerConfig) ([]io.Writer, error) {
|
||
var writers []io.Writer
|
||
switch cfg.Output {
|
||
case "stdout":
|
||
writers = append(writers, os.Stdout)
|
||
case "stderr":
|
||
writers = append(writers, os.Stderr)
|
||
case "file":
|
||
if cfg.FilePath == "" {
|
||
return nil, fmt.Errorf("file path is required when output is file")
|
||
}
|
||
w, err := openLogFile(cfg.FilePath)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
writers = append(writers, w)
|
||
case "both":
|
||
writers = append(writers, os.Stdout)
|
||
if cfg.FilePath == "" {
|
||
return nil, fmt.Errorf("file path is required when output is both")
|
||
}
|
||
w, err := openLogFile(cfg.FilePath)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
writers = append(writers, w)
|
||
default:
|
||
return nil, fmt.Errorf("invalid output type: %s", cfg.Output)
|
||
}
|
||
return writers, nil
|
||
}
|
||
|
||
// openLogFile creates any missing parent directories (mode 0755) and opens
// path for appending. The file is created with mode 0666 (before umask) if it
// does not exist. The returned writer is the underlying *os.File, so the caller
// can close it through io.Closer.
func openLogFile(path string) (io.Writer, error) {
	if mkErr := os.MkdirAll(filepath.Dir(path), 0o755); mkErr != nil {
		return nil, fmt.Errorf("failed to create log directory: %w", mkErr)
	}
	file, openErr := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o666)
	if openErr != nil {
		return nil, fmt.Errorf("failed to open log file: %w", openErr)
	}
	return file, nil
}
|
||
|
||
func (l *Logger) processLogs() {
|
||
defer l.wg.Done()
|
||
for {
|
||
select {
|
||
case msg, ok := <-l.logChan:
|
||
if !ok {
|
||
return
|
||
}
|
||
l.writeEntry(msg)
|
||
case <-l.done:
|
||
for {
|
||
select {
|
||
case msg, ok := <-l.logChan:
|
||
if !ok {
|
||
return
|
||
}
|
||
l.writeEntry(msg)
|
||
default:
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func (l *Logger) isClosed() bool {
|
||
l.mu.RLock()
|
||
defer l.mu.RUnlock()
|
||
return l.closed
|
||
}
|
||
|
||
func (l *Logger) shouldLog(level string) bool {
|
||
switch l.level {
|
||
case "debug":
|
||
return true
|
||
case "info":
|
||
return level != "debug"
|
||
case "error":
|
||
return level == "error"
|
||
default:
|
||
return level != "debug"
|
||
}
|
||
}
|
||
|
||
func (l *Logger) emit(level, message string, fields map[string]any) {
|
||
if l.isClosed() || !l.shouldLog(level) {
|
||
return
|
||
}
|
||
|
||
entry := logEntry{level: level, message: message, fields: fields}
|
||
if l.async {
|
||
select {
|
||
case l.logChan <- entry:
|
||
default:
|
||
l.dropped.Add(1)
|
||
}
|
||
return
|
||
}
|
||
l.writeEntry(entry)
|
||
}
|
||
|
||
func (l *Logger) writeEntry(entry logEntry) {
|
||
line := l.formatLine(entry.level, entry.message, entry.fields)
|
||
for _, w := range l.writers {
|
||
_, _ = fmt.Fprintln(w, line)
|
||
}
|
||
}
|
||
|
||
func (l *Logger) formatLine(level, message string, fields map[string]any) string {
|
||
ts := time.Now().Format("2006-01-02 15:04:05")
|
||
payload := map[string]any{
|
||
"time": ts,
|
||
"level": level,
|
||
"message": message,
|
||
}
|
||
for k, v := range fields {
|
||
payload[k] = v
|
||
}
|
||
b, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return fmt.Sprintf("%s[%s] %s %v", l.prefix, level, message, fields)
|
||
}
|
||
return l.prefix + string(b)
|
||
}
|
||
|
||
// Debug 记录调试日志
|
||
func (l *Logger) Debug(message string, fields map[string]any) {
|
||
l.emit("debug", message, fields)
|
||
}
|
||
|
||
// Info 记录信息日志
|
||
func (l *Logger) Info(message string, fields map[string]any) {
|
||
l.emit("info", message, fields)
|
||
}
|
||
|
||
// Error 记录错误日志
|
||
func (l *Logger) Error(message string, fields map[string]any) {
|
||
l.emit("error", message, fields)
|
||
}
|
||
|
||
// Close 刷盘并关闭异步队列
|
||
func (l *Logger) Close() error {
|
||
if !l.async {
|
||
return nil
|
||
}
|
||
l.mu.Lock()
|
||
if l.closed {
|
||
l.mu.Unlock()
|
||
return nil
|
||
}
|
||
l.closed = true
|
||
l.mu.Unlock()
|
||
|
||
close(l.done)
|
||
close(l.logChan)
|
||
l.wg.Wait()
|
||
return nil
|
||
}
|
||
|
||
// DroppedCount 返回因队列满而丢弃的日志条数
|
||
func (l *Logger) DroppedCount() uint64 {
|
||
return l.dropped.Load()
|
||
}
|
||
|
||
// ContextLogger 带 context 的 logger(自动附加 request_id)
|
||
type ContextLogger struct {
|
||
base *Logger
|
||
ctx context.Context
|
||
}
|
||
|
||
// FromContext 从 context 获取 logger,自动附加 request_id
|
||
func FromContext(ctx context.Context) *ContextLogger {
|
||
base := getDefaultLogger()
|
||
if base == nil {
|
||
if l, err := NewLogger(nil); err == nil {
|
||
base = l
|
||
}
|
||
}
|
||
return &ContextLogger{base: base, ctx: ctx}
|
||
}
|
||
|
||
// FromContextWithLogger 使用指定 logger 并从 context 附加 request_id
|
||
func FromContextWithLogger(ctx context.Context, base *Logger) *ContextLogger {
|
||
if base == nil {
|
||
base = getDefaultLogger()
|
||
}
|
||
return &ContextLogger{base: base, ctx: ctx}
|
||
}
|
||
|
||
func (c *ContextLogger) mergeFields(fields map[string]any) map[string]any {
|
||
out := make(map[string]any, len(fields)+1)
|
||
for k, v := range fields {
|
||
out[k] = v
|
||
}
|
||
if id := RequestIDFromContext(c.ctx); id != "" {
|
||
out["request_id"] = id
|
||
}
|
||
return out
|
||
}
|
||
|
||
// Debug 记录调试日志
|
||
func (c *ContextLogger) Debug(message string, fields map[string]any) {
|
||
if c.base == nil {
|
||
return
|
||
}
|
||
c.base.Debug(message, c.mergeFields(fields))
|
||
}
|
||
|
||
// Info 记录信息日志
|
||
func (c *ContextLogger) Info(message string, fields map[string]any) {
|
||
if c.base == nil {
|
||
return
|
||
}
|
||
c.base.Info(message, c.mergeFields(fields))
|
||
}
|
||
|
||
// Error 记录错误日志
|
||
func (c *ContextLogger) Error(message string, fields map[string]any) {
|
||
if c.base == nil {
|
||
return
|
||
}
|
||
c.base.Error(message, c.mergeFields(fields))
|
||
}
|