feat: add user groups

This commit is contained in:
Elias Schneider
2024-10-02 08:43:44 +02:00
parent 7a54d3ae20
commit 24c948e6a6
40 changed files with 1142 additions and 37 deletions

View File

@@ -41,6 +41,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
userService := service.NewUserService(db, jwtService)
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService)
testService := service.NewTestService(db, appConfigService)
userGroupService := service.NewUserGroupService(db)
r.Use(middleware.NewCorsMiddleware().Add())
r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60))
@@ -57,6 +58,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService)
controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService)
controller.NewAuditLogController(apiGroup, auditLogService, jwtAuthMiddleware)
controller.NewUserGroupController(apiGroup, jwtAuthMiddleware, userGroupService)
// Add test controller in non-production environments
if common.EnvConfig.AppEnv != "production" {

View File

@@ -15,4 +15,5 @@ var (
ErrOidcInvalidCallbackURL = errors.New("invalid callback URL")
ErrFileTypeNotSupported = errors.New("file type not supported")
ErrInvalidCredentials = errors.New("no user found with provided credentials")
ErrNameAlreadyInUse = errors.New("name is already in use")
)

View File

@@ -0,0 +1,162 @@
package controller
import (
"errors"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
)
func NewUserGroupController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, userGroupService *service.UserGroupService) {
ugc := UserGroupController{
UserGroupService: userGroupService,
}
group.GET("/user-groups", jwtAuthMiddleware.Add(true), ugc.list)
group.GET("/user-groups/:id", jwtAuthMiddleware.Add(true), ugc.get)
group.POST("/user-groups", jwtAuthMiddleware.Add(true), ugc.create)
group.PUT("/user-groups/:id", jwtAuthMiddleware.Add(true), ugc.update)
group.DELETE("/user-groups/:id", jwtAuthMiddleware.Add(true), ugc.delete)
group.PUT("/user-groups/:id/users", jwtAuthMiddleware.Add(true), ugc.updateUsers)
}
type UserGroupController struct {
UserGroupService *service.UserGroupService
}
func (ugc *UserGroupController) list(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
searchTerm := c.Query("search")
groups, pagination, err := ugc.UserGroupService.List(searchTerm, page, pageSize)
if err != nil {
utils.ControllerError(c, err)
return
}
var groupsDto = make([]dto.UserGroupDtoWithUserCount, len(groups))
for i, group := range groups {
var groupDto dto.UserGroupDtoWithUserCount
if err := dto.MapStruct(group, &groupDto); err != nil {
utils.ControllerError(c, err)
return
}
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(group.ID)
if err != nil {
utils.ControllerError(c, err)
return
}
groupsDto[i] = groupDto
}
c.JSON(http.StatusOK, gin.H{
"data": groupsDto,
"pagination": pagination,
})
}
func (ugc *UserGroupController) get(c *gin.Context) {
group, err := ugc.UserGroupService.Get(c.Param("id"))
if err != nil {
utils.ControllerError(c, err)
return
}
var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, groupDto)
}
func (ugc *UserGroupController) create(c *gin.Context) {
var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}
group, err := ugc.UserGroupService.Create(input)
if err != nil {
if errors.Is(err, common.ErrNameAlreadyInUse) {
utils.CustomControllerError(c, http.StatusConflict, err.Error())
} else {
utils.ControllerError(c, err)
}
return
}
var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusCreated, groupDto)
}
func (ugc *UserGroupController) update(c *gin.Context) {
var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}
group, err := ugc.UserGroupService.Update(c.Param("id"), input)
if err != nil {
if errors.Is(err, common.ErrNameAlreadyInUse) {
utils.CustomControllerError(c, http.StatusConflict, err.Error())
} else {
utils.ControllerError(c, err)
}
return
}
var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, groupDto)
}
func (ugc *UserGroupController) delete(c *gin.Context) {
if err := ugc.UserGroupService.Delete(c.Param("id")); err != nil {
utils.ControllerError(c, err)
return
}
c.Status(http.StatusNoContent)
}
func (ugc *UserGroupController) updateUsers(c *gin.Context) {
var input dto.UserGroupUpdateUsersDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}
group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input)
if err != nil {
utils.ControllerError(c, err)
return
}
var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, groupDto)
}

