mirror of
https://github.com/nikdoof/pocket-id.git
synced 2025-12-14 07:12:19 +00:00
refactor: use dependency injection in backend
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"golang-rest-api-template/internal/bootstrap"
|
"github.com/stonith404/pocket-id/backend/internal/bootstrap"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
module golang-rest-api-template
|
module github.com/stonith404/pocket-id/backend
|
||||||
|
|
||||||
go 1.22
|
go 1.22
|
||||||
|
|
||||||
|
|||||||
28
backend/internal/bootstrap/application_images_bootstrap.go
Normal file
28
backend/internal/bootstrap/application_images_bootstrap.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package bootstrap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initApplicationImages() {
|
||||||
|
dirPath := common.EnvConfig.UploadPath + "/application-images"
|
||||||
|
|
||||||
|
files, err := os.ReadDir(dirPath)
|
||||||
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
log.Fatalf("Error reading directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if files already exist
|
||||||
|
if len(files) > 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy files from source to destination
|
||||||
|
err = utils.CopyDirectory("./images", dirPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error copying directory: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,78 +1,16 @@
|
|||||||
package bootstrap
|
package bootstrap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||||
"golang-rest-api-template/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/job"
|
||||||
"golang-rest-api-template/internal/common/middleware"
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
"golang-rest-api-template/internal/handler"
|
|
||||||
"golang-rest-api-template/internal/job"
|
|
||||||
"golang-rest-api-template/internal/utils"
|
|
||||||
"golang.org/x/time/rate"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Bootstrap() {
|
func Bootstrap() {
|
||||||
common.InitDatabase()
|
db := newDatabase()
|
||||||
common.InitDbConfig()
|
appConfigService := service.NewAppConfigService(db)
|
||||||
|
|
||||||
initApplicationImages()
|
initApplicationImages()
|
||||||
job.RegisterJobs()
|
job.RegisterJobs(db)
|
||||||
initRouter()
|
initRouter(db, appConfigService)
|
||||||
}
|
|
||||||
|
|
||||||
func initRouter() {
|
|
||||||
switch common.EnvConfig.AppEnv {
|
|
||||||
case "production":
|
|
||||||
gin.SetMode(gin.ReleaseMode)
|
|
||||||
case "development":
|
|
||||||
gin.SetMode(gin.DebugMode)
|
|
||||||
case "test":
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
}
|
|
||||||
|
|
||||||
r := gin.Default()
|
|
||||||
|
|
||||||
r.Use(gin.Logger())
|
|
||||||
|
|
||||||
r.Use(middleware.Cors())
|
|
||||||
r.Use(middleware.RateLimiter(rate.Every(time.Second), 60))
|
|
||||||
|
|
||||||
apiGroup := r.Group("/api")
|
|
||||||
handler.RegisterRoutes(apiGroup)
|
|
||||||
handler.RegisterOIDCRoutes(apiGroup)
|
|
||||||
handler.RegisterUserRoutes(apiGroup)
|
|
||||||
handler.RegisterConfigurationRoutes(apiGroup)
|
|
||||||
if common.EnvConfig.AppEnv != "production" {
|
|
||||||
handler.RegisterTestRoutes(apiGroup)
|
|
||||||
}
|
|
||||||
|
|
||||||
baseGroup := r.Group("/")
|
|
||||||
handler.RegisterWellKnownRoutes(baseGroup)
|
|
||||||
|
|
||||||
if err := r.Run(common.EnvConfig.Host + ":" + common.EnvConfig.Port); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func initApplicationImages() {
|
|
||||||
dirPath := common.EnvConfig.UploadPath + "/application-images"
|
|
||||||
|
|
||||||
files, err := os.ReadDir(dirPath)
|
|
||||||
if err != nil && !os.IsNotExist(err) {
|
|
||||||
log.Fatalf("Error reading directory: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip if files already exist
|
|
||||||
if len(files) > 1 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy files from source to destination
|
|
||||||
err = utils.CopyDirectory("./images", dirPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Error copying directory: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,51 +1,54 @@
|
|||||||
package common
|
package bootstrap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/golang-migrate/migrate/v4"
|
"github.com/golang-migrate/migrate/v4"
|
||||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var DB *gorm.DB
|
func newDatabase() (db *gorm.DB) {
|
||||||
|
db, err := connectDatabase()
|
||||||
func InitDatabase() {
|
|
||||||
connectDatabase()
|
|
||||||
sqlDb, err := DB.DB()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("failed to get sql db", err)
|
log.Fatalf("failed to connect to database: %v", err)
|
||||||
}
|
}
|
||||||
|
sqlDb, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to get sql.DB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
driver, err := sqlite3.WithInstance(sqlDb, &sqlite3.Config{})
|
driver, err := sqlite3.WithInstance(sqlDb, &sqlite3.Config{})
|
||||||
m, err := migrate.NewWithDatabaseInstance(
|
m, err := migrate.NewWithDatabaseInstance(
|
||||||
"file://migrations",
|
"file://migrations",
|
||||||
"postgres", driver)
|
"postgres", driver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("failed to create migration instance", err)
|
log.Fatalf("failed to create migration instance: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.Up()
|
err = m.Up()
|
||||||
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||||
log.Fatal("failed to run migrations", err)
|
log.Fatalf("failed to apply migrations: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func connectDatabase() {
|
func connectDatabase() (db *gorm.DB, err error) {
|
||||||
var database *gorm.DB
|
dbPath := common.EnvConfig.DBPath
|
||||||
var err error
|
|
||||||
|
|
||||||
dbPath := EnvConfig.DBPath
|
// Use in-memory database for testing
|
||||||
if EnvConfig.AppEnv == "test" {
|
if common.EnvConfig.AppEnv == "test" {
|
||||||
dbPath = "file::memory:?cache=shared"
|
dbPath = "file::memory:?cache=shared"
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 1; i <= 3; i++ {
|
for i := 1; i <= 3; i++ {
|
||||||
database, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||||
TranslateError: true,
|
TranslateError: true,
|
||||||
Logger: getLogger(),
|
Logger: getLogger(),
|
||||||
})
|
})
|
||||||
@@ -57,11 +60,11 @@ func connectDatabase() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
DB = database
|
return db, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getLogger() logger.Interface {
|
func getLogger() logger.Interface {
|
||||||
isProduction := EnvConfig.AppEnv == "production"
|
isProduction := common.EnvConfig.AppEnv == "production"
|
||||||
|
|
||||||
var logLevel logger.LogLevel
|
var logLevel logger.LogLevel
|
||||||
if isProduction {
|
if isProduction {
|
||||||
@@ -70,7 +73,6 @@ func getLogger() logger.Interface {
|
|||||||
logLevel = logger.Info
|
logLevel = logger.Info
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the GORM logger
|
|
||||||
return logger.New(
|
return logger.New(
|
||||||
log.New(os.Stdout, "\r\n", log.LstdFlags),
|
log.New(os.Stdout, "\r\n", log.LstdFlags),
|
||||||
logger.Config{
|
logger.Config{
|
||||||
67
backend/internal/bootstrap/router_bootstrap.go
Normal file
67
backend/internal/bootstrap/router_bootstrap.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package bootstrap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/controller"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
|
||||||
|
// Set the appropriate Gin mode based on the environment
|
||||||
|
switch common.EnvConfig.AppEnv {
|
||||||
|
case "production":
|
||||||
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
case "development":
|
||||||
|
gin.SetMode(gin.DebugMode)
|
||||||
|
case "test":
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := gin.Default()
|
||||||
|
r.Use(gin.Logger())
|
||||||
|
|
||||||
|
// Add middleware
|
||||||
|
r.Use(
|
||||||
|
middleware.NewCorsMiddleware().Add(),
|
||||||
|
middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initialize services
|
||||||
|
webauthnService := service.NewWebAuthnService(db, appConfigService)
|
||||||
|
jwtService := service.NewJwtService(appConfigService)
|
||||||
|
userService := service.NewUserService(db, jwtService)
|
||||||
|
oidcService := service.NewOidcService(db, jwtService)
|
||||||
|
testService := service.NewTestService(db, appConfigService)
|
||||||
|
|
||||||
|
// Initialize middleware
|
||||||
|
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService)
|
||||||
|
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
|
||||||
|
|
||||||
|
// Set up API routes
|
||||||
|
apiGroup := r.Group("/api")
|
||||||
|
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService)
|
||||||
|
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService)
|
||||||
|
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService)
|
||||||
|
controller.NewApplicationConfigurationController(apiGroup, jwtAuthMiddleware, appConfigService)
|
||||||
|
|
||||||
|
// Add test controller in non-production environments
|
||||||
|
if common.EnvConfig.AppEnv != "production" {
|
||||||
|
controller.NewTestController(apiGroup, testService)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up base routes
|
||||||
|
baseGroup := r.Group("/")
|
||||||
|
controller.NewWellKnownController(baseGroup, jwtService)
|
||||||
|
|
||||||
|
// Run the server
|
||||||
|
if err := r.Run(common.EnvConfig.Host + ":" + common.EnvConfig.Port); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/caarlos0/env/v11"
|
|
||||||
_ "github.com/joho/godotenv/autoload"
|
|
||||||
"golang-rest-api-template/internal/model"
|
|
||||||
"log"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
type EnvConfigSchema struct {
|
|
||||||
AppEnv string `env:"APP_ENV"`
|
|
||||||
AppURL string `env:"PUBLIC_APP_URL"`
|
|
||||||
DBPath string `env:"DB_PATH"`
|
|
||||||
UploadPath string `env:"UPLOAD_PATH"`
|
|
||||||
Port string `env:"BACKEND_PORT"`
|
|
||||||
Host string `env:"HOST"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var EnvConfig = &EnvConfigSchema{
|
|
||||||
AppEnv: "production",
|
|
||||||
DBPath: "data/pocket-id.db",
|
|
||||||
UploadPath: "data/uploads",
|
|
||||||
AppURL: "http://localhost",
|
|
||||||
Port: "8080",
|
|
||||||
Host: "localhost",
|
|
||||||
}
|
|
||||||
|
|
||||||
var DbConfig = NewDefaultDbConfig()
|
|
||||||
|
|
||||||
func NewDefaultDbConfig() model.ApplicationConfiguration {
|
|
||||||
return model.ApplicationConfiguration{
|
|
||||||
AppName: model.ApplicationConfigurationVariable{
|
|
||||||
Key: "appName",
|
|
||||||
Type: "string",
|
|
||||||
IsPublic: true,
|
|
||||||
Value: "Pocket ID",
|
|
||||||
},
|
|
||||||
SessionDuration: model.ApplicationConfigurationVariable{
|
|
||||||
Key: "sessionDuration",
|
|
||||||
Type: "number",
|
|
||||||
Value: "60",
|
|
||||||
},
|
|
||||||
BackgroundImageType: model.ApplicationConfigurationVariable{
|
|
||||||
Key: "backgroundImageType",
|
|
||||||
Type: "string",
|
|
||||||
IsInternal: true,
|
|
||||||
Value: "jpg",
|
|
||||||
},
|
|
||||||
LogoImageType: model.ApplicationConfigurationVariable{
|
|
||||||
Key: "logoImageType",
|
|
||||||
Type: "string",
|
|
||||||
IsInternal: true,
|
|
||||||
Value: "svg",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadDbConfigFromDb refreshes the database configuration by loading the current values
|
|
||||||
// from the database and updating the DbConfig struct.
|
|
||||||
func LoadDbConfigFromDb() error {
|
|
||||||
dbConfigReflectValue := reflect.ValueOf(&DbConfig).Elem()
|
|
||||||
|
|
||||||
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
|
|
||||||
dbConfigField := dbConfigReflectValue.Field(i)
|
|
||||||
currentConfigVar := dbConfigField.Interface().(model.ApplicationConfigurationVariable)
|
|
||||||
var storedConfigVar model.ApplicationConfigurationVariable
|
|
||||||
if err := DB.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitDbConfig creates the default configuration values in the database if they do not exist,
|
|
||||||
// updates existing configurations if they differ from the default, and deletes any configurations
|
|
||||||
// that are not in the default configuration.
|
|
||||||
func InitDbConfig() {
|
|
||||||
// Reflect to get the underlying value of DbConfig and its default configuration
|
|
||||||
dbConfigReflectValue := reflect.ValueOf(&DbConfig).Elem()
|
|
||||||
defaultDbConfig := NewDefaultDbConfig()
|
|
||||||
defaultConfigReflectValue := reflect.ValueOf(&defaultDbConfig).Elem()
|
|
||||||
defaultKeys := make(map[string]struct{})
|
|
||||||
|
|
||||||
// Iterate over the fields of DbConfig
|
|
||||||
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
|
|
||||||
dbConfigField := dbConfigReflectValue.Field(i)
|
|
||||||
currentConfigVar := dbConfigField.Interface().(model.ApplicationConfigurationVariable)
|
|
||||||
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.ApplicationConfigurationVariable)
|
|
||||||
defaultKeys[currentConfigVar.Key] = struct{}{}
|
|
||||||
|
|
||||||
var storedConfigVar model.ApplicationConfigurationVariable
|
|
||||||
if err := DB.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
|
|
||||||
// If the configuration does not exist, create it
|
|
||||||
if err := DB.Create(&defaultConfigVar).Error; err != nil {
|
|
||||||
log.Fatalf("Failed to create default configuration: %v", err)
|
|
||||||
}
|
|
||||||
dbConfigField.Set(reflect.ValueOf(defaultConfigVar))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update existing configuration if it differs from the default
|
|
||||||
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal {
|
|
||||||
storedConfigVar.Type = defaultConfigVar.Type
|
|
||||||
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
|
|
||||||
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
|
|
||||||
if err := DB.Save(&storedConfigVar).Error; err != nil {
|
|
||||||
log.Fatalf("Failed to update configuration: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the value in DbConfig
|
|
||||||
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete any configurations not in the default keys
|
|
||||||
var allConfigVars []model.ApplicationConfigurationVariable
|
|
||||||
if err := DB.Find(&allConfigVars).Error; err != nil {
|
|
||||||
log.Fatalf("Failed to retrieve existing configurations: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, config := range allConfigVars {
|
|
||||||
if _, exists := defaultKeys[config.Key]; !exists {
|
|
||||||
if err := DB.Delete(&config).Error; err != nil {
|
|
||||||
log.Fatalf("Failed to delete outdated configuration: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
31
backend/internal/common/env_config.go
Normal file
31
backend/internal/common/env_config.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/caarlos0/env/v11"
|
||||||
|
_ "github.com/joho/godotenv/autoload"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EnvConfigSchema struct {
|
||||||
|
AppEnv string `env:"APP_ENV"`
|
||||||
|
AppURL string `env:"PUBLIC_APP_URL"`
|
||||||
|
DBPath string `env:"DB_PATH"`
|
||||||
|
UploadPath string `env:"UPLOAD_PATH"`
|
||||||
|
Port string `env:"BACKEND_PORT"`
|
||||||
|
Host string `env:"HOST"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var EnvConfig = &EnvConfigSchema{
|
||||||
|
AppEnv: "production",
|
||||||
|
DBPath: "data/pocket-id.db",
|
||||||
|
UploadPath: "data/uploads",
|
||||||
|
AppURL: "http://localhost",
|
||||||
|
Port: "8080",
|
||||||
|
Host: "localhost",
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
18
backend/internal/common/errors.go
Normal file
18
backend/internal/common/errors.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrUsernameTaken = errors.New("username is already taken")
|
||||||
|
ErrEmailTaken = errors.New("email is already taken")
|
||||||
|
ErrSetupAlreadyCompleted = errors.New("setup already completed")
|
||||||
|
ErrInvalidBody = errors.New("invalid request body")
|
||||||
|
ErrTokenInvalidOrExpired = errors.New("token is invalid or expired")
|
||||||
|
ErrOidcMissingAuthorization = errors.New("missing authorization")
|
||||||
|
ErrOidcGrantTypeNotSupported = errors.New("grant type not supported")
|
||||||
|
ErrOidcMissingClientCredentials = errors.New("client id or secret not provided")
|
||||||
|
ErrOidcClientSecretInvalid = errors.New("invalid client secret")
|
||||||
|
ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code")
|
||||||
|
ErrFileTypeNotSupported = errors.New("file type not supported")
|
||||||
|
ErrInvalidCredentials = errors.New("no user found with provided credentials")
|
||||||
|
)
|
||||||
@@ -1,209 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"golang-rest-api-template/internal/model"
|
|
||||||
"golang-rest-api-template/internal/utils"
|
|
||||||
"log"
|
|
||||||
"math/big"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"slices"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
PrivateKey *rsa.PrivateKey
|
|
||||||
PublicKey *rsa.PublicKey
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
privateKeyPath = "data/keys/jwt_private_key.pem"
|
|
||||||
publicKeyPath = "data/keys/jwt_public_key.pem"
|
|
||||||
)
|
|
||||||
|
|
||||||
type accessTokenJWTClaims struct {
|
|
||||||
jwt.RegisteredClaims
|
|
||||||
IsAdmin bool `json:"isAdmin,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateIDToken generates an ID token for the given user, clientID, scope and nonce.
|
|
||||||
func GenerateIDToken(user model.User, clientID string, scope string, nonce string) (tokenString string, err error) {
|
|
||||||
profileClaims := map[string]interface{}{
|
|
||||||
"given_name": user.FirstName,
|
|
||||||
"family_name": user.LastName,
|
|
||||||
"email": user.Email,
|
|
||||||
"preferred_username": user.Username,
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := jwt.MapClaims{
|
|
||||||
"sub": user.ID,
|
|
||||||
"aud": clientID,
|
|
||||||
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
|
||||||
"iat": jwt.NewNumericDate(time.Now()),
|
|
||||||
}
|
|
||||||
|
|
||||||
if nonce != "" {
|
|
||||||
claims["nonce"] = nonce
|
|
||||||
}
|
|
||||||
if strings.Contains(scope, "profile") {
|
|
||||||
for k, v := range profileClaims {
|
|
||||||
claims[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if strings.Contains(scope, "email") {
|
|
||||||
claims["email"] = user.Email
|
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
||||||
signedToken, err := token.SignedString(PrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return signedToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateAccessToken generates an access token for the given user.
|
|
||||||
func GenerateAccessToken(user model.User) (tokenString string, err error) {
|
|
||||||
sessionDurationInMinutes, _ := strconv.Atoi(DbConfig.SessionDuration.Value)
|
|
||||||
claim := accessTokenJWTClaims{
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
Subject: user.ID,
|
|
||||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)),
|
|
||||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
||||||
Audience: jwt.ClaimStrings{utils.GetHostFromURL(EnvConfig.AppURL)},
|
|
||||||
},
|
|
||||||
IsAdmin: user.IsAdmin,
|
|
||||||
}
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
|
||||||
tokenString, err = token.SignedString(PrivateKey)
|
|
||||||
return tokenString, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// VerifyAccessToken verifies the given access token and returns the claims if the token is valid.
|
|
||||||
func VerifyAccessToken(tokenString string) (*accessTokenJWTClaims, error) {
|
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &accessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
|
||||||
return PublicKey, nil
|
|
||||||
})
|
|
||||||
if err != nil || !token.Valid {
|
|
||||||
return nil, errors.New("couldn't handle this token")
|
|
||||||
}
|
|
||||||
|
|
||||||
claims, isValid := token.Claims.(*accessTokenJWTClaims)
|
|
||||||
if !isValid {
|
|
||||||
return nil, errors.New("can't parse claims")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !slices.Contains(claims.Audience, utils.GetHostFromURL(EnvConfig.AppURL)) {
|
|
||||||
return nil, errors.New("audience doesn't match")
|
|
||||||
}
|
|
||||||
return claims, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type JWK struct {
|
|
||||||
Kty string `json:"kty"`
|
|
||||||
Use string `json:"use"`
|
|
||||||
Kid string `json:"kid"`
|
|
||||||
Alg string `json:"alg"`
|
|
||||||
N string `json:"n"`
|
|
||||||
E string `json:"e"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetJWK returns the JSON Web Key (JWK) for the public key.
|
|
||||||
func GetJWK() (JWK, error) {
|
|
||||||
if PublicKey == nil {
|
|
||||||
return JWK{}, errors.New("public key is not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create JWK from RSA public key
|
|
||||||
jwk := JWK{
|
|
||||||
Kty: "RSA",
|
|
||||||
Use: "sig",
|
|
||||||
Kid: "1", // Key ID can be set to any identifier. Here it's statically set to "1"
|
|
||||||
Alg: "RS256",
|
|
||||||
N: base64.RawURLEncoding.EncodeToString(PublicKey.N.Bytes()),
|
|
||||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(PublicKey.E)).Bytes()),
|
|
||||||
}
|
|
||||||
|
|
||||||
return jwk, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateKeys generates a new RSA key pair and saves the private and public keys to the data folder.
|
|
||||||
func generateKeys() {
|
|
||||||
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
|
|
||||||
log.Fatal("Failed to create directories for keys", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Failed to generate private key", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKeyFile, err := os.Create(privateKeyPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Failed to create private key file", err)
|
|
||||||
}
|
|
||||||
defer privateKeyFile.Close()
|
|
||||||
|
|
||||||
privateKeyPEM := pem.EncodeToMemory(
|
|
||||||
&pem.Block{
|
|
||||||
Type: "RSA PRIVATE KEY",
|
|
||||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
_, err = privateKeyFile.Write(privateKeyPEM)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Failed to write private key file", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
publicKey := &privateKey.PublicKey
|
|
||||||
publicKeyFile, err := os.Create(publicKeyPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Failed to create public key file", err)
|
|
||||||
}
|
|
||||||
defer publicKeyFile.Close()
|
|
||||||
|
|
||||||
publicKeyPEM := pem.EncodeToMemory(
|
|
||||||
&pem.Block{
|
|
||||||
Type: "RSA PUBLIC KEY",
|
|
||||||
Bytes: x509.MarshalPKCS1PublicKey(publicKey),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
_, err = publicKeyFile.Write(publicKeyPEM)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Failed to write public key file", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
|
|
||||||
generateKeys()
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Can't read jwt private key", err)
|
|
||||||
}
|
|
||||||
PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Can't parse jwt private key", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
publicKeyBytes, err := os.ReadFile(publicKeyPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Can't read jwt public key", err)
|
|
||||||
}
|
|
||||||
PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Can't parse jwt public key", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/go-webauthn/webauthn/webauthn"
|
|
||||||
"golang-rest-api-template/internal/utils"
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
WebAuthn *webauthn.WebAuthn
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
config := &webauthn.Config{
|
|
||||||
RPDisplayName: DbConfig.AppName.Value,
|
|
||||||
RPID: utils.GetHostFromURL(EnvConfig.AppURL),
|
|
||||||
RPOrigins: []string{EnvConfig.AppURL},
|
|
||||||
Timeouts: webauthn.TimeoutsConfig{
|
|
||||||
Login: webauthn.TimeoutConfig{
|
|
||||||
Enforce: true,
|
|
||||||
Timeout: time.Second * 60,
|
|
||||||
TimeoutUVD: time.Second * 60,
|
|
||||||
},
|
|
||||||
Registration: webauthn.TimeoutConfig{
|
|
||||||
Enforce: true,
|
|
||||||
Timeout: time.Second * 60,
|
|
||||||
TimeoutUVD: time.Second * 60,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if WebAuthn, err = webauthn.New(config); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
140
backend/internal/controller/app_config_controller.go
Normal file
140
backend/internal/controller/app_config_controller.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewApplicationConfigurationController(
|
||||||
|
group *gin.RouterGroup,
|
||||||
|
jwtAuthMiddleware *middleware.JwtAuthMiddleware,
|
||||||
|
appConfigService *service.AppConfigService) {
|
||||||
|
|
||||||
|
acc := &ApplicationConfigurationController{
|
||||||
|
appConfigService: appConfigService,
|
||||||
|
}
|
||||||
|
group.GET("/application-configuration", acc.listApplicationConfigurationHandler)
|
||||||
|
group.GET("/application-configuration/all", jwtAuthMiddleware.Add(true), acc.listAllApplicationConfigurationHandler)
|
||||||
|
group.PUT("/application-configuration", acc.updateApplicationConfigurationHandler)
|
||||||
|
|
||||||
|
group.GET("/application-configuration/logo", acc.getLogoHandler)
|
||||||
|
group.GET("/application-configuration/background-image", acc.getBackgroundImageHandler)
|
||||||
|
group.GET("/application-configuration/favicon", acc.getFaviconHandler)
|
||||||
|
group.PUT("/application-configuration/logo", jwtAuthMiddleware.Add(true), acc.updateLogoHandler)
|
||||||
|
group.PUT("/application-configuration/favicon", jwtAuthMiddleware.Add(true), acc.updateFaviconHandler)
|
||||||
|
group.PUT("/application-configuration/background-image", jwtAuthMiddleware.Add(true), acc.updateBackgroundImageHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ApplicationConfigurationController struct {
|
||||||
|
appConfigService *service.AppConfigService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acc *ApplicationConfigurationController) listApplicationConfigurationHandler(c *gin.Context) {
|
||||||
|
configuration, err := acc.appConfigService.ListApplicationConfiguration(false)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, configuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acc *ApplicationConfigurationController) listAllApplicationConfigurationHandler(c *gin.Context) {
|
||||||
|
configuration, err := acc.appConfigService.ListApplicationConfiguration(true)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, configuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acc *ApplicationConfigurationController) updateApplicationConfigurationHandler(c *gin.Context) {
|
||||||
|
var input model.AppConfigUpdateDto
|
||||||
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
savedConfigVariables, err := acc.appConfigService.UpdateApplicationConfiguration(input)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, savedConfigVariables)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acc *ApplicationConfigurationController) getLogoHandler(c *gin.Context) {
|
||||||
|
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
|
||||||
|
acc.getImage(c, "logo", imageType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acc *ApplicationConfigurationController) getFaviconHandler(c *gin.Context) {
|
||||||
|
acc.getImage(c, "favicon", "ico")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acc *ApplicationConfigurationController) getBackgroundImageHandler(c *gin.Context) {
|
||||||
|
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
|
||||||
|
acc.getImage(c, "background", imageType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acc *ApplicationConfigurationController) updateLogoHandler(c *gin.Context) {
|
||||||
|
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
|
||||||
|
acc.updateImage(c, "logo", imageType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acc *ApplicationConfigurationController) updateFaviconHandler(c *gin.Context) {
|
||||||
|
file, err := c.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fileType := utils.GetFileExtension(file.Filename)
|
||||||
|
if fileType != "ico" {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, "File must be of type .ico")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
acc.updateImage(c, "favicon", "ico")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acc *ApplicationConfigurationController) updateBackgroundImageHandler(c *gin.Context) {
|
||||||
|
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
|
||||||
|
acc.updateImage(c, "background", imageType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acc *ApplicationConfigurationController) getImage(c *gin.Context, name string, imageType string) {
|
||||||
|
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
|
||||||
|
mimeType := utils.GetImageMimeType(imageType)
|
||||||
|
|
||||||
|
c.Header("Content-Type", mimeType)
|
||||||
|
c.File(imagePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acc *ApplicationConfigurationController) updateImage(c *gin.Context, imageName string, oldImageType string) {
|
||||||
|
file, err := c.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, common.ErrFileTypeNotSupported) {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
||||||
|
} else {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
218
backend/internal/controller/oidc_controller.go
Normal file
218
backend/internal/controller/oidc_controller.go
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService) {
|
||||||
|
oc := &OidcController{OidcService: oidcService}
|
||||||
|
|
||||||
|
group.POST("/oidc/authorize", jwtAuthMiddleware.Add(false), oc.authorizeHandler)
|
||||||
|
group.POST("/oidc/authorize/new-client", jwtAuthMiddleware.Add(false), oc.authorizeNewClientHandler)
|
||||||
|
group.POST("/oidc/token", oc.createIDTokenHandler)
|
||||||
|
|
||||||
|
group.GET("/oidc/clients", jwtAuthMiddleware.Add(true), oc.listClientsHandler)
|
||||||
|
group.POST("/oidc/clients", jwtAuthMiddleware.Add(true), oc.createClientHandler)
|
||||||
|
group.GET("/oidc/clients/:id", oc.getClientHandler)
|
||||||
|
group.PUT("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.updateClientHandler)
|
||||||
|
group.DELETE("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.deleteClientHandler)
|
||||||
|
|
||||||
|
group.POST("/oidc/clients/:id/secret", jwtAuthMiddleware.Add(true), oc.createClientSecretHandler)
|
||||||
|
|
||||||
|
group.GET("/oidc/clients/:id/logo", oc.getClientLogoHandler)
|
||||||
|
group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler)
|
||||||
|
group.POST("/oidc/clients/:id/logo", jwtAuthMiddleware.Add(true), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcController struct {
|
||||||
|
OidcService *service.OidcService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||||
|
var parsedBody model.AuthorizeRequest
|
||||||
|
if err := c.ShouldBindJSON(&parsedBody); err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code, err := oc.OidcService.Authorize(parsedBody, c.GetString("userID"))
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, common.ErrOidcMissingAuthorization) {
|
||||||
|
utils.HandlerError(c, http.StatusForbidden, err.Error())
|
||||||
|
} else {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"code": code})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
|
||||||
|
var parsedBody model.AuthorizeNewClientDto
|
||||||
|
if err := c.ShouldBindJSON(&parsedBody); err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code, err := oc.OidcService.AuthorizeNewClient(parsedBody, c.GetString("userID"))
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"code": code})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) createIDTokenHandler(c *gin.Context) {
|
||||||
|
var body model.OidcIdTokenDto
|
||||||
|
|
||||||
|
if err := c.ShouldBind(&body); err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, err := oc.OidcService.CreateIDToken(body)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, common.ErrOidcGrantTypeNotSupported) ||
|
||||||
|
errors.Is(err, common.ErrOidcMissingClientCredentials) ||
|
||||||
|
errors.Is(err, common.ErrOidcClientSecretInvalid) ||
|
||||||
|
errors.Is(err, common.ErrOidcInvalidAuthorizationCode) {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
||||||
|
} else {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"id_token": idToken})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||||
|
clientId := c.Param("id")
|
||||||
|
client, err := oc.OidcService.GetClient(clientId)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, client)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
|
||||||
|
searchTerm := c.Query("search")
|
||||||
|
|
||||||
|
clients, pagination, err := oc.OidcService.ListClients(searchTerm, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"data": clients,
|
||||||
|
"pagination": pagination,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||||
|
var input model.OidcClientCreateDto
|
||||||
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := oc.OidcService.CreateClient(input, c.GetString("userID"))
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusCreated, client)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
|
||||||
|
err := oc.OidcService.DeleteClient(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||||
|
var input model.OidcClientCreateDto
|
||||||
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := oc.OidcService.UpdateClient(c.Param("id"), input)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusNoContent, client)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
||||||
|
secret, err := oc.OidcService.CreateClientSecret(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"secret": secret})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
||||||
|
imagePath, mimeType, err := oc.OidcService.GetClientLogo(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Header("Content-Type", mimeType)
|
||||||
|
c.File(imagePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||||
|
file, err := c.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = oc.OidcService.UpdateClientLogo(c.Param("id"), file)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, common.ErrFileTypeNotSupported) {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
||||||
|
} else {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
||||||
|
err := oc.OidcService.DeleteClientLogo(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
36
backend/internal/controller/test_controller.go
Normal file
36
backend/internal/controller/test_controller.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewTestController(group *gin.RouterGroup, testService *service.TestService) {
|
||||||
|
testController := &TestController{TestService: testService}
|
||||||
|
|
||||||
|
group.POST("/test/reset", testController.resetAndSeedHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestController struct {
|
||||||
|
TestService *service.TestService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
||||||
|
if err := tc.TestService.ResetDatabase(); err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tc.TestService.ResetApplicationImages(); err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tc.TestService.SeedDatabase(); err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, gin.H{"message": "Database reset and seeded"})
|
||||||
|
}
|
||||||
182
backend/internal/controller/user_controller.go
Normal file
182
backend/internal/controller/user_controller.go
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, userService *service.UserService) {
|
||||||
|
uc := UserController{
|
||||||
|
UserService: userService,
|
||||||
|
}
|
||||||
|
|
||||||
|
group.GET("/users", jwtAuthMiddleware.Add(true), uc.listUsersHandler)
|
||||||
|
group.GET("/users/me", jwtAuthMiddleware.Add(false), uc.getCurrentUserHandler)
|
||||||
|
group.GET("/users/:id", jwtAuthMiddleware.Add(true), uc.getUserHandler)
|
||||||
|
group.POST("/users", jwtAuthMiddleware.Add(true), uc.createUserHandler)
|
||||||
|
group.PUT("/users/:id", jwtAuthMiddleware.Add(true), uc.updateUserHandler)
|
||||||
|
group.PUT("/users/me", jwtAuthMiddleware.Add(false), uc.updateCurrentUserHandler)
|
||||||
|
group.DELETE("/users/:id", jwtAuthMiddleware.Add(true), uc.deleteUserHandler)
|
||||||
|
|
||||||
|
group.POST("/users/:id/one-time-access-token", jwtAuthMiddleware.Add(true), uc.createOneTimeAccessTokenHandler)
|
||||||
|
group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler)
|
||||||
|
group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserController struct {
|
||||||
|
UserService *service.UserService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) listUsersHandler(c *gin.Context) {
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
|
||||||
|
searchTerm := c.Query("search")
|
||||||
|
|
||||||
|
users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"data": users,
|
||||||
|
"pagination": pagination,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) getUserHandler(c *gin.Context) {
|
||||||
|
user, err := uc.UserService.GetUser(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
|
||||||
|
user, err := uc.UserService.GetUser(c.GetString("userID"))
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) deleteUserHandler(c *gin.Context) {
|
||||||
|
if err := uc.UserService.DeleteUser(c.Param("id")); err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) createUserHandler(c *gin.Context) {
|
||||||
|
var user model.User
|
||||||
|
if err := c.ShouldBindJSON(&user); err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := uc.UserService.CreateUser(&user); err != nil {
|
||||||
|
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
|
||||||
|
utils.HandlerError(c, http.StatusConflict, err.Error())
|
||||||
|
} else {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusCreated, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) updateUserHandler(c *gin.Context) {
|
||||||
|
uc.updateUser(c, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
|
||||||
|
uc.updateUser(c, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
|
||||||
|
var input model.OneTimeAccessTokenCreateDto
|
||||||
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusCreated, gin.H{"token": token})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
|
||||||
|
user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token"))
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, common.ErrTokenInvalidOrExpired) {
|
||||||
|
utils.HandlerError(c, http.StatusUnauthorized, err.Error())
|
||||||
|
} else {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
||||||
|
c.JSON(http.StatusOK, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
|
||||||
|
user, token, err := uc.UserService.SetupInitialAdmin()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, common.ErrSetupAlreadyCompleted) {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
||||||
|
} else {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
||||||
|
c.JSON(http.StatusOK, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
|
||||||
|
var updatedUser model.User
|
||||||
|
if err := c.ShouldBindJSON(&updatedUser); err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID string
|
||||||
|
if updateOwnUser {
|
||||||
|
userID = c.GetString("userID")
|
||||||
|
} else {
|
||||||
|
userID = c.Param("id")
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := uc.UserService.UpdateUser(userID, updatedUser, updateOwnUser)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
|
||||||
|
utils.HandlerError(c, http.StatusConflict, err.Error())
|
||||||
|
} else {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, user)
|
||||||
|
}
|
||||||
160
backend/internal/controller/webauthn_controller.go
Normal file
160
backend/internal/controller/webauthn_controller.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/go-webauthn/webauthn/protocol"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewWebauthnController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, webauthnService *service.WebAuthnService, jwtService *service.JwtService) {
|
||||||
|
wc := &WebauthnController{webAuthnService: webauthnService, jwtService: jwtService}
|
||||||
|
group.GET("/webauthn/register/start", jwtAuthMiddleware.Add(false), wc.beginRegistrationHandler)
|
||||||
|
group.POST("/webauthn/register/finish", jwtAuthMiddleware.Add(false), wc.verifyRegistrationHandler)
|
||||||
|
|
||||||
|
group.GET("/webauthn/login/start", wc.beginLoginHandler)
|
||||||
|
group.POST("/webauthn/login/finish", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), wc.verifyLoginHandler)
|
||||||
|
|
||||||
|
group.POST("/webauthn/logout", jwtAuthMiddleware.Add(false), wc.logoutHandler)
|
||||||
|
|
||||||
|
group.GET("/webauthn/credentials", jwtAuthMiddleware.Add(false), wc.listCredentialsHandler)
|
||||||
|
group.PATCH("/webauthn/credentials/:id", jwtAuthMiddleware.Add(false), wc.updateCredentialHandler)
|
||||||
|
group.DELETE("/webauthn/credentials/:id", jwtAuthMiddleware.Add(false), wc.deleteCredentialHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebauthnController struct {
|
||||||
|
webAuthnService *service.WebAuthnService
|
||||||
|
jwtService *service.JwtService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
|
||||||
|
userID := c.GetString("userID")
|
||||||
|
options, err := wc.webAuthnService.BeginRegistration(userID)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
log.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.SetCookie("session_id", options.SessionID, int(options.Timeout.Seconds()), "/", "", false, true)
|
||||||
|
c.JSON(http.StatusOK, options.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
|
||||||
|
sessionID, err := c.Cookie("session_id")
|
||||||
|
if err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := c.GetString("userID")
|
||||||
|
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, credential)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
|
||||||
|
options, err := wc.webAuthnService.BeginLogin()
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.SetCookie("session_id", options.SessionID, int(options.Timeout.Seconds()), "/", "", false, true)
|
||||||
|
c.JSON(http.StatusOK, options.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
|
||||||
|
sessionID, err := c.Cookie("session_id")
|
||||||
|
if err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := c.GetString("userID")
|
||||||
|
user, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, common.ErrInvalidCredentials) {
|
||||||
|
utils.HandlerError(c, http.StatusUnauthorized, err.Error())
|
||||||
|
} else {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := wc.jwtService.GenerateAccessToken(*user)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
||||||
|
c.JSON(http.StatusOK, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
|
||||||
|
userID := c.GetString("userID")
|
||||||
|
credentials, err := wc.webAuthnService.ListCredentials(userID)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, credentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
|
||||||
|
userID := c.GetString("userID")
|
||||||
|
credentialID := c.Param("id")
|
||||||
|
|
||||||
|
err := wc.webAuthnService.DeleteCredential(userID, credentialID)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
|
||||||
|
userID := c.GetString("userID")
|
||||||
|
credentialID := c.Param("id")
|
||||||
|
|
||||||
|
var input model.WebauthnCredentialUpdateDto
|
||||||
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
|
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
|
||||||
|
if err != nil {
|
||||||
|
utils.UnknownHandlerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wc *WebauthnController) logoutHandler(c *gin.Context) {
|
||||||
|
c.SetCookie("access_token", "", 0, "/", "", false, true)
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
@@ -1,19 +1,25 @@
|
|||||||
package handler
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"golang-rest-api-template/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
"golang-rest-api-template/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RegisterWellKnownRoutes(group *gin.RouterGroup) {
|
func NewWellKnownController(group *gin.RouterGroup, jwtService *service.JwtService) {
|
||||||
group.GET("/.well-known/jwks.json", jwks)
|
wkc := &WellKnownController{jwtService: jwtService}
|
||||||
group.GET("/.well-known/openid-configuration", openIDConfiguration)
|
group.GET("/.well-known/jwks.json", wkc.jwksHandler)
|
||||||
|
group.GET("/.well-known/openid-configuration", wkc.openIDConfigurationHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func jwks(c *gin.Context) {
|
type WellKnownController struct {
|
||||||
jwk, err := common.GetJWK()
|
jwtService *service.JwtService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
|
||||||
|
jwk, err := wkc.jwtService.GetJWK()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.UnknownHandlerError(c, err)
|
||||||
return
|
return
|
||||||
@@ -22,7 +28,7 @@ func jwks(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"keys": []interface{}{jwk}})
|
c.JSON(http.StatusOK, gin.H{"keys": []interface{}{jwk}})
|
||||||
}
|
}
|
||||||
|
|
||||||
func openIDConfiguration(c *gin.Context) {
|
func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) {
|
||||||
appUrl := common.EnvConfig.AppURL
|
appUrl := common.EnvConfig.AppURL
|
||||||
config := map[string]interface{}{
|
config := map[string]interface{}{
|
||||||
"issuer": appUrl,
|
"issuer": appUrl,
|
||||||
@@ -1,196 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"golang-rest-api-template/internal/common"
|
|
||||||
"golang-rest-api-template/internal/common/middleware"
|
|
||||||
"golang-rest-api-template/internal/model"
|
|
||||||
"golang-rest-api-template/internal/utils"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RegisterConfigurationRoutes(group *gin.RouterGroup) {
|
|
||||||
group.GET("/application-configuration", listApplicationConfigurationHandler)
|
|
||||||
group.GET("/application-configuration/all", middleware.JWTAuth(true), listAllApplicationConfigurationHandler)
|
|
||||||
group.PUT("/application-configuration", updateApplicationConfigurationHandler)
|
|
||||||
|
|
||||||
group.GET("/application-configuration/logo", getLogoHandler)
|
|
||||||
group.GET("/application-configuration/background-image", getBackgroundImageHandler)
|
|
||||||
group.GET("/application-configuration/favicon", getFaviconHandler)
|
|
||||||
group.PUT("/application-configuration/logo", middleware.JWTAuth(true), updateLogoHandler)
|
|
||||||
group.PUT("/application-configuration/favicon", middleware.JWTAuth(true), updateFaviconHandler)
|
|
||||||
group.PUT("/application-configuration/background-image", middleware.JWTAuth(true), updateBackgroundImageHandler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func listApplicationConfigurationHandler(c *gin.Context) {
|
|
||||||
listApplicationConfiguration(c, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func listAllApplicationConfigurationHandler(c *gin.Context) {
|
|
||||||
listApplicationConfiguration(c, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateApplicationConfigurationHandler(c *gin.Context) {
|
|
||||||
var input model.ApplicationConfigurationUpdateDto
|
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
savedConfigVariables := make([]model.ApplicationConfigurationVariable, 10)
|
|
||||||
|
|
||||||
tx := common.DB.Begin()
|
|
||||||
rt := reflect.ValueOf(input).Type()
|
|
||||||
rv := reflect.ValueOf(input)
|
|
||||||
|
|
||||||
// Loop over the input struct fields and update the related configuration variables
|
|
||||||
for i := 0; i < rt.NumField(); i++ {
|
|
||||||
field := rt.Field(i)
|
|
||||||
key := field.Tag.Get("json")
|
|
||||||
value := rv.FieldByName(field.Name).String()
|
|
||||||
|
|
||||||
// Get the existing configuration variable from the db
|
|
||||||
var applicationConfigurationVariable model.ApplicationConfigurationVariable
|
|
||||||
if err := tx.First(&applicationConfigurationVariable, "key = ? AND is_internal = false", key).Error; err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, fmt.Sprintf("Invalid configuration variable '%s'", value))
|
|
||||||
} else {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the value of the existing configuration variable and save it
|
|
||||||
applicationConfigurationVariable.Value = value
|
|
||||||
if err := tx.Save(&applicationConfigurationVariable).Error; err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
savedConfigVariables[i] = applicationConfigurationVariable
|
|
||||||
}
|
|
||||||
|
|
||||||
tx.Commit()
|
|
||||||
|
|
||||||
if err := common.LoadDbConfigFromDb(); err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, savedConfigVariables)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLogoHandler(c *gin.Context) {
|
|
||||||
imagType := common.DbConfig.LogoImageType.Value
|
|
||||||
getImage(c, "logo", imagType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getFaviconHandler(c *gin.Context) {
|
|
||||||
getImage(c, "favicon", "ico")
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBackgroundImageHandler(c *gin.Context) {
|
|
||||||
imageType := common.DbConfig.BackgroundImageType.Value
|
|
||||||
getImage(c, "background", imageType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateLogoHandler(c *gin.Context) {
|
|
||||||
imageType := common.DbConfig.LogoImageType.Value
|
|
||||||
updateImage(c, "logo", imageType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateFaviconHandler(c *gin.Context) {
|
|
||||||
file, err := c.FormFile("file")
|
|
||||||
if err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fileType := utils.GetFileExtension(file.Filename)
|
|
||||||
if fileType != "ico" {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "File must be of type .ico")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
updateImage(c, "favicon", "ico")
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateBackgroundImageHandler(c *gin.Context) {
|
|
||||||
imagType := common.DbConfig.BackgroundImageType.Value
|
|
||||||
updateImage(c, "background", imagType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getImage(c *gin.Context, name string, imageType string) {
|
|
||||||
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
|
|
||||||
mimeType := utils.GetImageMimeType(imageType)
|
|
||||||
|
|
||||||
c.Header("Content-Type", mimeType)
|
|
||||||
c.File(imagePath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateImage(c *gin.Context, imageName string, oldImageType string) {
|
|
||||||
file, err := c.FormFile("file")
|
|
||||||
if err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fileType := utils.GetFileExtension(file.Filename)
|
|
||||||
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "File type not supported")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete the old image if it has a different file type
|
|
||||||
if fileType != oldImageType {
|
|
||||||
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType)
|
|
||||||
if err := os.Remove(oldImagePath); err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType)
|
|
||||||
err = c.SaveUploadedFile(file, imagePath)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the file type in the database
|
|
||||||
key := fmt.Sprintf("%sImageType", imageName)
|
|
||||||
err = common.DB.Model(&model.ApplicationConfigurationVariable{}).Where("key = ?", key).Update("value", fileType).Error
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.LoadDbConfigFromDb(); err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func listApplicationConfiguration(c *gin.Context, showAll bool) {
|
|
||||||
var configuration []model.ApplicationConfigurationVariable
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if showAll {
|
|
||||||
err = common.DB.Find(&configuration).Error
|
|
||||||
} else {
|
|
||||||
err = common.DB.Find(&configuration, "is_public = true").Error
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(200, configuration)
|
|
||||||
}
|
|
||||||
@@ -1,415 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"golang-rest-api-template/internal/common"
|
|
||||||
"golang-rest-api-template/internal/common/middleware"
|
|
||||||
"golang-rest-api-template/internal/model"
|
|
||||||
"golang-rest-api-template/internal/utils"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RegisterOIDCRoutes(group *gin.RouterGroup) {
|
|
||||||
group.POST("/oidc/authorize", middleware.JWTAuth(false), authorizeHandler)
|
|
||||||
group.POST("/oidc/authorize/new-client", middleware.JWTAuth(false), authorizeNewClientHandler)
|
|
||||||
group.POST("/oidc/token", createIDTokenHandler)
|
|
||||||
|
|
||||||
group.GET("/oidc/clients", middleware.JWTAuth(true), listClientsHandler)
|
|
||||||
group.POST("/oidc/clients", middleware.JWTAuth(true), createClientHandler)
|
|
||||||
group.GET("/oidc/clients/:id", getClientHandler)
|
|
||||||
group.PUT("/oidc/clients/:id", middleware.JWTAuth(true), updateClientHandler)
|
|
||||||
group.DELETE("/oidc/clients/:id", middleware.JWTAuth(true), deleteClientHandler)
|
|
||||||
|
|
||||||
group.POST("/oidc/clients/:id/secret", middleware.JWTAuth(true), createClientSecretHandler)
|
|
||||||
|
|
||||||
group.GET("/oidc/clients/:id/logo", getClientLogoHandler)
|
|
||||||
group.DELETE("/oidc/clients/:id/logo", deleteClientLogoHandler)
|
|
||||||
group.POST("/oidc/clients/:id/logo", middleware.JWTAuth(true), middleware.LimitFileSize(2<<20), updateClientLogoHandler)
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthorizeRequest struct {
|
|
||||||
ClientID string `json:"clientID" binding:"required"`
|
|
||||||
Scope string `json:"scope" binding:"required"`
|
|
||||||
Nonce string `json:"nonce"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func authorizeHandler(c *gin.Context) {
|
|
||||||
var parsedBody AuthorizeRequest
|
|
||||||
if err := c.ShouldBindJSON(&parsedBody); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
|
||||||
common.DB.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", parsedBody.ClientID, c.GetString("userID"))
|
|
||||||
|
|
||||||
// If the record isn't found or the scope is different return an error
|
|
||||||
// The client will have to call the authorizeNewClientHandler
|
|
||||||
if userAuthorizedOIDCClient.Scope != parsedBody.Scope {
|
|
||||||
utils.HandlerError(c, http.StatusForbidden, "missing authorization")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
authorizationCode, err := createAuthorizationCode(parsedBody.ClientID, c.GetString("userID"), parsedBody.Scope, parsedBody.Nonce)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"code": authorizationCode})
|
|
||||||
}
|
|
||||||
|
|
||||||
// authorizeNewClientHandler authorizes a new client for the user
|
|
||||||
// a new client is a new client when the user has not authorized the client before
|
|
||||||
func authorizeNewClientHandler(c *gin.Context) {
|
|
||||||
var parsedBody model.AuthorizeNewClientDto
|
|
||||||
if err := c.ShouldBindJSON(&parsedBody); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
|
||||||
UserID: c.GetString("userID"),
|
|
||||||
ClientID: parsedBody.ClientID,
|
|
||||||
Scope: parsedBody.Scope,
|
|
||||||
}
|
|
||||||
err := common.DB.Create(&userAuthorizedClient).Error
|
|
||||||
|
|
||||||
if err != nil && errors.Is(err, gorm.ErrDuplicatedKey) {
|
|
||||||
err = common.DB.Model(&userAuthorizedClient).Update("scope", parsedBody.Scope).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
authorizationCode, err := createAuthorizationCode(parsedBody.ClientID, c.GetString("userID"), parsedBody.Scope, parsedBody.Nonce)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"code": authorizationCode})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func createIDTokenHandler(c *gin.Context) {
|
|
||||||
var body model.OidcIdTokenDto
|
|
||||||
|
|
||||||
if err := c.ShouldBind(&body); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Currently only authorization_code grant type is supported
|
|
||||||
if body.GrantType != "authorization_code" {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "grant type not supported")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
clientID := body.ClientID
|
|
||||||
clientSecret := body.ClientSecret
|
|
||||||
|
|
||||||
// Client id and secret can also be passed over the Authorization header
|
|
||||||
if clientID == "" || clientSecret == "" {
|
|
||||||
var ok bool
|
|
||||||
clientID, clientSecret, ok = c.Request.BasicAuth()
|
|
||||||
if !ok {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "Client id and secret not provided")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the client
|
|
||||||
var client model.OidcClient
|
|
||||||
err := common.DB.First(&client, "id = ?", clientID, clientSecret).Error
|
|
||||||
if err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "OIDC OIDC client not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if client secret is correct
|
|
||||||
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
|
||||||
if err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid client secret")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
|
||||||
err = common.DB.Preload("User").First(&authorizationCodeMetaData, "code = ?", body.Code).Error
|
|
||||||
if err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid authorization code")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the client id matches the client id in the authorization code and if the code has expired
|
|
||||||
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.Before(time.Now()) {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid authorization code")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
idToken, e := common.GenerateIDToken(authorizationCodeMetaData.User, clientID, authorizationCodeMetaData.Scope, authorizationCodeMetaData.Nonce)
|
|
||||||
if e != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete the authorization code after it has been used
|
|
||||||
common.DB.Delete(&authorizationCodeMetaData)
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"id_token": idToken})
|
|
||||||
}
|
|
||||||
|
|
||||||
func getClientHandler(c *gin.Context) {
|
|
||||||
clientId := c.Param("id")
|
|
||||||
|
|
||||||
var client model.OidcClient
|
|
||||||
err := common.DB.First(&client, "id = ?", clientId).Error
|
|
||||||
if err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, client)
|
|
||||||
}
|
|
||||||
|
|
||||||
func listClientsHandler(c *gin.Context) {
|
|
||||||
var clients []model.OidcClient
|
|
||||||
searchTerm := c.Query("search")
|
|
||||||
|
|
||||||
query := common.DB.Model(&model.OidcClient{})
|
|
||||||
|
|
||||||
if searchTerm != "" {
|
|
||||||
searchPattern := "%" + searchTerm + "%"
|
|
||||||
query = query.Where("name LIKE ?", searchPattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
pagination, err := utils.Paginate(c, query, &clients)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"data": clients,
|
|
||||||
"pagination": pagination,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func createClientHandler(c *gin.Context) {
|
|
||||||
var input model.OidcClientCreateDto
|
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client := model.OidcClient{
|
|
||||||
Name: input.Name,
|
|
||||||
CallbackURL: input.CallbackURL,
|
|
||||||
CreatedByID: c.GetString("userID"),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.DB.Create(&client).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, client)
|
|
||||||
}
|
|
||||||
|
|
||||||
func deleteClientHandler(c *gin.Context) {
|
|
||||||
var client model.OidcClient
|
|
||||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "OIDC OIDC client not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.DB.Delete(&client).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateClientHandler(c *gin.Context) {
|
|
||||||
var input model.OidcClientCreateDto
|
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var client model.OidcClient
|
|
||||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client.Name = input.Name
|
|
||||||
client.CallbackURL = input.CallbackURL
|
|
||||||
|
|
||||||
if err := common.DB.Save(&client).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusNoContent, client)
|
|
||||||
}
|
|
||||||
|
|
||||||
// createClientSecretHandler creates a new secret for the client and revokes the old one
|
|
||||||
func createClientSecretHandler(c *gin.Context) {
|
|
||||||
var client model.OidcClient
|
|
||||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
clientSecret, err := utils.GenerateRandomAlphanumericString(32)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client.Secret = string(hashedSecret)
|
|
||||||
if err := common.DB.Save(&client).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"secret": clientSecret})
|
|
||||||
}
|
|
||||||
|
|
||||||
func getClientLogoHandler(c *gin.Context) {
|
|
||||||
var client model.OidcClient
|
|
||||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if client.ImageType == nil {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "image not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
imageType := *client.ImageType
|
|
||||||
|
|
||||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType)
|
|
||||||
mimeType := utils.GetImageMimeType(imageType)
|
|
||||||
|
|
||||||
c.Header("Content-Type", mimeType)
|
|
||||||
c.File(imagePath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateClientLogoHandler(c *gin.Context) {
|
|
||||||
file, err := c.FormFile("file")
|
|
||||||
if err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fileType := utils.GetFileExtension(file.Filename)
|
|
||||||
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "file type not supported")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, c.Param("id"), fileType)
|
|
||||||
err = c.SaveUploadedFile(file, imagePath)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var client model.OidcClient
|
|
||||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete the old image if it has a different file type
|
|
||||||
if client.ImageType != nil && fileType != *client.ImageType {
|
|
||||||
oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
|
|
||||||
if err := os.Remove(oldImagePath); err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
client.ImageType = &fileType
|
|
||||||
if err := common.DB.Save(&client).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func deleteClientLogoHandler(c *gin.Context) {
|
|
||||||
var client model.OidcClient
|
|
||||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if client.ImageType == nil {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "image not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
|
|
||||||
if err := os.Remove(imagePath); err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client.ImageType = nil
|
|
||||||
if err := common.DB.Save(&client).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
|
|
||||||
randomString, err := utils.GenerateRandomAlphanumericString(32)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
oidcAuthorizationCode := model.OidcAuthorizationCode{
|
|
||||||
ExpiresAt: time.Now().Add(15 * time.Minute),
|
|
||||||
Code: randomString,
|
|
||||||
ClientID: clientID,
|
|
||||||
UserID: userID,
|
|
||||||
Scope: scope,
|
|
||||||
Nonce: nonce,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.DB.Create(&oidcAuthorizationCode).Error; err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return randomString, nil
|
|
||||||
}
|
|
||||||
@@ -1,276 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"golang-rest-api-template/internal/common"
|
|
||||||
"golang-rest-api-template/internal/common/middleware"
|
|
||||||
"golang-rest-api-template/internal/model"
|
|
||||||
"golang-rest-api-template/internal/utils"
|
|
||||||
"golang.org/x/time/rate"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RegisterUserRoutes(group *gin.RouterGroup) {
|
|
||||||
group.GET("/users", middleware.JWTAuth(true), listUsersHandler)
|
|
||||||
group.GET("/users/me", middleware.JWTAuth(false), getCurrentUserHandler)
|
|
||||||
group.GET("/users/:id", middleware.JWTAuth(true), getUserHandler)
|
|
||||||
group.POST("/users", middleware.JWTAuth(true), createUserHandler)
|
|
||||||
group.PUT("/users/:id", middleware.JWTAuth(true), updateUserHandler)
|
|
||||||
group.PUT("/users/me", middleware.JWTAuth(false), updateCurrentUserHandler)
|
|
||||||
group.DELETE("/users/:id", middleware.JWTAuth(true), deleteUserHandler)
|
|
||||||
|
|
||||||
group.POST("/users/:id/one-time-access-token", middleware.JWTAuth(true), createOneTimeAccessTokenHandler)
|
|
||||||
group.POST("/one-time-access-token/:token", middleware.RateLimiter(rate.Every(10*time.Second), 5), exchangeOneTimeAccessTokenHandler)
|
|
||||||
group.POST("/one-time-access-token/setup", getSetupAccessTokenHandler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func listUsersHandler(c *gin.Context) {
|
|
||||||
var users []model.User
|
|
||||||
searchTerm := c.Query("search")
|
|
||||||
|
|
||||||
query := common.DB.Model(&model.User{})
|
|
||||||
|
|
||||||
if searchTerm != "" {
|
|
||||||
searchPattern := "%" + searchTerm + "%"
|
|
||||||
query = query.Where("email LIKE ? OR first_name LIKE ? OR username LIKE ?", searchPattern, searchPattern, searchPattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
pagination, err := utils.Paginate(c, query, &users)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"data": users,
|
|
||||||
"pagination": pagination,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUserHandler(c *gin.Context) {
|
|
||||||
var user model.User
|
|
||||||
if err := common.DB.Where("id = ?", c.Param("id")).First(&user).Error; err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "User not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, user)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCurrentUserHandler(c *gin.Context) {
|
|
||||||
var user model.User
|
|
||||||
if err := common.DB.Where("id = ?", c.GetString("userID")).First(&user).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, user)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func deleteUserHandler(c *gin.Context) {
|
|
||||||
var user model.User
|
|
||||||
if err := common.DB.Where("id = ?", c.Param("id")).First(&user).Error; err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "User not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.DB.Delete(&user).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func createUserHandler(c *gin.Context) {
|
|
||||||
var user model.User
|
|
||||||
if err := c.ShouldBindJSON(&user); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.DB.Create(&user).Error; err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
|
||||||
if err := checkDuplicatedFields(user); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, user)
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateUserHandler(c *gin.Context) {
|
|
||||||
updateUser(c, c.Param("id"), false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateCurrentUserHandler(c *gin.Context) {
|
|
||||||
updateUser(c, c.GetString("userID"), true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func createOneTimeAccessTokenHandler(c *gin.Context) {
|
|
||||||
var input model.OneTimeAccessTokenCreateDto
|
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
randomString, err := utils.GenerateRandomAlphanumericString(16)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
oneTimeAccessToken := model.OneTimeAccessToken{
|
|
||||||
UserID: input.UserID,
|
|
||||||
ExpiresAt: input.ExpiresAt,
|
|
||||||
Token: randomString,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.DB.Create(&oneTimeAccessToken).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, gin.H{"token": oneTimeAccessToken.Token})
|
|
||||||
}
|
|
||||||
|
|
||||||
func exchangeOneTimeAccessTokenHandler(c *gin.Context) {
|
|
||||||
var oneTimeAccessToken model.OneTimeAccessToken
|
|
||||||
if err := common.DB.Where("token = ? AND expires_at > ?", c.Param("token"), utils.FormatDateForDb(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
utils.HandlerError(c, http.StatusForbidden, "Token is invalid or expired")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := common.GenerateAccessToken(oneTimeAccessToken.User)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.DB.Delete(&oneTimeAccessToken).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, oneTimeAccessToken.User)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getSetupAccessTokenHandler creates the initial admin user and returns an access token for the user
|
|
||||||
// This handler is only available if there are no users in the database
|
|
||||||
func getSetupAccessTokenHandler(c *gin.Context) {
|
|
||||||
var userCount int64
|
|
||||||
if err := common.DB.Model(&model.User{}).Count(&userCount).Error; err != nil {
|
|
||||||
log.Fatal("failed to count users", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there are more than one user, we don't need to create the admin user
|
|
||||||
if userCount > 1 {
|
|
||||||
utils.HandlerError(c, http.StatusForbidden, "Setup already completed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var user = model.User{
|
|
||||||
FirstName: "Admin",
|
|
||||||
LastName: "Admin",
|
|
||||||
Username: "admin",
|
|
||||||
Email: "admin@admin.com",
|
|
||||||
IsAdmin: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the initial admin user if it doesn't exist
|
|
||||||
if err := common.DB.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
|
|
||||||
log.Fatal("failed to create admin user", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the user already has credentials, the setup is already completed
|
|
||||||
if len(user.Credentials) > 0 {
|
|
||||||
utils.HandlerError(c, http.StatusForbidden, "Setup already completed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := common.GenerateAccessToken(user)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
|
||||||
c.JSON(http.StatusOK, user)
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateUser(c *gin.Context, userID string, updateOwnUser bool) {
|
|
||||||
var user model.User
|
|
||||||
if err := common.DB.Where("id = ?", userID).First(&user).Error; err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "User not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var updatedUser model.User
|
|
||||||
if err := c.ShouldBindJSON(&updatedUser); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user.FirstName = updatedUser.FirstName
|
|
||||||
user.LastName = updatedUser.LastName
|
|
||||||
user.Email = updatedUser.Email
|
|
||||||
user.Username = updatedUser.Username
|
|
||||||
user.Username = updatedUser.Username
|
|
||||||
if !updateOwnUser {
|
|
||||||
user.IsAdmin = updatedUser.IsAdmin
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.DB.Save(user).Error; err != nil {
|
|
||||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
|
||||||
if err := checkDuplicatedFields(user); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, user)
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkDuplicatedFields(user model.User) error {
|
|
||||||
var existingUser model.User
|
|
||||||
|
|
||||||
if common.DB.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil {
|
|
||||||
return errors.New("email is already taken")
|
|
||||||
}
|
|
||||||
|
|
||||||
if common.DB.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil {
|
|
||||||
return errors.New("username is already taken")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,257 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/go-webauthn/webauthn/protocol"
|
|
||||||
"github.com/go-webauthn/webauthn/webauthn"
|
|
||||||
"golang-rest-api-template/internal/common"
|
|
||||||
"golang-rest-api-template/internal/common/middleware"
|
|
||||||
"golang-rest-api-template/internal/model"
|
|
||||||
"golang-rest-api-template/internal/utils"
|
|
||||||
"golang.org/x/time/rate"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RegisterRoutes(group *gin.RouterGroup) {
|
|
||||||
group.GET("/webauthn/register/start", middleware.JWTAuth(false), beginRegistrationHandler)
|
|
||||||
group.POST("/webauthn/register/finish", middleware.JWTAuth(false), verifyRegistrationHandler)
|
|
||||||
|
|
||||||
group.GET("/webauthn/login/start", beginLoginHandler)
|
|
||||||
group.POST("/webauthn/login/finish", middleware.RateLimiter(rate.Every(10*time.Second), 5), verifyLoginHandler)
|
|
||||||
|
|
||||||
group.POST("/webauthn/logout", middleware.JWTAuth(false), logoutHandler)
|
|
||||||
|
|
||||||
group.GET("/webauthn/credentials", middleware.JWTAuth(false), listCredentialsHandler)
|
|
||||||
group.PATCH("/webauthn/credentials/:id", middleware.JWTAuth(false), updateCredentialHandler)
|
|
||||||
group.DELETE("/webauthn/credentials/:id", middleware.JWTAuth(false), deleteCredentialHandler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func beginRegistrationHandler(c *gin.Context) {
|
|
||||||
var user model.User
|
|
||||||
err := common.DB.Preload("Credentials").Find(&user, "id = ?", c.GetString("userID")).Error
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
options, session, err := common.WebAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()))
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save the webauthn session so we can retrieve it in the verifyRegistrationHandler
|
|
||||||
sessionToStore := &model.WebauthnSession{
|
|
||||||
ExpiresAt: session.Expires,
|
|
||||||
Challenge: session.Challenge,
|
|
||||||
UserVerification: string(session.UserVerification),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = common.DB.Create(&sessionToStore).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.SetCookie("session_id", sessionToStore.ID, int(common.WebAuthn.Config.Timeouts.Registration.Timeout.Seconds()), "/", "", false, true)
|
|
||||||
c.JSON(http.StatusOK, options.Response)
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyRegistrationHandler(c *gin.Context) {
|
|
||||||
sessionID, err := c.Cookie("session_id")
|
|
||||||
if err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve the session that was previously created by the beginRegistrationHandler
|
|
||||||
var storedSession model.WebauthnSession
|
|
||||||
err = common.DB.First(&storedSession, "id = ?", sessionID).Error
|
|
||||||
|
|
||||||
session := webauthn.SessionData{
|
|
||||||
Challenge: storedSession.Challenge,
|
|
||||||
Expires: storedSession.ExpiresAt,
|
|
||||||
UserID: []byte(c.GetString("userID")),
|
|
||||||
}
|
|
||||||
|
|
||||||
var user model.User
|
|
||||||
err = common.DB.Find(&user, "id = ?", c.GetString("userID")).Error
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
credential, err := common.WebAuthn.FinishRegistration(&user, session, c.Request)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
credentialToStore := model.WebauthnCredential{
|
|
||||||
Name: "New Passkey",
|
|
||||||
CredentialID: string(credential.ID),
|
|
||||||
AttestationType: credential.AttestationType,
|
|
||||||
PublicKey: credential.PublicKey,
|
|
||||||
Transport: credential.Transport,
|
|
||||||
UserID: user.ID,
|
|
||||||
BackupEligible: credential.Flags.BackupEligible,
|
|
||||||
BackupState: credential.Flags.BackupState,
|
|
||||||
}
|
|
||||||
if err := common.DB.Create(&credentialToStore).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, credentialToStore)
|
|
||||||
}
|
|
||||||
|
|
||||||
func beginLoginHandler(c *gin.Context) {
|
|
||||||
options, session, err := common.WebAuthn.BeginDiscoverableLogin()
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save the webauthn session so we can retrieve it in the verifyLoginHandler
|
|
||||||
sessionToStore := &model.WebauthnSession{
|
|
||||||
ExpiresAt: session.Expires,
|
|
||||||
Challenge: session.Challenge,
|
|
||||||
UserVerification: string(session.UserVerification),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = common.DB.Create(&sessionToStore).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.SetCookie("session_id", sessionToStore.ID, int(common.WebAuthn.Config.Timeouts.Registration.Timeout.Seconds()), "/", "", false, true)
|
|
||||||
c.JSON(http.StatusOK, options.Response)
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyLoginHandler(c *gin.Context) {
|
|
||||||
sessionID, err := c.Cookie("session_id")
|
|
||||||
if err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
|
|
||||||
if err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "Invalid body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve the session that was previously created by the beginLoginHandler
|
|
||||||
var storedSession model.WebauthnSession
|
|
||||||
if err := common.DB.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
session := webauthn.SessionData{
|
|
||||||
Challenge: storedSession.Challenge,
|
|
||||||
Expires: storedSession.ExpiresAt,
|
|
||||||
}
|
|
||||||
|
|
||||||
var user *model.User
|
|
||||||
_, err = common.WebAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
|
|
||||||
if err := common.DB.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}, session, credentialAssertionData)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), gorm.ErrRecordNotFound.Error()) {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "no user with this passkey exists")
|
|
||||||
} else {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = common.DB.Find(&user, "id = ?", c.GetString("userID")).Error
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := common.GenerateAccessToken(*user)
|
|
||||||
if err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
|
||||||
c.JSON(http.StatusOK, user)
|
|
||||||
}
|
|
||||||
|
|
||||||
func listCredentialsHandler(c *gin.Context) {
|
|
||||||
var credentials []model.WebauthnCredential
|
|
||||||
if err := common.DB.Find(&credentials, "user_id = ?", c.GetString("userID")).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, credentials)
|
|
||||||
}
|
|
||||||
|
|
||||||
func deleteCredentialHandler(c *gin.Context) {
|
|
||||||
var passkeyCount int64
|
|
||||||
if err := common.DB.Model(&model.WebauthnCredential{}).Where("user_id = ?", c.GetString("userID")).Count(&passkeyCount).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if passkeyCount == 1 {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "You must have at least one passkey")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var credential model.WebauthnCredential
|
|
||||||
if err := common.DB.First(&credential, "id = ? AND user_id = ?", c.Param("id"), c.GetString("userID")).Error; err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "Credential not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.DB.Delete(&credential).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateCredentialHandler(c *gin.Context) {
|
|
||||||
var credential model.WebauthnCredential
|
|
||||||
if err := common.DB.Where("id = ? AND user_id = ?", c.Param("id"), c.GetString("userID")).First(&credential).Error; err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusNotFound, "Credential not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var input struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
credential.Name = input.Name
|
|
||||||
|
|
||||||
if err := common.DB.Save(&credential).Error; err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func logoutHandler(c *gin.Context) {
|
|
||||||
c.SetCookie("access_token", "", 0, "/", "", false, true)
|
|
||||||
c.Status(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
@@ -3,28 +3,46 @@ package job
|
|||||||
import (
|
import (
|
||||||
"github.com/go-co-op/gocron/v2"
|
"github.com/go-co-op/gocron/v2"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang-rest-api-template/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
"golang-rest-api-template/internal/model"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"golang-rest-api-template/internal/utils"
|
"gorm.io/gorm"
|
||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RegisterJobs() {
|
func RegisterJobs(db *gorm.DB) {
|
||||||
scheduler, err := gocron.NewScheduler()
|
scheduler, err := gocron.NewScheduler()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to create a new scheduler: %s", err)
|
log.Fatalf("Failed to create a new scheduler: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", clearWebauthnSessions)
|
jobs := &Jobs{db: db}
|
||||||
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", clearOneTimeAccessTokens)
|
|
||||||
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", clearOidcAuthorizationCodes)
|
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
|
||||||
|
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
|
||||||
|
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
|
||||||
|
|
||||||
scheduler.Start()
|
scheduler.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) {
|
type Jobs struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *Jobs) clearWebauthnSessions() error {
|
||||||
|
return j.db.Delete(&model.WebauthnSession{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *Jobs) clearOneTimeAccessTokens() error {
|
||||||
|
return j.db.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *Jobs) clearOidcAuthorizationCodes() error {
|
||||||
|
return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) {
|
||||||
_, err := scheduler.NewJob(
|
_, err := scheduler.NewJob(
|
||||||
gocron.CronJob(interval, false),
|
gocron.CronJob(interval, false),
|
||||||
gocron.NewTask(job),
|
gocron.NewTask(job),
|
||||||
@@ -42,16 +60,3 @@ func registerJob(scheduler gocron.Scheduler, name string, interval string, job f
|
|||||||
log.Fatalf("Failed to register job %q: %v", name, err)
|
log.Fatalf("Failed to register job %q: %v", name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func clearWebauthnSessions() error {
|
|
||||||
return common.DB.Delete(&model.WebauthnSession{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func clearOneTimeAccessTokens() error {
|
|
||||||
return common.DB.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func clearOidcAuthorizationCodes() error {
|
|
||||||
return common.DB.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"golang-rest-api-template/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Cors() gin.HandlerFunc {
|
type CorsMiddleware struct{}
|
||||||
|
|
||||||
|
func NewCorsMiddleware() *CorsMiddleware {
|
||||||
|
return &CorsMiddleware{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CorsMiddleware) Add() gin.HandlerFunc {
|
||||||
return cors.New(cors.Config{
|
return cors.New(cors.Config{
|
||||||
AllowOrigins: []string{common.EnvConfig.AppURL},
|
AllowOrigins: []string{common.EnvConfig.AppURL},
|
||||||
AllowMethods: []string{"*"},
|
AllowMethods: []string{"*"},
|
||||||
@@ -3,11 +3,17 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"golang-rest-api-template/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func LimitFileSize(maxSize int64) gin.HandlerFunc {
|
type FileSizeLimitMiddleware struct{}
|
||||||
|
|
||||||
|
func NewFileSizeLimitMiddleware() *FileSizeLimitMiddleware {
|
||||||
|
return &FileSizeLimitMiddleware{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
|
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
|
||||||
if err := c.Request.ParseMultipartForm(maxSize); err != nil {
|
if err := c.Request.ParseMultipartForm(maxSize); err != nil {
|
||||||
@@ -2,15 +2,22 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"golang-rest-api-template/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
"golang-rest-api-template/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func JWTAuth(adminOnly bool) gin.HandlerFunc {
|
type JwtAuthMiddleware struct {
|
||||||
return func(c *gin.Context) {
|
jwtService *service.JwtService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware {
|
||||||
|
return &JwtAuthMiddleware{jwtService: jwtService}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
// Extract the token from the cookie or the Authorization header
|
// Extract the token from the cookie or the Authorization header
|
||||||
token, err := c.Cookie("access_token")
|
token, err := c.Cookie("access_token")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -22,11 +29,9 @@ func JWTAuth(adminOnly bool) gin.HandlerFunc {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the token
|
claims, err := m.jwtService.VerifyAccessToken(token)
|
||||||
claims, err := common.VerifyAccessToken(token)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in")
|
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"golang-rest-api-template/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
"golang-rest-api-template/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -11,8 +11,13 @@ import (
|
|||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RateLimiter is a Gin middleware for rate limiting based on client IP
|
type RateLimitMiddleware struct{}
|
||||||
func RateLimiter(limit rate.Limit, burst int) gin.HandlerFunc {
|
|
||||||
|
func NewRateLimitMiddleware() *RateLimitMiddleware {
|
||||||
|
return &RateLimitMiddleware{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
|
||||||
// Start the cleanup routine
|
// Start the cleanup routine
|
||||||
go cleanupClients()
|
go cleanupClients()
|
||||||
|
|
||||||
20
backend/internal/model/app_config.go
Normal file
20
backend/internal/model/app_config.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
type AppConfigVariable struct {
|
||||||
|
Key string `gorm:"primaryKey;not null" json:"key"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
IsPublic bool `json:"-"`
|
||||||
|
IsInternal bool `json:"-"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AppConfig struct {
|
||||||
|
AppName AppConfigVariable
|
||||||
|
BackgroundImageType AppConfigVariable
|
||||||
|
LogoImageType AppConfigVariable
|
||||||
|
SessionDuration AppConfigVariable
|
||||||
|
}
|
||||||
|
|
||||||
|
type AppConfigUpdateDto struct {
|
||||||
|
AppName string `json:"appName" binding:"required"`
|
||||||
|
}
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
type ApplicationConfigurationVariable struct {
|
|
||||||
Key string `gorm:"primaryKey;not null" json:"key"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
IsPublic bool `json:"-"`
|
|
||||||
IsInternal bool `json:"-"`
|
|
||||||
Value string `json:"value"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ApplicationConfiguration struct {
|
|
||||||
AppName ApplicationConfigurationVariable
|
|
||||||
BackgroundImageType ApplicationConfigurationVariable
|
|
||||||
LogoImageType ApplicationConfigurationVariable
|
|
||||||
SessionDuration ApplicationConfigurationVariable
|
|
||||||
}
|
|
||||||
|
|
||||||
type ApplicationConfigurationUpdateDto struct {
|
|
||||||
AppName string `json:"appName" binding:"required"`
|
|
||||||
}
|
|
||||||
@@ -12,7 +12,7 @@ type Base struct {
|
|||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Base) BeforeCreate(db *gorm.DB) (err error) {
|
func (b *Base) BeforeCreate(_ *gorm.DB) (err error) {
|
||||||
if b.ID == "" {
|
if b.ID == "" {
|
||||||
b.ID = uuid.New().String()
|
b.ID = uuid.New().String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,3 +63,9 @@ type OidcIdTokenDto struct {
|
|||||||
ClientID string `form:"client_id"`
|
ClientID string `form:"client_id"`
|
||||||
ClientSecret string `form:"client_secret"`
|
ClientSecret string `form:"client_secret"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AuthorizeRequest struct {
|
||||||
|
ClientID string `json:"clientID" binding:"required"`
|
||||||
|
Scope string `json:"scope" binding:"required"`
|
||||||
|
Nonce string `json:"nonce"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -31,6 +31,18 @@ type WebauthnCredential struct {
|
|||||||
UserID string
|
UserID string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PublicKeyCredentialCreationOptions struct {
|
||||||
|
Response protocol.PublicKeyCredentialCreationOptions `json:"response"`
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
Timeout time.Duration `json:"timeout"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PublicKeyCredentialRequestOptions struct {
|
||||||
|
Response protocol.PublicKeyCredentialRequestOptions `json:"response"`
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
Timeout time.Duration `json:"timeout"`
|
||||||
|
}
|
||||||
|
|
||||||
type AuthenticatorTransportList []protocol.AuthenticatorTransport
|
type AuthenticatorTransportList []protocol.AuthenticatorTransport
|
||||||
|
|
||||||
// Scan and Value methods for GORM to handle the custom type
|
// Scan and Value methods for GORM to handle the custom type
|
||||||
@@ -46,3 +58,7 @@ func (atl *AuthenticatorTransportList) Scan(value interface{}) error {
|
|||||||
func (atl AuthenticatorTransportList) Value() (driver.Value, error) {
|
func (atl AuthenticatorTransportList) Value() (driver.Value, error) {
|
||||||
return json.Marshal(atl)
|
return json.Marshal(atl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WebauthnCredentialUpdateDto struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|||||||
213
backend/internal/service/app_config_service.go
Normal file
213
backend/internal/service/app_config_service.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"log"
|
||||||
|
"mime/multipart"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AppConfigService struct {
|
||||||
|
DbConfig *model.AppConfig
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAppConfigService(db *gorm.DB) *AppConfigService {
|
||||||
|
service := &AppConfigService{
|
||||||
|
DbConfig: &defaultDbConfig,
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
if err := service.InitDbConfig(); err != nil {
|
||||||
|
log.Fatalf("Failed to initialize app config service: %v", err)
|
||||||
|
}
|
||||||
|
return service
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultDbConfig = model.AppConfig{
|
||||||
|
AppName: model.AppConfigVariable{
|
||||||
|
Key: "appName",
|
||||||
|
Type: "string",
|
||||||
|
IsPublic: true,
|
||||||
|
Value: "Pocket ID",
|
||||||
|
},
|
||||||
|
SessionDuration: model.AppConfigVariable{
|
||||||
|
Key: "sessionDuration",
|
||||||
|
Type: "number",
|
||||||
|
Value: "60",
|
||||||
|
},
|
||||||
|
BackgroundImageType: model.AppConfigVariable{
|
||||||
|
Key: "backgroundImageType",
|
||||||
|
Type: "string",
|
||||||
|
IsInternal: true,
|
||||||
|
Value: "jpg",
|
||||||
|
},
|
||||||
|
LogoImageType: model.AppConfigVariable{
|
||||||
|
Key: "logoImageType",
|
||||||
|
Type: "string",
|
||||||
|
IsInternal: true,
|
||||||
|
Value: "svg",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AppConfigService) UpdateApplicationConfiguration(input model.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
|
||||||
|
savedConfigVariables := make([]model.AppConfigVariable, 10)
|
||||||
|
|
||||||
|
tx := s.db.Begin()
|
||||||
|
rt := reflect.ValueOf(input).Type()
|
||||||
|
rv := reflect.ValueOf(input)
|
||||||
|
|
||||||
|
for i := 0; i < rt.NumField(); i++ {
|
||||||
|
field := rt.Field(i)
|
||||||
|
key := field.Tag.Get("json")
|
||||||
|
value := rv.FieldByName(field.Name).String()
|
||||||
|
|
||||||
|
var applicationConfigurationVariable model.AppConfigVariable
|
||||||
|
if err := tx.First(&applicationConfigurationVariable, "key = ? AND is_internal = false", key).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
applicationConfigurationVariable.Value = value
|
||||||
|
if err := tx.Save(&applicationConfigurationVariable).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
savedConfigVariables[i] = applicationConfigurationVariable
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Commit()
|
||||||
|
|
||||||
|
if err := s.loadDbConfigFromDb(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return savedConfigVariables, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AppConfigService) UpdateImageType(imageName string, fileType string) error {
|
||||||
|
key := fmt.Sprintf("%sImageType", imageName)
|
||||||
|
err := s.db.Model(&model.AppConfigVariable{}).Where("key = ?", key).Update("value", fileType).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.loadDbConfigFromDb()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AppConfigService) ListApplicationConfiguration(showAll bool) ([]model.AppConfigVariable, error) {
|
||||||
|
var configuration []model.AppConfigVariable
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if showAll {
|
||||||
|
err = s.db.Find(&configuration).Error
|
||||||
|
} else {
|
||||||
|
err = s.db.Find(&configuration, "is_public = true").Error
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return configuration, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, imageName string, oldImageType string) error {
|
||||||
|
fileType := utils.GetFileExtension(uploadedFile.Filename)
|
||||||
|
mimeType := utils.GetImageMimeType(fileType)
|
||||||
|
if mimeType == "" {
|
||||||
|
return common.ErrFileTypeNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the old image if it has a different file type
|
||||||
|
if fileType != oldImageType {
|
||||||
|
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType)
|
||||||
|
if err := os.Remove(oldImagePath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType)
|
||||||
|
if err := utils.SaveFile(uploadedFile, imagePath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the file type in the database
|
||||||
|
if err := s.UpdateImageType(imageName, fileType); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitDbConfig creates the default configuration values in the database if they do not exist,
|
||||||
|
// updates existing configurations if they differ from the default, and deletes any configurations
|
||||||
|
// that are not in the default configuration.
|
||||||
|
func (s *AppConfigService) InitDbConfig() error {
|
||||||
|
// Reflect to get the underlying value of DbConfig and its default configuration
|
||||||
|
defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig)
|
||||||
|
defaultKeys := make(map[string]struct{})
|
||||||
|
|
||||||
|
// Iterate over the fields of DbConfig
|
||||||
|
for i := 0; i < defaultConfigReflectValue.NumField(); i++ {
|
||||||
|
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable)
|
||||||
|
|
||||||
|
defaultKeys[defaultConfigVar.Key] = struct{}{}
|
||||||
|
|
||||||
|
var storedConfigVar model.AppConfigVariable
|
||||||
|
if err := s.db.First(&storedConfigVar, "key = ?", defaultConfigVar.Key).Error; err != nil {
|
||||||
|
// If the configuration does not exist, create it
|
||||||
|
if err := s.db.Create(&defaultConfigVar).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update existing configuration if it differs from the default
|
||||||
|
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal {
|
||||||
|
storedConfigVar.Type = defaultConfigVar.Type
|
||||||
|
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
|
||||||
|
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
|
||||||
|
if err := s.db.Save(&storedConfigVar).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete any configurations not in the default keys
|
||||||
|
var allConfigVars []model.AppConfigVariable
|
||||||
|
if err := s.db.Find(&allConfigVars).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, config := range allConfigVars {
|
||||||
|
if _, exists := defaultKeys[config.Key]; !exists {
|
||||||
|
if err := s.db.Delete(&config).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.loadDbConfigFromDb()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AppConfigService) loadDbConfigFromDb() error {
|
||||||
|
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
|
||||||
|
|
||||||
|
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
|
||||||
|
dbConfigField := dbConfigReflectValue.Field(i)
|
||||||
|
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
|
||||||
|
var storedConfigVar model.AppConfigVariable
|
||||||
|
if err := s.db.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
248
backend/internal/service/jwt_service.go
Normal file
248
backend/internal/service/jwt_service.go
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
"log"
|
||||||
|
"math/big"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
privateKeyPath = "data/keys/jwt_private_key.pem"
|
||||||
|
publicKeyPath = "data/keys/jwt_public_key.pem"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JwtService struct {
|
||||||
|
publicKey *rsa.PublicKey
|
||||||
|
privateKey *rsa.PrivateKey
|
||||||
|
appConfigService *AppConfigService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewJwtService(appConfigService *AppConfigService) *JwtService {
|
||||||
|
service := &JwtService{
|
||||||
|
appConfigService: appConfigService,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure keys are generated or loaded
|
||||||
|
if err := service.loadOrGenerateKeys(); err != nil {
|
||||||
|
log.Fatalf("Failed to initialize jwt service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return service
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccessTokenJWTClaims struct {
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
IsAdmin bool `json:"isAdmin,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type JWK struct {
|
||||||
|
Kty string `json:"kty"`
|
||||||
|
Use string `json:"use"`
|
||||||
|
Kid string `json:"kid"`
|
||||||
|
Alg string `json:"alg"`
|
||||||
|
N string `json:"n"`
|
||||||
|
E string `json:"e"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadOrGenerateKeys loads RSA keys from the given paths or generates them if they do not exist.
|
||||||
|
func (s *JwtService) loadOrGenerateKeys() error {
|
||||||
|
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
|
||||||
|
if err := s.generateKeys(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("can't read jwt private key: " + err.Error())
|
||||||
|
}
|
||||||
|
s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("can't parse jwt private key: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyBytes, err := os.ReadFile(publicKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("can't read jwt public key: " + err.Error())
|
||||||
|
}
|
||||||
|
s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("can't parse jwt public key: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JwtService) GenerateIDToken(user model.User, clientID string, scope string, nonce string) (string, error) {
|
||||||
|
profileClaims := map[string]interface{}{
|
||||||
|
"given_name": user.FirstName,
|
||||||
|
"family_name": user.LastName,
|
||||||
|
"email": user.Email,
|
||||||
|
"preferred_username": user.Username,
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"sub": user.ID,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||||
|
"iat": jwt.NewNumericDate(time.Now()),
|
||||||
|
}
|
||||||
|
|
||||||
|
if nonce != "" {
|
||||||
|
claims["nonce"] = nonce
|
||||||
|
}
|
||||||
|
if strings.Contains(scope, "profile") {
|
||||||
|
for k, v := range profileClaims {
|
||||||
|
claims[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.Contains(scope, "email") {
|
||||||
|
claims["email"] = user.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||||
|
return token.SignedString(s.privateKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||||
|
sessionDurationInMinutes, _ := strconv.Atoi(s.appConfigService.DbConfig.SessionDuration.Value)
|
||||||
|
claim := AccessTokenJWTClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: user.ID,
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
Audience: jwt.ClaimStrings{utils.GetHostFromURL(common.EnvConfig.AppURL)},
|
||||||
|
},
|
||||||
|
IsAdmin: user.IsAdmin,
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||||
|
return token.SignedString(s.privateKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
|
||||||
|
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return s.publicKey, nil
|
||||||
|
})
|
||||||
|
if err != nil || !token.Valid {
|
||||||
|
return nil, errors.New("couldn't handle this token")
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, isValid := token.Claims.(*AccessTokenJWTClaims)
|
||||||
|
if !isValid {
|
||||||
|
return nil, errors.New("can't parse claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Contains(claims.Audience, utils.GetHostFromURL(common.EnvConfig.AppURL)) {
|
||||||
|
return nil, errors.New("audience doesn't match")
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJWK returns the JSON Web Key (JWK) for the public key.
|
||||||
|
func (s *JwtService) GetJWK() (JWK, error) {
|
||||||
|
if s.publicKey == nil {
|
||||||
|
return JWK{}, errors.New("public key is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
jwk := JWK{
|
||||||
|
Kty: "RSA",
|
||||||
|
Use: "sig",
|
||||||
|
Kid: "1",
|
||||||
|
Alg: "RS256",
|
||||||
|
N: base64.RawURLEncoding.EncodeToString(s.publicKey.N.Bytes()),
|
||||||
|
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.publicKey.E)).Bytes()),
|
||||||
|
}
|
||||||
|
|
||||||
|
return jwk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateKeys generates a new RSA key pair and saves them to the specified paths.
|
||||||
|
func (s *JwtService) generateKeys() error {
|
||||||
|
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
|
||||||
|
return errors.New("failed to create directories for keys: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("failed to generate private key: " + err.Error())
|
||||||
|
}
|
||||||
|
s.privateKey = privateKey
|
||||||
|
|
||||||
|
if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey := &privateKey.PublicKey
|
||||||
|
s.publicKey = publicKey
|
||||||
|
|
||||||
|
if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// savePEMKey saves a PEM encoded key to a file.
|
||||||
|
func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error {
|
||||||
|
keyFile, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("failed to create key file: " + err.Error())
|
||||||
|
}
|
||||||
|
defer keyFile.Close()
|
||||||
|
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: keyType,
|
||||||
|
Bytes: keyBytes,
|
||||||
|
})
|
||||||
|
|
||||||
|
if _, err := keyFile.Write(keyPEM); err != nil {
|
||||||
|
return errors.New("failed to write key file: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadKeys loads RSA keys from the given paths.
|
||||||
|
func (s *JwtService) loadKeys() error {
|
||||||
|
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
|
||||||
|
if err := s.generateKeys(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("can't read jwt private key: %w", err)
|
||||||
|
}
|
||||||
|
s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("can't parse jwt private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyBytes, err := os.ReadFile(publicKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("can't read jwt public key: %w", err)
|
||||||
|
}
|
||||||
|
s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("can't parse jwt public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
282
backend/internal/service/oidc_service.go
Normal file
282
backend/internal/service/oidc_service.go
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"mime/multipart"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OidcService struct {
|
||||||
|
db *gorm.DB
|
||||||
|
jwtService *JwtService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOidcService(db *gorm.DB, jwtService *JwtService) *OidcService {
|
||||||
|
return &OidcService{
|
||||||
|
db: db,
|
||||||
|
jwtService: jwtService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) Authorize(req model.AuthorizeRequest, userID string) (string, error) {
|
||||||
|
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
||||||
|
s.db.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", req.ClientID, userID)
|
||||||
|
|
||||||
|
if userAuthorizedOIDCClient.Scope != req.Scope {
|
||||||
|
return "", common.ErrOidcMissingAuthorization
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) AuthorizeNewClient(req model.AuthorizeNewClientDto, userID string) (string, error) {
|
||||||
|
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||||
|
UserID: userID,
|
||||||
|
ClientID: req.ClientID,
|
||||||
|
Scope: req.Scope,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Create(&userAuthorizedClient).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||||
|
err = s.db.Model(&userAuthorizedClient).Update("scope", req.Scope).Error
|
||||||
|
} else {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) CreateIDToken(req model.OidcIdTokenDto) (string, error) {
|
||||||
|
if req.GrantType != "authorization_code" {
|
||||||
|
return "", common.ErrOidcGrantTypeNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID := req.ClientID
|
||||||
|
clientSecret := req.ClientSecret
|
||||||
|
|
||||||
|
if clientID == "" || clientSecret == "" {
|
||||||
|
return "", common.ErrOidcMissingClientCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
var client model.OidcClient
|
||||||
|
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||||
|
if err != nil {
|
||||||
|
return "", common.ErrOidcClientSecretInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||||
|
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", req.Code).Error
|
||||||
|
if err != nil {
|
||||||
|
return "", common.ErrOidcInvalidAuthorizationCode
|
||||||
|
}
|
||||||
|
|
||||||
|
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.Before(time.Now()) {
|
||||||
|
return "", common.ErrOidcInvalidAuthorizationCode
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, err := s.jwtService.GenerateIDToken(authorizationCodeMetaData.User, clientID, authorizationCodeMetaData.Scope, authorizationCodeMetaData.Nonce)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.db.Delete(&authorizationCodeMetaData)
|
||||||
|
|
||||||
|
return idToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) GetClient(clientID string) (*model.OidcClient, error) {
|
||||||
|
var client model.OidcClient
|
||||||
|
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]model.OidcClient, utils.PaginationResponse, error) {
|
||||||
|
var clients []model.OidcClient
|
||||||
|
|
||||||
|
query := s.db.Model(&model.OidcClient{})
|
||||||
|
if searchTerm != "" {
|
||||||
|
searchPattern := "%" + searchTerm + "%"
|
||||||
|
query = query.Where("name LIKE ?", searchPattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
pagination, err := utils.Paginate(page, pageSize, query, &clients)
|
||||||
|
if err != nil {
|
||||||
|
return nil, utils.PaginationResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return clients, pagination, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) CreateClient(input model.OidcClientCreateDto, userID string) (*model.OidcClient, error) {
|
||||||
|
client := model.OidcClient{
|
||||||
|
Name: input.Name,
|
||||||
|
CallbackURL: input.CallbackURL,
|
||||||
|
CreatedByID: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Create(&client).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) UpdateClient(clientID string, input model.OidcClientCreateDto) (*model.OidcClient, error) {
|
||||||
|
var client model.OidcClient
|
||||||
|
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client.Name = input.Name
|
||||||
|
client.CallbackURL = input.CallbackURL
|
||||||
|
|
||||||
|
if err := s.db.Save(&client).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) DeleteClient(clientID string) error {
|
||||||
|
var client model.OidcClient
|
||||||
|
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Delete(&client).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) CreateClientSecret(clientID string) (string, error) {
|
||||||
|
var client model.OidcClient
|
||||||
|
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
clientSecret, err := utils.GenerateRandomAlphanumericString(32)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
client.Secret = string(hashedSecret)
|
||||||
|
if err := s.db.Save(&client).Error; err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientSecret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
|
||||||
|
var client model.OidcClient
|
||||||
|
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.ImageType == nil {
|
||||||
|
return "", "", errors.New("image not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
imageType := *client.ImageType
|
||||||
|
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType)
|
||||||
|
mimeType := utils.GetImageMimeType(imageType)
|
||||||
|
|
||||||
|
return imagePath, mimeType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error {
|
||||||
|
fileType := utils.GetFileExtension(file.Filename)
|
||||||
|
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
|
||||||
|
return common.ErrFileTypeNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType)
|
||||||
|
if err := utils.SaveFile(file, imagePath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var client model.OidcClient
|
||||||
|
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.ImageType != nil && fileType != *client.ImageType {
|
||||||
|
oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
|
||||||
|
if err := os.Remove(oldImagePath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client.ImageType = &fileType
|
||||||
|
if err := s.db.Save(&client).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) DeleteClientLogo(clientID string) error {
|
||||||
|
var client model.OidcClient
|
||||||
|
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.ImageType == nil {
|
||||||
|
return errors.New("image not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
|
||||||
|
if err := os.Remove(imagePath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
client.ImageType = nil
|
||||||
|
if err := s.db.Save(&client).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
|
||||||
|
randomString, err := utils.GenerateRandomAlphanumericString(32)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
oidcAuthorizationCode := model.OidcAuthorizationCode{
|
||||||
|
ExpiresAt: time.Now().Add(15 * time.Minute),
|
||||||
|
Code: randomString,
|
||||||
|
ClientID: clientID,
|
||||||
|
UserID: userID,
|
||||||
|
Scope: scope,
|
||||||
|
Nonce: nonce,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return randomString, nil
|
||||||
|
}
|
||||||
@@ -1,48 +1,33 @@
|
|||||||
package handler
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fxamacker/cbor/v2"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/go-webauthn/webauthn/protocol"
|
"github.com/go-webauthn/webauthn/protocol"
|
||||||
"golang-rest-api-template/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
"golang-rest-api-template/internal/model"
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
"golang-rest-api-template/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RegisterTestRoutes(group *gin.RouterGroup) {
|
type TestService struct {
|
||||||
group.POST("/test/reset", resetAndSeedHandler)
|
db *gorm.DB
|
||||||
|
appConfigService *AppConfigService
|
||||||
}
|
}
|
||||||
|
|
||||||
func resetAndSeedHandler(c *gin.Context) {
|
func NewTestService(db *gorm.DB, appConfigService *AppConfigService) *TestService {
|
||||||
if err := resetDatabase(); err != nil {
|
return &TestService{db: db, appConfigService: appConfigService}
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := resetApplicationImages(); err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := seedDatabase(); err != nil {
|
|
||||||
utils.UnknownHandlerError(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(200, gin.H{"message": "Database reset and seeded"})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// seedDatabase seeds the database with initial data and uses a transaction to ensure atomicity.
|
func (s *TestService) SeedDatabase() error {
|
||||||
func seedDatabase() error {
|
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
return common.DB.Transaction(func(tx *gorm.DB) error {
|
|
||||||
users := []model.User{
|
users := []model.User{
|
||||||
{
|
{
|
||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
@@ -128,11 +113,16 @@ func seedDatabase() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
publicKey1, err := getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
|
||||||
|
publicKey2, err := getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA==")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
webauthnCredentials := []model.WebauthnCredential{
|
webauthnCredentials := []model.WebauthnCredential{
|
||||||
{
|
{
|
||||||
Name: "Passkey 1",
|
Name: "Passkey 1",
|
||||||
CredentialID: "test-credential-1",
|
CredentialID: "test-credential-1",
|
||||||
PublicKey: getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg=="),
|
PublicKey: publicKey1,
|
||||||
AttestationType: "none",
|
AttestationType: "none",
|
||||||
Transport: model.AuthenticatorTransportList{protocol.Internal},
|
Transport: model.AuthenticatorTransportList{protocol.Internal},
|
||||||
UserID: users[0].ID,
|
UserID: users[0].ID,
|
||||||
@@ -140,7 +130,7 @@ func seedDatabase() error {
|
|||||||
{
|
{
|
||||||
Name: "Passkey 2",
|
Name: "Passkey 2",
|
||||||
CredentialID: "test-credential-2",
|
CredentialID: "test-credential-2",
|
||||||
PublicKey: getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA=="),
|
PublicKey: publicKey2,
|
||||||
AttestationType: "none",
|
AttestationType: "none",
|
||||||
Transport: model.AuthenticatorTransportList{protocol.Internal},
|
Transport: model.AuthenticatorTransportList{protocol.Internal},
|
||||||
UserID: users[0].ID,
|
UserID: users[0].ID,
|
||||||
@@ -165,9 +155,8 @@ func seedDatabase() error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// resetDatabase resets the database by deleting all rows from each table.
|
func (s *TestService) ResetDatabase() error {
|
||||||
func resetDatabase() error {
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
err := common.DB.Transaction(func(tx *gorm.DB) error {
|
|
||||||
var tables []string
|
var tables []string
|
||||||
if err := tx.Raw("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' AND name != 'schema_migrations';").Scan(&tables).Error; err != nil {
|
if err := tx.Raw("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' AND name != 'schema_migrations';").Scan(&tables).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -183,13 +172,11 @@ func resetDatabase() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
common.InitDbConfig()
|
err = s.appConfigService.InitDbConfig()
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// resetApplicationImages resets the application images by removing existing images and replacing them with the default ones
|
func (s *TestService) ResetApplicationImages() error {
|
||||||
func resetApplicationImages() error {
|
|
||||||
|
|
||||||
if err := os.RemoveAll(common.EnvConfig.UploadPath); err != nil {
|
if err := os.RemoveAll(common.EnvConfig.UploadPath); err != nil {
|
||||||
log.Printf("Error removing directory: %v", err)
|
log.Printf("Error removing directory: %v", err)
|
||||||
return err
|
return err
|
||||||
@@ -204,20 +191,19 @@ func resetApplicationImages() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key
|
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key
|
||||||
func getCborPublicKey(base64PublicKey string) []byte {
|
func getCborPublicKey(base64PublicKey string) ([]byte, error) {
|
||||||
decodedKey, err := base64.StdEncoding.DecodeString(base64PublicKey)
|
decodedKey, err := base64.StdEncoding.DecodeString(base64PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to decode base64 key: %v", err)
|
return nil, fmt.Errorf("failed to decode base64 key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey, err := x509.ParsePKIXPublicKey(decodedKey)
|
pubKey, err := x509.ParsePKIXPublicKey(decodedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to parse public key: %v", err)
|
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ecdsaPubKey, ok := pubKey.(*ecdsa.PublicKey)
|
ecdsaPubKey, ok := pubKey.(*ecdsa.PublicKey)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Fatalf("Not an ECDSA public key")
|
return nil, fmt.Errorf("not an ECDSA public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
coseKey := map[int]interface{}{
|
coseKey := map[int]interface{}{
|
||||||
@@ -230,8 +216,8 @@ func getCborPublicKey(base64PublicKey string) []byte {
|
|||||||
|
|
||||||
cborPublicKey, err := cbor.Marshal(coseKey)
|
cborPublicKey, err := cbor.Marshal(coseKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to encode CBOR: %v", err)
|
return nil, fmt.Errorf("failed to marshal COSE key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cborPublicKey
|
return cborPublicKey, nil
|
||||||
}
|
}
|
||||||
165
backend/internal/service/user_sevice.go
Normal file
165
backend/internal/service/user_sevice.go
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserService struct {
|
||||||
|
db *gorm.DB
|
||||||
|
jwtService *JwtService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserService(db *gorm.DB, jwtService *JwtService) *UserService {
|
||||||
|
return &UserService{db: db, jwtService: jwtService}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) ListUsers(searchTerm string, page int, pageSize int) ([]model.User, utils.PaginationResponse, error) {
|
||||||
|
var users []model.User
|
||||||
|
query := s.db.Model(&model.User{})
|
||||||
|
|
||||||
|
if searchTerm != "" {
|
||||||
|
searchPattern := "%" + searchTerm + "%"
|
||||||
|
query = query.Where("email LIKE ? OR first_name LIKE ? OR username LIKE ?", searchPattern, searchPattern, searchPattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
pagination, err := utils.Paginate(page, pageSize, query, &users)
|
||||||
|
return users, pagination, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) GetUser(userID string) (model.User, error) {
|
||||||
|
var user model.User
|
||||||
|
err := s.db.Where("id = ?", userID).First(&user).Error
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) DeleteUser(userID string) error {
|
||||||
|
var user model.User
|
||||||
|
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.db.Delete(&user).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) CreateUser(user *model.User) error {
|
||||||
|
if err := s.db.Create(user).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||||
|
return s.checkDuplicatedFields(*user)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) UpdateUser(userID string, updatedUser model.User, updateOwnUser bool) (model.User, error) {
|
||||||
|
var user model.User
|
||||||
|
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||||
|
return model.User{}, err
|
||||||
|
}
|
||||||
|
user.FirstName = updatedUser.FirstName
|
||||||
|
user.LastName = updatedUser.LastName
|
||||||
|
user.Email = updatedUser.Email
|
||||||
|
user.Username = updatedUser.Username
|
||||||
|
if !updateOwnUser {
|
||||||
|
user.IsAdmin = updatedUser.IsAdmin
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Save(&user).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||||
|
return user, s.checkDuplicatedFields(user)
|
||||||
|
}
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) {
|
||||||
|
randomString, err := utils.GenerateRandomAlphanumericString(16)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
oneTimeAccessToken := model.OneTimeAccessToken{
|
||||||
|
UserID: userID,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
Token: randomString,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Create(&oneTimeAccessToken).Error; err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return oneTimeAccessToken.Token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) ExchangeOneTimeAccessToken(token string) (model.User, string, error) {
|
||||||
|
var oneTimeAccessToken model.OneTimeAccessToken
|
||||||
|
if err := s.db.Where("token = ? AND expires_at > ?", token, utils.FormatDateForDb(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return model.User{}, "", common.ErrTokenInvalidOrExpired
|
||||||
|
}
|
||||||
|
return model.User{}, "", err
|
||||||
|
}
|
||||||
|
accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User)
|
||||||
|
if err != nil {
|
||||||
|
return model.User{}, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Delete(&oneTimeAccessToken).Error; err != nil {
|
||||||
|
return model.User{}, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return oneTimeAccessToken.User, accessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
|
||||||
|
var userCount int64
|
||||||
|
if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil {
|
||||||
|
return model.User{}, "", err
|
||||||
|
}
|
||||||
|
if userCount > 1 {
|
||||||
|
return model.User{}, "", common.ErrSetupAlreadyCompleted
|
||||||
|
}
|
||||||
|
|
||||||
|
user := model.User{
|
||||||
|
FirstName: "Admin",
|
||||||
|
LastName: "Admin",
|
||||||
|
Username: "admin",
|
||||||
|
Email: "admin@admin.com",
|
||||||
|
IsAdmin: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
|
||||||
|
return model.User{}, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(user.Credentials) > 0 {
|
||||||
|
return model.User{}, "", common.ErrSetupAlreadyCompleted
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := s.jwtService.GenerateAccessToken(user)
|
||||||
|
if err != nil {
|
||||||
|
return model.User{}, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) checkDuplicatedFields(user model.User) error {
|
||||||
|
var existingUser model.User
|
||||||
|
if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil {
|
||||||
|
return common.ErrEmailTaken
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil {
|
||||||
|
return common.ErrUsernameTaken
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
196
backend/internal/service/webauthn_service.go
Normal file
196
backend/internal/service/webauthn_service.go
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-webauthn/webauthn/protocol"
|
||||||
|
"github.com/go-webauthn/webauthn/webauthn"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WebAuthnService struct {
|
||||||
|
db *gorm.DB
|
||||||
|
webAuthn *webauthn.WebAuthn
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWebAuthnService(db *gorm.DB, appConfigService *AppConfigService) *WebAuthnService {
|
||||||
|
webauthnConfig := &webauthn.Config{
|
||||||
|
RPDisplayName: appConfigService.DbConfig.AppName.Value,
|
||||||
|
RPID: utils.GetHostFromURL(common.EnvConfig.AppURL),
|
||||||
|
RPOrigins: []string{common.EnvConfig.AppURL},
|
||||||
|
Timeouts: webauthn.TimeoutsConfig{
|
||||||
|
Login: webauthn.TimeoutConfig{
|
||||||
|
Enforce: true,
|
||||||
|
Timeout: time.Second * 60,
|
||||||
|
TimeoutUVD: time.Second * 60,
|
||||||
|
},
|
||||||
|
Registration: webauthn.TimeoutConfig{
|
||||||
|
Enforce: true,
|
||||||
|
Timeout: time.Second * 60,
|
||||||
|
TimeoutUVD: time.Second * 60,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
wa, _ := webauthn.New(webauthnConfig)
|
||||||
|
return &WebAuthnService{db: db, webAuthn: wa}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCredentialCreationOptions, error) {
|
||||||
|
var user model.User
|
||||||
|
if err := s.db.Preload("Credentials").Find(&user, "id = ?", userID).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
options, session, err := s.webAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionToStore := &model.WebauthnSession{
|
||||||
|
ExpiresAt: session.Expires,
|
||||||
|
Challenge: session.Challenge,
|
||||||
|
UserVerification: string(session.UserVerification),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Create(&sessionToStore).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &model.PublicKeyCredentialCreationOptions{
|
||||||
|
Response: options.Response,
|
||||||
|
SessionID: sessionToStore.ID,
|
||||||
|
Timeout: s.webAuthn.Config.Timeouts.Registration.Timeout,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (*model.WebauthnCredential, error) {
|
||||||
|
var storedSession model.WebauthnSession
|
||||||
|
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
session := webauthn.SessionData{
|
||||||
|
Challenge: storedSession.Challenge,
|
||||||
|
Expires: storedSession.ExpiresAt,
|
||||||
|
UserID: []byte(userID),
|
||||||
|
}
|
||||||
|
|
||||||
|
var user model.User
|
||||||
|
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
credential, err := s.webAuthn.FinishRegistration(&user, session, r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
credentialToStore := model.WebauthnCredential{
|
||||||
|
Name: "New Passkey",
|
||||||
|
CredentialID: string(credential.ID),
|
||||||
|
AttestationType: credential.AttestationType,
|
||||||
|
PublicKey: credential.PublicKey,
|
||||||
|
Transport: credential.Transport,
|
||||||
|
UserID: user.ID,
|
||||||
|
BackupEligible: credential.Flags.BackupEligible,
|
||||||
|
BackupState: credential.Flags.BackupState,
|
||||||
|
}
|
||||||
|
if err := s.db.Create(&credentialToStore).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &credentialToStore, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
|
||||||
|
options, session, err := s.webAuthn.BeginDiscoverableLogin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionToStore := &model.WebauthnSession{
|
||||||
|
ExpiresAt: session.Expires,
|
||||||
|
Challenge: session.Challenge,
|
||||||
|
UserVerification: string(session.UserVerification),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Create(&sessionToStore).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &model.PublicKeyCredentialRequestOptions{
|
||||||
|
Response: options.Response,
|
||||||
|
SessionID: sessionToStore.ID,
|
||||||
|
Timeout: s.webAuthn.Config.Timeouts.Registration.Timeout,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (*model.User, error) {
|
||||||
|
var storedSession model.WebauthnSession
|
||||||
|
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
session := webauthn.SessionData{
|
||||||
|
Challenge: storedSession.Challenge,
|
||||||
|
Expires: storedSession.ExpiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
var user *model.User
|
||||||
|
_, err := s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
|
||||||
|
if err := s.db.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}, session, credentialAssertionData)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, common.ErrInvalidCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) {
|
||||||
|
var credentials []model.WebauthnCredential
|
||||||
|
if err := s.db.Find(&credentials, "user_id = ?", userID).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return credentials, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error {
|
||||||
|
var credential model.WebauthnCredential
|
||||||
|
if err := s.db.First(&credential, "id = ? AND user_id = ?", credentialID, userID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Delete(&credential).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) error {
|
||||||
|
var credential model.WebauthnCredential
|
||||||
|
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
credential.Name = name
|
||||||
|
|
||||||
|
if err := s.db.Save(&credential).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -71,3 +72,24 @@ func copyFile(srcFilePath, destFilePath string) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SaveFile(file *multipart.FileHeader, dst string) error {
|
||||||
|
src, err := file.Open()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer src.Close()
|
||||||
|
|
||||||
|
if err = os.MkdirAll(filepath.Dir(dst), 0o750); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer out.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(out, src)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,15 +1,23 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UnknownHandlerError(c *gin.Context, err error) {
|
func UnknownHandlerError(c *gin.Context, err error) {
|
||||||
log.Println(err)
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
|
HandlerError(c, http.StatusNotFound, "Record not found")
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
log.Println(err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandlerError(c *gin.Context, statusCode int, message string) {
|
func HandlerError(c *gin.Context, statusCode int, message string) {
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"strconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type PaginationResponse struct {
|
type PaginationResponse struct {
|
||||||
@@ -12,10 +10,7 @@ type PaginationResponse struct {
|
|||||||
CurrentPage int `json:"currentPage"`
|
CurrentPage int `json:"currentPage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func Paginate(c *gin.Context, db *gorm.DB, result interface{}) (PaginationResponse, error) {
|
func Paginate(page int, pageSize int, db *gorm.DB, result interface{}) (PaginationResponse, error) {
|
||||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
|
||||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
|
|
||||||
|
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,2 @@
|
|||||||
|
ALTER TABLE app_config_variables
|
||||||
|
RENAME TO application_configuration_variables;
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
ALTER TABLE application_configuration_variables
|
||||||
|
RENAME TO app_config_variables;
|
||||||
@@ -69,15 +69,3 @@ test('Delete passkey from account', async ({ page }) => {
|
|||||||
|
|
||||||
await expect(page.getByRole('status')).toHaveText('Passkey deleted successfully');
|
await expect(page.getByRole('status')).toHaveText('Passkey deleted successfully');
|
||||||
});
|
});
|
||||||
|
|
||||||
test('Delete last passkey from account fails', async ({ page }) => {
|
|
||||||
await page.goto('/settings/account');
|
|
||||||
|
|
||||||
await page.getByLabel('Delete').first().click();
|
|
||||||
await page.getByText('Delete', { exact: true }).click();
|
|
||||||
|
|
||||||
await page.getByLabel('Delete').first().click();
|
|
||||||
await page.getByText('Delete', { exact: true }).click();
|
|
||||||
|
|
||||||
await expect(page.getByRole('status').first()).toHaveText('You must have at least one passkey');
|
|
||||||
});
|
|
||||||
|
|||||||
@@ -35,6 +35,21 @@ test('Create user fails with already taken email', async ({ page }) => {
|
|||||||
await expect(page.getByRole('status')).toHaveText('Email is already taken');
|
await expect(page.getByRole('status')).toHaveText('Email is already taken');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('Create user fails with already taken username', async ({ page }) => {
|
||||||
|
const user = users.steve;
|
||||||
|
|
||||||
|
await page.goto('/settings/admin/users');
|
||||||
|
|
||||||
|
await page.getByRole('button', { name: 'Add User' }).click();
|
||||||
|
await page.getByLabel('Firstname').fill(user.firstname);
|
||||||
|
await page.getByLabel('Lastname').fill(user.lastname);
|
||||||
|
await page.getByLabel('Email').fill(user.email);
|
||||||
|
await page.getByLabel('Username').fill(users.tim.username);
|
||||||
|
await page.getByRole('button', { name: 'Save' }).click();
|
||||||
|
|
||||||
|
await expect(page.getByRole('status')).toHaveText('Username is already taken');
|
||||||
|
});
|
||||||
|
|
||||||
test('Create one time access token', async ({ page }) => {
|
test('Create one time access token', async ({ page }) => {
|
||||||
await page.goto('/settings/admin/users');
|
await page.goto('/settings/admin/users');
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user