package repositories import ( "time" "git.secnex.io/secnex/idp-api/models" "gorm.io/gorm" ) type SessionRepository struct { db *gorm.DB } func NewSessionRepository(db *gorm.DB) *SessionRepository { return &SessionRepository{db: db} } // CreateSession creates a new session func (r *SessionRepository) CreateSession(session *models.Session) error { return r.db.Create(session).Error } func (r *SessionRepository) GetSessionCount(user *string) (int64, error) { var count int64 query := r.db if user != nil && *user != "" { query = query.Where("user_id = ?", *user) } err := query.Model(&models.Session{}).Count(&count).Error return count, err } // GetSessions retrieves all sessions from the database with pagination func (r *SessionRepository) GetSessions(page, limit int, user *string) ([]models.Session, error) { var sessions []models.Session query := r.db.Offset((page - 1) * limit).Limit(limit) if user != nil && *user != "" { query = query.Where("user_id = ?", *user) } err := query.Find(&sessions).Error return sessions, err } // GetSessionByID retrieves a session by ID with optional preloading func (r *SessionRepository) GetSessionByID(id string, preloadUser bool, preloadUserRelations ...string) (*models.Session, error) { var session models.Session query := r.db.Where("id = ? AND revoked = ? AND expires_at > ? AND logged_out = ?", id, false, time.Now(), false) if preloadUser { query = query.Preload("User") // Preload additional user relations if specified for _, relation := range preloadUserRelations { query = query.Preload("User." + relation) } } err := query.First(&session).Error if err != nil { return nil, err } return &session, nil } // GetSessionByUserID retrieves sessions by user ID with optional preloading func (r *SessionRepository) GetSessionByUserID(userID string, preloadUser bool, preloadUserRelations ...string) ([]models.Session, error) { var sessions []models.Session query := r.db.Where("user_id = ?", userID) if preloadUser { query = query.Preload("User") // Preload additional user relations if specified for _, relation := range preloadUserRelations { query = query.Preload("User." + relation) } } err := query.Find(&sessions).Error return sessions, err } // UpdateSession updates an existing session func (r *SessionRepository) UpdateSession(session *models.Session) error { return r.db.Save(session).Error } // DeleteSession deletes a session by ID func (r *SessionRepository) DeleteSession(id string) error { return r.db.Delete(&models.Session{}, "id = ?", id).Error } // RevokeSession revokes a session func (r *SessionRepository) RevokeSession(id string) error { return r.db.Model(&models.Session{}).Where("id = ?", id).Update("revoked", true).Error } // RevokeAllSessions revokes all sessions func (r *SessionRepository) RevokeAllSessions() error { return r.db.Model(&models.Session{}).Where("revoked = ?", false).Update("revoked", true).Error } // RevokeAllSessionsByUserID revokes all sessions for a user func (r *SessionRepository) RevokeAllSessionsByUserID(userID string) error { return r.db.Model(&models.Session{}).Where("user_id = ?", userID).Update("revoked", true).Error } // GetActiveSessionsByUserID gets all active (non-revoked) sessions for a user func (r *SessionRepository) GetActiveSessionsByUserID(userID string, preloadUser bool) ([]models.Session, error) { var sessions []models.Session query := r.db.Where("user_id = ? AND revoked = ?", userID, false) if preloadUser { query = query.Preload("User") } err := query.Find(&sessions).Error return sessions, err } // LogoutSessionByID logs out a session by ID func (r *SessionRepository) LogoutSessionByID(id string) error { return r.db.Model(&models.Session{}).Where("id = ?", id).Update("logged_out", true).Error }