Bluemage/go/pkg/log/http_transport.go

142 lines
4.1 KiB
Go
Raw Normal View History

package log
import (
"bytes"
"context"
"encoding/json"
"io"
"log/slog"
"net/http"
"time"
"github.com/c2h5oh/datasize"
)
// RoundTripCollector holds everything RoundTripper records about one HTTP
// round trip so that it can be rendered as a structured log value.
type RoundTripCollector struct {
	// Request is the HTTP request of the round trip.
	Request *http.Request
	// Response is the HTTP response. It may be nil.
	Response *http.Response
	// Start and End are the timestamps taken right before and right after
	// the round trip was handed to the next round tripper.
	Start, End time.Time
	// RequestBody and ResponseBody receive a copy of the body bytes as the
	// respective body is being read.
	RequestBody, ResponseBody *bytes.Buffer
}
// LogValue implements slog.LogValuer. It renders the collected round trip as a
// group with the "start" and "end" timestamps, the "duration", and a "request"
// and a "response" sub-group. Every part that has not been collected is
// omitted instead of being logged as a zero value.
//
// NOTE(review): headers and bodies are logged verbatim, which may include
// credentials (for example an Authorization header) — confirm that this is
// acceptable for the log sinks in use.
func (ht RoundTripCollector) LogValue() slog.Value {
	values := make([]slog.Attr, 0, 5)
	if !ht.Start.IsZero() {
		values = append(values, slog.Time("start", ht.Start))
	}
	if !ht.End.IsZero() {
		values = append(values, slog.Time("end", ht.End))
	}
	if !ht.Start.IsZero() && !ht.End.IsZero() {
		values = append(values, slog.Duration("duration", ht.End.Sub(ht.Start)))
	}

	if req := ht.Request; req != nil {
		vals := make([]slog.Attr, 0, 5)
		// A hand-built request may lack a URL; skip it rather than panic.
		if req.URL != nil {
			vals = append(vals, slog.String("url", req.URL.String()))
		}
		vals = append(vals, slog.String("method", req.Method))
		if headers := roundTripHeaderAttrs(req.Header); len(headers) > 0 {
			vals = append(vals, slog.Attr{Key: "headers", Value: slog.GroupValue(headers...)})
		}
		// The buffer is nil when the collector was filled in by hand.
		if ht.RequestBody != nil && ht.RequestBody.Len() > 0 {
			cl := datasize.ByteSize(ht.RequestBody.Len())
			vals = append(vals, slog.String("content_length", cl.HumanReadable()))
			vals = append(vals, roundTripBodyAttr(req.Header, ht.RequestBody))
		}
		values = append(values, slog.Attr{Key: "request", Value: slog.GroupValue(vals...)})
	}

	if resp := ht.Response; resp != nil {
		vals := make([]slog.Attr, 0, 4)
		vals = append(vals, slog.Int("code", resp.StatusCode))
		// ContentLength is -1 when the length is unknown; converting that to
		// an unsigned byte size would log a huge bogus value.
		if resp.ContentLength >= 0 {
			cl := datasize.ByteSize(resp.ContentLength)
			vals = append(vals, slog.String("content_length", cl.HumanReadable()))
		}
		if headers := roundTripHeaderAttrs(resp.Header); len(headers) > 0 {
			vals = append(vals, slog.Attr{Key: "headers", Value: slog.GroupValue(headers...)})
		}
		if ht.ResponseBody != nil && ht.ResponseBody.Len() > 0 {
			vals = append(vals, roundTripBodyAttr(resp.Header, ht.ResponseBody))
		}
		values = append(values, slog.Attr{Key: "response", Value: slog.GroupValue(vals...)})
	}

	return slog.GroupValue(values...)
}

// roundTripHeaderAttrs converts h into one string attribute per header name,
// carrying the first value of that header.
func roundTripHeaderAttrs(h http.Header) []slog.Attr {
	attrs := make([]slog.Attr, 0, len(h))
	for k := range h {
		attrs = append(attrs, slog.String(k, h.Get(k)))
	}
	return attrs
}

