Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
130 changes: 68 additions & 62 deletions tencentyun/TLSSigAPI.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"io"
"io/ioutil"
"strconv"
"sync"
"time"
)

Expand All @@ -23,7 +25,7 @@ import (
* expire - UserSig 票据的过期时间,单位是秒,比如 86400 代表生成的 UserSig 票据在一天后就无法再使用了。
*/

/**
/**
* Function: Used to issue UserSig that is required by the TRTC and IM services.
*
* Parameter description:
Expand Down Expand Up @@ -94,7 +96,6 @@ func GenUserSigWithBuf(sdkappid int, key string, userid string, expire int, buf
* - privilegeMap == 0010 1010 == 42: Indicates that the UserID has only the permissions to enter the room and receive audio/video data.
*/


func GenPrivateMapKey(sdkappid int, key string, userid string, expire int, roomid uint32, privilegeMap uint32) (string, error) {
var userbuf []byte = genUserBuf(userid, sdkappid, roomid, expire, privilegeMap, 0, "")
return genSig(sdkappid, key, userid, expire, userbuf)
Expand Down Expand Up @@ -252,52 +253,28 @@ func genUserBuf(account string, dwSdkappid int, dwAuthID uint32,
return userBuf
}

func hmacsha256(sdkappid int, key string, identifier string, currTime int64, expire int, base64UserBuf *string) string {
var contentToBeSigned string
contentToBeSigned = "TLS.identifier:" + identifier + "\n"
contentToBeSigned += "TLS.sdkappid:" + strconv.Itoa(sdkappid) + "\n"
contentToBeSigned += "TLS.time:" + strconv.FormatInt(currTime, 10) + "\n"
contentToBeSigned += "TLS.expire:" + strconv.Itoa(expire) + "\n"
if nil != base64UserBuf {
contentToBeSigned += "TLS.userbuf:" + *base64UserBuf + "\n"
}

h := hmac.New(sha256.New, []byte(key))
h.Write([]byte(contentToBeSigned))
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

func genSig(sdkappid int, key string, identifier string, expire int, userbuf []byte) (string, error) {
currTime := time.Now().Unix()
sigDoc := make(map[string]interface{})
sigDoc["TLS.ver"] = "2.0"
sigDoc["TLS.identifier"] = identifier
sigDoc["TLS.sdkappid"] = sdkappid
sigDoc["TLS.expire"] = expire
sigDoc["TLS.time"] = currTime
var base64UserBuf string
if nil != userbuf {
base64UserBuf = base64.StdEncoding.EncodeToString(userbuf)
sigDoc["TLS.userbuf"] = base64UserBuf
sigDoc["TLS.sig"] = hmacsha256(sdkappid, key, identifier, currTime, expire, &base64UserBuf)
} else {
sigDoc["TLS.sig"] = hmacsha256(sdkappid, key, identifier, currTime, expire, nil)
}

data, err := json.Marshal(sigDoc)
if err != nil {
return "", err
sigDoc := userSig{
Version: "2.0",
Identifier: identifier,
SdkAppID: uint64(sdkappid),
Expire: int64(expire),
Time: currTime,
UserBuf: userbuf,
}
sigDoc.Sig = sigDoc.sign(key)

var b bytes.Buffer
w := zlib.NewWriter(&b)
if _, err = w.Write(data); err != nil {
w := newZlibWriter(&b)
defer zlibWriterPool.Put(w)
if err := json.NewEncoder(w).Encode(sigDoc); err != nil {
return "", err
}
if err = w.Close(); err != nil {
if err := w.Close(); err != nil {
return "", err
}
return base64urlEncode(b.Bytes()), nil
return base64url.EncodeToString(b.Bytes()), nil
}

// VerifyUserSig 检验UserSig在now时间点时是否有效
Expand Down Expand Up @@ -327,7 +304,7 @@ type userSig struct {
Expire int64 `json:"TLS.expire,omitempty"`
Time int64 `json:"TLS.time,omitempty"`
UserBuf []byte `json:"TLS.userbuf,omitempty"`
Sig string `json:"TLS.sig,omitempty"`
Sig []byte `json:"TLS.sig,omitempty"`
}

func newUserSig(usersig string) (userSig, error) {
Expand Down Expand Up @@ -373,35 +350,41 @@ func (u userSig) verify(sdkappid uint64, key string, userid string, now time.Tim
} else if u.UserBuf != nil {
return ErrUserBufTypeNotMatch
}
if u.sign(key) != u.Sig {
if !bytes.Equal(u.sign(key), u.Sig) {
return ErrSigNotMatch
}
return nil
}

func (u userSig) sign(key string) string {
var sb bytes.Buffer
sb.WriteString("TLS.identifier:")
sb.WriteString(u.Identifier)
sb.WriteString("\n")
sb.WriteString("TLS.sdkappid:")
sb.WriteString(strconv.FormatUint(u.SdkAppID, 10))
sb.WriteString("\n")
sb.WriteString("TLS.time:")
sb.WriteString(strconv.FormatInt(u.Time, 10))
sb.WriteString("\n")
sb.WriteString("TLS.expire:")
sb.WriteString(strconv.FormatInt(u.Expire, 10))
sb.WriteString("\n")
if u.UserBuf != nil {
sb.WriteString("TLS.userbuf:")
sb.WriteString(base64.StdEncoding.EncodeToString(u.UserBuf))
sb.WriteString("\n")
}
var (
sigIdentifier = []byte("TLS.identifier:")
sigSdkAppID = []byte("TLS.sdkappid:")
sigTime = []byte("TLS.time:")
sigExpire = []byte("TLS.expire:")
sigUserBuf = []byte("TLS.userbuf:")
sigEnter = []byte("\n")
)

func (u userSig) sign(key string) []byte {
h := hmac.New(sha256.New, []byte(key))
h.Write(sb.Bytes())
return base64.StdEncoding.EncodeToString(h.Sum(nil))
h.Write(sigIdentifier)
h.Write([]byte(u.Identifier))
h.Write(sigEnter)
h.Write(sigSdkAppID)
h.Write([]byte(strconv.FormatUint(u.SdkAppID, 10)))
h.Write(sigEnter)
h.Write(sigTime)
h.Write([]byte(strconv.FormatInt(u.Time, 10)))
h.Write(sigEnter)
h.Write(sigExpire)
h.Write([]byte(strconv.FormatInt(u.Expire, 10)))
h.Write(sigEnter)
if u.UserBuf != nil {
h.Write(sigUserBuf)
h.Write([]byte(base64.StdEncoding.EncodeToString(u.UserBuf)))
h.Write(sigEnter)
}
return h.Sum(nil)
}

// 错误类型
Expand All @@ -413,3 +396,26 @@ var (
ErrUserBufNotMatch = errors.New("userbuf not match")
ErrSigNotMatch = errors.New("sig not match")
)

var (
zlibWriterPool sync.Pool
)

func newZlibWriter(w io.Writer) *zlib.Writer {
v := zlibWriterPool.Get()
if v == nil {
zw, err := zlib.NewWriterLevel(w, DefaultCompressionLevel)
if err != nil {
return zlib.NewWriter(w)
}
return zw
}
zw := v.(*zlib.Writer)
zw.Reset(w)
return zw
}

// DefaultCompressionLevel is the default compression level.
// Default is zlib.NoCompression.
// It can be set to any valid compression level to balance speed and size.
var DefaultCompressionLevel = zlib.NoCompression
12 changes: 12 additions & 0 deletions tencentyun/TLSSigAPI_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,15 @@ func TestGenAndVerify(t *testing.T) {
assert.Equal(t, ErrUserBufTypeNotMatch, VerifyUserSigWithBuf(1, "3", "3", bufSig, now, nil))
assert.Equal(t, ErrUserBufNotMatch, VerifyUserSigWithBuf(1, "3", "3", bufSig, now, []byte{6}))
}

func BenchmarkGenSig(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _ = GenUserSig(1, "abc", "a", 1)
}
}

func BenchmarkGenUserSigWithBuf(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _ = GenUserSigWithBuf(1, "abc", "a", 1, []byte{1})
}
}
8 changes: 1 addition & 7 deletions tencentyun/base64url.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@ import (
"strings"
)

func base64urlEncode(data []byte) string {
str := base64.StdEncoding.EncodeToString(data)
str = strings.Replace(str, "+", "*", -1)
str = strings.Replace(str, "/", "-", -1)
str = strings.Replace(str, "=", "_", -1)
return str
}
var base64url = base64.NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789*-").WithPadding('_')

func base64urlDecode(str string) ([]byte, error) {
str = strings.Replace(str, "_", "=", -1)
Expand Down
Loading