Skip to content

Commit c64f899

Browse files
authored
feat: implement session management (#9286)
* feat(auth): Added device session management - Added the `handleSession` function to manage user device sessions and verify client identity - Updated `auth.go` to call `handleSession` for device handling when a user logs in - Added the `Session` model to database migrations - Added `device.go` and `session.go` files to handle device session logic - Updated `settings.go` to add device-related configuration items, such as the maximum number of devices, device eviction policy, and session TTL * feat(session): Adds session management features - Added `SessionInactive` error type in `device.go` - Added session-related APIs in `router.go` to support listing and evicting sessions - Added `ListSessionsByUser`, `ListSessions`, and `MarkInactive` methods in `session.go` - Returns an appropriate error when the session state is `SessionInactive` * feat(auth): Marks the device session as invalid. - Import the `session` package into the `auth` module to handle device session status. - Add a check in the login logic. If `device_key` is obtained, call `session.MarkInactive` to mark the device session as invalid. - Store the invalid status in the context variable `session_inactive` for subsequent middleware checks. - Add a check in the session refresh logic to abort the process if the current session has been marked invalid. * feat(auth, session): Added device information processing and session management changes - Updated device handling logic in `auth.go` to pass user agent and IP information - Adjusted database queries in `session.go` to optimize session query fields and add `user_agent` and `ip` fields - Modified the `Handle` method to add `ua` and `ip` parameters to store the user agent and IP address - Added the `SessionResp` structure to return a session response containing `user_agent` and `ip` - Updated the `/admin/user/create` and `/webdav` endpoints to pass the user agent and IP address to the device handler
1 parent 3319f6e commit c64f899

File tree

15 files changed

+378
-4
lines changed

15 files changed

+378
-4
lines changed

internal/bootstrap/data/setting.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ func InitialSettings() []model.SettingItem {
165165
{Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL},
166166
{Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL},
167167
{Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC},
168+
{Key: conf.MaxDevices, Value: "0", Type: conf.TypeNumber, Group: model.GLOBAL},
169+
{Key: conf.DeviceEvictPolicy, Value: "deny", Type: conf.TypeSelect, Options: "deny,evict_oldest", Group: model.GLOBAL},
170+
{Key: conf.DeviceSessionTTL, Value: "86400", Type: conf.TypeNumber, Group: model.GLOBAL},
168171

169172
// single settings
170173
{Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE},

internal/conf/const.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ const (
4848
ForwardDirectLinkParams = "forward_direct_link_params"
4949
IgnoreDirectLinkParams = "ignore_direct_link_params"
5050
WebauthnLoginEnabled = "webauthn_login_enabled"
51+
MaxDevices = "max_devices"
52+
DeviceEvictPolicy = "device_evict_policy"
53+
DeviceSessionTTL = "device_session_ttl"
5154

5255
// index
5356
SearchIndex = "search_index"

internal/db/db.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ var db *gorm.DB
1212

1313
func Init(d *gorm.DB) {
1414
db = d
15-
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile))
15+
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.Role), new(model.Label), new(model.LabelFileBinding), new(model.ObjFile), new(model.Session))
1616
if err != nil {
1717
log.Fatalf("failed migrate database: %s", err.Error())
1818
}

