From 601f6c488a7b3c266a1d2174282ab3203841a6e5 Mon Sep 17 00:00:00 2001 From: Elias Schneider Date: Sat, 17 Aug 2024 21:57:14 +0200 Subject: [PATCH] refactor: use dependency injection in backend --- backend/cmd/main.go | 2 +- backend/go.mod | 2 +- .../bootstrap/application_images_bootstrap.go | 28 ++ backend/internal/bootstrap/bootstrap.go | 76 +--- .../db.go => bootstrap/db_bootstrap.go} | 44 +- .../internal/bootstrap/router_bootstrap.go | 67 +++ backend/internal/common/config.go | 138 ------ backend/internal/common/env_config.go | 31 ++ backend/internal/common/errors.go | 18 + backend/internal/common/jwt.go | 209 --------- backend/internal/common/webauthn.go | 37 -- .../controller/app_config_controller.go | 140 ++++++ .../internal/controller/oidc_controller.go | 218 +++++++++ .../internal/controller/test_controller.go | 36 ++ .../internal/controller/user_controller.go | 182 ++++++++ .../controller/webauthn_controller.go | 160 +++++++ .../well_known_controller.go} | 24 +- .../handler/application_configuration.go | 196 --------- backend/internal/handler/oidc.go | 415 ------------------ backend/internal/handler/user.go | 276 ------------ backend/internal/handler/webauthn.go | 257 ----------- backend/internal/job/db_cleanup.go | 47 +- .../internal/{common => }/middleware/cors.go | 10 +- .../middleware/file_size_limit.go | 10 +- .../{common => }/middleware/jwt_auth.go | 19 +- .../{common => }/middleware/rate_limit.go | 13 +- backend/internal/model/app_config.go | 20 + .../model/application_configuration.go | 20 - backend/internal/model/base.go | 2 +- backend/internal/model/oidc.go | 6 + backend/internal/model/webauthn.go | 16 + .../internal/service/app_config_service.go | 213 +++++++++ backend/internal/service/jwt_service.go | 248 +++++++++++ backend/internal/service/oidc_service.go | 282 ++++++++++++ .../test.go => service/test_service.go} | 76 ++-- backend/internal/service/user_sevice.go | 165 +++++++ backend/internal/service/webauthn_service.go | 196 +++++++++ backend/internal/utils/file_util.go | 22 + backend/internal/utils/handler_error_util.go | 12 +- backend/internal/utils/paging_util.go | 7 +- ...0240817191051_rename_config_table.down.sql | 2 + .../20240817191051_rename_config_table.up.sql | 2 + frontend/tests/account-settings.spec.ts | 12 - frontend/tests/user-settings.spec.ts | 15 + 44 files changed, 2220 insertions(+), 1751 deletions(-) create mode 100644 backend/internal/bootstrap/application_images_bootstrap.go rename backend/internal/{common/db.go => bootstrap/db_bootstrap.go} (63%) create mode 100644 backend/internal/bootstrap/router_bootstrap.go delete mode 100644 backend/internal/common/config.go create mode 100644 backend/internal/common/env_config.go create mode 100644 backend/internal/common/errors.go delete mode 100644 backend/internal/common/jwt.go delete mode 100644 backend/internal/common/webauthn.go create mode 100644 backend/internal/controller/app_config_controller.go create mode 100644 backend/internal/controller/oidc_controller.go create mode 100644 backend/internal/controller/test_controller.go create mode 100644 backend/internal/controller/user_controller.go create mode 100644 backend/internal/controller/webauthn_controller.go rename backend/internal/{handler/well_known.go => controller/well_known_controller.go} (57%) delete mode 100644 backend/internal/handler/application_configuration.go delete mode 100644 backend/internal/handler/oidc.go delete mode 100644 backend/internal/handler/user.go delete mode 100644 backend/internal/handler/webauthn.go rename backend/internal/{common => }/middleware/cors.go (57%) rename backend/internal/{common => }/middleware/file_size_limit.go (76%) rename backend/internal/{common => }/middleware/jwt_auth.go (68%) rename backend/internal/{common => }/middleware/rate_limit.go (83%) create mode 100644 backend/internal/model/app_config.go delete mode 100644 backend/internal/model/application_configuration.go create mode 100644 backend/internal/service/app_config_service.go create mode 100644 backend/internal/service/jwt_service.go create mode 100644 backend/internal/service/oidc_service.go rename backend/internal/{handler/test.go => service/test_service.go} (74%) create mode 100644 backend/internal/service/user_sevice.go create mode 100644 backend/internal/service/webauthn_service.go create mode 100644 backend/migrations/20240817191051_rename_config_table.down.sql create mode 100644 backend/migrations/20240817191051_rename_config_table.up.sql diff --git a/backend/cmd/main.go b/backend/cmd/main.go index 6d19218..b2ebc11 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -1,7 +1,7 @@ package main import ( - "golang-rest-api-template/internal/bootstrap" + "github.com/stonith404/pocket-id/backend/internal/bootstrap" ) func main() { diff --git a/backend/go.mod b/backend/go.mod index 2c8dce8..da81855 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,4 +1,4 @@ -module golang-rest-api-template +module github.com/stonith404/pocket-id/backend go 1.22 diff --git a/backend/internal/bootstrap/application_images_bootstrap.go b/backend/internal/bootstrap/application_images_bootstrap.go new file mode 100644 index 0000000..2f08c9e --- /dev/null +++ b/backend/internal/bootstrap/application_images_bootstrap.go @@ -0,0 +1,28 @@ +package bootstrap + +import ( + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/utils" + "log" + "os" +) + +func initApplicationImages() { + dirPath := common.EnvConfig.UploadPath + "/application-images" + + files, err := os.ReadDir(dirPath) + if err != nil && !os.IsNotExist(err) { + log.Fatalf("Error reading directory: %v", err) + } + + // Skip if files already exist + if len(files) > 1 { + return + } + + // Copy files from source to destination + err = utils.CopyDirectory("./images", dirPath) + if err != nil { + log.Fatalf("Error copying directory: %v", err) + } +} diff --git a/backend/internal/bootstrap/bootstrap.go b/backend/internal/bootstrap/bootstrap.go index 32fbd55..7fdabb6 100644 --- a/backend/internal/bootstrap/bootstrap.go +++ b/backend/internal/bootstrap/bootstrap.go @@ -1,78 +1,16 @@ package bootstrap import ( - "github.com/gin-gonic/gin" _ "github.com/golang-migrate/migrate/v4/source/file" - "golang-rest-api-template/internal/common" - "golang-rest-api-template/internal/common/middleware" - "golang-rest-api-template/internal/handler" - "golang-rest-api-template/internal/job" - "golang-rest-api-template/internal/utils" - "golang.org/x/time/rate" - "log" - "os" - "time" + "github.com/stonith404/pocket-id/backend/internal/job" + "github.com/stonith404/pocket-id/backend/internal/service" ) func Bootstrap() { - common.InitDatabase() - common.InitDbConfig() + db := newDatabase() + appConfigService := service.NewAppConfigService(db) + initApplicationImages() - job.RegisterJobs() - initRouter() -} - -func initRouter() { - switch common.EnvConfig.AppEnv { - case "production": - gin.SetMode(gin.ReleaseMode) - case "development": - gin.SetMode(gin.DebugMode) - case "test": - gin.SetMode(gin.TestMode) - } - - r := gin.Default() - - r.Use(gin.Logger()) - - r.Use(middleware.Cors()) - r.Use(middleware.RateLimiter(rate.Every(time.Second), 60)) - - apiGroup := r.Group("/api") - handler.RegisterRoutes(apiGroup) - handler.RegisterOIDCRoutes(apiGroup) - handler.RegisterUserRoutes(apiGroup) - handler.RegisterConfigurationRoutes(apiGroup) - if common.EnvConfig.AppEnv != "production" { - handler.RegisterTestRoutes(apiGroup) - } - - baseGroup := r.Group("/") - handler.RegisterWellKnownRoutes(baseGroup) - - if err := r.Run(common.EnvConfig.Host + ":" + common.EnvConfig.Port); err != nil { - log.Fatal(err) - } - -} - -func initApplicationImages() { - dirPath := common.EnvConfig.UploadPath + "/application-images" - - files, err := os.ReadDir(dirPath) - if err != nil && !os.IsNotExist(err) { - log.Fatalf("Error reading directory: %v", err) - } - - // Skip if files already exist - if len(files) > 1 { - return - } - - // Copy files from source to destination - err = utils.CopyDirectory("./images", dirPath) - if err != nil { - log.Fatalf("Error copying directory: %v", err) - } + job.RegisterJobs(db) + initRouter(db, appConfigService) } diff --git a/backend/internal/common/db.go b/backend/internal/bootstrap/db_bootstrap.go similarity index 63% rename from backend/internal/common/db.go rename to backend/internal/bootstrap/db_bootstrap.go index 3c3cf1c..d64172b 100644 --- a/backend/internal/common/db.go +++ b/backend/internal/bootstrap/db_bootstrap.go @@ -1,51 +1,54 @@ -package common +package bootstrap import ( "errors" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/sqlite3" + "github.com/stonith404/pocket-id/backend/internal/common" + "gorm.io/driver/sqlite" + "gorm.io/gorm" "gorm.io/gorm/logger" "log" "os" "time" - - "gorm.io/driver/sqlite" - "gorm.io/gorm" ) -var DB *gorm.DB - -func InitDatabase() { - connectDatabase() - sqlDb, err := DB.DB() +func newDatabase() (db *gorm.DB) { + db, err := connectDatabase() if err != nil { - log.Fatal("failed to get sql db", err) + log.Fatalf("failed to connect to database: %v", err) } + sqlDb, err := db.DB() + if err != nil { + log.Fatalf("failed to get sql.DB: %v", err) + } + driver, err := sqlite3.WithInstance(sqlDb, &sqlite3.Config{}) m, err := migrate.NewWithDatabaseInstance( "file://migrations", "postgres", driver) if err != nil { - log.Fatal("failed to create migration instance", err) + log.Fatalf("failed to create migration instance: %v", err) } err = m.Up() if err != nil && !errors.Is(err, migrate.ErrNoChange) { - log.Fatal("failed to run migrations", err) + log.Fatalf("failed to apply migrations: %v", err) } + + return db } -func connectDatabase() { - var database *gorm.DB - var err error +func connectDatabase() (db *gorm.DB, err error) { + dbPath := common.EnvConfig.DBPath - dbPath := EnvConfig.DBPath - if EnvConfig.AppEnv == "test" { + // Use in-memory database for testing + if common.EnvConfig.AppEnv == "test" { dbPath = "file::memory:?cache=shared" } for i := 1; i <= 3; i++ { - database, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{ + db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{ TranslateError: true, Logger: getLogger(), }) @@ -57,11 +60,11 @@ func connectDatabase() { } } - DB = database + return db, err } func getLogger() logger.Interface { - isProduction := EnvConfig.AppEnv == "production" + isProduction := common.EnvConfig.AppEnv == "production" var logLevel logger.LogLevel if isProduction { @@ -70,7 +73,6 @@ func getLogger() logger.Interface { logLevel = logger.Info } - // Create the GORM logger return logger.New( log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go new file mode 100644 index 0000000..008cdc6 --- /dev/null +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -0,0 +1,67 @@ +package bootstrap + +import ( + "log" + "time" + + "github.com/gin-gonic/gin" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/controller" + "github.com/stonith404/pocket-id/backend/internal/middleware" + "github.com/stonith404/pocket-id/backend/internal/service" + "golang.org/x/time/rate" + "gorm.io/gorm" +) + +func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { + // Set the appropriate Gin mode based on the environment + switch common.EnvConfig.AppEnv { + case "production": + gin.SetMode(gin.ReleaseMode) + case "development": + gin.SetMode(gin.DebugMode) + case "test": + gin.SetMode(gin.TestMode) + } + + r := gin.Default() + r.Use(gin.Logger()) + + // Add middleware + r.Use( + middleware.NewCorsMiddleware().Add(), + middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60), + ) + + // Initialize services + webauthnService := service.NewWebAuthnService(db, appConfigService) + jwtService := service.NewJwtService(appConfigService) + userService := service.NewUserService(db, jwtService) + oidcService := service.NewOidcService(db, jwtService) + testService := service.NewTestService(db, appConfigService) + + // Initialize middleware + jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService) + fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware() + + // Set up API routes + apiGroup := r.Group("/api") + controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService) + controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService) + controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService) + controller.NewApplicationConfigurationController(apiGroup, jwtAuthMiddleware, appConfigService) + + // Add test controller in non-production environments + if common.EnvConfig.AppEnv != "production" { + controller.NewTestController(apiGroup, testService) + } + + // Set up base routes + baseGroup := r.Group("/") + controller.NewWellKnownController(baseGroup, jwtService) + + // Run the server + if err := r.Run(common.EnvConfig.Host + ":" + common.EnvConfig.Port); err != nil { + log.Fatal(err) + } +} diff --git a/backend/internal/common/config.go b/backend/internal/common/config.go deleted file mode 100644 index 6762e59..0000000 --- a/backend/internal/common/config.go +++ /dev/null @@ -1,138 +0,0 @@ -package common - -import ( - "github.com/caarlos0/env/v11" - _ "github.com/joho/godotenv/autoload" - "golang-rest-api-template/internal/model" - "log" - "reflect" -) - -type EnvConfigSchema struct { - AppEnv string `env:"APP_ENV"` - AppURL string `env:"PUBLIC_APP_URL"` - DBPath string `env:"DB_PATH"` - UploadPath string `env:"UPLOAD_PATH"` - Port string `env:"BACKEND_PORT"` - Host string `env:"HOST"` -} - -var EnvConfig = &EnvConfigSchema{ - AppEnv: "production", - DBPath: "data/pocket-id.db", - UploadPath: "data/uploads", - AppURL: "http://localhost", - Port: "8080", - Host: "localhost", -} - -var DbConfig = NewDefaultDbConfig() - -func NewDefaultDbConfig() model.ApplicationConfiguration { - return model.ApplicationConfiguration{ - AppName: model.ApplicationConfigurationVariable{ - Key: "appName", - Type: "string", - IsPublic: true, - Value: "Pocket ID", - }, - SessionDuration: model.ApplicationConfigurationVariable{ - Key: "sessionDuration", - Type: "number", - Value: "60", - }, - BackgroundImageType: model.ApplicationConfigurationVariable{ - Key: "backgroundImageType", - Type: "string", - IsInternal: true, - Value: "jpg", - }, - LogoImageType: model.ApplicationConfigurationVariable{ - Key: "logoImageType", - Type: "string", - IsInternal: true, - Value: "svg", - }, - } -} - -// LoadDbConfigFromDb refreshes the database configuration by loading the current values -// from the database and updating the DbConfig struct. -func LoadDbConfigFromDb() error { - dbConfigReflectValue := reflect.ValueOf(&DbConfig).Elem() - - for i := 0; i < dbConfigReflectValue.NumField(); i++ { - dbConfigField := dbConfigReflectValue.Field(i) - currentConfigVar := dbConfigField.Interface().(model.ApplicationConfigurationVariable) - var storedConfigVar model.ApplicationConfigurationVariable - if err := DB.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil { - return err - } - - dbConfigField.Set(reflect.ValueOf(storedConfigVar)) - } - - return nil -} - -// InitDbConfig creates the default configuration values in the database if they do not exist, -// updates existing configurations if they differ from the default, and deletes any configurations -// that are not in the default configuration. -func InitDbConfig() { - // Reflect to get the underlying value of DbConfig and its default configuration - dbConfigReflectValue := reflect.ValueOf(&DbConfig).Elem() - defaultDbConfig := NewDefaultDbConfig() - defaultConfigReflectValue := reflect.ValueOf(&defaultDbConfig).Elem() - defaultKeys := make(map[string]struct{}) - - // Iterate over the fields of DbConfig - for i := 0; i < dbConfigReflectValue.NumField(); i++ { - dbConfigField := dbConfigReflectValue.Field(i) - currentConfigVar := dbConfigField.Interface().(model.ApplicationConfigurationVariable) - defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.ApplicationConfigurationVariable) - defaultKeys[currentConfigVar.Key] = struct{}{} - - var storedConfigVar model.ApplicationConfigurationVariable - if err := DB.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil { - // If the configuration does not exist, create it - if err := DB.Create(&defaultConfigVar).Error; err != nil { - log.Fatalf("Failed to create default configuration: %v", err) - } - dbConfigField.Set(reflect.ValueOf(defaultConfigVar)) - continue - } - - // Update existing configuration if it differs from the default - if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal { - storedConfigVar.Type = defaultConfigVar.Type - storedConfigVar.IsPublic = defaultConfigVar.IsPublic - storedConfigVar.IsInternal = defaultConfigVar.IsInternal - if err := DB.Save(&storedConfigVar).Error; err != nil { - log.Fatalf("Failed to update configuration: %v", err) - } - } - - // Set the value in DbConfig - dbConfigField.Set(reflect.ValueOf(storedConfigVar)) - } - - // Delete any configurations not in the default keys - var allConfigVars []model.ApplicationConfigurationVariable - if err := DB.Find(&allConfigVars).Error; err != nil { - log.Fatalf("Failed to retrieve existing configurations: %v", err) - } - - for _, config := range allConfigVars { - if _, exists := defaultKeys[config.Key]; !exists { - if err := DB.Delete(&config).Error; err != nil { - log.Fatalf("Failed to delete outdated configuration: %v", err) - } - } - } -} - -func init() { - if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil { - log.Fatal(err) - } -} diff --git a/backend/internal/common/env_config.go b/backend/internal/common/env_config.go new file mode 100644 index 0000000..4a0f673 --- /dev/null +++ b/backend/internal/common/env_config.go @@ -0,0 +1,31 @@ +package common + +import ( + "github.com/caarlos0/env/v11" + _ "github.com/joho/godotenv/autoload" + "log" +) + +type EnvConfigSchema struct { + AppEnv string `env:"APP_ENV"` + AppURL string `env:"PUBLIC_APP_URL"` + DBPath string `env:"DB_PATH"` + UploadPath string `env:"UPLOAD_PATH"` + Port string `env:"BACKEND_PORT"` + Host string `env:"HOST"` +} + +var EnvConfig = &EnvConfigSchema{ + AppEnv: "production", + DBPath: "data/pocket-id.db", + UploadPath: "data/uploads", + AppURL: "http://localhost", + Port: "8080", + Host: "localhost", +} + +func init() { + if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil { + log.Fatal(err) + } +} diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go new file mode 100644 index 0000000..7747f97 --- /dev/null +++ b/backend/internal/common/errors.go @@ -0,0 +1,18 @@ +package common + +import "errors" + +var ( + ErrUsernameTaken = errors.New("username is already taken") + ErrEmailTaken = errors.New("email is already taken") + ErrSetupAlreadyCompleted = errors.New("setup already completed") + ErrInvalidBody = errors.New("invalid request body") + ErrTokenInvalidOrExpired = errors.New("token is invalid or expired") + ErrOidcMissingAuthorization = errors.New("missing authorization") + ErrOidcGrantTypeNotSupported = errors.New("grant type not supported") + ErrOidcMissingClientCredentials = errors.New("client id or secret not provided") + ErrOidcClientSecretInvalid = errors.New("invalid client secret") + ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code") + ErrFileTypeNotSupported = errors.New("file type not supported") + ErrInvalidCredentials = errors.New("no user found with provided credentials") +) diff --git a/backend/internal/common/jwt.go b/backend/internal/common/jwt.go deleted file mode 100644 index 8a4586e..0000000 --- a/backend/internal/common/jwt.go +++ /dev/null @@ -1,209 +0,0 @@ -package common - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/base64" - "encoding/pem" - "errors" - "github.com/golang-jwt/jwt/v5" - "golang-rest-api-template/internal/model" - "golang-rest-api-template/internal/utils" - "log" - "math/big" - "os" - "path/filepath" - "slices" - "strconv" - "strings" - "time" -) - -var ( - PrivateKey *rsa.PrivateKey - PublicKey *rsa.PublicKey -) - -const ( - privateKeyPath = "data/keys/jwt_private_key.pem" - publicKeyPath = "data/keys/jwt_public_key.pem" -) - -type accessTokenJWTClaims struct { - jwt.RegisteredClaims - IsAdmin bool `json:"isAdmin,omitempty"` -} - -// GenerateIDToken generates an ID token for the given user, clientID, scope and nonce. -func GenerateIDToken(user model.User, clientID string, scope string, nonce string) (tokenString string, err error) { - profileClaims := map[string]interface{}{ - "given_name": user.FirstName, - "family_name": user.LastName, - "email": user.Email, - "preferred_username": user.Username, - } - - claims := jwt.MapClaims{ - "sub": user.ID, - "aud": clientID, - "exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), - "iat": jwt.NewNumericDate(time.Now()), - } - - if nonce != "" { - claims["nonce"] = nonce - } - if strings.Contains(scope, "profile") { - for k, v := range profileClaims { - claims[k] = v - } - } - if strings.Contains(scope, "email") { - claims["email"] = user.Email - } - - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - signedToken, err := token.SignedString(PrivateKey) - if err != nil { - return "", err - } - - return signedToken, nil -} - -// GenerateAccessToken generates an access token for the given user. -func GenerateAccessToken(user model.User) (tokenString string, err error) { - sessionDurationInMinutes, _ := strconv.Atoi(DbConfig.SessionDuration.Value) - claim := accessTokenJWTClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: user.ID, - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)), - IssuedAt: jwt.NewNumericDate(time.Now()), - Audience: jwt.ClaimStrings{utils.GetHostFromURL(EnvConfig.AppURL)}, - }, - IsAdmin: user.IsAdmin, - } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim) - tokenString, err = token.SignedString(PrivateKey) - return tokenString, err -} - -// VerifyAccessToken verifies the given access token and returns the claims if the token is valid. -func VerifyAccessToken(tokenString string) (*accessTokenJWTClaims, error) { - token, err := jwt.ParseWithClaims(tokenString, &accessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) { - return PublicKey, nil - }) - if err != nil || !token.Valid { - return nil, errors.New("couldn't handle this token") - } - - claims, isValid := token.Claims.(*accessTokenJWTClaims) - if !isValid { - return nil, errors.New("can't parse claims") - } - - if !slices.Contains(claims.Audience, utils.GetHostFromURL(EnvConfig.AppURL)) { - return nil, errors.New("audience doesn't match") - } - return claims, nil -} - -type JWK struct { - Kty string `json:"kty"` - Use string `json:"use"` - Kid string `json:"kid"` - Alg string `json:"alg"` - N string `json:"n"` - E string `json:"e"` -} - -// GetJWK returns the JSON Web Key (JWK) for the public key. -func GetJWK() (JWK, error) { - if PublicKey == nil { - return JWK{}, errors.New("public key is not initialized") - } - - // Create JWK from RSA public key - jwk := JWK{ - Kty: "RSA", - Use: "sig", - Kid: "1", // Key ID can be set to any identifier. Here it's statically set to "1" - Alg: "RS256", - N: base64.RawURLEncoding.EncodeToString(PublicKey.N.Bytes()), - E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(PublicKey.E)).Bytes()), - } - - return jwk, nil -} - -// generateKeys generates a new RSA key pair and saves the private and public keys to the data folder. -func generateKeys() { - if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil { - log.Fatal("Failed to create directories for keys", err) - } - - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - log.Fatal("Failed to generate private key", err) - } - - privateKeyFile, err := os.Create(privateKeyPath) - if err != nil { - log.Fatal("Failed to create private key file", err) - } - defer privateKeyFile.Close() - - privateKeyPEM := pem.EncodeToMemory( - &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(privateKey), - }, - ) - _, err = privateKeyFile.Write(privateKeyPEM) - if err != nil { - log.Fatal("Failed to write private key file", err) - } - - publicKey := &privateKey.PublicKey - publicKeyFile, err := os.Create(publicKeyPath) - if err != nil { - log.Fatal("Failed to create public key file", err) - } - defer publicKeyFile.Close() - - publicKeyPEM := pem.EncodeToMemory( - &pem.Block{ - Type: "RSA PUBLIC KEY", - Bytes: x509.MarshalPKCS1PublicKey(publicKey), - }, - ) - _, err = publicKeyFile.Write(publicKeyPEM) - if err != nil { - log.Fatal("Failed to write public key file", err) - } -} - -func init() { - if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) { - generateKeys() - } - - privateKeyBytes, err := os.ReadFile(privateKeyPath) - if err != nil { - log.Fatal("Can't read jwt private key", err) - } - PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes) - if err != nil { - log.Fatal("Can't parse jwt private key", err) - } - - publicKeyBytes, err := os.ReadFile(publicKeyPath) - if err != nil { - log.Fatal("Can't read jwt public key", err) - } - PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes) - if err != nil { - log.Fatal("Can't parse jwt public key", err) - } -} diff --git a/backend/internal/common/webauthn.go b/backend/internal/common/webauthn.go deleted file mode 100644 index 2c966a4..0000000 --- a/backend/internal/common/webauthn.go +++ /dev/null @@ -1,37 +0,0 @@ -package common - -import ( - "github.com/go-webauthn/webauthn/webauthn" - "golang-rest-api-template/internal/utils" - "log" - "time" -) - -var ( - WebAuthn *webauthn.WebAuthn - err error -) - -func init() { - config := &webauthn.Config{ - RPDisplayName: DbConfig.AppName.Value, - RPID: utils.GetHostFromURL(EnvConfig.AppURL), - RPOrigins: []string{EnvConfig.AppURL}, - Timeouts: webauthn.TimeoutsConfig{ - Login: webauthn.TimeoutConfig{ - Enforce: true, - Timeout: time.Second * 60, - TimeoutUVD: time.Second * 60, - }, - Registration: webauthn.TimeoutConfig{ - Enforce: true, - Timeout: time.Second * 60, - TimeoutUVD: time.Second * 60, - }, - }, - } - - if WebAuthn, err = webauthn.New(config); err != nil { - log.Fatal(err) - } -} diff --git a/backend/internal/controller/app_config_controller.go b/backend/internal/controller/app_config_controller.go new file mode 100644 index 0000000..11b382b --- /dev/null +++ b/backend/internal/controller/app_config_controller.go @@ -0,0 +1,140 @@ +package controller + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/middleware" + "github.com/stonith404/pocket-id/backend/internal/model" + "github.com/stonith404/pocket-id/backend/internal/service" + "github.com/stonith404/pocket-id/backend/internal/utils" + "net/http" +) + +func NewApplicationConfigurationController( + group *gin.RouterGroup, + jwtAuthMiddleware *middleware.JwtAuthMiddleware, + appConfigService *service.AppConfigService) { + + acc := &ApplicationConfigurationController{ + appConfigService: appConfigService, + } + group.GET("/application-configuration", acc.listApplicationConfigurationHandler) + group.GET("/application-configuration/all", jwtAuthMiddleware.Add(true), acc.listAllApplicationConfigurationHandler) + group.PUT("/application-configuration", acc.updateApplicationConfigurationHandler) + + group.GET("/application-configuration/logo", acc.getLogoHandler) + group.GET("/application-configuration/background-image", acc.getBackgroundImageHandler) + group.GET("/application-configuration/favicon", acc.getFaviconHandler) + group.PUT("/application-configuration/logo", jwtAuthMiddleware.Add(true), acc.updateLogoHandler) + group.PUT("/application-configuration/favicon", jwtAuthMiddleware.Add(true), acc.updateFaviconHandler) + group.PUT("/application-configuration/background-image", jwtAuthMiddleware.Add(true), acc.updateBackgroundImageHandler) +} + +type ApplicationConfigurationController struct { + appConfigService *service.AppConfigService +} + +func (acc *ApplicationConfigurationController) listApplicationConfigurationHandler(c *gin.Context) { + configuration, err := acc.appConfigService.ListApplicationConfiguration(false) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(200, configuration) +} + +func (acc *ApplicationConfigurationController) listAllApplicationConfigurationHandler(c *gin.Context) { + configuration, err := acc.appConfigService.ListApplicationConfiguration(true) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(200, configuration) +} + +func (acc *ApplicationConfigurationController) updateApplicationConfigurationHandler(c *gin.Context) { + var input model.AppConfigUpdateDto + if err := c.ShouldBindJSON(&input); err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + savedConfigVariables, err := acc.appConfigService.UpdateApplicationConfiguration(input) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusOK, savedConfigVariables) +} + +func (acc *ApplicationConfigurationController) getLogoHandler(c *gin.Context) { + imageType := acc.appConfigService.DbConfig.LogoImageType.Value + acc.getImage(c, "logo", imageType) +} + +func (acc *ApplicationConfigurationController) getFaviconHandler(c *gin.Context) { + acc.getImage(c, "favicon", "ico") +} + +func (acc *ApplicationConfigurationController) getBackgroundImageHandler(c *gin.Context) { + imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value + acc.getImage(c, "background", imageType) +} + +func (acc *ApplicationConfigurationController) updateLogoHandler(c *gin.Context) { + imageType := acc.appConfigService.DbConfig.LogoImageType.Value + acc.updateImage(c, "logo", imageType) +} + +func (acc *ApplicationConfigurationController) updateFaviconHandler(c *gin.Context) { + file, err := c.FormFile("file") + if err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + fileType := utils.GetFileExtension(file.Filename) + if fileType != "ico" { + utils.HandlerError(c, http.StatusBadRequest, "File must be of type .ico") + return + } + acc.updateImage(c, "favicon", "ico") +} + +func (acc *ApplicationConfigurationController) updateBackgroundImageHandler(c *gin.Context) { + imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value + acc.updateImage(c, "background", imageType) +} + +func (acc *ApplicationConfigurationController) getImage(c *gin.Context, name string, imageType string) { + imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType) + mimeType := utils.GetImageMimeType(imageType) + + c.Header("Content-Type", mimeType) + c.File(imagePath) +} + +func (acc *ApplicationConfigurationController) updateImage(c *gin.Context, imageName string, oldImageType string) { + file, err := c.FormFile("file") + if err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + err = acc.appConfigService.UpdateImage(file, imageName, oldImageType) + if err != nil { + if errors.Is(err, common.ErrFileTypeNotSupported) { + utils.HandlerError(c, http.StatusBadRequest, err.Error()) + } else { + utils.UnknownHandlerError(c, err) + } + return + } + + c.Status(http.StatusNoContent) +} diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go new file mode 100644 index 0000000..c0ea358 --- /dev/null +++ b/backend/internal/controller/oidc_controller.go @@ -0,0 +1,218 @@ +package controller + +import ( + "errors" + "github.com/gin-gonic/gin" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/middleware" + "github.com/stonith404/pocket-id/backend/internal/model" + "github.com/stonith404/pocket-id/backend/internal/service" + "github.com/stonith404/pocket-id/backend/internal/utils" + "net/http" + "strconv" +) + +func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService) { + oc := &OidcController{OidcService: oidcService} + + group.POST("/oidc/authorize", jwtAuthMiddleware.Add(false), oc.authorizeHandler) + group.POST("/oidc/authorize/new-client", jwtAuthMiddleware.Add(false), oc.authorizeNewClientHandler) + group.POST("/oidc/token", oc.createIDTokenHandler) + + group.GET("/oidc/clients", jwtAuthMiddleware.Add(true), oc.listClientsHandler) + group.POST("/oidc/clients", jwtAuthMiddleware.Add(true), oc.createClientHandler) + group.GET("/oidc/clients/:id", oc.getClientHandler) + group.PUT("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.updateClientHandler) + group.DELETE("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.deleteClientHandler) + + group.POST("/oidc/clients/:id/secret", jwtAuthMiddleware.Add(true), oc.createClientSecretHandler) + + group.GET("/oidc/clients/:id/logo", oc.getClientLogoHandler) + group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler) + group.POST("/oidc/clients/:id/logo", jwtAuthMiddleware.Add(true), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler) +} + +type OidcController struct { + OidcService *service.OidcService +} + +func (oc *OidcController) authorizeHandler(c *gin.Context) { + var parsedBody model.AuthorizeRequest + if err := c.ShouldBindJSON(&parsedBody); err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + code, err := oc.OidcService.Authorize(parsedBody, c.GetString("userID")) + if err != nil { + if errors.Is(err, common.ErrOidcMissingAuthorization) { + utils.HandlerError(c, http.StatusForbidden, err.Error()) + } else { + utils.UnknownHandlerError(c, err) + } + return + } + + c.JSON(http.StatusOK, gin.H{"code": code}) +} + +func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) { + var parsedBody model.AuthorizeNewClientDto + if err := c.ShouldBindJSON(&parsedBody); err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + code, err := oc.OidcService.AuthorizeNewClient(parsedBody, c.GetString("userID")) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{"code": code}) +} + +func (oc *OidcController) createIDTokenHandler(c *gin.Context) { + var body model.OidcIdTokenDto + + if err := c.ShouldBind(&body); err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + idToken, err := oc.OidcService.CreateIDToken(body) + if err != nil { + if errors.Is(err, common.ErrOidcGrantTypeNotSupported) || + errors.Is(err, common.ErrOidcMissingClientCredentials) || + errors.Is(err, common.ErrOidcClientSecretInvalid) || + errors.Is(err, common.ErrOidcInvalidAuthorizationCode) { + utils.HandlerError(c, http.StatusBadRequest, err.Error()) + } else { + utils.UnknownHandlerError(c, err) + } + return + } + + c.JSON(http.StatusOK, gin.H{"id_token": idToken}) +} + +func (oc *OidcController) getClientHandler(c *gin.Context) { + clientId := c.Param("id") + client, err := oc.OidcService.GetClient(clientId) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusOK, client) +} + +func (oc *OidcController) listClientsHandler(c *gin.Context) { + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10")) + searchTerm := c.Query("search") + + clients, pagination, err := oc.OidcService.ListClients(searchTerm, page, pageSize) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "data": clients, + "pagination": pagination, + }) +} + +func (oc *OidcController) createClientHandler(c *gin.Context) { + var input model.OidcClientCreateDto + if err := c.ShouldBindJSON(&input); err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + client, err := oc.OidcService.CreateClient(input, c.GetString("userID")) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusCreated, client) +} + +func (oc *OidcController) deleteClientHandler(c *gin.Context) { + err := oc.OidcService.DeleteClient(c.Param("id")) + if err != nil { + utils.HandlerError(c, http.StatusNotFound, "OIDC client not found") + return + } + + c.Status(http.StatusNoContent) +} + +func (oc *OidcController) updateClientHandler(c *gin.Context) { + var input model.OidcClientCreateDto + if err := c.ShouldBindJSON(&input); err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + client, err := oc.OidcService.UpdateClient(c.Param("id"), input) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusNoContent, client) +} + +func (oc *OidcController) createClientSecretHandler(c *gin.Context) { + secret, err := oc.OidcService.CreateClientSecret(c.Param("id")) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{"secret": secret}) +} + +func (oc *OidcController) getClientLogoHandler(c *gin.Context) { + imagePath, mimeType, err := oc.OidcService.GetClientLogo(c.Param("id")) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.Header("Content-Type", mimeType) + c.File(imagePath) +} + +func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { + file, err := c.FormFile("file") + if err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + err = oc.OidcService.UpdateClientLogo(c.Param("id"), file) + if err != nil { + if errors.Is(err, common.ErrFileTypeNotSupported) { + utils.HandlerError(c, http.StatusBadRequest, err.Error()) + } else { + utils.UnknownHandlerError(c, err) + } + return + } + + c.Status(http.StatusNoContent) +} + +func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) { + err := oc.OidcService.DeleteClientLogo(c.Param("id")) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.Status(http.StatusNoContent) +} diff --git a/backend/internal/controller/test_controller.go b/backend/internal/controller/test_controller.go new file mode 100644 index 0000000..4353cc8 --- /dev/null +++ b/backend/internal/controller/test_controller.go @@ -0,0 +1,36 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "github.com/stonith404/pocket-id/backend/internal/service" + "github.com/stonith404/pocket-id/backend/internal/utils" +) + +func NewTestController(group *gin.RouterGroup, testService *service.TestService) { + testController := &TestController{TestService: testService} + + group.POST("/test/reset", testController.resetAndSeedHandler) +} + +type TestController struct { + TestService *service.TestService +} + +func (tc *TestController) resetAndSeedHandler(c *gin.Context) { + if err := tc.TestService.ResetDatabase(); err != nil { + utils.UnknownHandlerError(c, err) + return + } + + if err := tc.TestService.ResetApplicationImages(); err != nil { + utils.UnknownHandlerError(c, err) + return + } + + if err := tc.TestService.SeedDatabase(); err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(200, gin.H{"message": "Database reset and seeded"}) +} diff --git a/backend/internal/controller/user_controller.go b/backend/internal/controller/user_controller.go new file mode 100644 index 0000000..30e29f7 --- /dev/null +++ b/backend/internal/controller/user_controller.go @@ -0,0 +1,182 @@ +package controller + +import ( + "errors" + "github.com/gin-gonic/gin" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/middleware" + "github.com/stonith404/pocket-id/backend/internal/model" + "github.com/stonith404/pocket-id/backend/internal/service" + "github.com/stonith404/pocket-id/backend/internal/utils" + "golang.org/x/time/rate" + "net/http" + "strconv" + "time" +) + +func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, userService *service.UserService) { + uc := UserController{ + UserService: userService, + } + + group.GET("/users", jwtAuthMiddleware.Add(true), uc.listUsersHandler) + group.GET("/users/me", jwtAuthMiddleware.Add(false), uc.getCurrentUserHandler) + group.GET("/users/:id", jwtAuthMiddleware.Add(true), uc.getUserHandler) + group.POST("/users", jwtAuthMiddleware.Add(true), uc.createUserHandler) + group.PUT("/users/:id", jwtAuthMiddleware.Add(true), uc.updateUserHandler) + group.PUT("/users/me", jwtAuthMiddleware.Add(false), uc.updateCurrentUserHandler) + group.DELETE("/users/:id", jwtAuthMiddleware.Add(true), uc.deleteUserHandler) + + group.POST("/users/:id/one-time-access-token", jwtAuthMiddleware.Add(true), uc.createOneTimeAccessTokenHandler) + group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler) + group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler) +} + +type UserController struct { + UserService *service.UserService +} + +func (uc *UserController) listUsersHandler(c *gin.Context) { + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10")) + searchTerm := c.Query("search") + + users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "data": users, + "pagination": pagination, + }) +} + +func (uc *UserController) getUserHandler(c *gin.Context) { + user, err := uc.UserService.GetUser(c.Param("id")) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusOK, user) +} + +func (uc *UserController) getCurrentUserHandler(c *gin.Context) { + user, err := uc.UserService.GetUser(c.GetString("userID")) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + c.JSON(http.StatusOK, user) +} + +func (uc *UserController) deleteUserHandler(c *gin.Context) { + if err := uc.UserService.DeleteUser(c.Param("id")); err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.Status(http.StatusNoContent) +} + +func (uc *UserController) createUserHandler(c *gin.Context) { + var user model.User + if err := c.ShouldBindJSON(&user); err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + if err := uc.UserService.CreateUser(&user); err != nil { + if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) { + utils.HandlerError(c, http.StatusConflict, err.Error()) + } else { + utils.UnknownHandlerError(c, err) + } + return + } + + c.JSON(http.StatusCreated, user) +} + +func (uc *UserController) updateUserHandler(c *gin.Context) { + uc.updateUser(c, false) +} + +func (uc *UserController) updateCurrentUserHandler(c *gin.Context) { + uc.updateUser(c, true) +} + +func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) { + var input model.OneTimeAccessTokenCreateDto + if err := c.ShouldBindJSON(&input); err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusCreated, gin.H{"token": token}) +} + +func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { + user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token")) + if err != nil { + if errors.Is(err, common.ErrTokenInvalidOrExpired) { + utils.HandlerError(c, http.StatusUnauthorized, err.Error()) + } else { + utils.UnknownHandlerError(c, err) + } + return + } + + c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true) + c.JSON(http.StatusOK, user) +} + +func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) { + user, token, err := uc.UserService.SetupInitialAdmin() + if err != nil { + if errors.Is(err, common.ErrSetupAlreadyCompleted) { + utils.HandlerError(c, http.StatusBadRequest, err.Error()) + } else { + utils.UnknownHandlerError(c, err) + } + return + } + + c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true) + c.JSON(http.StatusOK, user) +} + +func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { + var updatedUser model.User + if err := c.ShouldBindJSON(&updatedUser); err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + var userID string + if updateOwnUser { + userID = c.GetString("userID") + } else { + userID = c.Param("id") + } + + user, err := uc.UserService.UpdateUser(userID, updatedUser, updateOwnUser) + if err != nil { + if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) { + utils.HandlerError(c, http.StatusConflict, err.Error()) + } else { + utils.UnknownHandlerError(c, err) + } + return + } + + c.JSON(http.StatusOK, user) +} diff --git a/backend/internal/controller/webauthn_controller.go b/backend/internal/controller/webauthn_controller.go new file mode 100644 index 0000000..1c7acf7 --- /dev/null +++ b/backend/internal/controller/webauthn_controller.go @@ -0,0 +1,160 @@ +package controller + +import ( + "errors" + "github.com/go-webauthn/webauthn/protocol" + "github.com/stonith404/pocket-id/backend/internal/middleware" + "github.com/stonith404/pocket-id/backend/internal/model" + "log" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/service" + "github.com/stonith404/pocket-id/backend/internal/utils" + "golang.org/x/time/rate" +) + +func NewWebauthnController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, webauthnService *service.WebAuthnService, jwtService *service.JwtService) { + wc := &WebauthnController{webAuthnService: webauthnService, jwtService: jwtService} + group.GET("/webauthn/register/start", jwtAuthMiddleware.Add(false), wc.beginRegistrationHandler) + group.POST("/webauthn/register/finish", jwtAuthMiddleware.Add(false), wc.verifyRegistrationHandler) + + group.GET("/webauthn/login/start", wc.beginLoginHandler) + group.POST("/webauthn/login/finish", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), wc.verifyLoginHandler) + + group.POST("/webauthn/logout", jwtAuthMiddleware.Add(false), wc.logoutHandler) + + group.GET("/webauthn/credentials", jwtAuthMiddleware.Add(false), wc.listCredentialsHandler) + group.PATCH("/webauthn/credentials/:id", jwtAuthMiddleware.Add(false), wc.updateCredentialHandler) + group.DELETE("/webauthn/credentials/:id", jwtAuthMiddleware.Add(false), wc.deleteCredentialHandler) +} + +type WebauthnController struct { + webAuthnService *service.WebAuthnService + jwtService *service.JwtService +} + +func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) { + userID := c.GetString("userID") + options, err := wc.webAuthnService.BeginRegistration(userID) + if err != nil { + utils.UnknownHandlerError(c, err) + log.Println(err) + return + } + + c.SetCookie("session_id", options.SessionID, int(options.Timeout.Seconds()), "/", "", false, true) + c.JSON(http.StatusOK, options.Response) +} + +func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) { + sessionID, err := c.Cookie("session_id") + if err != nil { + utils.HandlerError(c, http.StatusBadRequest, "Session ID missing") + return + } + + userID := c.GetString("userID") + credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusOK, credential) +} + +func (wc *WebauthnController) beginLoginHandler(c *gin.Context) { + options, err := wc.webAuthnService.BeginLogin() + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.SetCookie("session_id", options.SessionID, int(options.Timeout.Seconds()), "/", "", false, true) + c.JSON(http.StatusOK, options.Response) +} + +func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { + sessionID, err := c.Cookie("session_id") + if err != nil { + utils.HandlerError(c, http.StatusBadRequest, "Session ID missing") + return + } + + credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body) + if err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + userID := c.GetString("userID") + user, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData) + if err != nil { + if errors.Is(err, common.ErrInvalidCredentials) { + utils.HandlerError(c, http.StatusUnauthorized, err.Error()) + } else { + utils.UnknownHandlerError(c, err) + } + return + } + + token, err := wc.jwtService.GenerateAccessToken(*user) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true) + c.JSON(http.StatusOK, user) +} + +func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) { + userID := c.GetString("userID") + credentials, err := wc.webAuthnService.ListCredentials(userID) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.JSON(http.StatusOK, credentials) +} + +func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) { + userID := c.GetString("userID") + credentialID := c.Param("id") + + err := wc.webAuthnService.DeleteCredential(userID, credentialID) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.Status(http.StatusNoContent) +} + +func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) { + userID := c.GetString("userID") + credentialID := c.Param("id") + + var input model.WebauthnCredentialUpdateDto + if err := c.ShouldBindJSON(&input); err != nil { + utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + return + } + + err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name) + if err != nil { + utils.UnknownHandlerError(c, err) + return + } + + c.Status(http.StatusNoContent) +} + +func (wc *WebauthnController) logoutHandler(c *gin.Context) { + c.SetCookie("access_token", "", 0, "/", "", false, true) + c.Status(http.StatusNoContent) +} diff --git a/backend/internal/handler/well_known.go b/backend/internal/controller/well_known_controller.go similarity index 57% rename from backend/internal/handler/well_known.go rename to backend/internal/controller/well_known_controller.go index 242cc32..17fb039 100644 --- a/backend/internal/handler/well_known.go +++ b/backend/internal/controller/well_known_controller.go @@ -1,19 +1,25 @@ -package handler +package controller import ( "github.com/gin-gonic/gin" - "golang-rest-api-template/internal/common" - "golang-rest-api-template/internal/utils" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/service" + "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" ) -func RegisterWellKnownRoutes(group *gin.RouterGroup) { - group.GET("/.well-known/jwks.json", jwks) - group.GET("/.well-known/openid-configuration", openIDConfiguration) +func NewWellKnownController(group *gin.RouterGroup, jwtService *service.JwtService) { + wkc := &WellKnownController{jwtService: jwtService} + group.GET("/.well-known/jwks.json", wkc.jwksHandler) + group.GET("/.well-known/openid-configuration", wkc.openIDConfigurationHandler) } -func jwks(c *gin.Context) { - jwk, err := common.GetJWK() +type WellKnownController struct { + jwtService *service.JwtService +} + +func (wkc *WellKnownController) jwksHandler(c *gin.Context) { + jwk, err := wkc.jwtService.GetJWK() if err != nil { utils.UnknownHandlerError(c, err) return @@ -22,7 +28,7 @@ func jwks(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"keys": []interface{}{jwk}}) } -func openIDConfiguration(c *gin.Context) { +func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) { appUrl := common.EnvConfig.AppURL config := map[string]interface{}{ "issuer": appUrl, diff --git a/backend/internal/handler/application_configuration.go b/backend/internal/handler/application_configuration.go deleted file mode 100644 index 62a3efc..0000000 --- a/backend/internal/handler/application_configuration.go +++ /dev/null @@ -1,196 +0,0 @@ -package handler - -import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "golang-rest-api-template/internal/common" - "golang-rest-api-template/internal/common/middleware" - "golang-rest-api-template/internal/model" - "golang-rest-api-template/internal/utils" - "gorm.io/gorm" - "net/http" - "os" - "reflect" -) - -func RegisterConfigurationRoutes(group *gin.RouterGroup) { - group.GET("/application-configuration", listApplicationConfigurationHandler) - group.GET("/application-configuration/all", middleware.JWTAuth(true), listAllApplicationConfigurationHandler) - group.PUT("/application-configuration", updateApplicationConfigurationHandler) - - group.GET("/application-configuration/logo", getLogoHandler) - group.GET("/application-configuration/background-image", getBackgroundImageHandler) - group.GET("/application-configuration/favicon", getFaviconHandler) - group.PUT("/application-configuration/logo", middleware.JWTAuth(true), updateLogoHandler) - group.PUT("/application-configuration/favicon", middleware.JWTAuth(true), updateFaviconHandler) - group.PUT("/application-configuration/background-image", middleware.JWTAuth(true), updateBackgroundImageHandler) -} - -func listApplicationConfigurationHandler(c *gin.Context) { - listApplicationConfiguration(c, false) -} - -func listAllApplicationConfigurationHandler(c *gin.Context) { - listApplicationConfiguration(c, true) -} - -func updateApplicationConfigurationHandler(c *gin.Context) { - var input model.ApplicationConfigurationUpdateDto - if err := c.ShouldBindJSON(&input); err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - savedConfigVariables := make([]model.ApplicationConfigurationVariable, 10) - - tx := common.DB.Begin() - rt := reflect.ValueOf(input).Type() - rv := reflect.ValueOf(input) - - // Loop over the input struct fields and update the related configuration variables - for i := 0; i < rt.NumField(); i++ { - field := rt.Field(i) - key := field.Tag.Get("json") - value := rv.FieldByName(field.Name).String() - - // Get the existing configuration variable from the db - var applicationConfigurationVariable model.ApplicationConfigurationVariable - if err := tx.First(&applicationConfigurationVariable, "key = ? AND is_internal = false", key).Error; err != nil { - tx.Rollback() - if errors.Is(err, gorm.ErrRecordNotFound) { - utils.HandlerError(c, http.StatusNotFound, fmt.Sprintf("Invalid configuration variable '%s'", value)) - } else { - utils.UnknownHandlerError(c, err) - } - return - } - - // Update the value of the existing configuration variable and save it - applicationConfigurationVariable.Value = value - if err := tx.Save(&applicationConfigurationVariable).Error; err != nil { - tx.Rollback() - utils.UnknownHandlerError(c, err) - return - } - - savedConfigVariables[i] = applicationConfigurationVariable - } - - tx.Commit() - - if err := common.LoadDbConfigFromDb(); err != nil { - utils.UnknownHandlerError(c, err) - } - - c.JSON(http.StatusOK, savedConfigVariables) - -} - -func getLogoHandler(c *gin.Context) { - imagType := common.DbConfig.LogoImageType.Value - getImage(c, "logo", imagType) -} - -func getFaviconHandler(c *gin.Context) { - getImage(c, "favicon", "ico") -} - -func getBackgroundImageHandler(c *gin.Context) { - imageType := common.DbConfig.BackgroundImageType.Value - getImage(c, "background", imageType) -} - -func updateLogoHandler(c *gin.Context) { - imageType := common.DbConfig.LogoImageType.Value - updateImage(c, "logo", imageType) -} - -func updateFaviconHandler(c *gin.Context) { - file, err := c.FormFile("file") - if err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - fileType := utils.GetFileExtension(file.Filename) - if fileType != "ico" { - utils.HandlerError(c, http.StatusBadRequest, "File must be of type .ico") - return - } - updateImage(c, "favicon", "ico") -} - -func updateBackgroundImageHandler(c *gin.Context) { - imagType := common.DbConfig.BackgroundImageType.Value - updateImage(c, "background", imagType) -} - -func getImage(c *gin.Context, name string, imageType string) { - imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType) - mimeType := utils.GetImageMimeType(imageType) - - c.Header("Content-Type", mimeType) - c.File(imagePath) -} - -func updateImage(c *gin.Context, imageName string, oldImageType string) { - file, err := c.FormFile("file") - if err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - fileType := utils.GetFileExtension(file.Filename) - if mimeType := utils.GetImageMimeType(fileType); mimeType == "" { - utils.HandlerError(c, http.StatusBadRequest, "File type not supported") - return - } - - // Delete the old image if it has a different file type - if fileType != oldImageType { - oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType) - if err := os.Remove(oldImagePath); err != nil { - utils.UnknownHandlerError(c, err) - return - } - } - - imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType) - err = c.SaveUploadedFile(file, imagePath) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - // Update the file type in the database - key := fmt.Sprintf("%sImageType", imageName) - err = common.DB.Model(&model.ApplicationConfigurationVariable{}).Where("key = ?", key).Update("value", fileType).Error - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - if err := common.LoadDbConfigFromDb(); err != nil { - utils.UnknownHandlerError(c, err) - } - - c.Status(http.StatusNoContent) -} - -func listApplicationConfiguration(c *gin.Context, showAll bool) { - var configuration []model.ApplicationConfigurationVariable - var err error - - if showAll { - err = common.DB.Find(&configuration).Error - } else { - err = common.DB.Find(&configuration, "is_public = true").Error - } - - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(200, configuration) -} diff --git a/backend/internal/handler/oidc.go b/backend/internal/handler/oidc.go deleted file mode 100644 index f45c6b3..0000000 --- a/backend/internal/handler/oidc.go +++ /dev/null @@ -1,415 +0,0 @@ -package handler - -import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "golang-rest-api-template/internal/common" - "golang-rest-api-template/internal/common/middleware" - "golang-rest-api-template/internal/model" - "golang-rest-api-template/internal/utils" - "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" - "net/http" - "os" - "time" -) - -func RegisterOIDCRoutes(group *gin.RouterGroup) { - group.POST("/oidc/authorize", middleware.JWTAuth(false), authorizeHandler) - group.POST("/oidc/authorize/new-client", middleware.JWTAuth(false), authorizeNewClientHandler) - group.POST("/oidc/token", createIDTokenHandler) - - group.GET("/oidc/clients", middleware.JWTAuth(true), listClientsHandler) - group.POST("/oidc/clients", middleware.JWTAuth(true), createClientHandler) - group.GET("/oidc/clients/:id", getClientHandler) - group.PUT("/oidc/clients/:id", middleware.JWTAuth(true), updateClientHandler) - group.DELETE("/oidc/clients/:id", middleware.JWTAuth(true), deleteClientHandler) - - group.POST("/oidc/clients/:id/secret", middleware.JWTAuth(true), createClientSecretHandler) - - group.GET("/oidc/clients/:id/logo", getClientLogoHandler) - group.DELETE("/oidc/clients/:id/logo", deleteClientLogoHandler) - group.POST("/oidc/clients/:id/logo", middleware.JWTAuth(true), middleware.LimitFileSize(2<<20), updateClientLogoHandler) -} - -type AuthorizeRequest struct { - ClientID string `json:"clientID" binding:"required"` - Scope string `json:"scope" binding:"required"` - Nonce string `json:"nonce"` -} - -func authorizeHandler(c *gin.Context) { - var parsedBody AuthorizeRequest - if err := c.ShouldBindJSON(&parsedBody); err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - var userAuthorizedOIDCClient model.UserAuthorizedOidcClient - common.DB.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", parsedBody.ClientID, c.GetString("userID")) - - // If the record isn't found or the scope is different return an error - // The client will have to call the authorizeNewClientHandler - if userAuthorizedOIDCClient.Scope != parsedBody.Scope { - utils.HandlerError(c, http.StatusForbidden, "missing authorization") - return - } - - authorizationCode, err := createAuthorizationCode(parsedBody.ClientID, c.GetString("userID"), parsedBody.Scope, parsedBody.Nonce) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(http.StatusOK, gin.H{"code": authorizationCode}) -} - -// authorizeNewClientHandler authorizes a new client for the user -// a new client is a new client when the user has not authorized the client before -func authorizeNewClientHandler(c *gin.Context) { - var parsedBody model.AuthorizeNewClientDto - if err := c.ShouldBindJSON(&parsedBody); err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - userAuthorizedClient := model.UserAuthorizedOidcClient{ - UserID: c.GetString("userID"), - ClientID: parsedBody.ClientID, - Scope: parsedBody.Scope, - } - err := common.DB.Create(&userAuthorizedClient).Error - - if err != nil && errors.Is(err, gorm.ErrDuplicatedKey) { - err = common.DB.Model(&userAuthorizedClient).Update("scope", parsedBody.Scope).Error - } - - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - authorizationCode, err := createAuthorizationCode(parsedBody.ClientID, c.GetString("userID"), parsedBody.Scope, parsedBody.Nonce) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(http.StatusOK, gin.H{"code": authorizationCode}) - -} - -func createIDTokenHandler(c *gin.Context) { - var body model.OidcIdTokenDto - - if err := c.ShouldBind(&body); err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - // Currently only authorization_code grant type is supported - if body.GrantType != "authorization_code" { - utils.HandlerError(c, http.StatusBadRequest, "grant type not supported") - return - } - - clientID := body.ClientID - clientSecret := body.ClientSecret - - // Client id and secret can also be passed over the Authorization header - if clientID == "" || clientSecret == "" { - var ok bool - clientID, clientSecret, ok = c.Request.BasicAuth() - if !ok { - utils.HandlerError(c, http.StatusBadRequest, "Client id and secret not provided") - return - } - } - - // Get the client - var client model.OidcClient - err := common.DB.First(&client, "id = ?", clientID, clientSecret).Error - if err != nil { - utils.HandlerError(c, http.StatusBadRequest, "OIDC OIDC client not found") - return - } - - // Check if client secret is correct - err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) - if err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid client secret") - return - } - - var authorizationCodeMetaData model.OidcAuthorizationCode - err = common.DB.Preload("User").First(&authorizationCodeMetaData, "code = ?", body.Code).Error - if err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid authorization code") - return - } - - // Check if the client id matches the client id in the authorization code and if the code has expired - if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.Before(time.Now()) { - utils.HandlerError(c, http.StatusBadRequest, "invalid authorization code") - return - } - - idToken, e := common.GenerateIDToken(authorizationCodeMetaData.User, clientID, authorizationCodeMetaData.Scope, authorizationCodeMetaData.Nonce) - if e != nil { - utils.UnknownHandlerError(c, err) - return - } - - // Delete the authorization code after it has been used - common.DB.Delete(&authorizationCodeMetaData) - - c.JSON(http.StatusOK, gin.H{"id_token": idToken}) -} - -func getClientHandler(c *gin.Context) { - clientId := c.Param("id") - - var client model.OidcClient - err := common.DB.First(&client, "id = ?", clientId).Error - if err != nil { - utils.HandlerError(c, http.StatusNotFound, "OIDC client not found") - return - } - - c.JSON(http.StatusOK, client) -} - -func listClientsHandler(c *gin.Context) { - var clients []model.OidcClient - searchTerm := c.Query("search") - - query := common.DB.Model(&model.OidcClient{}) - - if searchTerm != "" { - searchPattern := "%" + searchTerm + "%" - query = query.Where("name LIKE ?", searchPattern) - } - - pagination, err := utils.Paginate(c, query, &clients) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(http.StatusOK, gin.H{ - "data": clients, - "pagination": pagination, - }) -} - -func createClientHandler(c *gin.Context) { - var input model.OidcClientCreateDto - if err := c.ShouldBindJSON(&input); err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - client := model.OidcClient{ - Name: input.Name, - CallbackURL: input.CallbackURL, - CreatedByID: c.GetString("userID"), - } - - if err := common.DB.Create(&client).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(http.StatusCreated, client) -} - -func deleteClientHandler(c *gin.Context) { - var client model.OidcClient - if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil { - utils.HandlerError(c, http.StatusNotFound, "OIDC OIDC client not found") - return - } - - if err := common.DB.Delete(&client).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.Status(http.StatusNoContent) -} - -func updateClientHandler(c *gin.Context) { - var input model.OidcClientCreateDto - if err := c.ShouldBindJSON(&input); err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - var client model.OidcClient - if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - utils.HandlerError(c, http.StatusNotFound, "OIDC client not found") - return - } - utils.UnknownHandlerError(c, err) - return - } - - client.Name = input.Name - client.CallbackURL = input.CallbackURL - - if err := common.DB.Save(&client).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(http.StatusNoContent, client) -} - -// createClientSecretHandler creates a new secret for the client and revokes the old one -func createClientSecretHandler(c *gin.Context) { - var client model.OidcClient - if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - utils.HandlerError(c, http.StatusNotFound, "OIDC client not found") - return - } - utils.UnknownHandlerError(c, err) - return - } - - clientSecret, err := utils.GenerateRandomAlphanumericString(32) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - client.Secret = string(hashedSecret) - if err := common.DB.Save(&client).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(http.StatusOK, gin.H{"secret": clientSecret}) -} - -func getClientLogoHandler(c *gin.Context) { - var client model.OidcClient - if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil { - utils.HandlerError(c, http.StatusNotFound, "OIDC client not found") - return - } - - if client.ImageType == nil { - utils.HandlerError(c, http.StatusNotFound, "image not found") - return - } - - imageType := *client.ImageType - - imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType) - mimeType := utils.GetImageMimeType(imageType) - - c.Header("Content-Type", mimeType) - c.File(imagePath) -} - -func updateClientLogoHandler(c *gin.Context) { - file, err := c.FormFile("file") - if err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - fileType := utils.GetFileExtension(file.Filename) - if mimeType := utils.GetImageMimeType(fileType); mimeType == "" { - utils.HandlerError(c, http.StatusBadRequest, "file type not supported") - return - } - - imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, c.Param("id"), fileType) - err = c.SaveUploadedFile(file, imagePath) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - var client model.OidcClient - if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil { - utils.HandlerError(c, http.StatusNotFound, "OIDC client not found") - return - } - - // Delete the old image if it has a different file type - if client.ImageType != nil && fileType != *client.ImageType { - oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType) - if err := os.Remove(oldImagePath); err != nil { - utils.UnknownHandlerError(c, err) - return - } - } - - client.ImageType = &fileType - if err := common.DB.Save(&client).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.Status(http.StatusNoContent) -} - -func deleteClientLogoHandler(c *gin.Context) { - var client model.OidcClient - if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil { - utils.HandlerError(c, http.StatusNotFound, "OIDC client not found") - return - } - - if client.ImageType == nil { - utils.HandlerError(c, http.StatusNotFound, "image not found") - return - } - - imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType) - if err := os.Remove(imagePath); err != nil { - utils.UnknownHandlerError(c, err) - return - } - - client.ImageType = nil - if err := common.DB.Save(&client).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.Status(http.StatusNoContent) -} - -func createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) { - randomString, err := utils.GenerateRandomAlphanumericString(32) - if err != nil { - return "", err - } - - oidcAuthorizationCode := model.OidcAuthorizationCode{ - ExpiresAt: time.Now().Add(15 * time.Minute), - Code: randomString, - ClientID: clientID, - UserID: userID, - Scope: scope, - Nonce: nonce, - } - - if err := common.DB.Create(&oidcAuthorizationCode).Error; err != nil { - return "", err - } - - return randomString, nil -} diff --git a/backend/internal/handler/user.go b/backend/internal/handler/user.go deleted file mode 100644 index afbf083..0000000 --- a/backend/internal/handler/user.go +++ /dev/null @@ -1,276 +0,0 @@ -package handler - -import ( - "errors" - "github.com/gin-gonic/gin" - "golang-rest-api-template/internal/common" - "golang-rest-api-template/internal/common/middleware" - "golang-rest-api-template/internal/model" - "golang-rest-api-template/internal/utils" - "golang.org/x/time/rate" - "gorm.io/gorm" - "log" - "net/http" - "time" -) - -func RegisterUserRoutes(group *gin.RouterGroup) { - group.GET("/users", middleware.JWTAuth(true), listUsersHandler) - group.GET("/users/me", middleware.JWTAuth(false), getCurrentUserHandler) - group.GET("/users/:id", middleware.JWTAuth(true), getUserHandler) - group.POST("/users", middleware.JWTAuth(true), createUserHandler) - group.PUT("/users/:id", middleware.JWTAuth(true), updateUserHandler) - group.PUT("/users/me", middleware.JWTAuth(false), updateCurrentUserHandler) - group.DELETE("/users/:id", middleware.JWTAuth(true), deleteUserHandler) - - group.POST("/users/:id/one-time-access-token", middleware.JWTAuth(true), createOneTimeAccessTokenHandler) - group.POST("/one-time-access-token/:token", middleware.RateLimiter(rate.Every(10*time.Second), 5), exchangeOneTimeAccessTokenHandler) - group.POST("/one-time-access-token/setup", getSetupAccessTokenHandler) -} - -func listUsersHandler(c *gin.Context) { - var users []model.User - searchTerm := c.Query("search") - - query := common.DB.Model(&model.User{}) - - if searchTerm != "" { - searchPattern := "%" + searchTerm + "%" - query = query.Where("email LIKE ? OR first_name LIKE ? OR username LIKE ?", searchPattern, searchPattern, searchPattern) - } - - pagination, err := utils.Paginate(c, query, &users) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(http.StatusOK, gin.H{ - "data": users, - "pagination": pagination, - }) -} - -func getUserHandler(c *gin.Context) { - var user model.User - if err := common.DB.Where("id = ?", c.Param("id")).First(&user).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - utils.HandlerError(c, http.StatusNotFound, "User not found") - return - } - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(http.StatusOK, user) -} - -func getCurrentUserHandler(c *gin.Context) { - var user model.User - if err := common.DB.Where("id = ?", c.GetString("userID")).First(&user).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - c.JSON(http.StatusOK, user) - -} - -func deleteUserHandler(c *gin.Context) { - var user model.User - if err := common.DB.Where("id = ?", c.Param("id")).First(&user).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - utils.HandlerError(c, http.StatusNotFound, "User not found") - return - } - utils.UnknownHandlerError(c, err) - return - } - - if err := common.DB.Delete(&user).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.Status(http.StatusNoContent) -} - -func createUserHandler(c *gin.Context) { - var user model.User - if err := c.ShouldBindJSON(&user); err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - if err := common.DB.Create(&user).Error; err != nil { - if errors.Is(err, gorm.ErrDuplicatedKey) { - if err := checkDuplicatedFields(user); err != nil { - utils.HandlerError(c, http.StatusBadRequest, err.Error()) - return - } - } else { - utils.UnknownHandlerError(c, err) - return - } - } - - c.JSON(http.StatusCreated, user) -} - -func updateUserHandler(c *gin.Context) { - updateUser(c, c.Param("id"), false) -} - -func updateCurrentUserHandler(c *gin.Context) { - updateUser(c, c.GetString("userID"), true) -} - -func createOneTimeAccessTokenHandler(c *gin.Context) { - var input model.OneTimeAccessTokenCreateDto - if err := c.ShouldBindJSON(&input); err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - randomString, err := utils.GenerateRandomAlphanumericString(16) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - oneTimeAccessToken := model.OneTimeAccessToken{ - UserID: input.UserID, - ExpiresAt: input.ExpiresAt, - Token: randomString, - } - - if err := common.DB.Create(&oneTimeAccessToken).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(http.StatusCreated, gin.H{"token": oneTimeAccessToken.Token}) -} - -func exchangeOneTimeAccessTokenHandler(c *gin.Context) { - var oneTimeAccessToken model.OneTimeAccessToken - if err := common.DB.Where("token = ? AND expires_at > ?", c.Param("token"), utils.FormatDateForDb(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - utils.HandlerError(c, http.StatusForbidden, "Token is invalid or expired") - return - } - utils.UnknownHandlerError(c, err) - return - } - - token, err := common.GenerateAccessToken(oneTimeAccessToken.User) - if err != nil { - utils.UnknownHandlerError(c, err) - log.Println(err) - return - } - - if err := common.DB.Delete(&oneTimeAccessToken).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true) - - c.JSON(http.StatusOK, oneTimeAccessToken.User) -} - -// getSetupAccessTokenHandler creates the initial admin user and returns an access token for the user -// This handler is only available if there are no users in the database -func getSetupAccessTokenHandler(c *gin.Context) { - var userCount int64 - if err := common.DB.Model(&model.User{}).Count(&userCount).Error; err != nil { - log.Fatal("failed to count users", err) - } - - // If there are more than one user, we don't need to create the admin user - if userCount > 1 { - utils.HandlerError(c, http.StatusForbidden, "Setup already completed") - return - } - - var user = model.User{ - FirstName: "Admin", - LastName: "Admin", - Username: "admin", - Email: "admin@admin.com", - IsAdmin: true, - } - - // Create the initial admin user if it doesn't exist - if err := common.DB.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil { - log.Fatal("failed to create admin user", err) - } - - // If the user already has credentials, the setup is already completed - if len(user.Credentials) > 0 { - utils.HandlerError(c, http.StatusForbidden, "Setup already completed") - return - } - - token, err := common.GenerateAccessToken(user) - if err != nil { - utils.UnknownHandlerError(c, err) - log.Println(err) - return - } - c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true) - c.JSON(http.StatusOK, user) -} - -func updateUser(c *gin.Context, userID string, updateOwnUser bool) { - var user model.User - if err := common.DB.Where("id = ?", userID).First(&user).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - utils.HandlerError(c, http.StatusNotFound, "User not found") - return - } - utils.UnknownHandlerError(c, err) - return - } - var updatedUser model.User - if err := c.ShouldBindJSON(&updatedUser); err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - user.FirstName = updatedUser.FirstName - user.LastName = updatedUser.LastName - user.Email = updatedUser.Email - user.Username = updatedUser.Username - user.Username = updatedUser.Username - if !updateOwnUser { - user.IsAdmin = updatedUser.IsAdmin - } - - if err := common.DB.Save(user).Error; err != nil { - if errors.Is(err, gorm.ErrDuplicatedKey) { - if err := checkDuplicatedFields(user); err != nil { - utils.HandlerError(c, http.StatusBadRequest, err.Error()) - return - } - } else { - utils.UnknownHandlerError(c, err) - return - } - } - c.JSON(http.StatusOK, user) -} - -func checkDuplicatedFields(user model.User) error { - var existingUser model.User - - if common.DB.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil { - return errors.New("email is already taken") - } - - if common.DB.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil { - return errors.New("username is already taken") - } - - return nil -} diff --git a/backend/internal/handler/webauthn.go b/backend/internal/handler/webauthn.go deleted file mode 100644 index ce81c4e..0000000 --- a/backend/internal/handler/webauthn.go +++ /dev/null @@ -1,257 +0,0 @@ -package handler - -import ( - "github.com/gin-gonic/gin" - "github.com/go-webauthn/webauthn/protocol" - "github.com/go-webauthn/webauthn/webauthn" - "golang-rest-api-template/internal/common" - "golang-rest-api-template/internal/common/middleware" - "golang-rest-api-template/internal/model" - "golang-rest-api-template/internal/utils" - "golang.org/x/time/rate" - "gorm.io/gorm" - "log" - "net/http" - "strings" - "time" -) - -func RegisterRoutes(group *gin.RouterGroup) { - group.GET("/webauthn/register/start", middleware.JWTAuth(false), beginRegistrationHandler) - group.POST("/webauthn/register/finish", middleware.JWTAuth(false), verifyRegistrationHandler) - - group.GET("/webauthn/login/start", beginLoginHandler) - group.POST("/webauthn/login/finish", middleware.RateLimiter(rate.Every(10*time.Second), 5), verifyLoginHandler) - - group.POST("/webauthn/logout", middleware.JWTAuth(false), logoutHandler) - - group.GET("/webauthn/credentials", middleware.JWTAuth(false), listCredentialsHandler) - group.PATCH("/webauthn/credentials/:id", middleware.JWTAuth(false), updateCredentialHandler) - group.DELETE("/webauthn/credentials/:id", middleware.JWTAuth(false), deleteCredentialHandler) -} - -func beginRegistrationHandler(c *gin.Context) { - var user model.User - err := common.DB.Preload("Credentials").Find(&user, "id = ?", c.GetString("userID")).Error - if err != nil { - utils.UnknownHandlerError(c, err) - log.Println(err) - return - } - - options, session, err := common.WebAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors())) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - // Save the webauthn session so we can retrieve it in the verifyRegistrationHandler - sessionToStore := &model.WebauthnSession{ - ExpiresAt: session.Expires, - Challenge: session.Challenge, - UserVerification: string(session.UserVerification), - } - - if err = common.DB.Create(&sessionToStore).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.SetCookie("session_id", sessionToStore.ID, int(common.WebAuthn.Config.Timeouts.Registration.Timeout.Seconds()), "/", "", false, true) - c.JSON(http.StatusOK, options.Response) -} - -func verifyRegistrationHandler(c *gin.Context) { - sessionID, err := c.Cookie("session_id") - if err != nil { - utils.HandlerError(c, http.StatusBadRequest, "Session ID missing") - return - } - - // Retrieve the session that was previously created by the beginRegistrationHandler - var storedSession model.WebauthnSession - err = common.DB.First(&storedSession, "id = ?", sessionID).Error - - session := webauthn.SessionData{ - Challenge: storedSession.Challenge, - Expires: storedSession.ExpiresAt, - UserID: []byte(c.GetString("userID")), - } - - var user model.User - err = common.DB.Find(&user, "id = ?", c.GetString("userID")).Error - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - credential, err := common.WebAuthn.FinishRegistration(&user, session, c.Request) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - credentialToStore := model.WebauthnCredential{ - Name: "New Passkey", - CredentialID: string(credential.ID), - AttestationType: credential.AttestationType, - PublicKey: credential.PublicKey, - Transport: credential.Transport, - UserID: user.ID, - BackupEligible: credential.Flags.BackupEligible, - BackupState: credential.Flags.BackupState, - } - if err := common.DB.Create(&credentialToStore).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(http.StatusOK, credentialToStore) -} - -func beginLoginHandler(c *gin.Context) { - options, session, err := common.WebAuthn.BeginDiscoverableLogin() - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - // Save the webauthn session so we can retrieve it in the verifyLoginHandler - sessionToStore := &model.WebauthnSession{ - ExpiresAt: session.Expires, - Challenge: session.Challenge, - UserVerification: string(session.UserVerification), - } - - if err = common.DB.Create(&sessionToStore).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.SetCookie("session_id", sessionToStore.ID, int(common.WebAuthn.Config.Timeouts.Registration.Timeout.Seconds()), "/", "", false, true) - c.JSON(http.StatusOK, options.Response) -} - -func verifyLoginHandler(c *gin.Context) { - sessionID, err := c.Cookie("session_id") - if err != nil { - utils.HandlerError(c, http.StatusBadRequest, "Session ID missing") - return - } - - credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body) - if err != nil { - utils.HandlerError(c, http.StatusBadRequest, "Invalid body") - return - } - - // Retrieve the session that was previously created by the beginLoginHandler - var storedSession model.WebauthnSession - if err := common.DB.First(&storedSession, "id = ?", sessionID).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - session := webauthn.SessionData{ - Challenge: storedSession.Challenge, - Expires: storedSession.ExpiresAt, - } - - var user *model.User - _, err = common.WebAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) { - if err := common.DB.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil { - return nil, err - } - return user, nil - }, session, credentialAssertionData) - - if err != nil { - if strings.Contains(err.Error(), gorm.ErrRecordNotFound.Error()) { - utils.HandlerError(c, http.StatusBadRequest, "no user with this passkey exists") - } else { - utils.UnknownHandlerError(c, err) - } - return - } - - err = common.DB.Find(&user, "id = ?", c.GetString("userID")).Error - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - token, err := common.GenerateAccessToken(*user) - if err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true) - c.JSON(http.StatusOK, user) -} - -func listCredentialsHandler(c *gin.Context) { - var credentials []model.WebauthnCredential - if err := common.DB.Find(&credentials, "user_id = ?", c.GetString("userID")).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(http.StatusOK, credentials) -} - -func deleteCredentialHandler(c *gin.Context) { - var passkeyCount int64 - if err := common.DB.Model(&model.WebauthnCredential{}).Where("user_id = ?", c.GetString("userID")).Count(&passkeyCount).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - if passkeyCount == 1 { - utils.HandlerError(c, http.StatusBadRequest, "You must have at least one passkey") - return - } - - var credential model.WebauthnCredential - if err := common.DB.First(&credential, "id = ? AND user_id = ?", c.Param("id"), c.GetString("userID")).Error; err != nil { - utils.HandlerError(c, http.StatusNotFound, "Credential not found") - return - } - - if err := common.DB.Delete(&credential).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.Status(http.StatusNoContent) -} - -func updateCredentialHandler(c *gin.Context) { - var credential model.WebauthnCredential - if err := common.DB.Where("id = ? AND user_id = ?", c.Param("id"), c.GetString("userID")).First(&credential).Error; err != nil { - utils.HandlerError(c, http.StatusNotFound, "Credential not found") - return - } - - var input struct { - Name string `json:"name"` - } - if err := c.ShouldBindJSON(&input); err != nil { - utils.HandlerError(c, http.StatusBadRequest, "invalid request body") - return - } - - credential.Name = input.Name - - if err := common.DB.Save(&credential).Error; err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.Status(http.StatusNoContent) -} - -func logoutHandler(c *gin.Context) { - c.SetCookie("access_token", "", 0, "/", "", false, true) - c.Status(http.StatusNoContent) -} diff --git a/backend/internal/job/db_cleanup.go b/backend/internal/job/db_cleanup.go index 09b7092..fa7ac04 100644 --- a/backend/internal/job/db_cleanup.go +++ b/backend/internal/job/db_cleanup.go @@ -3,28 +3,46 @@ package job import ( "github.com/go-co-op/gocron/v2" "github.com/google/uuid" - "golang-rest-api-template/internal/common" - "golang-rest-api-template/internal/model" - "golang-rest-api-template/internal/utils" + "github.com/stonith404/pocket-id/backend/internal/model" + "github.com/stonith404/pocket-id/backend/internal/utils" + "gorm.io/gorm" "log" "time" ) -func RegisterJobs() { +func RegisterJobs(db *gorm.DB) { scheduler, err := gocron.NewScheduler() if err != nil { log.Fatalf("Failed to create a new scheduler: %s", err) } - registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", clearWebauthnSessions) - registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", clearOneTimeAccessTokens) - registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", clearOidcAuthorizationCodes) + jobs := &Jobs{db: db} + + registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions) + registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens) + registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes) scheduler.Start() } -func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) { +type Jobs struct { + db *gorm.DB +} +func (j *Jobs) clearWebauthnSessions() error { + return j.db.Delete(&model.WebauthnSession{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error +} + +func (j *Jobs) clearOneTimeAccessTokens() error { + return j.db.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error +} + +func (j *Jobs) clearOidcAuthorizationCodes() error { + return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error + +} + +func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) { _, err := scheduler.NewJob( gocron.CronJob(interval, false), gocron.NewTask(job), @@ -42,16 +60,3 @@ func registerJob(scheduler gocron.Scheduler, name string, interval string, job f log.Fatalf("Failed to register job %q: %v", name, err) } } - -func clearWebauthnSessions() error { - return common.DB.Delete(&model.WebauthnSession{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error -} - -func clearOneTimeAccessTokens() error { - return common.DB.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error -} - -func clearOidcAuthorizationCodes() error { - return common.DB.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error - -} diff --git a/backend/internal/common/middleware/cors.go b/backend/internal/middleware/cors.go similarity index 57% rename from backend/internal/common/middleware/cors.go rename to backend/internal/middleware/cors.go index 30d478d..1cb09ce 100644 --- a/backend/internal/common/middleware/cors.go +++ b/backend/internal/middleware/cors.go @@ -1,14 +1,20 @@ package middleware import ( - "golang-rest-api-template/internal/common" + "github.com/stonith404/pocket-id/backend/internal/common" "time" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" ) -func Cors() gin.HandlerFunc { +type CorsMiddleware struct{} + +func NewCorsMiddleware() *CorsMiddleware { + return &CorsMiddleware{} +} + +func (m *CorsMiddleware) Add() gin.HandlerFunc { return cors.New(cors.Config{ AllowOrigins: []string{common.EnvConfig.AppURL}, AllowMethods: []string{"*"}, diff --git a/backend/internal/common/middleware/file_size_limit.go b/backend/internal/middleware/file_size_limit.go similarity index 76% rename from backend/internal/common/middleware/file_size_limit.go rename to backend/internal/middleware/file_size_limit.go index 5fd2dcf..d300f6f 100644 --- a/backend/internal/common/middleware/file_size_limit.go +++ b/backend/internal/middleware/file_size_limit.go @@ -3,11 +3,17 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" - "golang-rest-api-template/internal/utils" + "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" ) -func LimitFileSize(maxSize int64) gin.HandlerFunc { +type FileSizeLimitMiddleware struct{} + +func NewFileSizeLimitMiddleware() *FileSizeLimitMiddleware { + return &FileSizeLimitMiddleware{} +} + +func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc { return func(c *gin.Context) { c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize) if err := c.Request.ParseMultipartForm(maxSize); err != nil { diff --git a/backend/internal/common/middleware/jwt_auth.go b/backend/internal/middleware/jwt_auth.go similarity index 68% rename from backend/internal/common/middleware/jwt_auth.go rename to backend/internal/middleware/jwt_auth.go index f88f438..e7ebd6e 100644 --- a/backend/internal/common/middleware/jwt_auth.go +++ b/backend/internal/middleware/jwt_auth.go @@ -2,15 +2,22 @@ package middleware import ( "github.com/gin-gonic/gin" - "golang-rest-api-template/internal/common" - "golang-rest-api-template/internal/utils" + "github.com/stonith404/pocket-id/backend/internal/service" + "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" "strings" ) -func JWTAuth(adminOnly bool) gin.HandlerFunc { - return func(c *gin.Context) { +type JwtAuthMiddleware struct { + jwtService *service.JwtService +} +func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware { + return &JwtAuthMiddleware{jwtService: jwtService} +} + +func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc { + return func(c *gin.Context) { // Extract the token from the cookie or the Authorization header token, err := c.Cookie("access_token") if err != nil { @@ -22,11 +29,9 @@ func JWTAuth(adminOnly bool) gin.HandlerFunc { c.Abort() return } - } - // Verify the token - claims, err := common.VerifyAccessToken(token) + claims, err := m.jwtService.VerifyAccessToken(token) if err != nil { utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in") c.Abort() diff --git a/backend/internal/common/middleware/rate_limit.go b/backend/internal/middleware/rate_limit.go similarity index 83% rename from backend/internal/common/middleware/rate_limit.go rename to backend/internal/middleware/rate_limit.go index b133a3c..36aba16 100644 --- a/backend/internal/common/middleware/rate_limit.go +++ b/backend/internal/middleware/rate_limit.go @@ -1,8 +1,8 @@ package middleware import ( - "golang-rest-api-template/internal/common" - "golang-rest-api-template/internal/utils" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" "sync" "time" @@ -11,8 +11,13 @@ import ( "golang.org/x/time/rate" ) -// RateLimiter is a Gin middleware for rate limiting based on client IP -func RateLimiter(limit rate.Limit, burst int) gin.HandlerFunc { +type RateLimitMiddleware struct{} + +func NewRateLimitMiddleware() *RateLimitMiddleware { + return &RateLimitMiddleware{} +} + +func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc { // Start the cleanup routine go cleanupClients() diff --git a/backend/internal/model/app_config.go b/backend/internal/model/app_config.go new file mode 100644 index 0000000..3ce60d1 --- /dev/null +++ b/backend/internal/model/app_config.go @@ -0,0 +1,20 @@ +package model + +type AppConfigVariable struct { + Key string `gorm:"primaryKey;not null" json:"key"` + Type string `json:"type"` + IsPublic bool `json:"-"` + IsInternal bool `json:"-"` + Value string `json:"value"` +} + +type AppConfig struct { + AppName AppConfigVariable + BackgroundImageType AppConfigVariable + LogoImageType AppConfigVariable + SessionDuration AppConfigVariable +} + +type AppConfigUpdateDto struct { + AppName string `json:"appName" binding:"required"` +} diff --git a/backend/internal/model/application_configuration.go b/backend/internal/model/application_configuration.go deleted file mode 100644 index b7a464d..0000000 --- a/backend/internal/model/application_configuration.go +++ /dev/null @@ -1,20 +0,0 @@ -package model - -type ApplicationConfigurationVariable struct { - Key string `gorm:"primaryKey;not null" json:"key"` - Type string `json:"type"` - IsPublic bool `json:"-"` - IsInternal bool `json:"-"` - Value string `json:"value"` -} - -type ApplicationConfiguration struct { - AppName ApplicationConfigurationVariable - BackgroundImageType ApplicationConfigurationVariable - LogoImageType ApplicationConfigurationVariable - SessionDuration ApplicationConfigurationVariable -} - -type ApplicationConfigurationUpdateDto struct { - AppName string `json:"appName" binding:"required"` -} diff --git a/backend/internal/model/base.go b/backend/internal/model/base.go index 1fcdcde..dc1d402 100644 --- a/backend/internal/model/base.go +++ b/backend/internal/model/base.go @@ -12,7 +12,7 @@ type Base struct { CreatedAt time.Time `json:"createdAt"` } -func (b *Base) BeforeCreate(db *gorm.DB) (err error) { +func (b *Base) BeforeCreate(_ *gorm.DB) (err error) { if b.ID == "" { b.ID = uuid.New().String() } diff --git a/backend/internal/model/oidc.go b/backend/internal/model/oidc.go index 5fd2e3b..2b5e45e 100644 --- a/backend/internal/model/oidc.go +++ b/backend/internal/model/oidc.go @@ -63,3 +63,9 @@ type OidcIdTokenDto struct { ClientID string `form:"client_id"` ClientSecret string `form:"client_secret"` } + +type AuthorizeRequest struct { + ClientID string `json:"clientID" binding:"required"` + Scope string `json:"scope" binding:"required"` + Nonce string `json:"nonce"` +} diff --git a/backend/internal/model/webauthn.go b/backend/internal/model/webauthn.go index 94886f7..379077e 100644 --- a/backend/internal/model/webauthn.go +++ b/backend/internal/model/webauthn.go @@ -31,6 +31,18 @@ type WebauthnCredential struct { UserID string } +type PublicKeyCredentialCreationOptions struct { + Response protocol.PublicKeyCredentialCreationOptions `json:"response"` + SessionID string `json:"session_id"` + Timeout time.Duration `json:"timeout"` +} + +type PublicKeyCredentialRequestOptions struct { + Response protocol.PublicKeyCredentialRequestOptions `json:"response"` + SessionID string `json:"session_id"` + Timeout time.Duration `json:"timeout"` +} + type AuthenticatorTransportList []protocol.AuthenticatorTransport // Scan and Value methods for GORM to handle the custom type @@ -46,3 +58,7 @@ func (atl *AuthenticatorTransportList) Scan(value interface{}) error { func (atl AuthenticatorTransportList) Value() (driver.Value, error) { return json.Marshal(atl) } + +type WebauthnCredentialUpdateDto struct { + Name string `json:"name"` +} diff --git a/backend/internal/service/app_config_service.go b/backend/internal/service/app_config_service.go new file mode 100644 index 0000000..4c15340 --- /dev/null +++ b/backend/internal/service/app_config_service.go @@ -0,0 +1,213 @@ +package service + +import ( + "fmt" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/model" + "github.com/stonith404/pocket-id/backend/internal/utils" + "gorm.io/gorm" + "log" + "mime/multipart" + "os" + "reflect" +) + +type AppConfigService struct { + DbConfig *model.AppConfig + db *gorm.DB +} + +func NewAppConfigService(db *gorm.DB) *AppConfigService { + service := &AppConfigService{ + DbConfig: &defaultDbConfig, + db: db, + } + if err := service.InitDbConfig(); err != nil { + log.Fatalf("Failed to initialize app config service: %v", err) + } + return service +} + +var defaultDbConfig = model.AppConfig{ + AppName: model.AppConfigVariable{ + Key: "appName", + Type: "string", + IsPublic: true, + Value: "Pocket ID", + }, + SessionDuration: model.AppConfigVariable{ + Key: "sessionDuration", + Type: "number", + Value: "60", + }, + BackgroundImageType: model.AppConfigVariable{ + Key: "backgroundImageType", + Type: "string", + IsInternal: true, + Value: "jpg", + }, + LogoImageType: model.AppConfigVariable{ + Key: "logoImageType", + Type: "string", + IsInternal: true, + Value: "svg", + }, +} + +func (s *AppConfigService) UpdateApplicationConfiguration(input model.AppConfigUpdateDto) ([]model.AppConfigVariable, error) { + savedConfigVariables := make([]model.AppConfigVariable, 10) + + tx := s.db.Begin() + rt := reflect.ValueOf(input).Type() + rv := reflect.ValueOf(input) + + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + key := field.Tag.Get("json") + value := rv.FieldByName(field.Name).String() + + var applicationConfigurationVariable model.AppConfigVariable + if err := tx.First(&applicationConfigurationVariable, "key = ? AND is_internal = false", key).Error; err != nil { + tx.Rollback() + return nil, err + } + + applicationConfigurationVariable.Value = value + if err := tx.Save(&applicationConfigurationVariable).Error; err != nil { + tx.Rollback() + return nil, err + } + + savedConfigVariables[i] = applicationConfigurationVariable + } + + tx.Commit() + + if err := s.loadDbConfigFromDb(); err != nil { + return nil, err + } + + return savedConfigVariables, nil +} + +func (s *AppConfigService) UpdateImageType(imageName string, fileType string) error { + key := fmt.Sprintf("%sImageType", imageName) + err := s.db.Model(&model.AppConfigVariable{}).Where("key = ?", key).Update("value", fileType).Error + if err != nil { + return err + } + + return s.loadDbConfigFromDb() +} + +func (s *AppConfigService) ListApplicationConfiguration(showAll bool) ([]model.AppConfigVariable, error) { + var configuration []model.AppConfigVariable + var err error + + if showAll { + err = s.db.Find(&configuration).Error + } else { + err = s.db.Find(&configuration, "is_public = true").Error + } + + if err != nil { + return nil, err + } + + return configuration, nil +} + +func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, imageName string, oldImageType string) error { + fileType := utils.GetFileExtension(uploadedFile.Filename) + mimeType := utils.GetImageMimeType(fileType) + if mimeType == "" { + return common.ErrFileTypeNotSupported + } + + // Delete the old image if it has a different file type + if fileType != oldImageType { + oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType) + if err := os.Remove(oldImagePath); err != nil { + return err + } + } + + imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType) + if err := utils.SaveFile(uploadedFile, imagePath); err != nil { + return err + } + + // Update the file type in the database + if err := s.UpdateImageType(imageName, fileType); err != nil { + return err + } + + return nil +} + +// InitDbConfig creates the default configuration values in the database if they do not exist, +// updates existing configurations if they differ from the default, and deletes any configurations +// that are not in the default configuration. +func (s *AppConfigService) InitDbConfig() error { + // Reflect to get the underlying value of DbConfig and its default configuration + defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig) + defaultKeys := make(map[string]struct{}) + + // Iterate over the fields of DbConfig + for i := 0; i < defaultConfigReflectValue.NumField(); i++ { + defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable) + + defaultKeys[defaultConfigVar.Key] = struct{}{} + + var storedConfigVar model.AppConfigVariable + if err := s.db.First(&storedConfigVar, "key = ?", defaultConfigVar.Key).Error; err != nil { + // If the configuration does not exist, create it + if err := s.db.Create(&defaultConfigVar).Error; err != nil { + return err + } + continue + } + + // Update existing configuration if it differs from the default + if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal { + storedConfigVar.Type = defaultConfigVar.Type + storedConfigVar.IsPublic = defaultConfigVar.IsPublic + storedConfigVar.IsInternal = defaultConfigVar.IsInternal + if err := s.db.Save(&storedConfigVar).Error; err != nil { + return err + } + } + } + + // Delete any configurations not in the default keys + var allConfigVars []model.AppConfigVariable + if err := s.db.Find(&allConfigVars).Error; err != nil { + return err + } + + for _, config := range allConfigVars { + if _, exists := defaultKeys[config.Key]; !exists { + if err := s.db.Delete(&config).Error; err != nil { + return err + } + } + } + return s.loadDbConfigFromDb() +} + +func (s *AppConfigService) loadDbConfigFromDb() error { + dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem() + + for i := 0; i < dbConfigReflectValue.NumField(); i++ { + dbConfigField := dbConfigReflectValue.Field(i) + currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable) + var storedConfigVar model.AppConfigVariable + if err := s.db.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil { + return err + } + + dbConfigField.Set(reflect.ValueOf(storedConfigVar)) + } + + return nil +} diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go new file mode 100644 index 0000000..c47508d --- /dev/null +++ b/backend/internal/service/jwt_service.go @@ -0,0 +1,248 @@ +package service + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "errors" + "fmt" + "github.com/golang-jwt/jwt/v5" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/model" + "github.com/stonith404/pocket-id/backend/internal/utils" + "log" + "math/big" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + "time" +) + +const ( + privateKeyPath = "data/keys/jwt_private_key.pem" + publicKeyPath = "data/keys/jwt_public_key.pem" +) + +type JwtService struct { + publicKey *rsa.PublicKey + privateKey *rsa.PrivateKey + appConfigService *AppConfigService +} + +func NewJwtService(appConfigService *AppConfigService) *JwtService { + service := &JwtService{ + appConfigService: appConfigService, + } + + // Ensure keys are generated or loaded + if err := service.loadOrGenerateKeys(); err != nil { + log.Fatalf("Failed to initialize jwt service: %v", err) + } + + return service +} + +type AccessTokenJWTClaims struct { + jwt.RegisteredClaims + IsAdmin bool `json:"isAdmin,omitempty"` +} + +type JWK struct { + Kty string `json:"kty"` + Use string `json:"use"` + Kid string `json:"kid"` + Alg string `json:"alg"` + N string `json:"n"` + E string `json:"e"` +} + +// loadOrGenerateKeys loads RSA keys from the given paths or generates them if they do not exist. +func (s *JwtService) loadOrGenerateKeys() error { + if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) { + if err := s.generateKeys(); err != nil { + return err + } + } + + privateKeyBytes, err := os.ReadFile(privateKeyPath) + if err != nil { + return errors.New("can't read jwt private key: " + err.Error()) + } + s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes) + if err != nil { + return errors.New("can't parse jwt private key: " + err.Error()) + } + + publicKeyBytes, err := os.ReadFile(publicKeyPath) + if err != nil { + return errors.New("can't read jwt public key: " + err.Error()) + } + s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes) + if err != nil { + return errors.New("can't parse jwt public key: " + err.Error()) + } + + return nil +} + +func (s *JwtService) GenerateIDToken(user model.User, clientID string, scope string, nonce string) (string, error) { + profileClaims := map[string]interface{}{ + "given_name": user.FirstName, + "family_name": user.LastName, + "email": user.Email, + "preferred_username": user.Username, + } + + claims := jwt.MapClaims{ + "sub": user.ID, + "aud": clientID, + "exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + "iat": jwt.NewNumericDate(time.Now()), + } + + if nonce != "" { + claims["nonce"] = nonce + } + if strings.Contains(scope, "profile") { + for k, v := range profileClaims { + claims[k] = v + } + } + if strings.Contains(scope, "email") { + claims["email"] = user.Email + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + return token.SignedString(s.privateKey) +} + +func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { + sessionDurationInMinutes, _ := strconv.Atoi(s.appConfigService.DbConfig.SessionDuration.Value) + claim := AccessTokenJWTClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: user.ID, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Audience: jwt.ClaimStrings{utils.GetHostFromURL(common.EnvConfig.AppURL)}, + }, + IsAdmin: user.IsAdmin, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim) + return token.SignedString(s.privateKey) +} + +func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) { + return s.publicKey, nil + }) + if err != nil || !token.Valid { + return nil, errors.New("couldn't handle this token") + } + + claims, isValid := token.Claims.(*AccessTokenJWTClaims) + if !isValid { + return nil, errors.New("can't parse claims") + } + + if !slices.Contains(claims.Audience, utils.GetHostFromURL(common.EnvConfig.AppURL)) { + return nil, errors.New("audience doesn't match") + } + return claims, nil +} + +// GetJWK returns the JSON Web Key (JWK) for the public key. +func (s *JwtService) GetJWK() (JWK, error) { + if s.publicKey == nil { + return JWK{}, errors.New("public key is not initialized") + } + + jwk := JWK{ + Kty: "RSA", + Use: "sig", + Kid: "1", + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(s.publicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.publicKey.E)).Bytes()), + } + + return jwk, nil +} + +// generateKeys generates a new RSA key pair and saves them to the specified paths. +func (s *JwtService) generateKeys() error { + if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil { + return errors.New("failed to create directories for keys: " + err.Error()) + } + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return errors.New("failed to generate private key: " + err.Error()) + } + s.privateKey = privateKey + + if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil { + return err + } + + publicKey := &privateKey.PublicKey + s.publicKey = publicKey + + if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil { + return err + } + + return nil +} + +// savePEMKey saves a PEM encoded key to a file. +func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error { + keyFile, err := os.Create(path) + if err != nil { + return errors.New("failed to create key file: " + err.Error()) + } + defer keyFile.Close() + + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: keyType, + Bytes: keyBytes, + }) + + if _, err := keyFile.Write(keyPEM); err != nil { + return errors.New("failed to write key file: " + err.Error()) + } + + return nil +} + +// loadKeys loads RSA keys from the given paths. +func (s *JwtService) loadKeys() error { + if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) { + if err := s.generateKeys(); err != nil { + return err + } + } + + privateKeyBytes, err := os.ReadFile(privateKeyPath) + if err != nil { + return fmt.Errorf("can't read jwt private key: %w", err) + } + s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes) + if err != nil { + return fmt.Errorf("can't parse jwt private key: %w", err) + } + + publicKeyBytes, err := os.ReadFile(publicKeyPath) + if err != nil { + return fmt.Errorf("can't read jwt public key: %w", err) + } + s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes) + if err != nil { + return fmt.Errorf("can't parse jwt public key: %w", err) + } + + return nil +} diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go new file mode 100644 index 0000000..ba7c66b --- /dev/null +++ b/backend/internal/service/oidc_service.go @@ -0,0 +1,282 @@ +package service + +import ( + "errors" + "fmt" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/model" + "github.com/stonith404/pocket-id/backend/internal/utils" + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" + "mime/multipart" + "os" + "time" +) + +type OidcService struct { + db *gorm.DB + jwtService *JwtService +} + +func NewOidcService(db *gorm.DB, jwtService *JwtService) *OidcService { + return &OidcService{ + db: db, + jwtService: jwtService, + } +} + +func (s *OidcService) Authorize(req model.AuthorizeRequest, userID string) (string, error) { + var userAuthorizedOIDCClient model.UserAuthorizedOidcClient + s.db.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", req.ClientID, userID) + + if userAuthorizedOIDCClient.Scope != req.Scope { + return "", common.ErrOidcMissingAuthorization + } + + return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce) +} + +func (s *OidcService) AuthorizeNewClient(req model.AuthorizeNewClientDto, userID string) (string, error) { + userAuthorizedClient := model.UserAuthorizedOidcClient{ + UserID: userID, + ClientID: req.ClientID, + Scope: req.Scope, + } + + if err := s.db.Create(&userAuthorizedClient).Error; err != nil { + if errors.Is(err, gorm.ErrDuplicatedKey) { + err = s.db.Model(&userAuthorizedClient).Update("scope", req.Scope).Error + } else { + return "", err + } + } + + return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce) +} + +func (s *OidcService) CreateIDToken(req model.OidcIdTokenDto) (string, error) { + if req.GrantType != "authorization_code" { + return "", common.ErrOidcGrantTypeNotSupported + } + + clientID := req.ClientID + clientSecret := req.ClientSecret + + if clientID == "" || clientSecret == "" { + return "", common.ErrOidcMissingClientCredentials + } + + var client model.OidcClient + if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + return "", err + } + + err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) + if err != nil { + return "", common.ErrOidcClientSecretInvalid + } + + var authorizationCodeMetaData model.OidcAuthorizationCode + err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", req.Code).Error + if err != nil { + return "", common.ErrOidcInvalidAuthorizationCode + } + + if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.Before(time.Now()) { + return "", common.ErrOidcInvalidAuthorizationCode + } + + idToken, err := s.jwtService.GenerateIDToken(authorizationCodeMetaData.User, clientID, authorizationCodeMetaData.Scope, authorizationCodeMetaData.Nonce) + if err != nil { + return "", err + } + + s.db.Delete(&authorizationCodeMetaData) + + return idToken, nil +} + +func (s *OidcService) GetClient(clientID string) (*model.OidcClient, error) { + var client model.OidcClient + if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + return nil, err + } + return &client, nil +} + +func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]model.OidcClient, utils.PaginationResponse, error) { + var clients []model.OidcClient + + query := s.db.Model(&model.OidcClient{}) + if searchTerm != "" { + searchPattern := "%" + searchTerm + "%" + query = query.Where("name LIKE ?", searchPattern) + } + + pagination, err := utils.Paginate(page, pageSize, query, &clients) + if err != nil { + return nil, utils.PaginationResponse{}, err + } + + return clients, pagination, nil +} + +func (s *OidcService) CreateClient(input model.OidcClientCreateDto, userID string) (*model.OidcClient, error) { + client := model.OidcClient{ + Name: input.Name, + CallbackURL: input.CallbackURL, + CreatedByID: userID, + } + + if err := s.db.Create(&client).Error; err != nil { + return nil, err + } + + return &client, nil +} + +func (s *OidcService) UpdateClient(clientID string, input model.OidcClientCreateDto) (*model.OidcClient, error) { + var client model.OidcClient + if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + return nil, err + } + + client.Name = input.Name + client.CallbackURL = input.CallbackURL + + if err := s.db.Save(&client).Error; err != nil { + return nil, err + } + + return &client, nil +} + +func (s *OidcService) DeleteClient(clientID string) error { + var client model.OidcClient + if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + return err + } + + if err := s.db.Delete(&client).Error; err != nil { + return err + } + + return nil +} + +func (s *OidcService) CreateClientSecret(clientID string) (string, error) { + var client model.OidcClient + if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + return "", err + } + + clientSecret, err := utils.GenerateRandomAlphanumericString(32) + if err != nil { + return "", err + } + + hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost) + if err != nil { + return "", err + } + + client.Secret = string(hashedSecret) + if err := s.db.Save(&client).Error; err != nil { + return "", err + } + + return clientSecret, nil +} + +func (s *OidcService) GetClientLogo(clientID string) (string, string, error) { + var client model.OidcClient + if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + return "", "", err + } + + if client.ImageType == nil { + return "", "", errors.New("image not found") + } + + imageType := *client.ImageType + imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType) + mimeType := utils.GetImageMimeType(imageType) + + return imagePath, mimeType, nil +} + +func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error { + fileType := utils.GetFileExtension(file.Filename) + if mimeType := utils.GetImageMimeType(fileType); mimeType == "" { + return common.ErrFileTypeNotSupported + } + + imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType) + if err := utils.SaveFile(file, imagePath); err != nil { + return err + } + + var client model.OidcClient + if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + return err + } + + if client.ImageType != nil && fileType != *client.ImageType { + oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType) + if err := os.Remove(oldImagePath); err != nil { + return err + } + } + + client.ImageType = &fileType + if err := s.db.Save(&client).Error; err != nil { + return err + } + + return nil +} + +func (s *OidcService) DeleteClientLogo(clientID string) error { + var client model.OidcClient + if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + return err + } + + if client.ImageType == nil { + return errors.New("image not found") + } + + imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType) + if err := os.Remove(imagePath); err != nil { + return err + } + + client.ImageType = nil + if err := s.db.Save(&client).Error; err != nil { + return err + } + + return nil +} + +func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) { + randomString, err := utils.GenerateRandomAlphanumericString(32) + if err != nil { + return "", err + } + + oidcAuthorizationCode := model.OidcAuthorizationCode{ + ExpiresAt: time.Now().Add(15 * time.Minute), + Code: randomString, + ClientID: clientID, + UserID: userID, + Scope: scope, + Nonce: nonce, + } + + if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil { + return "", err + } + + return randomString, nil +} diff --git a/backend/internal/handler/test.go b/backend/internal/service/test_service.go similarity index 74% rename from backend/internal/handler/test.go rename to backend/internal/service/test_service.go index 71c1797..fe71be4 100644 --- a/backend/internal/handler/test.go +++ b/backend/internal/service/test_service.go @@ -1,48 +1,33 @@ -package handler +package service import ( "crypto/ecdsa" "crypto/x509" "encoding/base64" + "fmt" + "github.com/fxamacker/cbor/v2" "log" "os" "time" - "github.com/fxamacker/cbor/v2" - "github.com/gin-gonic/gin" "github.com/go-webauthn/webauthn/protocol" - "golang-rest-api-template/internal/common" - "golang-rest-api-template/internal/model" - "golang-rest-api-template/internal/utils" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/model" + "github.com/stonith404/pocket-id/backend/internal/utils" "gorm.io/gorm" ) -func RegisterTestRoutes(group *gin.RouterGroup) { - group.POST("/test/reset", resetAndSeedHandler) +type TestService struct { + db *gorm.DB + appConfigService *AppConfigService } -func resetAndSeedHandler(c *gin.Context) { - if err := resetDatabase(); err != nil { - utils.UnknownHandlerError(c, err) - return - } - - if err := resetApplicationImages(); err != nil { - utils.UnknownHandlerError(c, err) - return - } - - if err := seedDatabase(); err != nil { - utils.UnknownHandlerError(c, err) - return - } - - c.JSON(200, gin.H{"message": "Database reset and seeded"}) +func NewTestService(db *gorm.DB, appConfigService *AppConfigService) *TestService { + return &TestService{db: db, appConfigService: appConfigService} } -// seedDatabase seeds the database with initial data and uses a transaction to ensure atomicity. -func seedDatabase() error { - return common.DB.Transaction(func(tx *gorm.DB) error { +func (s *TestService) SeedDatabase() error { + return s.db.Transaction(func(tx *gorm.DB) error { users := []model.User{ { Base: model.Base{ @@ -128,11 +113,16 @@ func seedDatabase() error { return err } + publicKey1, err := getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==") + publicKey2, err := getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA==") + if err != nil { + return err + } webauthnCredentials := []model.WebauthnCredential{ { Name: "Passkey 1", CredentialID: "test-credential-1", - PublicKey: getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg=="), + PublicKey: publicKey1, AttestationType: "none", Transport: model.AuthenticatorTransportList{protocol.Internal}, UserID: users[0].ID, @@ -140,7 +130,7 @@ func seedDatabase() error { { Name: "Passkey 2", CredentialID: "test-credential-2", - PublicKey: getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA=="), + PublicKey: publicKey2, AttestationType: "none", Transport: model.AuthenticatorTransportList{protocol.Internal}, UserID: users[0].ID, @@ -165,9 +155,8 @@ func seedDatabase() error { }) } -// resetDatabase resets the database by deleting all rows from each table. -func resetDatabase() error { - err := common.DB.Transaction(func(tx *gorm.DB) error { +func (s *TestService) ResetDatabase() error { + err := s.db.Transaction(func(tx *gorm.DB) error { var tables []string if err := tx.Raw("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' AND name != 'schema_migrations';").Scan(&tables).Error; err != nil { return err @@ -183,13 +172,11 @@ func resetDatabase() error { if err != nil { return err } - common.InitDbConfig() - return nil + err = s.appConfigService.InitDbConfig() + return err } -// resetApplicationImages resets the application images by removing existing images and replacing them with the default ones -func resetApplicationImages() error { - +func (s *TestService) ResetApplicationImages() error { if err := os.RemoveAll(common.EnvConfig.UploadPath); err != nil { log.Printf("Error removing directory: %v", err) return err @@ -204,20 +191,19 @@ func resetApplicationImages() error { } // getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key -func getCborPublicKey(base64PublicKey string) []byte { +func getCborPublicKey(base64PublicKey string) ([]byte, error) { decodedKey, err := base64.StdEncoding.DecodeString(base64PublicKey) if err != nil { - log.Fatalf("Failed to decode base64 key: %v", err) + return nil, fmt.Errorf("failed to decode base64 key: %w", err) } - pubKey, err := x509.ParsePKIXPublicKey(decodedKey) if err != nil { - log.Fatalf("Failed to parse public key: %v", err) + return nil, fmt.Errorf("failed to parse public key: %w", err) } ecdsaPubKey, ok := pubKey.(*ecdsa.PublicKey) if !ok { - log.Fatalf("Not an ECDSA public key") + return nil, fmt.Errorf("not an ECDSA public key") } coseKey := map[int]interface{}{ @@ -230,8 +216,8 @@ func getCborPublicKey(base64PublicKey string) []byte { cborPublicKey, err := cbor.Marshal(coseKey) if err != nil { - log.Fatalf("Failed to encode CBOR: %v", err) + return nil, fmt.Errorf("failed to marshal COSE key: %w", err) } - return cborPublicKey + return cborPublicKey, nil } diff --git a/backend/internal/service/user_sevice.go b/backend/internal/service/user_sevice.go new file mode 100644 index 0000000..0c5c5a2 --- /dev/null +++ b/backend/internal/service/user_sevice.go @@ -0,0 +1,165 @@ +package service + +import ( + "errors" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/model" + "github.com/stonith404/pocket-id/backend/internal/utils" + "gorm.io/gorm" + "time" +) + +type UserService struct { + db *gorm.DB + jwtService *JwtService +} + +func NewUserService(db *gorm.DB, jwtService *JwtService) *UserService { + return &UserService{db: db, jwtService: jwtService} +} + +func (s *UserService) ListUsers(searchTerm string, page int, pageSize int) ([]model.User, utils.PaginationResponse, error) { + var users []model.User + query := s.db.Model(&model.User{}) + + if searchTerm != "" { + searchPattern := "%" + searchTerm + "%" + query = query.Where("email LIKE ? OR first_name LIKE ? OR username LIKE ?", searchPattern, searchPattern, searchPattern) + } + + pagination, err := utils.Paginate(page, pageSize, query, &users) + return users, pagination, err +} + +func (s *UserService) GetUser(userID string) (model.User, error) { + var user model.User + err := s.db.Where("id = ?", userID).First(&user).Error + return user, err +} + +func (s *UserService) DeleteUser(userID string) error { + var user model.User + if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil { + return err + } + + return s.db.Delete(&user).Error +} + +func (s *UserService) CreateUser(user *model.User) error { + if err := s.db.Create(user).Error; err != nil { + if errors.Is(err, gorm.ErrDuplicatedKey) { + return s.checkDuplicatedFields(*user) + } + return err + } + return nil +} + +func (s *UserService) UpdateUser(userID string, updatedUser model.User, updateOwnUser bool) (model.User, error) { + var user model.User + if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil { + return model.User{}, err + } + user.FirstName = updatedUser.FirstName + user.LastName = updatedUser.LastName + user.Email = updatedUser.Email + user.Username = updatedUser.Username + if !updateOwnUser { + user.IsAdmin = updatedUser.IsAdmin + } + + if err := s.db.Save(&user).Error; err != nil { + if errors.Is(err, gorm.ErrDuplicatedKey) { + return user, s.checkDuplicatedFields(user) + } + return user, err + } + + return user, nil +} + +func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) { + randomString, err := utils.GenerateRandomAlphanumericString(16) + if err != nil { + return "", err + } + + oneTimeAccessToken := model.OneTimeAccessToken{ + UserID: userID, + ExpiresAt: expiresAt, + Token: randomString, + } + + if err := s.db.Create(&oneTimeAccessToken).Error; err != nil { + return "", err + } + + return oneTimeAccessToken.Token, nil +} + +func (s *UserService) ExchangeOneTimeAccessToken(token string) (model.User, string, error) { + var oneTimeAccessToken model.OneTimeAccessToken + if err := s.db.Where("token = ? AND expires_at > ?", token, utils.FormatDateForDb(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return model.User{}, "", common.ErrTokenInvalidOrExpired + } + return model.User{}, "", err + } + accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User) + if err != nil { + return model.User{}, "", err + } + + if err := s.db.Delete(&oneTimeAccessToken).Error; err != nil { + return model.User{}, "", err + } + + return oneTimeAccessToken.User, accessToken, nil +} + +func (s *UserService) SetupInitialAdmin() (model.User, string, error) { + var userCount int64 + if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil { + return model.User{}, "", err + } + if userCount > 1 { + return model.User{}, "", common.ErrSetupAlreadyCompleted + } + + user := model.User{ + FirstName: "Admin", + LastName: "Admin", + Username: "admin", + Email: "admin@admin.com", + IsAdmin: true, + } + + if err := s.db.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil { + return model.User{}, "", err + } + + if len(user.Credentials) > 0 { + return model.User{}, "", common.ErrSetupAlreadyCompleted + } + + token, err := s.jwtService.GenerateAccessToken(user) + if err != nil { + return model.User{}, "", err + } + + return user, token, nil +} + +func (s *UserService) checkDuplicatedFields(user model.User) error { + var existingUser model.User + if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil { + return common.ErrEmailTaken + } + + if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil { + return common.ErrUsernameTaken + } + + return nil +} diff --git a/backend/internal/service/webauthn_service.go b/backend/internal/service/webauthn_service.go new file mode 100644 index 0000000..fdf4dbc --- /dev/null +++ b/backend/internal/service/webauthn_service.go @@ -0,0 +1,196 @@ +package service + +import ( + "github.com/go-webauthn/webauthn/protocol" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/model" + "github.com/stonith404/pocket-id/backend/internal/utils" + "gorm.io/gorm" + "net/http" + "time" +) + +type WebAuthnService struct { + db *gorm.DB + webAuthn *webauthn.WebAuthn +} + +func NewWebAuthnService(db *gorm.DB, appConfigService *AppConfigService) *WebAuthnService { + webauthnConfig := &webauthn.Config{ + RPDisplayName: appConfigService.DbConfig.AppName.Value, + RPID: utils.GetHostFromURL(common.EnvConfig.AppURL), + RPOrigins: []string{common.EnvConfig.AppURL}, + Timeouts: webauthn.TimeoutsConfig{ + Login: webauthn.TimeoutConfig{ + Enforce: true, + Timeout: time.Second * 60, + TimeoutUVD: time.Second * 60, + }, + Registration: webauthn.TimeoutConfig{ + Enforce: true, + Timeout: time.Second * 60, + TimeoutUVD: time.Second * 60, + }, + }, + } + + wa, _ := webauthn.New(webauthnConfig) + return &WebAuthnService{db: db, webAuthn: wa} +} + +func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCredentialCreationOptions, error) { + var user model.User + if err := s.db.Preload("Credentials").Find(&user, "id = ?", userID).Error; err != nil { + return nil, err + } + + options, session, err := s.webAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors())) + if err != nil { + return nil, err + } + + sessionToStore := &model.WebauthnSession{ + ExpiresAt: session.Expires, + Challenge: session.Challenge, + UserVerification: string(session.UserVerification), + } + + if err := s.db.Create(&sessionToStore).Error; err != nil { + return nil, err + } + + return &model.PublicKeyCredentialCreationOptions{ + Response: options.Response, + SessionID: sessionToStore.ID, + Timeout: s.webAuthn.Config.Timeouts.Registration.Timeout, + }, nil +} + +func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (*model.WebauthnCredential, error) { + var storedSession model.WebauthnSession + if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { + return nil, err + } + + session := webauthn.SessionData{ + Challenge: storedSession.Challenge, + Expires: storedSession.ExpiresAt, + UserID: []byte(userID), + } + + var user model.User + if err := s.db.Find(&user, "id = ?", userID).Error; err != nil { + return nil, err + } + + credential, err := s.webAuthn.FinishRegistration(&user, session, r) + if err != nil { + return nil, err + } + + credentialToStore := model.WebauthnCredential{ + Name: "New Passkey", + CredentialID: string(credential.ID), + AttestationType: credential.AttestationType, + PublicKey: credential.PublicKey, + Transport: credential.Transport, + UserID: user.ID, + BackupEligible: credential.Flags.BackupEligible, + BackupState: credential.Flags.BackupState, + } + if err := s.db.Create(&credentialToStore).Error; err != nil { + return nil, err + } + + return &credentialToStore, nil +} + +func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) { + options, session, err := s.webAuthn.BeginDiscoverableLogin() + if err != nil { + return nil, err + } + + sessionToStore := &model.WebauthnSession{ + ExpiresAt: session.Expires, + Challenge: session.Challenge, + UserVerification: string(session.UserVerification), + } + + if err := s.db.Create(&sessionToStore).Error; err != nil { + return nil, err + } + + return &model.PublicKeyCredentialRequestOptions{ + Response: options.Response, + SessionID: sessionToStore.ID, + Timeout: s.webAuthn.Config.Timeouts.Registration.Timeout, + }, nil +} + +func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (*model.User, error) { + var storedSession model.WebauthnSession + if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { + return nil, err + } + + session := webauthn.SessionData{ + Challenge: storedSession.Challenge, + Expires: storedSession.ExpiresAt, + } + + var user *model.User + _, err := s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) { + if err := s.db.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil { + return nil, err + } + return user, nil + }, session, credentialAssertionData) + + if err != nil { + return nil, common.ErrInvalidCredentials + } + + if err := s.db.Find(&user, "id = ?", userID).Error; err != nil { + return nil, err + } + + return user, nil +} + +func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) { + var credentials []model.WebauthnCredential + if err := s.db.Find(&credentials, "user_id = ?", userID).Error; err != nil { + return nil, err + } + return credentials, nil +} + +func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error { + var credential model.WebauthnCredential + if err := s.db.First(&credential, "id = ? AND user_id = ?", credentialID, userID).Error; err != nil { + return err + } + + if err := s.db.Delete(&credential).Error; err != nil { + return err + } + + return nil +} + +func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) error { + var credential model.WebauthnCredential + if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil { + return err + } + + credential.Name = name + + if err := s.db.Save(&credential).Error; err != nil { + return err + } + + return nil +} diff --git a/backend/internal/utils/file_util.go b/backend/internal/utils/file_util.go index f8141cd..5bf45f4 100644 --- a/backend/internal/utils/file_util.go +++ b/backend/internal/utils/file_util.go @@ -2,6 +2,7 @@ package utils import ( "io" + "mime/multipart" "os" "path/filepath" "strings" @@ -71,3 +72,24 @@ func copyFile(srcFilePath, destFilePath string) error { return nil } + +func SaveFile(file *multipart.FileHeader, dst string) error { + src, err := file.Open() + if err != nil { + return err + } + defer src.Close() + + if err = os.MkdirAll(filepath.Dir(dst), 0o750); err != nil { + return err + } + + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + + _, err = io.Copy(out, src) + return err +} diff --git a/backend/internal/utils/handler_error_util.go b/backend/internal/utils/handler_error_util.go index a920630..e644501 100644 --- a/backend/internal/utils/handler_error_util.go +++ b/backend/internal/utils/handler_error_util.go @@ -1,15 +1,23 @@ package utils import ( + "errors" "github.com/gin-gonic/gin" + "gorm.io/gorm" "log" "net/http" "strings" ) func UnknownHandlerError(c *gin.Context, err error) { - log.Println(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"}) + if errors.Is(err, gorm.ErrRecordNotFound) { + HandlerError(c, http.StatusNotFound, "Record not found") + return + } else { + log.Println(err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"}) + } + } func HandlerError(c *gin.Context, statusCode int, message string) { diff --git a/backend/internal/utils/paging_util.go b/backend/internal/utils/paging_util.go index ecd9697..b998648 100644 --- a/backend/internal/utils/paging_util.go +++ b/backend/internal/utils/paging_util.go @@ -1,9 +1,7 @@ package utils import ( - "github.com/gin-gonic/gin" "gorm.io/gorm" - "strconv" ) type PaginationResponse struct { @@ -12,10 +10,7 @@ type PaginationResponse struct { CurrentPage int `json:"currentPage"` } -func Paginate(c *gin.Context, db *gorm.DB, result interface{}) (PaginationResponse, error) { - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10")) - +func Paginate(page int, pageSize int, db *gorm.DB, result interface{}) (PaginationResponse, error) { if page < 1 { page = 1 } diff --git a/backend/migrations/20240817191051_rename_config_table.down.sql b/backend/migrations/20240817191051_rename_config_table.down.sql new file mode 100644 index 0000000..49996fa --- /dev/null +++ b/backend/migrations/20240817191051_rename_config_table.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE app_config_variables + RENAME TO application_configuration_variables; \ No newline at end of file diff --git a/backend/migrations/20240817191051_rename_config_table.up.sql b/backend/migrations/20240817191051_rename_config_table.up.sql new file mode 100644 index 0000000..87bbb87 --- /dev/null +++ b/backend/migrations/20240817191051_rename_config_table.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE application_configuration_variables + RENAME TO app_config_variables; \ No newline at end of file diff --git a/frontend/tests/account-settings.spec.ts b/frontend/tests/account-settings.spec.ts index ebca515..00a36b0 100644 --- a/frontend/tests/account-settings.spec.ts +++ b/frontend/tests/account-settings.spec.ts @@ -69,15 +69,3 @@ test('Delete passkey from account', async ({ page }) => { await expect(page.getByRole('status')).toHaveText('Passkey deleted successfully'); }); - -test('Delete last passkey from account fails', async ({ page }) => { - await page.goto('/settings/account'); - - await page.getByLabel('Delete').first().click(); - await page.getByText('Delete', { exact: true }).click(); - - await page.getByLabel('Delete').first().click(); - await page.getByText('Delete', { exact: true }).click(); - - await expect(page.getByRole('status').first()).toHaveText('You must have at least one passkey'); -}); diff --git a/frontend/tests/user-settings.spec.ts b/frontend/tests/user-settings.spec.ts index 167be81..ced3154 100644 --- a/frontend/tests/user-settings.spec.ts +++ b/frontend/tests/user-settings.spec.ts @@ -35,6 +35,21 @@ test('Create user fails with already taken email', async ({ page }) => { await expect(page.getByRole('status')).toHaveText('Email is already taken'); }); +test('Create user fails with already taken username', async ({ page }) => { + const user = users.steve; + + await page.goto('/settings/admin/users'); + + await page.getByRole('button', { name: 'Add User' }).click(); + await page.getByLabel('Firstname').fill(user.firstname); + await page.getByLabel('Lastname').fill(user.lastname); + await page.getByLabel('Email').fill(user.email); + await page.getByLabel('Username').fill(users.tim.username); + await page.getByRole('button', { name: 'Save' }).click(); + + await expect(page.getByRole('status')).toHaveText('Username is already taken'); +}); + test('Create one time access token', async ({ page }) => { await page.goto('/settings/admin/users');