142 lines
4.1 KiB
Go
142 lines
4.1 KiB
Go
|
package log
|
||
|
|
||
|
import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"log/slog"
	"mime"
	"net/http"
	"time"

	"github.com/c2h5oh/datasize"
)
|
||
|
|
||
|
// RoundTripCollector accumulates the details of one HTTP round trip: the
// request and response, timestamps taken around the call to the underlying
// transport, and copies of the request/response body bytes that were read.
//
// It provides a LogValue method, so it can be used directly as a slog
// attribute value.
type RoundTripCollector struct {
	// Request is the request that was sent.
	Request *http.Request
	// HTTP Response. May be nil.
	Response *http.Response

	// Start and End are the times taken immediately before and after the
	// underlying transport's RoundTrip call.
	Start, End time.Time

	// RequestBody and ResponseBody receive copies of the body bytes as they
	// are read; they hold only the portion of each body consumed so far.
	RequestBody, ResponseBody *bytes.Buffer
}
|
||
|
|
||
|
func (ht RoundTripCollector) LogValue() slog.Value {
|
||
|
values := make([]slog.Attr, 0, 4)
|
||
|
if !ht.Start.IsZero() {
|
||
|
values = append(values, slog.Time("start", ht.Start))
|
||
|
}
|
||
|
if !ht.End.IsZero() {
|
||
|
values = append(values, slog.Time("end", ht.End))
|
||
|
}
|
||
|
if !ht.Start.IsZero() && !ht.End.IsZero() {
|
||
|
values = append(values, slog.Duration("duration", ht.End.Sub(ht.Start)))
|
||
|
}
|
||
|
if ht.Request != nil {
|
||
|
vals := make([]slog.Attr, 0, 5)
|
||
|
vals = append(vals, slog.String("url", ht.Request.URL.String()))
|
||
|
vals = append(vals, slog.String("method", ht.Request.Method))
|
||
|
headers := []slog.Attr{}
|
||
|
for k := range ht.Request.Header {
|
||
|
headers = append(headers, slog.String(k, ht.Request.Header.Get(k)))
|
||
|
}
|
||
|
if len(headers) > 0 {
|
||
|
vals = append(vals, slog.Attr{Key: "headers", Value: slog.GroupValue(headers...)})
|
||
|
}
|
||
|
if ht.RequestBody.Len() > 0 {
|
||
|
cl := datasize.ByteSize(ht.RequestBody.Len())
|
||
|
vals = append(vals, slog.String("content_length", cl.HumanReadable()))
|
||
|
if ht.Request.Header.Get("Content-Type") == "application/json" {
|
||
|
vals = append(vals, slog.Any("body", json.RawMessage(ht.RequestBody.Bytes())))
|
||
|
} else {
|
||
|
vals = append(vals, slog.String("body", ht.RequestBody.String()))
|
||
|
}
|
||
|
}
|
||
|
values = append(values, slog.Attr{Key: "request", Value: slog.GroupValue(vals...)})
|
||
|
}
|
||
|
if ht.Response != nil {
|
||
|
vals := make([]slog.Attr, 0, 4)
|
||
|
vals = append(vals, slog.Int("code", ht.Response.StatusCode))
|
||
|
cl := datasize.ByteSize(ht.Response.ContentLength)
|
||
|
vals = append(vals, slog.String("content_length", cl.HumanReadable()))
|
||
|
headers := []slog.Attr{}
|
||
|
for k := range ht.Response.Header {
|
||
|
headers = append(headers, slog.String(k, ht.Response.Header.Get(k)))
|
||
|
}
|
||
|
if ht.ResponseBody.Len() > 0 {
|
||
|
if ht.Response.Header.Get("Content-Type") == "application/json" {
|
||
|
vals = append(vals, slog.Any("body", json.RawMessage(ht.ResponseBody.Bytes())))
|
||
|
} else {
|
||
|
vals = append(vals, slog.String("body", ht.ResponseBody.String()))
|
||
|
}
|
||
|
}
|
||
|
values = append(values, slog.Attr{Key: "response", Value: slog.GroupValue(vals...)})
|
||
|
}
|
||
|
return slog.GroupValue(values...)
|
||
|
}
|
||
|
|
||
|
type httpLogCollectorKey struct{}
|
||
|
|
||
|
// ContextWithRoundTripCollector injects an *HTTPLogCollector into given context.
|
||
|
func ContextWithRoundTripCollector(ctx context.Context) (context.Context, *RoundTripCollector) {
|
||
|
coll := &RoundTripCollector{}
|
||
|
return context.WithValue(ctx, httpLogCollectorKey{}, coll), coll
|
||
|
}
|
||
|
|
||
|
// RoundTripCollectorFromContext gets an *HTTPLogCollector instance.
|
||
|
//
|
||
|
// Returns nil if not found.
|
||
|
func RoundTripCollectorFromContext(ctx context.Context) *RoundTripCollector {
|
||
|
coll, _ := ctx.Value(httpLogCollectorKey{}).(*RoundTripCollector)
|
||
|
return coll
|
||
|
}
|
||
|
|
||
|
// RoundTripper is an http.RoundTripper middleware. It records each round trip
// into the *RoundTripCollector found in the request's context (if any) and
// forwards the request to Next.
type RoundTripper struct {
	// Next is the transport requests are forwarded to. NewRoundTripper
	// defaults it to http.DefaultTransport.
	Next http.RoundTripper
}

// NewRoundTripper creates a new http log collector round tripper that
// forwards requests to next.
//
// If next is nil, uses http.DefaultTransport instead.
func NewRoundTripper(next http.RoundTripper) *RoundTripper {
	rt := &RoundTripper{Next: next}
	if rt.Next == nil {
		rt.Next = http.DefaultTransport
	}
	return rt
}
|
||
|
|
||
|
// bodyCloser pairs a reader with a separate close function. This lets a
// wrapping reader (e.g. an io.TeeReader over a body) still close the
// original underlying body.
type bodyCloser struct {
	io.Reader
	close func() error
}

// Close invokes the stored close function and returns its result.
func (bc bodyCloser) Close() error {
	closeFn := bc.close
	return closeFn()
}
|
||
|
|
||
|
// RoundTrip implements http.RoundTripper
|
||
|
func (ht *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||
|
coll := RoundTripCollectorFromContext(req.Context())
|
||
|
if coll == nil {
|
||
|
return ht.Next.RoundTrip(req)
|
||
|
}
|
||
|
coll.RequestBody = new(bytes.Buffer)
|
||
|
coll.ResponseBody = new(bytes.Buffer)
|
||
|
coll.Request = req
|
||
|
if coll.Request.Body != nil {
|
||
|
tee := io.TeeReader(coll.Request.Body, coll.RequestBody)
|
||
|
req.Body = bodyCloser{tee, req.Body.Close}
|
||
|
}
|
||
|
coll.Start = time.Now()
|
||
|
resp, err := ht.Next.RoundTrip(req)
|
||
|
coll.End = time.Now()
|
||
|
coll.Response = resp
|
||
|
if resp != nil {
|
||
|
coll.Request = resp.Request
|
||
|
tee := io.TeeReader(resp.Body, coll.ResponseBody)
|
||
|
resp.Body = bodyCloser{tee, resp.Body.Close}
|
||
|
}
|
||
|
return resp, err
|
||
|
}
|