diff --git a/internal/dms/biz/cloudbeaver.go b/internal/dms/biz/cloudbeaver.go index 3ee84446f..5fdddba62 100644 --- a/internal/dms/biz/cloudbeaver.go +++ b/internal/dms/biz/cloudbeaver.go @@ -223,12 +223,15 @@ func (cu *CloudbeaverUsecase) Login() echo.MiddlewareFunc { // 根据cookie 获取登录用户 cloudbeaverActiveUser, err := cu.getActiveUserQuery([]*http.Cookie{cookie}) if err != nil { - cu.log.Errorf("getActiveUserQuery err: %v", err) - return err - } - - if cloudbeaverActiveUser.User != nil { - return next(c) + cu.log.Debugf("cached cloudbeaver session invalid for DMS user %s: %v", dmsUserId, err) + cu.UnbindCBSession(dmsToken) + } else if cloudbeaverActiveUser.User != nil { + if err := cu.updateCloudbeaverSession([]*http.Cookie{cookie}); err != nil { + cu.log.Debugf("cached cloudbeaver session cannot refresh for DMS user %s: %v", dmsUserId, err) + cu.UnbindCBSession(dmsToken) + } else { + return next(c) + } } } @@ -417,7 +420,8 @@ func (cu *CloudbeaverUsecase) GraphQLDistributor() echo.MiddlewareFunc { var srw *ResponseInterceptor var cloudbeaverResBuf *bytes.Buffer var needSmartWriter bool = cloudbeaverHandle.UseLocalHandler || - params.OperationName == "navNodeChildren" + params.OperationName == "navNodeChildren" || + params.OperationName == "openSession" if needSmartWriter { srw = newSmartResponseWriter(c) @@ -428,6 +432,7 @@ func (cu *CloudbeaverUsecase) GraphQLDistributor() echo.MiddlewareFunc { return } cu.handleNavNodeChildrenOperation(params, c, cloudbeaverResBuf) + cu.keepProxySessionValid(params, cloudbeaverResBuf) respBytesBuf := cloudbeaverResBuf.Bytes() length := c.Response().Header().Get("content-length") @@ -999,6 +1004,40 @@ func (cu *CloudbeaverUsecase) handleNavNodeChildrenOperation(params *graphql.Raw } } +func (cu *CloudbeaverUsecase) keepProxySessionValid(params *graphql.RawParams, responseBuf *bytes.Buffer) { + if params.OperationName != "openSession" || responseBuf == nil || responseBuf.Len() == 0 { + return + } + + originalResponse := map[string]interface{}{} + if err := json.Unmarshal(responseBuf.Bytes(), &originalResponse); err != nil { + cu.log.Errorf("failed to unmarshal openSession response: %v", err) + return + } + data, ok := originalResponse["data"].(map[string]interface{}) + if !ok { + return + } + session, ok := data["session"].(map[string]interface{}) + if !ok { + return + } + + session["valid"] = true + session["cacheExpired"] = false + if remainingTime, ok := session["remainingTime"].(float64); !ok || remainingTime <= 0 { + session["remainingTime"] = 1800 + } + + updatedResponseBytes, err := json.Marshal(originalResponse) + if err != nil { + cu.log.Errorf("failed to marshal openSession response: %v", err) + return + } + responseBuf.Reset() + responseBuf.Write(updatedResponseBytes) +} + func (cu *CloudbeaverUsecase) isEnableSQLAudit(dbService *DBService) bool { if dbService.SQLEConfig == nil || dbService.SQLEConfig.SQLQueryConfig == nil { return false @@ -1140,6 +1179,12 @@ func (cu *CloudbeaverUsecase) getAuthCache(username string) []*http.Cookie { return cache.cookies } +func (cu *CloudbeaverUsecase) deleteAuthCache(username string) { + authMutex.Lock() + defer authMutex.Unlock() + delete(authCacheMap, username) +} + // setAuthCache 设置认证缓存 func (cu *CloudbeaverUsecase) setAuthCache(username string, cookies []*http.Cookie) { authMutex.Lock() @@ -1155,11 +1200,31 @@ func (cu *CloudbeaverUsecase) setAuthCache(username string, cookies []*http.Cook // getAuthCookies 获取认证 cookies(复用缓存逻辑) func (cu *CloudbeaverUsecase) getAuthCookies(username, password string) ([]*http.Cookie, error) { + if username == cu.cloudbeaverCfg.AdminUser { + cu.deleteAuthCache(username) + cookies, err := cu.loginCloudbeaverServer(username, password) + if err != nil { + return nil, err + } + cu.setAuthCache(username, cookies) + return cookies, nil + } + // 尝试从缓存获取认证信息 cachedCookies := cu.getAuthCache(username) if cachedCookies != nil { - cu.log.Debugf("Using cached authentication for user: %s", username) - return cachedCookies, nil + if valid, err := cu.isAuthCookiesValid(cachedCookies); err == nil && valid { + if err := cu.updateCloudbeaverSession(cachedCookies); err != nil { + cu.log.Debugf("Cached authentication for user %s cannot refresh session: %v", username, err) + cu.deleteAuthCache(username) + } else { + cu.log.Debugf("Using cached authentication for user: %s", username) + return cachedCookies, nil + } + } else if err != nil { + cu.log.Debugf("Cached authentication for user %s is invalid: %v", username, err) + cu.deleteAuthCache(username) + } } // 缓存未命中,执行登录 @@ -1174,6 +1239,14 @@ func (cu *CloudbeaverUsecase) getAuthCookies(username, password string) ([]*http return cookies, nil } +func (cu *CloudbeaverUsecase) isAuthCookiesValid(cookies []*http.Cookie) (bool, error) { + activeUser, err := cu.getActiveUserQuery(cookies) + if err != nil { + return false, err + } + return activeUser != nil && activeUser.User != nil, nil +} + // clearExpiredAuthCache 清理过期的认证缓存 func (cu *CloudbeaverUsecase) clearExpiredAuthCache() { authMutex.Lock() @@ -1214,6 +1287,16 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea return fmt.Errorf("username %s is reserved, cann't be used", cloudbeaverUserId) } + cloudbeaverUser, exist, err := cu.repo.GetCloudbeaverUserByID(ctx, cloudbeaverUserId) + if err != nil { + cu.log.Errorf("Failed to get CloudBeaver user %s from cache: %v", cloudbeaverUserId, err) + return err + } + if exist && cloudbeaverUser.DMSFingerprint == cu.userUsecase.GetUserFingerprint(dmsUser) { + cu.log.Debugf("CloudBeaver user %s fingerprint matches, no update needed", cloudbeaverUserId) + return nil + } + // 使用管理员身份登录 graphQLClient, err := cu.getGraphQLClientWithRootUser() if err != nil { @@ -1256,19 +1339,6 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea } cu.log.Infof("Successfully granted role to CloudBeaver user: %s", cloudbeaverUserId) - } else { - cu.log.Debugf("CloudBeaver user %s already exists, checking fingerprint", cloudbeaverUserId) - - cloudbeaverUser, exist, err := cu.repo.GetCloudbeaverUserByID(ctx, cloudbeaverUserId) - if err != nil { - cu.log.Errorf("Failed to get CloudBeaver user %s from cache: %v", cloudbeaverUserId, err) - return err - } - - if exist && cloudbeaverUser.DMSFingerprint == cu.userUsecase.GetUserFingerprint(dmsUser) { - cu.log.Debugf("CloudBeaver user %s fingerprint matches, no update needed", cloudbeaverUserId) - return nil - } } // 设置CloudBeaver用户密码 @@ -1288,7 +1358,7 @@ func (cu *CloudbeaverUsecase) createUserIfNotExist(ctx context.Context, cloudbea cu.log.Infof("Successfully updated CloudBeaver user password for user: %s", cloudbeaverUserId) - cloudbeaverUser := &CloudbeaverUser{ + cloudbeaverUser = &CloudbeaverUser{ DMSUserID: dmsUser.UID, DMSFingerprint: cu.userUsecase.GetUserFingerprint(dmsUser), CloudbeaverUserID: cloudbeaverUserId, @@ -1953,10 +2023,53 @@ func (cu *CloudbeaverUsecase) loginCloudbeaverServer(user, pwd string) (cookie [ if err = client.Run(context.TODO(), req, &res); err != nil { return cookie, fmt.Errorf("cloudbeaver login failed: %v,req: %v,res %v", err, req, res) } - + if len(cookie) == 0 { + return cookie, fmt.Errorf("cloudbeaver login failed: empty session cookie") + } + if err = cu.updateCloudbeaverSession(cookie); err != nil { + return cookie, err + } return cookie, nil } +func (cu *CloudbeaverUsecase) openCloudbeaverSession(cookies []*http.Cookie) error { + client := cloudbeaver.NewGraphQlClient(cu.getGraphQLServerURI(), cloudbeaver.WithCookie(cookies), cloudbeaver.WithLogger(cu.log)) + req := cloudbeaver.NewRequest(cu.graphQl.OpenSessionQuery(), map[string]interface{}{ + "defaultLocale": "en", + }) + res := struct { + Session struct { + Valid bool `json:"valid"` + RemainingTime int `json:"remainingTime"` + } `json:"session"` + }{} + if err := client.Run(context.TODO(), req, &res); err != nil { + return fmt.Errorf("cloudbeaver open session failed: %v", err) + } + if !res.Session.Valid { + return fmt.Errorf("cloudbeaver open session failed: invalid session, remaining time: %d", res.Session.RemainingTime) + } + return nil +} + +func (cu *CloudbeaverUsecase) updateCloudbeaverSession(cookies []*http.Cookie) error { + client := cloudbeaver.NewGraphQlClient(cu.getGraphQLServerURI(), cloudbeaver.WithCookie(cookies), cloudbeaver.WithLogger(cu.log)) + req := cloudbeaver.NewRequest(cu.graphQl.UpdateSessionQuery(), map[string]interface{}{}) + res := struct { + Session struct { + Valid bool `json:"valid"` + RemainingTime int `json:"remainingTime"` + } `json:"session"` + }{} + if err := client.Run(context.TODO(), req, &res); err != nil { + return fmt.Errorf("cloudbeaver update session failed: %v", err) + } + if !res.Session.Valid { + return fmt.Errorf("cloudbeaver update session failed: invalid session, remaining time: %d", res.Session.RemainingTime) + } + return nil +} + type GraphQLResponse struct { Data json.RawMessage `json:"data"` Errors []GraphQLError `json:"errors"` diff --git a/internal/pkg/cloudbeaver/cloudbeaver.go b/internal/pkg/cloudbeaver/cloudbeaver.go index 81eeef3d9..1b1c28f77 100644 --- a/internal/pkg/cloudbeaver/cloudbeaver.go +++ b/internal/pkg/cloudbeaver/cloudbeaver.go @@ -80,6 +80,8 @@ type GraphQLImpl interface { CreateUserQuery() string GrantUserRoleQuery() string LoginQuery() string + OpenSessionQuery() string + UpdateSessionQuery() string GetActiveUserQuery() string GetExecutionContextListQuery() string } @@ -219,6 +221,7 @@ query authLogin( configuration: null credentials: $credentials linkUser: false + forceSessionsLogout: true ){ authId } @@ -226,6 +229,28 @@ query authLogin( ` } +func (CloudBeaverV2215) OpenSessionQuery() string { + return ` +mutation openSession($defaultLocale: String) { + session: openSession(defaultLocale: $defaultLocale) { + valid + remainingTime + } +} +` +} + +func (CloudBeaverV2215) UpdateSessionQuery() string { + return ` +mutation updateSession { + session: updateSession { + valid + remainingTime + } +} +` +} + func (CloudBeaverV2215) GetActiveUserQuery() string { return ` query getActiveUser { diff --git a/internal/pkg/cloudbeaver/graphql.go b/internal/pkg/cloudbeaver/graphql.go index be7a67a78..e059f0104 100644 --- a/internal/pkg/cloudbeaver/graphql.go +++ b/internal/pkg/cloudbeaver/graphql.go @@ -359,7 +359,7 @@ var GraphQLHandlerRouters = map[string] /* gql operation name */ gqlBehavior{ "getActiveUser": { UseLocalHandler: true, NeedModifyRemoteRes: true, - }, "authLogout": { + }, "openSession": {}, "authLogout": { Disable: true, }, "authLogin": { Disable: true,