View File

@@ -57,15 +57,37 @@ func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
// Handle direct assignment for simple types
if sourceField.Type() == destField.Type() {
destField.Set(sourceField)
} else if sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice {
// Handle slices
if sourceField.Type().Elem() == destField.Type().Elem() {
// Direct assignment for slices of primitive types or non-struct elements
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
newSlice.Index(j).Set(sourceField.Index(j))
}
destField.Set(newSlice)
} else if sourceField.Type().Elem().Kind() == reflect.Struct && destField.Type().Elem().Kind() == reflect.Struct {
// Recursively map slices of structs
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
// Get the element from both source and destination slice
sourceElem := sourceField.Index(j)
destElem := reflect.New(destField.Type().Elem()).Elem()
// Recursively map the struct elements
if err := mapStructInternal(sourceElem, destElem); err != nil {
return err
}
// Set the mapped element in the new slice
newSlice.Index(j).Set(destElem)
}
destField.Set(newSlice)
}
} else if sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct {

View File

@@ -0,0 +1,32 @@
package dto
import "time"
type UserGroupDtoWithUsers struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
Users []UserDto `json:"users"`
CreatedAt time.Time `json:"createdAt"`
}
type UserGroupDtoWithUserCount struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
UserCount int64 `json:"userCount"`
CreatedAt time.Time `json:"createdAt"`
}
type UserGroupCreateDto struct {
FriendlyName string `json:"friendlyName" binding:"required,min=3,max=30"`
Name string `json:"name" binding:"required,min=3,max=30,userGroupName"`
}
type UserGroupUpdateUsersDto struct {
UserIDs []string `json:"userIds" binding:"required"`
}
type AssignUserToGroupDto struct {
UserID string `json:"userId" binding:"required"`
}

View File

@@ -28,6 +28,13 @@ var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
return matched
}
var validateUserGroupName validator.Func = func(fl validator.FieldLevel) bool {
// [a-z0-9_] : The group name can only contain lowercase letters, numbers, and underscores
regex := "^[a-z0-9_]+$"
matched, _ := regexp.MatchString(regex, fl.Field().String())
return matched
}
func init() {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("urlList", validateUrlList); err != nil {
@@ -39,4 +46,10 @@ func init() {
log.Fatalf("Failed to register custom validation: %v", err)
}
}
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("userGroupName", validateUserGroupName); err != nil {
log.Fatalf("Failed to register custom validation: %v", err)
}
}
}

View File

@@ -15,6 +15,7 @@ type User struct {
LastName string
IsAdmin bool
UserGroups []UserGroup `gorm:"many2many:user_groups_users;"`
Credentials []WebauthnCredential
}

View File

@@ -0,0 +1,8 @@
package model
type UserGroup struct {
Base
FriendlyName string
Name string `gorm:"unique"`
Users []User `gorm:"many2many:user_groups_users;"`
}

View File