internal/db/session.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package db
2+
3+
import (
4+
"github.com/alist-org/alist/v3/internal/model"
5+
"github.com/pkg/errors"
6+
"gorm.io/gorm/clause"
7+
)
8+
9+
func GetSession(userID uint, deviceKey string) (*model.Session, error) {
10+
s := model.Session{UserID: userID, DeviceKey: deviceKey}
11+
if err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where(&s).First(&s).Error; err != nil {
12+
return nil, errors.Wrap(err, "failed find session")
13+
}
14+
return &s, nil
15+
}
16+
17+
func CreateSession(s *model.Session) error {
18+
return errors.WithStack(db.Create(s).Error)
19+
}
20+
21+
func UpsertSession(s *model.Session) error {
22+
return errors.WithStack(db.Clauses(clause.OnConflict{UpdateAll: true}).Create(s).Error)
23+
}
24+
25+
func DeleteSession(userID uint, deviceKey string) error {
26+
return errors.WithStack(db.Where("user_id = ? AND device_key = ?", userID, deviceKey).Delete(&model.Session{}).Error)
27+
}
28+
29+
func CountSessionsByUser(userID uint) (int64, error) {
30+
var count int64
31+
err := db.Model(&model.Session{}).Where("user_id = ?", userID).Count(&count).Error
32+
return count, errors.WithStack(err)
33+
}
34+
35+
func DeleteSessionsBefore(ts int64) error {
36+
return errors.WithStack(db.Where("last_active < ?", ts).Delete(&model.Session{}).Error)
37+
}
38+
39+
func GetOldestSession(userID uint) (*model.Session, error) {
40+
var s model.Session
41+
if err := db.Where("user_id = ?", userID).Order("last_active ASC").First(&s).Error; err != nil {
42+
return nil, errors.Wrap(err, "failed get oldest session")
43+
}
44+
return &s, nil
45+
}
46+
47+
func UpdateSessionLastActive(userID uint, deviceKey string, lastActive int64) error {
48+
return errors.WithStack(db.Model(&model.Session{}).Where("user_id = ? AND device_key = ?", userID, deviceKey).Update("last_active", lastActive).Error)
49+
}
50+
51+
func ListSessionsByUser(userID uint) ([]model.Session, error) {
52+
var sessions []model.Session
53+
err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("user_id = ? AND status = ?", userID, model.SessionActive).Find(&sessions).Error
54+
return sessions, errors.WithStack(err)
55+
}
56+
57+
func ListSessions() ([]model.Session, error) {
58+
var sessions []model.Session
59+
err := db.Select("user_id, device_key, last_active, status, user_agent, ip").Where("status = ?", model.SessionActive).Find(&sessions).Error
60+
return sessions, errors.WithStack(err)
61+
}
62+
63+
func MarkInactive(sessionID string) error {
64+
return errors.WithStack(db.Model(&model.Session{}).Where("device_key = ?", sessionID).Update("status", model.SessionInactive).Error)
65+
}

internal/device/session.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package device
2+
3+
import (
4+
"time"
5+
6+
"github.com/alist-org/alist/v3/internal/conf"
7+
"github.com/alist-org/alist/v3/internal/db"
8+
"github.com/alist-org/alist/v3/internal/errs"
9+
"github.com/alist-org/alist/v3/internal/model"
10+
"github.com/alist-org/alist/v3/internal/setting"
11+
"github.com/alist-org/alist/v3/pkg/utils"
12+
"github.com/pkg/errors"
13+
"gorm.io/gorm"
14+
)
15+
16+
// Handle verifies device sessions for a user and upserts current session.
17+
func Handle(userID uint, deviceKey, ua, ip string) error {
18+
ttl := setting.GetInt(conf.DeviceSessionTTL, 86400)
19+
if ttl > 0 {
20+
_ = db.DeleteSessionsBefore(time.Now().Unix() - int64(ttl))
21+
}
22+
23+
ip = utils.MaskIP(ip)
24+
25+
now := time.Now().Unix()
26+
sess, err := db.GetSession(userID, deviceKey)
27+
if err == nil {
28+
if sess.Status == model.SessionInactive {
29+
return errors.WithStack(errs.SessionInactive)
30+
}
31+
sess.LastActive = now
32+
sess.Status = model.SessionActive
33+
sess.UserAgent = ua
34+
sess.IP = ip
35+
return db.UpsertSession(sess)
36+
}
37+
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
38+
return err
39+
}
40+
41+
max := setting.GetInt(conf.MaxDevices, 0)
42+
if max > 0 {
43+
count, err := db.CountSessionsByUser(userID)
44+
if err != nil {
45+
return err
46+
}
47+
if count >= int64(max) {
48+
policy := setting.GetStr(conf.DeviceEvictPolicy, "deny")
49+
if policy == "evict_oldest" {
50+
oldest, err := db.GetOldestSession(userID)
51+
if err == nil {
52+
_ = db.DeleteSession(userID, oldest.DeviceKey)
53+
}
54+
} else {
55+
return errors.WithStack(errs.TooManyDevices)
56+
}
57+
}
58+
}
59+
60+
s := &model.Session{UserID: userID, DeviceKey: deviceKey, UserAgent: ua, IP: ip, LastActive: now, Status: model.SessionActive}
61+
return db.CreateSession(s)
62+
}
63+
64+
// Refresh updates last_active for the session.
65+
func Refresh(userID uint, deviceKey string) {
66+
_ = db.UpdateSessionLastActive(userID, deviceKey, time.Now().Unix())
67+
}

