package log import ( "bytes" "context" "encoding/json" "io" "log/slog" "net/http" "time" "github.com/c2h5oh/datasize" ) type RoundTripCollector struct { Request *http.Request // HTTP Response. May be nil. Response *http.Response Start time.Time End time.Time RequestBody *bytes.Buffer ResponseBody *bytes.Buffer } func (ht RoundTripCollector) LogValue() slog.Value { values := make([]slog.Attr, 0, 4) if !ht.Start.IsZero() { values = append(values, slog.Time("start", ht.Start)) } if !ht.End.IsZero() { values = append(values, slog.Time("end", ht.End)) } if !ht.Start.IsZero() && !ht.End.IsZero() { values = append(values, slog.Duration("duration", ht.End.Sub(ht.Start))) } if ht.Request != nil { vals := make([]slog.Attr, 0, 5) vals = append(vals, slog.String("url", ht.Request.URL.String())) vals = append(vals, slog.String("method", ht.Request.Method)) headers := []slog.Attr{} for k := range ht.Request.Header { headers = append(headers, slog.String(k, ht.Request.Header.Get(k))) } if len(headers) > 0 { vals = append(vals, slog.Attr{Key: "headers", Value: slog.GroupValue(headers...)}) } if ht.RequestBody.Len() > 0 { cl := datasize.ByteSize(ht.RequestBody.Len()) vals = append(vals, slog.String("content_length", cl.HumanReadable())) if ht.Request.Header.Get("Content-Type") == "application/json" { vals = append(vals, slog.Any("body", json.RawMessage(ht.RequestBody.Bytes()))) } else { vals = append(vals, slog.String("body", ht.RequestBody.String())) } } values = append(values, slog.Attr{Key: "request", Value: slog.GroupValue(vals...)}) } if ht.Response != nil { vals := make([]slog.Attr, 0, 4) vals = append(vals, slog.Int("code", ht.Response.StatusCode)) cl := datasize.ByteSize(ht.Response.ContentLength) vals = append(vals, slog.String("content_length", cl.HumanReadable())) headers := []slog.Attr{} for k := range ht.Response.Header { headers = append(headers, slog.String(k, ht.Response.Header.Get(k))) } if ht.ResponseBody.Len() > 0 { if ht.Response.Header.Get("Content-Type") == "application/json" { vals = append(vals, slog.Any("body", json.RawMessage(ht.ResponseBody.Bytes()))) } else { vals = append(vals, slog.String("body", ht.ResponseBody.String())) } } values = append(values, slog.Attr{Key: "response", Value: slog.GroupValue(vals...)}) } return slog.GroupValue(values...) } type httpLogCollectorKey struct{} // ContextWithRoundTripCollector injects an *HTTPLogCollector into given context. func ContextWithRoundTripCollector(ctx context.Context) (context.Context, *RoundTripCollector) { coll := &RoundTripCollector{} return context.WithValue(ctx, httpLogCollectorKey{}, coll), coll } // RoundTripCollectorFromContext gets an *HTTPLogCollector instance. // // Returns nil if not found. func RoundTripCollectorFromContext(ctx context.Context) *RoundTripCollector { coll, _ := ctx.Value(httpLogCollectorKey{}).(*RoundTripCollector) return coll } type RoundTripper struct { Next http.RoundTripper } // NewRoundTripper creates a new http log collector round tripper. // // If next is nil, uses http.DefaultTransport instead. func NewRoundTripper(next http.RoundTripper) *RoundTripper { if next == nil { next = http.DefaultTransport } return &RoundTripper{next} } type bodyCloser struct { io.Reader close func() error } func (b bodyCloser) Close() error { return b.close() } // RoundTrip implements http.RoundTripper func (ht *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { coll := RoundTripCollectorFromContext(req.Context()) if coll == nil { return ht.Next.RoundTrip(req) } coll.RequestBody = new(bytes.Buffer) coll.ResponseBody = new(bytes.Buffer) coll.Request = req if coll.Request.Body != nil { tee := io.TeeReader(coll.Request.Body, coll.RequestBody) req.Body = bodyCloser{tee, req.Body.Close} } coll.Start = time.Now() resp, err := ht.Next.RoundTrip(req) coll.End = time.Now() coll.Response = resp if resp != nil { coll.Request = resp.Request tee := io.TeeReader(resp.Body, coll.ResponseBody) resp.Body = bodyCloser{tee, resp.Body.Close} } return resp, err }