Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 137 additions & 24 deletions internal/dms/biz/cloudbeaver.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,15 @@ func (cu *CloudbeaverUsecase) Login() echo.MiddlewareFunc {
// 根据cookie 获取登录用户
cloudbeaverActiveUser, err := cu.getActiveUserQuery([]*http.Cookie{cookie})
if err != nil {
cu.log.Errorf("getActiveUserQuery err: %v", err)
return err
}

if cloudbeaverActiveUser.User != nil {
return next(c)
cu.log.Debugf("cached cloudbeaver session invalid for DMS user %s: %v", dmsUserId, err)
cu.UnbindCBSession(dmsToken)
} else if cloudbeaverActiveUser.User != nil {
if err := cu.updateCloudbeaverSession([]*http.Cookie{cookie}); err != nil {
cu.log.Debugf("cached cloudbeaver session cannot refresh for DMS user %s: %v", dmsUserId, err)
cu.UnbindCBSession(dmsToken)
} else {
return next(c)
}
}
}

Expand Down Expand Up @@ -417,7 +420,8 @@ func (cu *CloudbeaverUsecase) GraphQLDistributor() echo.MiddlewareFunc {
var srw *ResponseInterceptor
var cloudbeaverResBuf *bytes.Buffer
var needSmartWriter bool = cloudbeaverHandle.UseLocalHandler ||
params.OperationName == "navNodeChildren"
params.OperationName == "navNodeChildren" ||
params.OperationName == "openSession"

if needSmartWriter {
srw = newSmartResponseWriter(c)
Expand All @@ -428,6 +432,7 @@ func (cu *CloudbeaverUsecase) GraphQLDistributor() echo.MiddlewareFunc {
return
}
cu.handleNavNodeChildrenOperation(params, c, cloudbeaverResBuf)
cu.keepProxySessionValid(params, cloudbeaverResBuf)
respBytesBuf := cloudbeaverResBuf.Bytes()

length := c.Response().Header().Get("content-length")
Expand Down Expand Up @@ -999,6 +1004,40 @@ func (cu *CloudbeaverUsecase) handleNavNodeChildrenOperation(params *graphql.Raw
}
}

func (cu *CloudbeaverUsecase) keepProxySessionValid(params *graphql.RawParams, responseBuf *bytes.Buffer) {
if params.OperationName != "openSession" || responseBuf == nil || responseBuf.Len() == 0 {
return
}

originalResponse := map[string]interface{}{}
if err := json.Unmarshal(responseBuf.Bytes(), &originalResponse); err != nil {
cu.log.Errorf("failed to unmarshal openSession response: %v", err)
return
}
data, ok := originalResponse["data"].(map[string]interface{})
if !ok {
return
}
session, ok := data["session"].(map[string]interface{})
if !ok {
return
}

session["valid"] = true
session["cacheExpired"] = false
if remainingTime, ok := session["remainingTime"].(float64); !ok || remainingTime <= 0 {
session["remainingTime"] = 1800
}

updatedResponseBytes, err := json.Marshal(originalResponse)
if err != nil {
cu.log.Errorf("failed to marshal openSession response: %v", err)
return
}
responseBuf.Reset()
responseBuf.Write(updatedResponseBytes)
}

func (cu *CloudbeaverUsecase) isEnableSQLAudit(dbService *DBService) bool {
if dbService.SQLEConfig == nil || dbService.SQLEConfig.SQLQueryConfig == nil {
return false
Expand Down Expand Up @@ -1140,6 +1179,12 @@ func (cu *CloudbeaverUsecase) getAuthCache(username string) []*http.Cookie {
return cache.cookies
}

func (cu *CloudbeaverUsecase) deleteAuthCache(username string) {
authMutex.Lock()
defer authMutex.Unlock()
delete(authCacheMap, username)
}

// setAuthCache 设置认证缓存
func (cu *CloudbeaverUsecase) setAuthCache(username string, cookies []*http.Cookie) {
authMutex.Lock()
Expand All @@ -1155,11 +1200,31 @@ func (cu *CloudbeaverUsecase) setAuthCache(username string, cookies []*http.Cook

// getAuthCookies 获取认证 cookies(复用缓存逻辑)
func (cu *CloudbeaverUsecase) getAuthCookies(username, password string) ([]*http.Cookie, error) {
if username == cu.cloudbeaverCfg.AdminUser {
cu.deleteAuthCache(username)
cookies, err := cu.loginCloudbeaverServer(username, password)
if err != nil {
return nil, err
}
cu.setAuthCache(username, cookies)
return cookies, nil
}

// 尝试从缓存获取认证信息
cachedCookies := cu.getAuthCache(username)
if cachedCookies != nil {
cu.log.Debugf("Using cached authentication for user: %s", username)
return cachedCookies, nil
if valid, err := cu.isAuthCookiesValid(cachedCookies); err == nil && valid {
if err := cu.updateCloudbeaverSession(cachedCookies); err != nil {
cu.log.Debugf("Cached authentication for user %s cannot refresh session: %v", username, err)
cu.deleteAuthCache(username)
} else {
cu.log.Debugf("Using cached authentication for user: %s", username)
return cachedCookies, nil
}
} else if err != nil {
cu.log.Debugf("Cached authentication for user %s is invalid: %v", username, err)
cu.deleteAuthCache(username)
}
}

// 缓存未命中,执行登录
Expand All @@ -1174,6 +1239,14 @@ func (cu *CloudbeaverUsecase) getAuthCookies(username, password string) ([]*http
return cookies, nil
}

func (cu *CloudbeaverUsecase) isAuthCookiesValid(cookies []*http.Cookie) (bool, error) {
activeUser, err := cu.getActiveUserQuery(cookies)
if err != nil {
return false, err
}
return activeUser != nil && activeUser.User != nil, nil
}

// clearExpiredAuthCache 清理过期的认证缓存
func (cu *CloudbeaverUsecase) clearExpiredAuthCache() {
authMutex.Lock()
Expand Down Expand Up @@ -1214,6 +1287,16 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea
return fmt.Errorf("username %s is reserved, cann't be used", cloudbeaverUserId)
}

cloudbeaverUser, exist, err := cu.repo.GetCloudbeaverUserByID(ctx, cloudbeaverUserId)
if err != nil {
cu.log.Errorf("Failed to get CloudBeaver user %s from cache: %v", cloudbeaverUserId, err)
return err
}
if exist && cloudbeaverUser.DMSFingerprint == cu.userUsecase.GetUserFingerprint(dmsUser) {
cu.log.Debugf("CloudBeaver user %s fingerprint matches, no update needed", cloudbeaverUserId)
return nil
}

// 使用管理员身份登录
graphQLClient, err := cu.getGraphQLClientWithRootUser()
if err != nil {
Expand Down Expand Up @@ -1256,19 +1339,6 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea
}

cu.log.Infof("Successfully granted role to CloudBeaver user: %s", cloudbeaverUserId)
} else {
cu.log.Debugf("CloudBeaver user %s already exists, checking fingerprint", cloudbeaverUserId)

cloudbeaverUser, exist, err := cu.repo.GetCloudbeaverUserByID(ctx, cloudbeaverUserId)
if err != nil {
cu.log.Errorf("Failed to get CloudBeaver user %s from cache: %v", cloudbeaverUserId, err)
return err
}

if exist && cloudbeaverUser.DMSFingerprint == cu.userUsecase.GetUserFingerprint(dmsUser) {
cu.log.Debugf("CloudBeaver user %s fingerprint matches, no update needed", cloudbeaverUserId)
return nil
}
}

// 设置CloudBeaver用户密码
Expand All @@ -1288,7 +1358,7 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea

cu.log.Infof("Successfully updated CloudBeaver user password for user: %s", cloudbeaverUserId)

cloudbeaverUser := &CloudbeaverUser{
cloudbeaverUser = &CloudbeaverUser{
DMSUserID: dmsUser.UID,
DMSFingerprint: cu.userUsecase.GetUserFingerprint(dmsUser),
CloudbeaverUserID: cloudbeaverUserId,
Expand Down Expand Up @@ -1953,10 +2023,53 @@ func (cu *CloudbeaverUsecase) loginCloudbeaverServer(user, pwd string) (cookie [
if err = client.Run(context.TODO(), req, &res); err != nil {
return cookie, fmt.Errorf("cloudbeaver login failed: %v,req: %v,res %v", err, req, res)
}

if len(cookie) == 0 {
return cookie, fmt.Errorf("cloudbeaver login failed: empty session cookie")
}
if err = cu.updateCloudbeaverSession(cookie); err != nil {
return cookie, err
}
return cookie, nil
}

func (cu *CloudbeaverUsecase) openCloudbeaverSession(cookies []*http.Cookie) error {
client := cloudbeaver.NewGraphQlClient(cu.getGraphQLServerURI(), cloudbeaver.WithCookie(cookies), cloudbeaver.WithLogger(cu.log))
req := cloudbeaver.NewRequest(cu.graphQl.OpenSessionQuery(), map[string]interface{}{
"defaultLocale": "en",
})
res := struct {
Session struct {
Valid bool `json:"valid"`
RemainingTime int `json:"remainingTime"`
} `json:"session"`
}{}
if err := client.Run(context.TODO(), req, &res); err != nil {
return fmt.Errorf("cloudbeaver open session failed: %v", err)
}
if !res.Session.Valid {
return fmt.Errorf("cloudbeaver open session failed: invalid session, remaining time: %d", res.Session.RemainingTime)
}
return nil
}

func (cu *CloudbeaverUsecase) updateCloudbeaverSession(cookies []*http.Cookie) error {
client := cloudbeaver.NewGraphQlClient(cu.getGraphQLServerURI(), cloudbeaver.WithCookie(cookies), cloudbeaver.WithLogger(cu.log))
req := cloudbeaver.NewRequest(cu.graphQl.UpdateSessionQuery(), map[string]interface{}{})
res := struct {
Session struct {
Valid bool `json:"valid"`
RemainingTime int `json:"remainingTime"`
} `json:"session"`
}{}
if err := client.Run(context.TODO(), req, &res); err != nil {
return fmt.Errorf("cloudbeaver update session failed: %v", err)
}
if !res.Session.Valid {
return fmt.Errorf("cloudbeaver update session failed: invalid session, remaining time: %d", res.Session.RemainingTime)
}
return nil
}

type GraphQLResponse struct {
Data json.RawMessage `json:"data"`
Errors []GraphQLError `json:"errors"`
Expand Down
25 changes: 25 additions & 0 deletions internal/pkg/cloudbeaver/cloudbeaver.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ type GraphQLImpl interface {
CreateUserQuery() string
GrantUserRoleQuery() string
LoginQuery() string
OpenSessionQuery() string
UpdateSessionQuery() string
GetActiveUserQuery() string
GetExecutionContextListQuery() string
}
Expand Down Expand Up @@ -219,13 +221,36 @@ query authLogin(
configuration: null
credentials: $credentials
linkUser: false
forceSessionsLogout: true
){
authId
}
}
`
}

func (CloudBeaverV2215) OpenSessionQuery() string {
return `
mutation openSession($defaultLocale: String) {
session: openSession(defaultLocale: $defaultLocale) {
valid
remainingTime
}
}
`
}

func (CloudBeaverV2215) UpdateSessionQuery() string {
return `
mutation updateSession {
session: updateSession {
valid
remainingTime
}
}
`
}

func (CloudBeaverV2215) GetActiveUserQuery() string {
return `
query getActiveUser {
Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/cloudbeaver/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ var GraphQLHandlerRouters = map[string] /* gql operation name */ gqlBehavior{
"getActiveUser": {
UseLocalHandler: true,
NeedModifyRemoteRes: true,
}, "authLogout": {
}, "openSession": {}, "authLogout": {
Disable: true,
}, "authLogin": {
Disable: true,
Expand Down