@@ -301,15 +301,21 @@ func (s *OidcService) DeleteClientLogo(clientID string) error {
func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (map[string]interface{}, error) {
var authorizedOidcClient model.UserAuthorizedOidcClient
if err := s.db.Preload("User").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil {
if err := s.db.Preload("User.UserGroups").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil {
return nil, err
}
user := authorizedOidcClient.User
scope := authorizedOidcClient.Scope
userGroups := make([]string, len(user.UserGroups))
for i, group := range user.UserGroups {
userGroups[i] = group.Name
}
claims := map[string]interface{}{
"sub": user.ID,
"sub": user.ID,
"groups": userGroups,
}
if strings.Contains(scope, "email") {

View File

@@ -0,0 +1,111 @@
package service
import (
"errors"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils"
"gorm.io/gorm"
)
type UserGroupService struct {
db *gorm.DB
}
func NewUserGroupService(db *gorm.DB) *UserGroupService {
return &UserGroupService{db: db}
}
func (s *UserGroupService) List(name string, page int, pageSize int) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
query := s.db.Model(&model.UserGroup{})
if name != "" {
query = query.Where("name LIKE ?", "%"+name+"%")
}
response, err = utils.Paginate(page, pageSize, query, &groups)
return groups, response, err
}
func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) {
err = s.db.Where("id = ?", id).Preload("Users").First(&group).Error
return group, err
}
func (s *UserGroupService) Delete(id string) error {
var group model.UserGroup
if err := s.db.Where("id = ?", id).First(&group).Error; err != nil {
return err
}
return s.db.Delete(&group).Error
}
func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
group = model.UserGroup{
FriendlyName: input.FriendlyName,
Name: input.Name,
}
if err := s.db.Preload("Users").Create(&group).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, common.ErrNameAlreadyInUse
}
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
group, err = s.Get(id)
if err != nil {
return model.UserGroup{}, err
}
group.Name = input.Name
group.FriendlyName = input.FriendlyName
if err := s.db.Preload("Users").Save(&group).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, common.ErrNameAlreadyInUse
}
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) UpdateUsers(id string, input dto.UserGroupUpdateUsersDto) (group model.UserGroup, err error) {
group, err = s.Get(id)
if err != nil {
return model.UserGroup{}, err
}
// Fetch the users based on UserIDs in input
var users []model.User
if len(input.UserIDs) > 0 {
if err := s.db.Where("id IN (?)", input.UserIDs).Find(&users).Error; err != nil {
return model.UserGroup{}, err
}
}
// Replace the current users with the new set of users
if err := s.db.Model(&group).Association("Users").Replace(users); err != nil {
return model.UserGroup{}, err
}
// Save the updated group
if err := s.db.Save(&group).Error; err != nil {
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) GetUserCountOfGroup(id string) (int64, error) {
var group model.UserGroup
if err := s.db.Preload("Users").Where("id = ?", id).First(&group).Error; err != nil {
return 0, err
}
return s.db.Model(&group).Association("Users").Count(), nil
}

View File

@@ -5,9 +5,10 @@ import (
)
type PaginationResponse struct {
TotalPages int64 `json:"totalPages"`
TotalItems int64 `json:"totalItems"`
CurrentPage int `json:"currentPage"`
TotalPages int64 `json:"totalPages"`
TotalItems int64 `json:"totalItems"`
CurrentPage int `json:"currentPage"`
ItemsPerPage int `json:"itemsPerPage"`
}
func Paginate(page int, pageSize int, db *gorm.DB, result interface{}) (PaginationResponse, error) {
@@ -33,8 +34,9 @@ func Paginate(page int, pageSize int, db *gorm.DB, result interface{}) (Paginati
}
return PaginationResponse{
TotalPages: (totalItems + int64(pageSize) - 1) / int64(pageSize),
TotalItems: totalItems,
CurrentPage: page,
TotalPages: (totalItems + int64(pageSize) - 1) / int64(pageSize),
TotalItems: totalItems,
CurrentPage: page,
ItemsPerPage: pageSize,
}, nil
}

View File

@@ -0,0 +1,2 @@
DROP TABLE user_groups;
DROP TABLE user_groups_users;

View File

@@ -0,0 +1,16 @@
CREATE TABLE user_groups
(
id TEXT NOT NULL PRIMARY KEY,
created_at DATETIME,
friendly_name TEXT NOT NULL,
name TEXT NOT NULL UNIQUE
);
CREATE TABLE user_groups_users
(
user_id TEXT NOT NULL,
user_group_id TEXT NOT NULL,
PRIMARY KEY (user_id, user_group_id),
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE,
FOREIGN KEY (user_group_id) REFERENCES user_groups (id) ON DELETE CASCADE
);