internal/errs/device.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package errs
2+
3+
import "errors"
4+
5+
var (
6+
TooManyDevices = errors.New("too many active devices")
7+
SessionInactive = errors.New("session inactive")
8+
)

internal/model/session.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package model
2+
3+
// Session represents a device session of a user.
4+
type Session struct {
5+
UserID uint `json:"user_id" gorm:"index"`
6+
DeviceKey string `json:"device_key" gorm:"primaryKey;size:64"`
7+
UserAgent string `json:"user_agent" gorm:"size:255"`
8+
IP string `json:"ip" gorm:"size:64"`
9+
LastActive int64 `json:"last_active"`
10+
Status int `json:"status"`
11+
}
12+
13+
const (
14+
SessionActive = iota
15+
SessionInactive
16+
)

internal/session/session.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package session
2+
3+
import "github.com/alist-org/alist/v3/internal/db"
4+
5+
// MarkInactive marks the session with the given ID as inactive.
6+
func MarkInactive(sessionID string) error {
7+
return db.MarkInactive(sessionID)
8+
}

pkg/utils/mask.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package utils
2+
3+
import "strings"
4+
5+
// MaskIP anonymizes middle segments of an IP address.
6+
func MaskIP(ip string) string {
7+
if ip == "" {
8+
return ""
9+
}
10+
if strings.Contains(ip, ":") {
11+
parts := strings.Split(ip, ":")
12+
if len(parts) > 2 {
13+
for i := 1; i < len(parts)-1; i++ {
14+
if parts[i] != "" {
15+
parts[i] = "*"
16+
}
17+
}
18+
return strings.Join(parts, ":")
19+
}
20+
return ip
21+
}
22+
parts := strings.Split(ip, ".")
23+
if len(parts) == 4 {
24+
for i := 1; i < len(parts)-1; i++ {
25+
parts[i] = "*"
26+
}
27+
return strings.Join(parts, ".")
28+
}
29+
return ip
30+
}

server/handles/auth.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/alist-org/alist/v3/internal/conf"
1313
"github.com/alist-org/alist/v3/internal/model"
1414
"github.com/alist-org/alist/v3/internal/op"
15+
"github.com/alist-org/alist/v3/internal/session"
1516
"github.com/alist-org/alist/v3/internal/setting"
1617
"github.com/alist-org/alist/v3/server/common"
1718
"github.com/gin-gonic/gin"
@@ -247,6 +248,13 @@ func Verify2FA(c *gin.Context) {
247248
}
248249

249250
func LogOut(c *gin.Context) {
251+
if keyVal, ok := c.Get("device_key"); ok {
252+
if err := session.MarkInactive(keyVal.(string)); err != nil {
253+
common.ErrorResp(c, err, 500)
254+
return
255+
}
256+
c.Set("session_inactive", true)
257+
}
250258
err := common.InvalidateToken(c.GetHeader("Authorization"))
251259
if err != nil {
252260
common.ErrorResp(c, err, 500)

0 commit comments

Comments
 (0)