diff --git a/pkg/middlewares/request_id.go b/pkg/middlewares/request_id.go new file mode 100644 index 00000000..b8f01f9b --- /dev/null +++ b/pkg/middlewares/request_id.go @@ -0,0 +1,27 @@ +package middlewares + +import ( + "github.com/mayswind/lab/pkg/core" + "github.com/mayswind/lab/pkg/requestid" + "github.com/mayswind/lab/pkg/settings" +) + +const REQUEST_ID_HEADER = "X-Request-ID" + +func RequestId(config *settings.Config) core.MiddlewareHandlerFunc { + return func (c *core.Context) { + if requestid.Container.Current == nil { + c.Next() + return + } + + requestId := requestid.Container.Current.GenerateRequestId(c.ClientIP()) + c.SetRequestId(requestId) + + if config.EnableRequestIdHeader { + c.Header(REQUEST_ID_HEADER, requestId) + } + + c.Next() + } +} diff --git a/pkg/requestid/default_request_id_generator.go b/pkg/requestid/default_request_id_generator.go new file mode 100644 index 00000000..41af9618 --- /dev/null +++ b/pkg/requestid/default_request_id_generator.go @@ -0,0 +1,234 @@ +package requestid + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "hash/crc32" + "math" + "net" + "sync/atomic" + "time" + + "github.com/mayswind/lab/pkg/errs" + "github.com/mayswind/lab/pkg/log" + "github.com/mayswind/lab/pkg/settings" + "github.com/mayswind/lab/pkg/utils" +) + +const REQUEST_ID_LENGTH = 36 +const SECONDS_TODAY_BITS = 17 +const SECONDS_TODAY_BITS_MASK = (1 << SECONDS_TODAY_BITS) - 1 +const RANDOM_NUMBER_BITS = 15 +const RANDOM_NUMBER_BITS_MASK = (1 << RANDOM_NUMBER_BITS) - 1 +const REQ_SEQ_NUMBER_BITS = 31 +const REQ_SEQ_NUMBER_BITS_MASK = (1 << REQ_SEQ_NUMBER_BITS) - 1 +const CLIENT_IPV6_BIT = 1 +const CLIENT_IPV6_BIT_MASK = 1 + +type RequestIdInfo struct { + ServerUniqId uint16 + InstanceUniqId uint16 + SecondsElapsedToday uint32 + RandomNumber uint32 + RequestSeqId uint32 + IsClientIpv6 bool + ClientIp uint32 +} + +type DefaultRequestIdGenerator struct { + serverUniqId uint16 + instanceUniqId uint16 + requestSeqId uint32 +} + +func NewDefaultRequestIdGenerator(config *settings.Config) (*DefaultRequestIdGenerator, error) { + serverUniqId, err := getServerUniqId(config) + + if err != nil { + return nil, err + } + + instanceUniqId := getInstanceUniqId(config) + + generator := &DefaultRequestIdGenerator{ + serverUniqId: serverUniqId, + instanceUniqId: instanceUniqId, + } + + return generator, nil +} + +func getServerUniqId(config *settings.Config) (uint16, error) { + localAddr := "" + settingAddr := net.ParseIP(config.HttpAddr) + + if settingAddr != nil && !settingAddr.IsUnspecified() { + localAddr = settingAddr.String() + } else { + var err error + localAddr, err = utils.GetLocalIPAddressesString() + + if err != nil { + log.Warnf("[default_request_id_generator.getServerUniqId] failed to get local ipv4 address, because %s", err.Error()) + return 0, err + } + } + + serverUniqFlag := fmt.Sprintf("%s_%s", localAddr, config.SecretKey) + + return uint16(crc32.ChecksumIEEE([]byte(serverUniqFlag))), nil +} + +func getInstanceUniqId(config *settings.Config) uint16 { + var instanceUniqFlag string + + if config.Protocol == settings.SCHEME_SOCKET { + instanceUniqFlag = fmt.Sprintf("%s_%s", config.UnixSocketPath, config.SecretKey) + } else { + instanceUniqFlag = fmt.Sprintf("%d_%s", config.HttpPort, config.SecretKey) + } + + return uint16(crc32.ChecksumIEEE([]byte(instanceUniqFlag))) + +} + +func (r *DefaultRequestIdGenerator) ParseRequestIdInfo(requestId string) (*RequestIdInfo, error) { + if requestId == "" || len(requestId) != REQUEST_ID_LENGTH { + return nil, errs.ErrRequestIdInvalid + } + + requestIdData := r.parseRequestIdFromUuid(requestId) + return r.parseRequestIdInfo(requestIdData), nil +} + +func (r *DefaultRequestIdGenerator) GetCurrentServerUniqId() uint16 { + return r.serverUniqId +} + +func (r *DefaultRequestIdGenerator) GetCurrentInstanceUniqId() uint16 { + return r.instanceUniqId +} + +func (r *DefaultRequestIdGenerator) GenerateRequestId(clientIpAddr string) string { + ip := net.ParseIP(clientIpAddr) + isClientIpv6 := ip.To4() == nil + var clientIp uint32 + + if isClientIpv6 { + clientIp = crc32.ChecksumIEEE([]byte(ip.String())) + } else { + clientIp = binary.BigEndian.Uint32(ip.To4()) + } + + requestId := r.getRequestId(r.serverUniqId, r.instanceUniqId, isClientIpv6, clientIp) + + return requestId +} + +func (r *DefaultRequestIdGenerator) getRequestId(serverUniqId uint16, instanceUniqId uint16, clientIpV6 bool, clientIp uint32) string { + clientIpv6Flag := uint32(0) + + if clientIpV6 { + clientIpv6Flag = uint32(1) + } + + // 128bits = serverUniqId(16bits) + instanceUniqId(16bits) + secondsElapsedToday(17bits) + randomNumber(15bits) + sequentialNumber(31bits) + clientIpv6Flag(1bit) + clientIp(32bits) + + secondsElapsedToday := r.getSecondsElapsedToday() + secondsLow17bits := uint32(secondsElapsedToday & SECONDS_TODAY_BITS_MASK) + + randomNumber, _ := utils.GetRandomInteger(math.MaxInt16) + randomNumberLow15bits := uint32(randomNumber & RANDOM_NUMBER_BITS_MASK) + + secondsAndRandomNumber := (secondsLow17bits << RANDOM_NUMBER_BITS) | randomNumberLow15bits + + seqId := atomic.AddUint32(&r.requestSeqId, 1) + seqIdLow31bits := seqId & REQ_SEQ_NUMBER_BITS_MASK + + seqIdAndClientIpv6Flag := (seqIdLow31bits << CLIENT_IPV6_BIT) | (clientIpv6Flag & CLIENT_IPV6_BIT_MASK) + + buf := &bytes.Buffer{} + _ = binary.Write(buf, binary.BigEndian, serverUniqId) + _ = binary.Write(buf, binary.BigEndian, instanceUniqId) + _ = binary.Write(buf, binary.BigEndian, secondsAndRandomNumber) + _ = binary.Write(buf, binary.BigEndian, seqIdAndClientIpv6Flag) + _ = binary.Write(buf, binary.BigEndian, clientIp) + + return r.getUuidFromRequestId(buf) +} + +func (r *DefaultRequestIdGenerator) getSecondsElapsedToday() int { + now := time.Now() + seconds := now.Hour()*24*60 + now.Minute()*60 + now.Second() + + return seconds +} + +func (r *DefaultRequestIdGenerator) getUuidFromRequestId(buffer *bytes.Buffer) string { + data := buffer.Bytes() + result := make([]byte, 36) + + hex.Encode(result[0:8], data[0:4]) + result[8] = '-' + hex.Encode(result[9:13], data[4:6]) + result[13] = '-' + hex.Encode(result[14:18], data[6:8]) + result[18] = '-' + hex.Encode(result[19:23], data[8:10]) + result[23] = '-' + hex.Encode(result[24:], data[10:]) + + return string(result) +} + +func (r *DefaultRequestIdGenerator) parseRequestIdInfo(data []byte) *RequestIdInfo { + buf := bytes.NewBuffer(data) + + var serverUniqId uint16 + var instanceUniqId uint16 + var secondsAndRandomNumber uint32 + var seqIdAndClientIpv6Flag uint32 + var clientIp uint32 + + _ = binary.Read(buf, binary.BigEndian, &serverUniqId) + _ = binary.Read(buf, binary.BigEndian, &instanceUniqId) + _ = binary.Read(buf, binary.BigEndian, &secondsAndRandomNumber) + _ = binary.Read(buf, binary.BigEndian, &seqIdAndClientIpv6Flag) + _ = binary.Read(buf, binary.BigEndian, &clientIp) + + secondsElapsedToday := (secondsAndRandomNumber >> RANDOM_NUMBER_BITS) & SECONDS_TODAY_BITS_MASK + randomNumber := (secondsAndRandomNumber & RANDOM_NUMBER_BITS_MASK) + + seqId := (seqIdAndClientIpv6Flag >> CLIENT_IPV6_BIT) & REQ_SEQ_NUMBER_BITS_MASK + isClientIpv6Flag := (seqIdAndClientIpv6Flag & CLIENT_IPV6_BIT_MASK) + isClientIpv6 := false + + if isClientIpv6Flag == 1 { + isClientIpv6 = true + } + + return &RequestIdInfo{ + ServerUniqId: serverUniqId, + InstanceUniqId: instanceUniqId, + SecondsElapsedToday: secondsElapsedToday, + RequestSeqId: seqId, + RandomNumber: randomNumber, + IsClientIpv6: isClientIpv6, + ClientIp: clientIp, + } +} + +func (r *DefaultRequestIdGenerator) parseRequestIdFromUuid(uuid string) []byte { + data := []byte(uuid) + result := make([]byte, 16) + + _, _ = hex.Decode(result[0:4], data[0:8]) + _, _ = hex.Decode(result[4:6], data[9:13]) + _, _ = hex.Decode(result[6:8], data[14:18]) + _, _ = hex.Decode(result[8:10], data[19:23]) + _, _ = hex.Decode(result[10:], data[24:]) + + return result +} diff --git a/pkg/requestid/default_request_id_generator_test.go b/pkg/requestid/default_request_id_generator_test.go new file mode 100644 index 00000000..da203bf5 --- /dev/null +++ b/pkg/requestid/default_request_id_generator_test.go @@ -0,0 +1,100 @@ +package requestid + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/mayswind/lab/pkg/settings" +) + +func TestNewDefaultRequestIdGenerator_Http(t *testing.T) { + generator, _ := NewDefaultRequestIdGenerator(&settings.Config{HttpAddr: "123.234.123.234", HttpPort: 8080, SecretKey: "secretkey"}) + requestId := generator.GenerateRequestId("127.0.0.1") + requestIdInfo := generator.parseRequestIdInfo(generator.parseRequestIdFromUuid(requestId)) + + expectedServerUniqId := uint16(0x2476) // crc32("123.234.123.234" + "_" + "secretkey") & 0xFFFF + actualServerUniqId := requestIdInfo.ServerUniqId + assert.Equal(t, expectedServerUniqId, actualServerUniqId) + + expectedInstanceUniqId := uint16(0x0e79) // crc32("8080" + "_" + "secretkey") & 0xFFFF + actualInstanceUniqId := requestIdInfo.InstanceUniqId + assert.Equal(t, expectedInstanceUniqId, actualInstanceUniqId) +} + +func TestNewDefaultRequestIdGenerator_UnixSocket(t *testing.T) { + generator, _ := NewDefaultRequestIdGenerator(&settings.Config{HttpAddr: "1.2.3.4", UnixSocketPath: "/var/lib/labapp/lab.sock", Protocol: "socket", SecretKey: "secretkey"}) + requestId := generator.GenerateRequestId("127.0.0.1") + requestIdInfo := generator.parseRequestIdInfo(generator.parseRequestIdFromUuid(requestId)) + + expectedServerUniqId := uint16(0x5bdb) // crc32("1.2.3.4" + "_" + "secretkey") & 0xFFFF + actualServerUniqId := requestIdInfo.ServerUniqId + assert.Equal(t, expectedServerUniqId, actualServerUniqId) + + expectedInstanceUniqId := uint16(0x694b) // crc32("/var/lib/labapp/lab.sock" + "_" + "secretkey") & 0xFFFF + actualInstanceUniqId := requestIdInfo.InstanceUniqId + assert.Equal(t, expectedInstanceUniqId, actualInstanceUniqId) +} + +func TestNewDefaultRequestIdGenerator_ClientIpv4(t *testing.T) { + generator, _ := NewDefaultRequestIdGenerator(&settings.Config{HttpAddr: "1.2.3.4", UnixSocketPath: "/var/lib/labapp/lab.sock", Protocol: "socket", SecretKey: "secretkey"}) + requestId := generator.GenerateRequestId("127.0.0.1") + requestIdInfo := generator.parseRequestIdInfo(generator.parseRequestIdFromUuid(requestId)) + + expectedClientIp := uint32(0x7f000001) // 127.0.0.1 + actualClientIp := requestIdInfo.ClientIp + assert.Equal(t, expectedClientIp, actualClientIp) + + expectedClientIpv6 := false + actualClientIpv6 := requestIdInfo.IsClientIpv6 + assert.Equal(t, expectedClientIpv6, actualClientIpv6) + + requestId = generator.GenerateRequestId("192.168.1.100") + requestIdInfo = generator.parseRequestIdInfo(generator.parseRequestIdFromUuid(requestId)) + + expectedClientIp = uint32(0xc0a80164) // 192.168.1.100 + actualClientIp = requestIdInfo.ClientIp + assert.Equal(t, expectedClientIp, actualClientIp) + + expectedClientIpv6 = false + actualClientIpv6 = requestIdInfo.IsClientIpv6 + assert.Equal(t, expectedClientIpv6, actualClientIpv6) +} + +func TestNewDefaultRequestIdGenerator_ClientIpv6(t *testing.T) { + generator, _ := NewDefaultRequestIdGenerator(&settings.Config{HttpAddr: "1.2.3.4", UnixSocketPath: "/var/lib/labapp/lab.sock", Protocol: "socket", SecretKey: "secretkey"}) + requestId := generator.GenerateRequestId("2001:abc:def:1234::1") + requestIdInfo := generator.parseRequestIdInfo(generator.parseRequestIdFromUuid(requestId)) + + expectedClientIp := uint32(0x76fe1b98) // crc32("2001:abc:def:1234::1") + actualClientIp := requestIdInfo.ClientIp + assert.Equal(t, expectedClientIp, actualClientIp) + + expectedClientIpv6 := true + actualClientIpv6 := requestIdInfo.IsClientIpv6 + assert.Equal(t, expectedClientIpv6, actualClientIpv6) + + requestId = generator.GenerateRequestId("2400:abcd:1234:1:56ef:ab78:c90d:1e2f") + requestIdInfo = generator.parseRequestIdInfo(generator.parseRequestIdFromUuid(requestId)) + + expectedClientIp = uint32(0xa0a25faa) // crc32("2400:abcd:1234:1:56ef:ab78:c90d:1e2f") + actualClientIp = requestIdInfo.ClientIp + assert.Equal(t, expectedClientIp, actualClientIp) + + expectedClientIpv6 = true + actualClientIpv6 = requestIdInfo.IsClientIpv6 + assert.Equal(t, expectedClientIpv6, actualClientIpv6) +} + +func TestGenerateRequestId_100Times(t *testing.T) { + generator, _ := NewDefaultRequestIdGenerator(&settings.Config{HttpAddr: "1.2.3.4", HttpPort: 1234, SecretKey: "secretkey"}) + + for i := 1; i <= 100; i++ { + requestId := generator.GenerateRequestId("127.0.0.1") + requestIdInfo := generator.parseRequestIdInfo(generator.parseRequestIdFromUuid(requestId)) + + expectedRequestSeqId := uint32(i) + actualRequestSeqId := requestIdInfo.RequestSeqId + assert.Equal(t, expectedRequestSeqId, actualRequestSeqId) + } +} diff --git a/pkg/requestid/request_id_container.go b/pkg/requestid/request_id_container.go new file mode 100644 index 00000000..b486201c --- /dev/null +++ b/pkg/requestid/request_id_container.go @@ -0,0 +1,28 @@ +package requestid + +import ( + "github.com/mayswind/lab/pkg/settings" +) + +type RequestIdContainer struct { + Current RequestIdGenerator +} + +var ( + Container = &RequestIdContainer{} +) + +func InitializeRequestIdGenerator(config *settings.Config) error { + generator, err := NewDefaultRequestIdGenerator(config) + + if err != nil { + return err + } + + Container.Current = generator + return nil +} + +func (u *RequestIdContainer) GenerateRequestId(clientIpAddr string) string { + return u.Current.GenerateRequestId(clientIpAddr) +} diff --git a/pkg/requestid/request_id_generator.go b/pkg/requestid/request_id_generator.go new file mode 100644 index 00000000..c85cf683 --- /dev/null +++ b/pkg/requestid/request_id_generator.go @@ -0,0 +1,7 @@ +package requestid + +type RequestIdGenerator interface { + GenerateRequestId(clientIpAddr string) string + GetCurrentServerUniqId() uint16 + GetCurrentInstanceUniqId() uint16 +}