// roundTripBodyAttr builds the "body" attribute for the buffered body.
//
// The body is passed through as raw JSON only when the Content-Type header is
// exactly "application/json" and the buffered bytes are valid JSON. The buffer
// only holds what has been read from the body so far, so it can contain a
// truncated document; such content is logged as a plain string instead of
// being handed to the log handler as broken JSON.
func roundTripBodyAttr(h http.Header, body *bytes.Buffer) slog.Attr {
	if h.Get("Content-Type") == "application/json" && json.Valid(body.Bytes()) {
		return slog.Any("body", json.RawMessage(body.Bytes()))
	}
	return slog.String("body", body.String())
}
// httpLogCollectorKey is the context key under which the *RoundTripCollector
// is stored. Being an unexported type, it cannot collide with keys defined by
// other packages.
type httpLogCollectorKey struct{}
// ContextWithRoundTripCollector allocates a fresh *RoundTripCollector and
// returns it together with a child of ctx that carries it.
func ContextWithRoundTripCollector(ctx context.Context) (context.Context, *RoundTripCollector) {
	collector := new(RoundTripCollector)
	ctx = context.WithValue(ctx, httpLogCollectorKey{}, collector)
	return ctx, collector
}
// RoundTripCollectorFromContext looks up the *RoundTripCollector carried by
// ctx.
//
// It returns nil when ctx does not carry one.
func RoundTripCollectorFromContext(ctx context.Context) *RoundTripCollector {
	if collector, ok := ctx.Value(httpLogCollectorKey{}).(*RoundTripCollector); ok {
		return collector
	}
	return nil
}
// RoundTripper is an http.RoundTripper that records each round trip into the
// RoundTripCollector found in the request context, if there is one.
type RoundTripper struct {
	// Next is the round tripper the request is handed to.
	Next http.RoundTripper
}

// NewRoundTripper returns a RoundTripper that delegates to next.
//
// A nil next is replaced by http.DefaultTransport.
func NewRoundTripper(next http.RoundTripper) *RoundTripper {
	rt := &RoundTripper{Next: http.DefaultTransport}
	if next != nil {
		rt.Next = next
	}
	return rt
}
// bodyCloser pairs a reader with a separate close function so that the
// combination can be used as an io.ReadCloser.
type bodyCloser struct {
	io.Reader
	// close is what Close delegates to.
	close func() error
}

// Close invokes the stored close function and returns its result.
func (bc bodyCloser) Close() error {
	return bc.close()
}
// RoundTrip implements http.RoundTripper.
//
// If the request context carries no RoundTripCollector, the request is simply
// forwarded. Otherwise the collector is reset for this round trip: the start
// and end times are recorded, and the request and response bodies are teed
// into the collector's buffers as they are read. The buffers therefore only
// contain what has actually been consumed so far.
//
// The response and error of the next round tripper are returned unchanged,
// apart from the response body being wrapped for capturing.
func (ht *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	next := ht.Next
	if next == nil {
		// Keep a zero-value RoundTripper usable, like NewRoundTripper(nil).
		next = http.DefaultTransport
	}

	coll := RoundTripCollectorFromContext(req.Context())
	if coll == nil {
		return next.RoundTrip(req)
	}

	coll.RequestBody = new(bytes.Buffer)
	coll.ResponseBody = new(bytes.Buffer)
	coll.Request = req

	out := req
	if req.Body != nil {
		// An http.RoundTripper must not modify the caller's request, so the
		// tee is installed on a shallow copy rather than on req itself.
		out = req.WithContext(req.Context())
		out.Body = bodyCloser{io.TeeReader(req.Body, coll.RequestBody), req.Body.Close}
	}

	coll.Start = time.Now()
	resp, err := next.RoundTrip(out)
	coll.End = time.Now()
	coll.Response = resp
	if resp == nil {
		return resp, err
	}

	// Prefer the request reported by the response, but never replace the
	// collected request with nil when the next round tripper leaves it unset.
	if resp.Request != nil {
		coll.Request = resp.Request
	}
	// A 101 Switching Protocols body also implements io.Writer; wrapping it
	// would hide that, so it is passed through untouched. A nil body (possible
	// with custom round trippers) cannot be teed either.
	if resp.Body != nil && resp.StatusCode != http.StatusSwitchingProtocols {
		resp.Body = bodyCloser{io.TeeReader(resp.Body, coll.ResponseBody), resp.Body.Close}
	}
	return resp, err
}