refactor: use dependency injection in backend

This commit is contained in:
Elias Schneider
2024-08-17 21:57:14 +02:00
parent 0595d73ea5
commit 601f6c488a
44 changed files with 2220 additions and 1751 deletions

View File

@@ -1,7 +1,7 @@
package main
import (
"golang-rest-api-template/internal/bootstrap"
"github.com/stonith404/pocket-id/backend/internal/bootstrap"
)
func main() {

View File

@@ -1,4 +1,4 @@
module golang-rest-api-template
module github.com/stonith404/pocket-id/backend
go 1.22

View File

@@ -0,0 +1,28 @@
package bootstrap
import (
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/utils"
"log"
"os"
)
func initApplicationImages() {
dirPath := common.EnvConfig.UploadPath + "/application-images"
files, err := os.ReadDir(dirPath)
if err != nil && !os.IsNotExist(err) {
log.Fatalf("Error reading directory: %v", err)
}
// Skip if files already exist
if len(files) > 1 {
return
}
// Copy files from source to destination
err = utils.CopyDirectory("./images", dirPath)
if err != nil {
log.Fatalf("Error copying directory: %v", err)
}
}

View File

@@ -1,78 +1,16 @@
package bootstrap
import (
"github.com/gin-gonic/gin"
_ "github.com/golang-migrate/migrate/v4/source/file"
"golang-rest-api-template/internal/common"
"golang-rest-api-template/internal/common/middleware"
"golang-rest-api-template/internal/handler"
"golang-rest-api-template/internal/job"
"golang-rest-api-template/internal/utils"
"golang.org/x/time/rate"
"log"
"os"
"time"
"github.com/stonith404/pocket-id/backend/internal/job"
"github.com/stonith404/pocket-id/backend/internal/service"
)
func Bootstrap() {
common.InitDatabase()
common.InitDbConfig()
db := newDatabase()
appConfigService := service.NewAppConfigService(db)
initApplicationImages()
job.RegisterJobs()
initRouter()
}
func initRouter() {
switch common.EnvConfig.AppEnv {
case "production":
gin.SetMode(gin.ReleaseMode)
case "development":
gin.SetMode(gin.DebugMode)
case "test":
gin.SetMode(gin.TestMode)
}
r := gin.Default()
r.Use(gin.Logger())
r.Use(middleware.Cors())
r.Use(middleware.RateLimiter(rate.Every(time.Second), 60))
apiGroup := r.Group("/api")
handler.RegisterRoutes(apiGroup)
handler.RegisterOIDCRoutes(apiGroup)
handler.RegisterUserRoutes(apiGroup)
handler.RegisterConfigurationRoutes(apiGroup)
if common.EnvConfig.AppEnv != "production" {
handler.RegisterTestRoutes(apiGroup)
}
baseGroup := r.Group("/")
handler.RegisterWellKnownRoutes(baseGroup)
if err := r.Run(common.EnvConfig.Host + ":" + common.EnvConfig.Port); err != nil {
log.Fatal(err)
}
}
func initApplicationImages() {
dirPath := common.EnvConfig.UploadPath + "/application-images"
files, err := os.ReadDir(dirPath)
if err != nil && !os.IsNotExist(err) {
log.Fatalf("Error reading directory: %v", err)
}
// Skip if files already exist
if len(files) > 1 {
return
}
// Copy files from source to destination
err = utils.CopyDirectory("./images", dirPath)
if err != nil {
log.Fatalf("Error copying directory: %v", err)
}
job.RegisterJobs(db)
initRouter(db, appConfigService)
}

View File

@@ -1,51 +1,54 @@
package common
package bootstrap
import (
"errors"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/stonith404/pocket-id/backend/internal/common"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"log"
"os"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
var DB *gorm.DB
func InitDatabase() {
connectDatabase()
sqlDb, err := DB.DB()
func newDatabase() (db *gorm.DB) {
db, err := connectDatabase()
if err != nil {
log.Fatal("failed to get sql db", err)
log.Fatalf("failed to connect to database: %v", err)
}
sqlDb, err := db.DB()
if err != nil {
log.Fatalf("failed to get sql.DB: %v", err)
}
driver, err := sqlite3.WithInstance(sqlDb, &sqlite3.Config{})
m, err := migrate.NewWithDatabaseInstance(
"file://migrations",
"postgres", driver)
if err != nil {
log.Fatal("failed to create migration instance", err)
log.Fatalf("failed to create migration instance: %v", err)
}
err = m.Up()
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
log.Fatal("failed to run migrations", err)
log.Fatalf("failed to apply migrations: %v", err)
}
return db
}
func connectDatabase() {
var database *gorm.DB
var err error
func connectDatabase() (db *gorm.DB, err error) {
dbPath := common.EnvConfig.DBPath
dbPath := EnvConfig.DBPath
if EnvConfig.AppEnv == "test" {
// Use in-memory database for testing
if common.EnvConfig.AppEnv == "test" {
dbPath = "file::memory:?cache=shared"
}
for i := 1; i <= 3; i++ {
database, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
TranslateError: true,
Logger: getLogger(),
})
@@ -57,11 +60,11 @@ func connectDatabase() {
}
}
DB = database
return db, err
}
func getLogger() logger.Interface {
isProduction := EnvConfig.AppEnv == "production"
isProduction := common.EnvConfig.AppEnv == "production"
var logLevel logger.LogLevel
if isProduction {
@@ -70,7 +73,6 @@ func getLogger() logger.Interface {
logLevel = logger.Info
}
// Create the GORM logger
return logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags),
logger.Config{

View File

@@ -0,0 +1,67 @@
package bootstrap
import (
"log"
"time"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/controller"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/service"
"golang.org/x/time/rate"
"gorm.io/gorm"
)
func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
// Set the appropriate Gin mode based on the environment
switch common.EnvConfig.AppEnv {
case "production":
gin.SetMode(gin.ReleaseMode)
case "development":
gin.SetMode(gin.DebugMode)
case "test":
gin.SetMode(gin.TestMode)
}
r := gin.Default()
r.Use(gin.Logger())
// Add middleware
r.Use(
middleware.NewCorsMiddleware().Add(),
middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60),
)
// Initialize services
webauthnService := service.NewWebAuthnService(db, appConfigService)
jwtService := service.NewJwtService(appConfigService)
userService := service.NewUserService(db, jwtService)
oidcService := service.NewOidcService(db, jwtService)
testService := service.NewTestService(db, appConfigService)
// Initialize middleware
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService)
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
// Set up API routes
apiGroup := r.Group("/api")
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService)
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService)
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService)
controller.NewApplicationConfigurationController(apiGroup, jwtAuthMiddleware, appConfigService)
// Add test controller in non-production environments
if common.EnvConfig.AppEnv != "production" {
controller.NewTestController(apiGroup, testService)
}
// Set up base routes
baseGroup := r.Group("/")
controller.NewWellKnownController(baseGroup, jwtService)
// Run the server
if err := r.Run(common.EnvConfig.Host + ":" + common.EnvConfig.Port); err != nil {
log.Fatal(err)
}
}

View File

@@ -1,138 +0,0 @@
package common
import (
"github.com/caarlos0/env/v11"
_ "github.com/joho/godotenv/autoload"
"golang-rest-api-template/internal/model"
"log"
"reflect"
)
type EnvConfigSchema struct {
AppEnv string `env:"APP_ENV"`
AppURL string `env:"PUBLIC_APP_URL"`
DBPath string `env:"DB_PATH"`
UploadPath string `env:"UPLOAD_PATH"`
Port string `env:"BACKEND_PORT"`
Host string `env:"HOST"`
}
var EnvConfig = &EnvConfigSchema{
AppEnv: "production",
DBPath: "data/pocket-id.db",
UploadPath: "data/uploads",
AppURL: "http://localhost",
Port: "8080",
Host: "localhost",
}
var DbConfig = NewDefaultDbConfig()
func NewDefaultDbConfig() model.ApplicationConfiguration {
return model.ApplicationConfiguration{
AppName: model.ApplicationConfigurationVariable{
Key: "appName",
Type: "string",
IsPublic: true,
Value: "Pocket ID",
},
SessionDuration: model.ApplicationConfigurationVariable{
Key: "sessionDuration",
Type: "number",
Value: "60",
},
BackgroundImageType: model.ApplicationConfigurationVariable{
Key: "backgroundImageType",
Type: "string",
IsInternal: true,
Value: "jpg",
},
LogoImageType: model.ApplicationConfigurationVariable{
Key: "logoImageType",
Type: "string",
IsInternal: true,
Value: "svg",
},
}
}
// LoadDbConfigFromDb refreshes the database configuration by loading the current values
// from the database and updating the DbConfig struct.
func LoadDbConfigFromDb() error {
dbConfigReflectValue := reflect.ValueOf(&DbConfig).Elem()
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
dbConfigField := dbConfigReflectValue.Field(i)
currentConfigVar := dbConfigField.Interface().(model.ApplicationConfigurationVariable)
var storedConfigVar model.ApplicationConfigurationVariable
if err := DB.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
return err
}
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
return nil
}
// InitDbConfig creates the default configuration values in the database if they do not exist,
// updates existing configurations if they differ from the default, and deletes any configurations
// that are not in the default configuration.
func InitDbConfig() {
// Reflect to get the underlying value of DbConfig and its default configuration
dbConfigReflectValue := reflect.ValueOf(&DbConfig).Elem()
defaultDbConfig := NewDefaultDbConfig()
defaultConfigReflectValue := reflect.ValueOf(&defaultDbConfig).Elem()
defaultKeys := make(map[string]struct{})
// Iterate over the fields of DbConfig
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
dbConfigField := dbConfigReflectValue.Field(i)
currentConfigVar := dbConfigField.Interface().(model.ApplicationConfigurationVariable)
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.ApplicationConfigurationVariable)
defaultKeys[currentConfigVar.Key] = struct{}{}
var storedConfigVar model.ApplicationConfigurationVariable
if err := DB.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
// If the configuration does not exist, create it
if err := DB.Create(&defaultConfigVar).Error; err != nil {
log.Fatalf("Failed to create default configuration: %v", err)
}
dbConfigField.Set(reflect.ValueOf(defaultConfigVar))
continue
}
// Update existing configuration if it differs from the default
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal {
storedConfigVar.Type = defaultConfigVar.Type
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
if err := DB.Save(&storedConfigVar).Error; err != nil {
log.Fatalf("Failed to update configuration: %v", err)
}
}
// Set the value in DbConfig
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
// Delete any configurations not in the default keys
var allConfigVars []model.ApplicationConfigurationVariable
if err := DB.Find(&allConfigVars).Error; err != nil {
log.Fatalf("Failed to retrieve existing configurations: %v", err)
}
for _, config := range allConfigVars {
if _, exists := defaultKeys[config.Key]; !exists {
if err := DB.Delete(&config).Error; err != nil {
log.Fatalf("Failed to delete outdated configuration: %v", err)
}
}
}
}
func init() {
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
log.Fatal(err)
}
}

View File

@@ -0,0 +1,31 @@
package common
import (
"github.com/caarlos0/env/v11"
_ "github.com/joho/godotenv/autoload"
"log"
)
type EnvConfigSchema struct {
AppEnv string `env:"APP_ENV"`
AppURL string `env:"PUBLIC_APP_URL"`
DBPath string `env:"DB_PATH"`
UploadPath string `env:"UPLOAD_PATH"`
Port string `env:"BACKEND_PORT"`
Host string `env:"HOST"`
}
var EnvConfig = &EnvConfigSchema{
AppEnv: "production",
DBPath: "data/pocket-id.db",
UploadPath: "data/uploads",
AppURL: "http://localhost",
Port: "8080",
Host: "localhost",
}
func init() {
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
log.Fatal(err)
}
}

View File

@@ -0,0 +1,18 @@
package common
import "errors"
var (
ErrUsernameTaken = errors.New("username is already taken")
ErrEmailTaken = errors.New("email is already taken")
ErrSetupAlreadyCompleted = errors.New("setup already completed")
ErrInvalidBody = errors.New("invalid request body")
ErrTokenInvalidOrExpired = errors.New("token is invalid or expired")
ErrOidcMissingAuthorization = errors.New("missing authorization")
ErrOidcGrantTypeNotSupported = errors.New("grant type not supported")
ErrOidcMissingClientCredentials = errors.New("client id or secret not provided")
ErrOidcClientSecretInvalid = errors.New("invalid client secret")
ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code")
ErrFileTypeNotSupported = errors.New("file type not supported")
ErrInvalidCredentials = errors.New("no user found with provided credentials")
)

View File

@@ -1,209 +0,0 @@
package common
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"github.com/golang-jwt/jwt/v5"
"golang-rest-api-template/internal/model"
"golang-rest-api-template/internal/utils"
"log"
"math/big"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"time"
)
var (
PrivateKey *rsa.PrivateKey
PublicKey *rsa.PublicKey
)
const (
privateKeyPath = "data/keys/jwt_private_key.pem"
publicKeyPath = "data/keys/jwt_public_key.pem"
)
type accessTokenJWTClaims struct {
jwt.RegisteredClaims
IsAdmin bool `json:"isAdmin,omitempty"`
}
// GenerateIDToken generates an ID token for the given user, clientID, scope and nonce.
func GenerateIDToken(user model.User, clientID string, scope string, nonce string) (tokenString string, err error) {
profileClaims := map[string]interface{}{
"given_name": user.FirstName,
"family_name": user.LastName,
"email": user.Email,
"preferred_username": user.Username,
}
claims := jwt.MapClaims{
"sub": user.ID,
"aud": clientID,
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
"iat": jwt.NewNumericDate(time.Now()),
}
if nonce != "" {
claims["nonce"] = nonce
}
if strings.Contains(scope, "profile") {
for k, v := range profileClaims {
claims[k] = v
}
}
if strings.Contains(scope, "email") {
claims["email"] = user.Email
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signedToken, err := token.SignedString(PrivateKey)
if err != nil {
return "", err
}
return signedToken, nil
}
// GenerateAccessToken generates an access token for the given user.
func GenerateAccessToken(user model.User) (tokenString string, err error) {
sessionDurationInMinutes, _ := strconv.Atoi(DbConfig.SessionDuration.Value)
claim := accessTokenJWTClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: user.ID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Audience: jwt.ClaimStrings{utils.GetHostFromURL(EnvConfig.AppURL)},
},
IsAdmin: user.IsAdmin,
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
tokenString, err = token.SignedString(PrivateKey)
return tokenString, err
}
// VerifyAccessToken verifies the given access token and returns the claims if the token is valid.
func VerifyAccessToken(tokenString string) (*accessTokenJWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &accessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
return PublicKey, nil
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
}
claims, isValid := token.Claims.(*accessTokenJWTClaims)
if !isValid {
return nil, errors.New("can't parse claims")
}
if !slices.Contains(claims.Audience, utils.GetHostFromURL(EnvConfig.AppURL)) {
return nil, errors.New("audience doesn't match")
}
return claims, nil
}
type JWK struct {
Kty string `json:"kty"`
Use string `json:"use"`
Kid string `json:"kid"`
Alg string `json:"alg"`
N string `json:"n"`
E string `json:"e"`
}
// GetJWK returns the JSON Web Key (JWK) for the public key.
func GetJWK() (JWK, error) {
if PublicKey == nil {
return JWK{}, errors.New("public key is not initialized")
}
// Create JWK from RSA public key
jwk := JWK{
Kty: "RSA",
Use: "sig",
Kid: "1", // Key ID can be set to any identifier. Here it's statically set to "1"
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(PublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(PublicKey.E)).Bytes()),
}
return jwk, nil
}
// generateKeys generates a new RSA key pair and saves the private and public keys to the data folder.
func generateKeys() {
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
log.Fatal("Failed to create directories for keys", err)
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
log.Fatal("Failed to generate private key", err)
}
privateKeyFile, err := os.Create(privateKeyPath)
if err != nil {
log.Fatal("Failed to create private key file", err)
}
defer privateKeyFile.Close()
privateKeyPEM := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
},
)
_, err = privateKeyFile.Write(privateKeyPEM)
if err != nil {
log.Fatal("Failed to write private key file", err)
}
publicKey := &privateKey.PublicKey
publicKeyFile, err := os.Create(publicKeyPath)
if err != nil {
log.Fatal("Failed to create public key file", err)
}
defer publicKeyFile.Close()
publicKeyPEM := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: x509.MarshalPKCS1PublicKey(publicKey),
},
)
_, err = publicKeyFile.Write(publicKeyPEM)
if err != nil {
log.Fatal("Failed to write public key file", err)
}
}
func init() {
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
generateKeys()
}
privateKeyBytes, err := os.ReadFile(privateKeyPath)
if err != nil {
log.Fatal("Can't read jwt private key", err)
}
PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
if err != nil {
log.Fatal("Can't parse jwt private key", err)
}
publicKeyBytes, err := os.ReadFile(publicKeyPath)
if err != nil {
log.Fatal("Can't read jwt public key", err)
}
PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
if err != nil {
log.Fatal("Can't parse jwt public key", err)
}
}

View File

@@ -1,37 +0,0 @@
package common
import (
"github.com/go-webauthn/webauthn/webauthn"
"golang-rest-api-template/internal/utils"
"log"
"time"
)
var (
WebAuthn *webauthn.WebAuthn
err error
)
func init() {
config := &webauthn.Config{
RPDisplayName: DbConfig.AppName.Value,
RPID: utils.GetHostFromURL(EnvConfig.AppURL),
RPOrigins: []string{EnvConfig.AppURL},
Timeouts: webauthn.TimeoutsConfig{
Login: webauthn.TimeoutConfig{
Enforce: true,
Timeout: time.Second * 60,
TimeoutUVD: time.Second * 60,
},
Registration: webauthn.TimeoutConfig{
Enforce: true,
Timeout: time.Second * 60,
TimeoutUVD: time.Second * 60,
},
},
}
if WebAuthn, err = webauthn.New(config); err != nil {
log.Fatal(err)
}
}

View File

@@ -0,0 +1,140 @@
package controller
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
"net/http"
)
func NewApplicationConfigurationController(
group *gin.RouterGroup,
jwtAuthMiddleware *middleware.JwtAuthMiddleware,
appConfigService *service.AppConfigService) {
acc := &ApplicationConfigurationController{
appConfigService: appConfigService,
}
group.GET("/application-configuration", acc.listApplicationConfigurationHandler)
group.GET("/application-configuration/all", jwtAuthMiddleware.Add(true), acc.listAllApplicationConfigurationHandler)
group.PUT("/application-configuration", acc.updateApplicationConfigurationHandler)
group.GET("/application-configuration/logo", acc.getLogoHandler)
group.GET("/application-configuration/background-image", acc.getBackgroundImageHandler)
group.GET("/application-configuration/favicon", acc.getFaviconHandler)
group.PUT("/application-configuration/logo", jwtAuthMiddleware.Add(true), acc.updateLogoHandler)
group.PUT("/application-configuration/favicon", jwtAuthMiddleware.Add(true), acc.updateFaviconHandler)
group.PUT("/application-configuration/background-image", jwtAuthMiddleware.Add(true), acc.updateBackgroundImageHandler)
}
type ApplicationConfigurationController struct {
appConfigService *service.AppConfigService
}
func (acc *ApplicationConfigurationController) listApplicationConfigurationHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListApplicationConfiguration(false)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(200, configuration)
}
func (acc *ApplicationConfigurationController) listAllApplicationConfigurationHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListApplicationConfiguration(true)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(200, configuration)
}
func (acc *ApplicationConfigurationController) updateApplicationConfigurationHandler(c *gin.Context) {
var input model.AppConfigUpdateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
savedConfigVariables, err := acc.appConfigService.UpdateApplicationConfiguration(input)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, savedConfigVariables)
}
func (acc *ApplicationConfigurationController) getLogoHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
acc.getImage(c, "logo", imageType)
}
func (acc *ApplicationConfigurationController) getFaviconHandler(c *gin.Context) {
acc.getImage(c, "favicon", "ico")
}
func (acc *ApplicationConfigurationController) getBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
acc.getImage(c, "background", imageType)
}
func (acc *ApplicationConfigurationController) updateLogoHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
acc.updateImage(c, "logo", imageType)
}
func (acc *ApplicationConfigurationController) updateFaviconHandler(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
fileType := utils.GetFileExtension(file.Filename)
if fileType != "ico" {
utils.HandlerError(c, http.StatusBadRequest, "File must be of type .ico")
return
}
acc.updateImage(c, "favicon", "ico")
}
func (acc *ApplicationConfigurationController) updateBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
acc.updateImage(c, "background", imageType)
}
func (acc *ApplicationConfigurationController) getImage(c *gin.Context, name string, imageType string) {
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
mimeType := utils.GetImageMimeType(imageType)
c.Header("Content-Type", mimeType)
c.File(imagePath)
}
func (acc *ApplicationConfigurationController) updateImage(c *gin.Context, imageName string, oldImageType string) {
file, err := c.FormFile("file")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
if err != nil {
if errors.Is(err, common.ErrFileTypeNotSupported) {
utils.HandlerError(c, http.StatusBadRequest, err.Error())
} else {
utils.UnknownHandlerError(c, err)
}
return
}
c.Status(http.StatusNoContent)
}

View File

@@ -0,0 +1,218 @@
package controller
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
"net/http"
"strconv"
)
func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService) {
oc := &OidcController{OidcService: oidcService}
group.POST("/oidc/authorize", jwtAuthMiddleware.Add(false), oc.authorizeHandler)
group.POST("/oidc/authorize/new-client", jwtAuthMiddleware.Add(false), oc.authorizeNewClientHandler)
group.POST("/oidc/token", oc.createIDTokenHandler)
group.GET("/oidc/clients", jwtAuthMiddleware.Add(true), oc.listClientsHandler)
group.POST("/oidc/clients", jwtAuthMiddleware.Add(true), oc.createClientHandler)
group.GET("/oidc/clients/:id", oc.getClientHandler)
group.PUT("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.updateClientHandler)
group.DELETE("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.deleteClientHandler)
group.POST("/oidc/clients/:id/secret", jwtAuthMiddleware.Add(true), oc.createClientSecretHandler)
group.GET("/oidc/clients/:id/logo", oc.getClientLogoHandler)
group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler)
group.POST("/oidc/clients/:id/logo", jwtAuthMiddleware.Add(true), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler)
}
type OidcController struct {
OidcService *service.OidcService
}
func (oc *OidcController) authorizeHandler(c *gin.Context) {
var parsedBody model.AuthorizeRequest
if err := c.ShouldBindJSON(&parsedBody); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
code, err := oc.OidcService.Authorize(parsedBody, c.GetString("userID"))
if err != nil {
if errors.Is(err, common.ErrOidcMissingAuthorization) {
utils.HandlerError(c, http.StatusForbidden, err.Error())
} else {
utils.UnknownHandlerError(c, err)
}
return
}
c.JSON(http.StatusOK, gin.H{"code": code})
}
func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
var parsedBody model.AuthorizeNewClientDto
if err := c.ShouldBindJSON(&parsedBody); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
code, err := oc.OidcService.AuthorizeNewClient(parsedBody, c.GetString("userID"))
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"code": code})
}
func (oc *OidcController) createIDTokenHandler(c *gin.Context) {
var body model.OidcIdTokenDto
if err := c.ShouldBind(&body); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
idToken, err := oc.OidcService.CreateIDToken(body)
if err != nil {
if errors.Is(err, common.ErrOidcGrantTypeNotSupported) ||
errors.Is(err, common.ErrOidcMissingClientCredentials) ||
errors.Is(err, common.ErrOidcClientSecretInvalid) ||
errors.Is(err, common.ErrOidcInvalidAuthorizationCode) {
utils.HandlerError(c, http.StatusBadRequest, err.Error())
} else {
utils.UnknownHandlerError(c, err)
}
return
}
c.JSON(http.StatusOK, gin.H{"id_token": idToken})
}
func (oc *OidcController) getClientHandler(c *gin.Context) {
clientId := c.Param("id")
client, err := oc.OidcService.GetClient(clientId)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, client)
}
func (oc *OidcController) listClientsHandler(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
searchTerm := c.Query("search")
clients, pagination, err := oc.OidcService.ListClients(searchTerm, page, pageSize)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"data": clients,
"pagination": pagination,
})
}
func (oc *OidcController) createClientHandler(c *gin.Context) {
var input model.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
client, err := oc.OidcService.CreateClient(input, c.GetString("userID"))
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusCreated, client)
}
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
err := oc.OidcService.DeleteClient(c.Param("id"))
if err != nil {
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
return
}
c.Status(http.StatusNoContent)
}
func (oc *OidcController) updateClientHandler(c *gin.Context) {
var input model.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
client, err := oc.OidcService.UpdateClient(c.Param("id"), input)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusNoContent, client)
}
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
secret, err := oc.OidcService.CreateClientSecret(c.Param("id"))
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"secret": secret})
}
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
imagePath, mimeType, err := oc.OidcService.GetClientLogo(c.Param("id"))
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.Header("Content-Type", mimeType)
c.File(imagePath)
}
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
err = oc.OidcService.UpdateClientLogo(c.Param("id"), file)
if err != nil {
if errors.Is(err, common.ErrFileTypeNotSupported) {
utils.HandlerError(c, http.StatusBadRequest, err.Error())
} else {
utils.UnknownHandlerError(c, err)
}
return
}
c.Status(http.StatusNoContent)
}
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
err := oc.OidcService.DeleteClientLogo(c.Param("id"))
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.Status(http.StatusNoContent)
}

View File

@@ -0,0 +1,36 @@
package controller
import (
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
)
func NewTestController(group *gin.RouterGroup, testService *service.TestService) {
testController := &TestController{TestService: testService}
group.POST("/test/reset", testController.resetAndSeedHandler)
}
type TestController struct {
TestService *service.TestService
}
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
if err := tc.TestService.ResetDatabase(); err != nil {
utils.UnknownHandlerError(c, err)
return
}
if err := tc.TestService.ResetApplicationImages(); err != nil {
utils.UnknownHandlerError(c, err)
return
}
if err := tc.TestService.SeedDatabase(); err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(200, gin.H{"message": "Database reset and seeded"})
}

View File

@@ -0,0 +1,182 @@
package controller
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
"golang.org/x/time/rate"
"net/http"
"strconv"
"time"
)
func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, userService *service.UserService) {
uc := UserController{
UserService: userService,
}
group.GET("/users", jwtAuthMiddleware.Add(true), uc.listUsersHandler)
group.GET("/users/me", jwtAuthMiddleware.Add(false), uc.getCurrentUserHandler)
group.GET("/users/:id", jwtAuthMiddleware.Add(true), uc.getUserHandler)
group.POST("/users", jwtAuthMiddleware.Add(true), uc.createUserHandler)
group.PUT("/users/:id", jwtAuthMiddleware.Add(true), uc.updateUserHandler)
group.PUT("/users/me", jwtAuthMiddleware.Add(false), uc.updateCurrentUserHandler)
group.DELETE("/users/:id", jwtAuthMiddleware.Add(true), uc.deleteUserHandler)
group.POST("/users/:id/one-time-access-token", jwtAuthMiddleware.Add(true), uc.createOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler)
}
type UserController struct {
UserService *service.UserService
}
func (uc *UserController) listUsersHandler(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
searchTerm := c.Query("search")
users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"data": users,
"pagination": pagination,
})
}
func (uc *UserController) getUserHandler(c *gin.Context) {
user, err := uc.UserService.GetUser(c.Param("id"))
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, user)
}
func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
user, err := uc.UserService.GetUser(c.GetString("userID"))
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, user)
}
func (uc *UserController) deleteUserHandler(c *gin.Context) {
if err := uc.UserService.DeleteUser(c.Param("id")); err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.Status(http.StatusNoContent)
}
func (uc *UserController) createUserHandler(c *gin.Context) {
var user model.User
if err := c.ShouldBindJSON(&user); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
if err := uc.UserService.CreateUser(&user); err != nil {
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
utils.HandlerError(c, http.StatusConflict, err.Error())
} else {
utils.UnknownHandlerError(c, err)
}
return
}
c.JSON(http.StatusCreated, user)
}
func (uc *UserController) updateUserHandler(c *gin.Context) {
uc.updateUser(c, false)
}
func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
uc.updateUser(c, true)
}
func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
var input model.OneTimeAccessTokenCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusCreated, gin.H{"token": token})
}
func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token"))
if err != nil {
if errors.Is(err, common.ErrTokenInvalidOrExpired) {
utils.HandlerError(c, http.StatusUnauthorized, err.Error())
} else {
utils.UnknownHandlerError(c, err)
}
return
}
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, user)
}
func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
user, token, err := uc.UserService.SetupInitialAdmin()
if err != nil {
if errors.Is(err, common.ErrSetupAlreadyCompleted) {
utils.HandlerError(c, http.StatusBadRequest, err.Error())
} else {
utils.UnknownHandlerError(c, err)
}
return
}
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, user)
}
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
var updatedUser model.User
if err := c.ShouldBindJSON(&updatedUser); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
var userID string
if updateOwnUser {
userID = c.GetString("userID")
} else {
userID = c.Param("id")
}
user, err := uc.UserService.UpdateUser(userID, updatedUser, updateOwnUser)
if err != nil {
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
utils.HandlerError(c, http.StatusConflict, err.Error())
} else {
utils.UnknownHandlerError(c, err)
}
return
}
c.JSON(http.StatusOK, user)
}

View File

@@ -0,0 +1,160 @@
package controller
import (
"errors"
"github.com/go-webauthn/webauthn/protocol"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
"golang.org/x/time/rate"
)
func NewWebauthnController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, webauthnService *service.WebAuthnService, jwtService *service.JwtService) {
wc := &WebauthnController{webAuthnService: webauthnService, jwtService: jwtService}
group.GET("/webauthn/register/start", jwtAuthMiddleware.Add(false), wc.beginRegistrationHandler)
group.POST("/webauthn/register/finish", jwtAuthMiddleware.Add(false), wc.verifyRegistrationHandler)
group.GET("/webauthn/login/start", wc.beginLoginHandler)
group.POST("/webauthn/login/finish", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), wc.verifyLoginHandler)
group.POST("/webauthn/logout", jwtAuthMiddleware.Add(false), wc.logoutHandler)
group.GET("/webauthn/credentials", jwtAuthMiddleware.Add(false), wc.listCredentialsHandler)
group.PATCH("/webauthn/credentials/:id", jwtAuthMiddleware.Add(false), wc.updateCredentialHandler)
group.DELETE("/webauthn/credentials/:id", jwtAuthMiddleware.Add(false), wc.deleteCredentialHandler)
}
type WebauthnController struct {
webAuthnService *service.WebAuthnService
jwtService *service.JwtService
}
func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
userID := c.GetString("userID")
options, err := wc.webAuthnService.BeginRegistration(userID)
if err != nil {
utils.UnknownHandlerError(c, err)
log.Println(err)
return
}
c.SetCookie("session_id", options.SessionID, int(options.Timeout.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, options.Response)
}
func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
sessionID, err := c.Cookie("session_id")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
return
}
userID := c.GetString("userID")
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, credential)
}
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
options, err := wc.webAuthnService.BeginLogin()
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.SetCookie("session_id", options.SessionID, int(options.Timeout.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, options.Response)
}
func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
sessionID, err := c.Cookie("session_id")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
return
}
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
userID := c.GetString("userID")
user, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData)
if err != nil {
if errors.Is(err, common.ErrInvalidCredentials) {
utils.HandlerError(c, http.StatusUnauthorized, err.Error())
} else {
utils.UnknownHandlerError(c, err)
}
return
}
token, err := wc.jwtService.GenerateAccessToken(*user)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, user)
}
func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
userID := c.GetString("userID")
credentials, err := wc.webAuthnService.ListCredentials(userID)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, credentials)
}
func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
userID := c.GetString("userID")
credentialID := c.Param("id")
err := wc.webAuthnService.DeleteCredential(userID, credentialID)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.Status(http.StatusNoContent)
}
func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
userID := c.GetString("userID")
credentialID := c.Param("id")
var input model.WebauthnCredentialUpdateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
return
}
err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.Status(http.StatusNoContent)
}
func (wc *WebauthnController) logoutHandler(c *gin.Context) {
c.SetCookie("access_token", "", 0, "/", "", false, true)
c.Status(http.StatusNoContent)
}

View File

@@ -1,19 +1,25 @@
package handler
package controller
import (
"github.com/gin-gonic/gin"
"golang-rest-api-template/internal/common"
"golang-rest-api-template/internal/utils"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
"net/http"
)
func RegisterWellKnownRoutes(group *gin.RouterGroup) {
group.GET("/.well-known/jwks.json", jwks)
group.GET("/.well-known/openid-configuration", openIDConfiguration)
func NewWellKnownController(group *gin.RouterGroup, jwtService *service.JwtService) {
wkc := &WellKnownController{jwtService: jwtService}
group.GET("/.well-known/jwks.json", wkc.jwksHandler)
group.GET("/.well-known/openid-configuration", wkc.openIDConfigurationHandler)
}
func jwks(c *gin.Context) {
jwk, err := common.GetJWK()
type WellKnownController struct {
jwtService *service.JwtService
}
func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
jwk, err := wkc.jwtService.GetJWK()
if err != nil {
utils.UnknownHandlerError(c, err)
return
@@ -22,7 +28,7 @@ func jwks(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"keys": []interface{}{jwk}})
}
func openIDConfiguration(c *gin.Context) {
func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) {
appUrl := common.EnvConfig.AppURL
config := map[string]interface{}{
"issuer": appUrl,

View File

@@ -1,196 +0,0 @@
package handler
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"golang-rest-api-template/internal/common"
"golang-rest-api-template/internal/common/middleware"
"golang-rest-api-template/internal/model"
"golang-rest-api-template/internal/utils"
"gorm.io/gorm"
"net/http"
"os"
"reflect"
)
func RegisterConfigurationRoutes(group *gin.RouterGroup) {
group.GET("/application-configuration", listApplicationConfigurationHandler)
group.GET("/application-configuration/all", middleware.JWTAuth(true), listAllApplicationConfigurationHandler)
group.PUT("/application-configuration", updateApplicationConfigurationHandler)
group.GET("/application-configuration/logo", getLogoHandler)
group.GET("/application-configuration/background-image", getBackgroundImageHandler)
group.GET("/application-configuration/favicon", getFaviconHandler)
group.PUT("/application-configuration/logo", middleware.JWTAuth(true), updateLogoHandler)
group.PUT("/application-configuration/favicon", middleware.JWTAuth(true), updateFaviconHandler)
group.PUT("/application-configuration/background-image", middleware.JWTAuth(true), updateBackgroundImageHandler)
}
func listApplicationConfigurationHandler(c *gin.Context) {
listApplicationConfiguration(c, false)
}
func listAllApplicationConfigurationHandler(c *gin.Context) {
listApplicationConfiguration(c, true)
}
func updateApplicationConfigurationHandler(c *gin.Context) {
var input model.ApplicationConfigurationUpdateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
savedConfigVariables := make([]model.ApplicationConfigurationVariable, 10)
tx := common.DB.Begin()
rt := reflect.ValueOf(input).Type()
rv := reflect.ValueOf(input)
// Loop over the input struct fields and update the related configuration variables
for i := 0; i < rt.NumField(); i++ {
field := rt.Field(i)
key := field.Tag.Get("json")
value := rv.FieldByName(field.Name).String()
// Get the existing configuration variable from the db
var applicationConfigurationVariable model.ApplicationConfigurationVariable
if err := tx.First(&applicationConfigurationVariable, "key = ? AND is_internal = false", key).Error; err != nil {
tx.Rollback()
if errors.Is(err, gorm.ErrRecordNotFound) {
utils.HandlerError(c, http.StatusNotFound, fmt.Sprintf("Invalid configuration variable '%s'", value))
} else {
utils.UnknownHandlerError(c, err)
}
return
}
// Update the value of the existing configuration variable and save it
applicationConfigurationVariable.Value = value
if err := tx.Save(&applicationConfigurationVariable).Error; err != nil {
tx.Rollback()
utils.UnknownHandlerError(c, err)
return
}
savedConfigVariables[i] = applicationConfigurationVariable
}
tx.Commit()
if err := common.LoadDbConfigFromDb(); err != nil {
utils.UnknownHandlerError(c, err)
}
c.JSON(http.StatusOK, savedConfigVariables)
}
func getLogoHandler(c *gin.Context) {
imagType := common.DbConfig.LogoImageType.Value
getImage(c, "logo", imagType)
}
func getFaviconHandler(c *gin.Context) {
getImage(c, "favicon", "ico")
}
func getBackgroundImageHandler(c *gin.Context) {
imageType := common.DbConfig.BackgroundImageType.Value
getImage(c, "background", imageType)
}
func updateLogoHandler(c *gin.Context) {
imageType := common.DbConfig.LogoImageType.Value
updateImage(c, "logo", imageType)
}
func updateFaviconHandler(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
fileType := utils.GetFileExtension(file.Filename)
if fileType != "ico" {
utils.HandlerError(c, http.StatusBadRequest, "File must be of type .ico")
return
}
updateImage(c, "favicon", "ico")
}
func updateBackgroundImageHandler(c *gin.Context) {
imagType := common.DbConfig.BackgroundImageType.Value
updateImage(c, "background", imagType)
}
func getImage(c *gin.Context, name string, imageType string) {
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
mimeType := utils.GetImageMimeType(imageType)
c.Header("Content-Type", mimeType)
c.File(imagePath)
}
func updateImage(c *gin.Context, imageName string, oldImageType string) {
file, err := c.FormFile("file")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
fileType := utils.GetFileExtension(file.Filename)
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
utils.HandlerError(c, http.StatusBadRequest, "File type not supported")
return
}
// Delete the old image if it has a different file type
if fileType != oldImageType {
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType)
if err := os.Remove(oldImagePath); err != nil {
utils.UnknownHandlerError(c, err)
return
}
}
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType)
err = c.SaveUploadedFile(file, imagePath)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
// Update the file type in the database
key := fmt.Sprintf("%sImageType", imageName)
err = common.DB.Model(&model.ApplicationConfigurationVariable{}).Where("key = ?", key).Update("value", fileType).Error
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
if err := common.LoadDbConfigFromDb(); err != nil {
utils.UnknownHandlerError(c, err)
}
c.Status(http.StatusNoContent)
}
func listApplicationConfiguration(c *gin.Context, showAll bool) {
var configuration []model.ApplicationConfigurationVariable
var err error
if showAll {
err = common.DB.Find(&configuration).Error
} else {
err = common.DB.Find(&configuration, "is_public = true").Error
}
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(200, configuration)
}

View File

@@ -1,415 +0,0 @@
package handler
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"golang-rest-api-template/internal/common"
"golang-rest-api-template/internal/common/middleware"
"golang-rest-api-template/internal/model"
"golang-rest-api-template/internal/utils"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"net/http"
"os"
"time"
)
func RegisterOIDCRoutes(group *gin.RouterGroup) {
group.POST("/oidc/authorize", middleware.JWTAuth(false), authorizeHandler)
group.POST("/oidc/authorize/new-client", middleware.JWTAuth(false), authorizeNewClientHandler)
group.POST("/oidc/token", createIDTokenHandler)
group.GET("/oidc/clients", middleware.JWTAuth(true), listClientsHandler)
group.POST("/oidc/clients", middleware.JWTAuth(true), createClientHandler)
group.GET("/oidc/clients/:id", getClientHandler)
group.PUT("/oidc/clients/:id", middleware.JWTAuth(true), updateClientHandler)
group.DELETE("/oidc/clients/:id", middleware.JWTAuth(true), deleteClientHandler)
group.POST("/oidc/clients/:id/secret", middleware.JWTAuth(true), createClientSecretHandler)
group.GET("/oidc/clients/:id/logo", getClientLogoHandler)
group.DELETE("/oidc/clients/:id/logo", deleteClientLogoHandler)
group.POST("/oidc/clients/:id/logo", middleware.JWTAuth(true), middleware.LimitFileSize(2<<20), updateClientLogoHandler)
}
type AuthorizeRequest struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
Nonce string `json:"nonce"`
}
func authorizeHandler(c *gin.Context) {
var parsedBody AuthorizeRequest
if err := c.ShouldBindJSON(&parsedBody); err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
common.DB.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", parsedBody.ClientID, c.GetString("userID"))
// If the record isn't found or the scope is different return an error
// The client will have to call the authorizeNewClientHandler
if userAuthorizedOIDCClient.Scope != parsedBody.Scope {
utils.HandlerError(c, http.StatusForbidden, "missing authorization")
return
}
authorizationCode, err := createAuthorizationCode(parsedBody.ClientID, c.GetString("userID"), parsedBody.Scope, parsedBody.Nonce)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"code": authorizationCode})
}
// authorizeNewClientHandler authorizes a new client for the user
// a new client is a new client when the user has not authorized the client before
func authorizeNewClientHandler(c *gin.Context) {
var parsedBody model.AuthorizeNewClientDto
if err := c.ShouldBindJSON(&parsedBody); err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
userAuthorizedClient := model.UserAuthorizedOidcClient{
UserID: c.GetString("userID"),
ClientID: parsedBody.ClientID,
Scope: parsedBody.Scope,
}
err := common.DB.Create(&userAuthorizedClient).Error
if err != nil && errors.Is(err, gorm.ErrDuplicatedKey) {
err = common.DB.Model(&userAuthorizedClient).Update("scope", parsedBody.Scope).Error
}
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
authorizationCode, err := createAuthorizationCode(parsedBody.ClientID, c.GetString("userID"), parsedBody.Scope, parsedBody.Nonce)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"code": authorizationCode})
}
func createIDTokenHandler(c *gin.Context) {
var body model.OidcIdTokenDto
if err := c.ShouldBind(&body); err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
// Currently only authorization_code grant type is supported
if body.GrantType != "authorization_code" {
utils.HandlerError(c, http.StatusBadRequest, "grant type not supported")
return
}
clientID := body.ClientID
clientSecret := body.ClientSecret
// Client id and secret can also be passed over the Authorization header
if clientID == "" || clientSecret == "" {
var ok bool
clientID, clientSecret, ok = c.Request.BasicAuth()
if !ok {
utils.HandlerError(c, http.StatusBadRequest, "Client id and secret not provided")
return
}
}
// Get the client
var client model.OidcClient
err := common.DB.First(&client, "id = ?", clientID, clientSecret).Error
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "OIDC OIDC client not found")
return
}
// Check if client secret is correct
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid client secret")
return
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err = common.DB.Preload("User").First(&authorizationCodeMetaData, "code = ?", body.Code).Error
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid authorization code")
return
}
// Check if the client id matches the client id in the authorization code and if the code has expired
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.Before(time.Now()) {
utils.HandlerError(c, http.StatusBadRequest, "invalid authorization code")
return
}
idToken, e := common.GenerateIDToken(authorizationCodeMetaData.User, clientID, authorizationCodeMetaData.Scope, authorizationCodeMetaData.Nonce)
if e != nil {
utils.UnknownHandlerError(c, err)
return
}
// Delete the authorization code after it has been used
common.DB.Delete(&authorizationCodeMetaData)
c.JSON(http.StatusOK, gin.H{"id_token": idToken})
}
func getClientHandler(c *gin.Context) {
clientId := c.Param("id")
var client model.OidcClient
err := common.DB.First(&client, "id = ?", clientId).Error
if err != nil {
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
return
}
c.JSON(http.StatusOK, client)
}
func listClientsHandler(c *gin.Context) {
var clients []model.OidcClient
searchTerm := c.Query("search")
query := common.DB.Model(&model.OidcClient{})
if searchTerm != "" {
searchPattern := "%" + searchTerm + "%"
query = query.Where("name LIKE ?", searchPattern)
}
pagination, err := utils.Paginate(c, query, &clients)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"data": clients,
"pagination": pagination,
})
}
func createClientHandler(c *gin.Context) {
var input model.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
client := model.OidcClient{
Name: input.Name,
CallbackURL: input.CallbackURL,
CreatedByID: c.GetString("userID"),
}
if err := common.DB.Create(&client).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusCreated, client)
}
func deleteClientHandler(c *gin.Context) {
var client model.OidcClient
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
utils.HandlerError(c, http.StatusNotFound, "OIDC OIDC client not found")
return
}
if err := common.DB.Delete(&client).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.Status(http.StatusNoContent)
}
func updateClientHandler(c *gin.Context) {
var input model.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
var client model.OidcClient
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
return
}
utils.UnknownHandlerError(c, err)
return
}
client.Name = input.Name
client.CallbackURL = input.CallbackURL
if err := common.DB.Save(&client).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusNoContent, client)
}
// createClientSecretHandler creates a new secret for the client and revokes the old one
func createClientSecretHandler(c *gin.Context) {
var client model.OidcClient
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
return
}
utils.UnknownHandlerError(c, err)
return
}
clientSecret, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
client.Secret = string(hashedSecret)
if err := common.DB.Save(&client).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"secret": clientSecret})
}
func getClientLogoHandler(c *gin.Context) {
var client model.OidcClient
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
return
}
if client.ImageType == nil {
utils.HandlerError(c, http.StatusNotFound, "image not found")
return
}
imageType := *client.ImageType
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType)
mimeType := utils.GetImageMimeType(imageType)
c.Header("Content-Type", mimeType)
c.File(imagePath)
}
func updateClientLogoHandler(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
fileType := utils.GetFileExtension(file.Filename)
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
utils.HandlerError(c, http.StatusBadRequest, "file type not supported")
return
}
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, c.Param("id"), fileType)
err = c.SaveUploadedFile(file, imagePath)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
var client model.OidcClient
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
return
}
// Delete the old image if it has a different file type
if client.ImageType != nil && fileType != *client.ImageType {
oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
if err := os.Remove(oldImagePath); err != nil {
utils.UnknownHandlerError(c, err)
return
}
}
client.ImageType = &fileType
if err := common.DB.Save(&client).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.Status(http.StatusNoContent)
}
func deleteClientLogoHandler(c *gin.Context) {
var client model.OidcClient
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
return
}
if client.ImageType == nil {
utils.HandlerError(c, http.StatusNotFound, "image not found")
return
}
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
if err := os.Remove(imagePath); err != nil {
utils.UnknownHandlerError(c, err)
return
}
client.ImageType = nil
if err := common.DB.Save(&client).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.Status(http.StatusNoContent)
}
func createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return "", err
}
oidcAuthorizationCode := model.OidcAuthorizationCode{
ExpiresAt: time.Now().Add(15 * time.Minute),
Code: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
Nonce: nonce,
}
if err := common.DB.Create(&oidcAuthorizationCode).Error; err != nil {
return "", err
}
return randomString, nil
}

View File

@@ -1,276 +0,0 @@
package handler
import (
"errors"
"github.com/gin-gonic/gin"
"golang-rest-api-template/internal/common"
"golang-rest-api-template/internal/common/middleware"
"golang-rest-api-template/internal/model"
"golang-rest-api-template/internal/utils"
"golang.org/x/time/rate"
"gorm.io/gorm"
"log"
"net/http"
"time"
)
func RegisterUserRoutes(group *gin.RouterGroup) {
group.GET("/users", middleware.JWTAuth(true), listUsersHandler)
group.GET("/users/me", middleware.JWTAuth(false), getCurrentUserHandler)
group.GET("/users/:id", middleware.JWTAuth(true), getUserHandler)
group.POST("/users", middleware.JWTAuth(true), createUserHandler)
group.PUT("/users/:id", middleware.JWTAuth(true), updateUserHandler)
group.PUT("/users/me", middleware.JWTAuth(false), updateCurrentUserHandler)
group.DELETE("/users/:id", middleware.JWTAuth(true), deleteUserHandler)
group.POST("/users/:id/one-time-access-token", middleware.JWTAuth(true), createOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/:token", middleware.RateLimiter(rate.Every(10*time.Second), 5), exchangeOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/setup", getSetupAccessTokenHandler)
}
func listUsersHandler(c *gin.Context) {
var users []model.User
searchTerm := c.Query("search")
query := common.DB.Model(&model.User{})
if searchTerm != "" {
searchPattern := "%" + searchTerm + "%"
query = query.Where("email LIKE ? OR first_name LIKE ? OR username LIKE ?", searchPattern, searchPattern, searchPattern)
}
pagination, err := utils.Paginate(c, query, &users)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"data": users,
"pagination": pagination,
})
}
func getUserHandler(c *gin.Context) {
var user model.User
if err := common.DB.Where("id = ?", c.Param("id")).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
utils.HandlerError(c, http.StatusNotFound, "User not found")
return
}
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, user)
}
func getCurrentUserHandler(c *gin.Context) {
var user model.User
if err := common.DB.Where("id = ?", c.GetString("userID")).First(&user).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, user)
}
func deleteUserHandler(c *gin.Context) {
var user model.User
if err := common.DB.Where("id = ?", c.Param("id")).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
utils.HandlerError(c, http.StatusNotFound, "User not found")
return
}
utils.UnknownHandlerError(c, err)
return
}
if err := common.DB.Delete(&user).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.Status(http.StatusNoContent)
}
func createUserHandler(c *gin.Context) {
var user model.User
if err := c.ShouldBindJSON(&user); err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
if err := common.DB.Create(&user).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
if err := checkDuplicatedFields(user); err != nil {
utils.HandlerError(c, http.StatusBadRequest, err.Error())
return
}
} else {
utils.UnknownHandlerError(c, err)
return
}
}
c.JSON(http.StatusCreated, user)
}
func updateUserHandler(c *gin.Context) {
updateUser(c, c.Param("id"), false)
}
func updateCurrentUserHandler(c *gin.Context) {
updateUser(c, c.GetString("userID"), true)
}
func createOneTimeAccessTokenHandler(c *gin.Context) {
var input model.OneTimeAccessTokenCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
randomString, err := utils.GenerateRandomAlphanumericString(16)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
oneTimeAccessToken := model.OneTimeAccessToken{
UserID: input.UserID,
ExpiresAt: input.ExpiresAt,
Token: randomString,
}
if err := common.DB.Create(&oneTimeAccessToken).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusCreated, gin.H{"token": oneTimeAccessToken.Token})
}
func exchangeOneTimeAccessTokenHandler(c *gin.Context) {
var oneTimeAccessToken model.OneTimeAccessToken
if err := common.DB.Where("token = ? AND expires_at > ?", c.Param("token"), utils.FormatDateForDb(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
utils.HandlerError(c, http.StatusForbidden, "Token is invalid or expired")
return
}
utils.UnknownHandlerError(c, err)
return
}
token, err := common.GenerateAccessToken(oneTimeAccessToken.User)
if err != nil {
utils.UnknownHandlerError(c, err)
log.Println(err)
return
}
if err := common.DB.Delete(&oneTimeAccessToken).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, oneTimeAccessToken.User)
}
// getSetupAccessTokenHandler creates the initial admin user and returns an access token for the user
// This handler is only available if there are no users in the database
func getSetupAccessTokenHandler(c *gin.Context) {
var userCount int64
if err := common.DB.Model(&model.User{}).Count(&userCount).Error; err != nil {
log.Fatal("failed to count users", err)
}
// If there are more than one user, we don't need to create the admin user
if userCount > 1 {
utils.HandlerError(c, http.StatusForbidden, "Setup already completed")
return
}
var user = model.User{
FirstName: "Admin",
LastName: "Admin",
Username: "admin",
Email: "admin@admin.com",
IsAdmin: true,
}
// Create the initial admin user if it doesn't exist
if err := common.DB.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
log.Fatal("failed to create admin user", err)
}
// If the user already has credentials, the setup is already completed
if len(user.Credentials) > 0 {
utils.HandlerError(c, http.StatusForbidden, "Setup already completed")
return
}
token, err := common.GenerateAccessToken(user)
if err != nil {
utils.UnknownHandlerError(c, err)
log.Println(err)
return
}
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, user)
}
func updateUser(c *gin.Context, userID string, updateOwnUser bool) {
var user model.User
if err := common.DB.Where("id = ?", userID).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
utils.HandlerError(c, http.StatusNotFound, "User not found")
return
}
utils.UnknownHandlerError(c, err)
return
}
var updatedUser model.User
if err := c.ShouldBindJSON(&updatedUser); err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
user.FirstName = updatedUser.FirstName
user.LastName = updatedUser.LastName
user.Email = updatedUser.Email
user.Username = updatedUser.Username
user.Username = updatedUser.Username
if !updateOwnUser {
user.IsAdmin = updatedUser.IsAdmin
}
if err := common.DB.Save(user).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
if err := checkDuplicatedFields(user); err != nil {
utils.HandlerError(c, http.StatusBadRequest, err.Error())
return
}
} else {
utils.UnknownHandlerError(c, err)
return
}
}
c.JSON(http.StatusOK, user)
}
func checkDuplicatedFields(user model.User) error {
var existingUser model.User
if common.DB.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil {
return errors.New("email is already taken")
}
if common.DB.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil {
return errors.New("username is already taken")
}
return nil
}

View File

@@ -1,257 +0,0 @@
package handler
import (
"github.com/gin-gonic/gin"
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"golang-rest-api-template/internal/common"
"golang-rest-api-template/internal/common/middleware"
"golang-rest-api-template/internal/model"
"golang-rest-api-template/internal/utils"
"golang.org/x/time/rate"
"gorm.io/gorm"
"log"
"net/http"
"strings"
"time"
)
func RegisterRoutes(group *gin.RouterGroup) {
group.GET("/webauthn/register/start", middleware.JWTAuth(false), beginRegistrationHandler)
group.POST("/webauthn/register/finish", middleware.JWTAuth(false), verifyRegistrationHandler)
group.GET("/webauthn/login/start", beginLoginHandler)
group.POST("/webauthn/login/finish", middleware.RateLimiter(rate.Every(10*time.Second), 5), verifyLoginHandler)
group.POST("/webauthn/logout", middleware.JWTAuth(false), logoutHandler)
group.GET("/webauthn/credentials", middleware.JWTAuth(false), listCredentialsHandler)
group.PATCH("/webauthn/credentials/:id", middleware.JWTAuth(false), updateCredentialHandler)
group.DELETE("/webauthn/credentials/:id", middleware.JWTAuth(false), deleteCredentialHandler)
}
func beginRegistrationHandler(c *gin.Context) {
var user model.User
err := common.DB.Preload("Credentials").Find(&user, "id = ?", c.GetString("userID")).Error
if err != nil {
utils.UnknownHandlerError(c, err)
log.Println(err)
return
}
options, session, err := common.WebAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()))
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
// Save the webauthn session so we can retrieve it in the verifyRegistrationHandler
sessionToStore := &model.WebauthnSession{
ExpiresAt: session.Expires,
Challenge: session.Challenge,
UserVerification: string(session.UserVerification),
}
if err = common.DB.Create(&sessionToStore).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.SetCookie("session_id", sessionToStore.ID, int(common.WebAuthn.Config.Timeouts.Registration.Timeout.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, options.Response)
}
func verifyRegistrationHandler(c *gin.Context) {
sessionID, err := c.Cookie("session_id")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
return
}
// Retrieve the session that was previously created by the beginRegistrationHandler
var storedSession model.WebauthnSession
err = common.DB.First(&storedSession, "id = ?", sessionID).Error
session := webauthn.SessionData{
Challenge: storedSession.Challenge,
Expires: storedSession.ExpiresAt,
UserID: []byte(c.GetString("userID")),
}
var user model.User
err = common.DB.Find(&user, "id = ?", c.GetString("userID")).Error
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
credential, err := common.WebAuthn.FinishRegistration(&user, session, c.Request)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
credentialToStore := model.WebauthnCredential{
Name: "New Passkey",
CredentialID: string(credential.ID),
AttestationType: credential.AttestationType,
PublicKey: credential.PublicKey,
Transport: credential.Transport,
UserID: user.ID,
BackupEligible: credential.Flags.BackupEligible,
BackupState: credential.Flags.BackupState,
}
if err := common.DB.Create(&credentialToStore).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, credentialToStore)
}
func beginLoginHandler(c *gin.Context) {
options, session, err := common.WebAuthn.BeginDiscoverableLogin()
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
// Save the webauthn session so we can retrieve it in the verifyLoginHandler
sessionToStore := &model.WebauthnSession{
ExpiresAt: session.Expires,
Challenge: session.Challenge,
UserVerification: string(session.UserVerification),
}
if err = common.DB.Create(&sessionToStore).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.SetCookie("session_id", sessionToStore.ID, int(common.WebAuthn.Config.Timeouts.Registration.Timeout.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, options.Response)
}
func verifyLoginHandler(c *gin.Context) {
sessionID, err := c.Cookie("session_id")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
return
}
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "Invalid body")
return
}
// Retrieve the session that was previously created by the beginLoginHandler
var storedSession model.WebauthnSession
if err := common.DB.First(&storedSession, "id = ?", sessionID).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
session := webauthn.SessionData{
Challenge: storedSession.Challenge,
Expires: storedSession.ExpiresAt,
}
var user *model.User
_, err = common.WebAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
if err := common.DB.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil {
return nil, err
}
return user, nil
}, session, credentialAssertionData)
if err != nil {
if strings.Contains(err.Error(), gorm.ErrRecordNotFound.Error()) {
utils.HandlerError(c, http.StatusBadRequest, "no user with this passkey exists")
} else {
utils.UnknownHandlerError(c, err)
}
return
}
err = common.DB.Find(&user, "id = ?", c.GetString("userID")).Error
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
token, err := common.GenerateAccessToken(*user)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, user)
}
func listCredentialsHandler(c *gin.Context) {
var credentials []model.WebauthnCredential
if err := common.DB.Find(&credentials, "user_id = ?", c.GetString("userID")).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, credentials)
}
func deleteCredentialHandler(c *gin.Context) {
var passkeyCount int64
if err := common.DB.Model(&model.WebauthnCredential{}).Where("user_id = ?", c.GetString("userID")).Count(&passkeyCount).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
if passkeyCount == 1 {
utils.HandlerError(c, http.StatusBadRequest, "You must have at least one passkey")
return
}
var credential model.WebauthnCredential
if err := common.DB.First(&credential, "id = ? AND user_id = ?", c.Param("id"), c.GetString("userID")).Error; err != nil {
utils.HandlerError(c, http.StatusNotFound, "Credential not found")
return
}
if err := common.DB.Delete(&credential).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.Status(http.StatusNoContent)
}
func updateCredentialHandler(c *gin.Context) {
var credential model.WebauthnCredential
if err := common.DB.Where("id = ? AND user_id = ?", c.Param("id"), c.GetString("userID")).First(&credential).Error; err != nil {
utils.HandlerError(c, http.StatusNotFound, "Credential not found")
return
}
var input struct {
Name string `json:"name"`
}
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
return
}
credential.Name = input.Name
if err := common.DB.Save(&credential).Error; err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.Status(http.StatusNoContent)
}
func logoutHandler(c *gin.Context) {
c.SetCookie("access_token", "", 0, "/", "", false, true)
c.Status(http.StatusNoContent)
}

View File

@@ -3,28 +3,46 @@ package job
import (
"github.com/go-co-op/gocron/v2"
"github.com/google/uuid"
"golang-rest-api-template/internal/common"
"golang-rest-api-template/internal/model"
"golang-rest-api-template/internal/utils"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils"
"gorm.io/gorm"
"log"
"time"
)
func RegisterJobs() {
func RegisterJobs(db *gorm.DB) {
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err)
}
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", clearWebauthnSessions)
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", clearOneTimeAccessTokens)
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", clearOidcAuthorizationCodes)
jobs := &Jobs{db: db}
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
scheduler.Start()
}
func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) {
type Jobs struct {
db *gorm.DB
}
func (j *Jobs) clearWebauthnSessions() error {
return j.db.Delete(&model.WebauthnSession{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
}
func (j *Jobs) clearOneTimeAccessTokens() error {
return j.db.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
}
func (j *Jobs) clearOidcAuthorizationCodes() error {
return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
}
func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) {
_, err := scheduler.NewJob(
gocron.CronJob(interval, false),
gocron.NewTask(job),
@@ -42,16 +60,3 @@ func registerJob(scheduler gocron.Scheduler, name string, interval string, job f
log.Fatalf("Failed to register job %q: %v", name, err)
}
}
func clearWebauthnSessions() error {
return common.DB.Delete(&model.WebauthnSession{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
}
func clearOneTimeAccessTokens() error {
return common.DB.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
}
func clearOidcAuthorizationCodes() error {
return common.DB.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
}

View File

@@ -1,14 +1,20 @@
package middleware
import (
"golang-rest-api-template/internal/common"
"github.com/stonith404/pocket-id/backend/internal/common"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)
func Cors() gin.HandlerFunc {
type CorsMiddleware struct{}
func NewCorsMiddleware() *CorsMiddleware {
return &CorsMiddleware{}
}
func (m *CorsMiddleware) Add() gin.HandlerFunc {
return cors.New(cors.Config{
AllowOrigins: []string{common.EnvConfig.AppURL},
AllowMethods: []string{"*"},

View File

@@ -3,11 +3,17 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"golang-rest-api-template/internal/utils"
"github.com/stonith404/pocket-id/backend/internal/utils"
"net/http"
)
func LimitFileSize(maxSize int64) gin.HandlerFunc {
type FileSizeLimitMiddleware struct{}
func NewFileSizeLimitMiddleware() *FileSizeLimitMiddleware {
return &FileSizeLimitMiddleware{}
}
func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc {
return func(c *gin.Context) {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
if err := c.Request.ParseMultipartForm(maxSize); err != nil {

View File

@@ -2,15 +2,22 @@ package middleware
import (
"github.com/gin-gonic/gin"
"golang-rest-api-template/internal/common"
"golang-rest-api-template/internal/utils"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
"net/http"
"strings"
)
func JWTAuth(adminOnly bool) gin.HandlerFunc {
return func(c *gin.Context) {
type JwtAuthMiddleware struct {
jwtService *service.JwtService
}
func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware {
return &JwtAuthMiddleware{jwtService: jwtService}
}
func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
return func(c *gin.Context) {
// Extract the token from the cookie or the Authorization header
token, err := c.Cookie("access_token")
if err != nil {
@@ -22,11 +29,9 @@ func JWTAuth(adminOnly bool) gin.HandlerFunc {
c.Abort()
return
}
}
// Verify the token
claims, err := common.VerifyAccessToken(token)
claims, err := m.jwtService.VerifyAccessToken(token)
if err != nil {
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in")
c.Abort()

View File

@@ -1,8 +1,8 @@
package middleware
import (
"golang-rest-api-template/internal/common"
"golang-rest-api-template/internal/utils"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/utils"
"net/http"
"sync"
"time"
@@ -11,8 +11,13 @@ import (
"golang.org/x/time/rate"
)
// RateLimiter is a Gin middleware for rate limiting based on client IP
func RateLimiter(limit rate.Limit, burst int) gin.HandlerFunc {
type RateLimitMiddleware struct{}
func NewRateLimitMiddleware() *RateLimitMiddleware {
return &RateLimitMiddleware{}
}
func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
// Start the cleanup routine
go cleanupClients()

View File

@@ -0,0 +1,20 @@
package model
type AppConfigVariable struct {
Key string `gorm:"primaryKey;not null" json:"key"`
Type string `json:"type"`
IsPublic bool `json:"-"`
IsInternal bool `json:"-"`
Value string `json:"value"`
}
type AppConfig struct {
AppName AppConfigVariable
BackgroundImageType AppConfigVariable
LogoImageType AppConfigVariable
SessionDuration AppConfigVariable
}
type AppConfigUpdateDto struct {
AppName string `json:"appName" binding:"required"`
}

View File

@@ -1,20 +0,0 @@
package model
type ApplicationConfigurationVariable struct {
Key string `gorm:"primaryKey;not null" json:"key"`
Type string `json:"type"`
IsPublic bool `json:"-"`
IsInternal bool `json:"-"`
Value string `json:"value"`
}
type ApplicationConfiguration struct {
AppName ApplicationConfigurationVariable
BackgroundImageType ApplicationConfigurationVariable
LogoImageType ApplicationConfigurationVariable
SessionDuration ApplicationConfigurationVariable
}
type ApplicationConfigurationUpdateDto struct {
AppName string `json:"appName" binding:"required"`
}

View File

@@ -12,7 +12,7 @@ type Base struct {
CreatedAt time.Time `json:"createdAt"`
}
func (b *Base) BeforeCreate(db *gorm.DB) (err error) {
func (b *Base) BeforeCreate(_ *gorm.DB) (err error) {
if b.ID == "" {
b.ID = uuid.New().String()
}

View File

@@ -63,3 +63,9 @@ type OidcIdTokenDto struct {
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
}
type AuthorizeRequest struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
Nonce string `json:"nonce"`
}

View File

@@ -31,6 +31,18 @@ type WebauthnCredential struct {
UserID string
}
type PublicKeyCredentialCreationOptions struct {
Response protocol.PublicKeyCredentialCreationOptions `json:"response"`
SessionID string `json:"session_id"`
Timeout time.Duration `json:"timeout"`
}
type PublicKeyCredentialRequestOptions struct {
Response protocol.PublicKeyCredentialRequestOptions `json:"response"`
SessionID string `json:"session_id"`
Timeout time.Duration `json:"timeout"`
}
type AuthenticatorTransportList []protocol.AuthenticatorTransport
// Scan and Value methods for GORM to handle the custom type
@@ -46,3 +58,7 @@ func (atl *AuthenticatorTransportList) Scan(value interface{}) error {
func (atl AuthenticatorTransportList) Value() (driver.Value, error) {
return json.Marshal(atl)
}
type WebauthnCredentialUpdateDto struct {
Name string `json:"name"`
}

View File

@@ -0,0 +1,213 @@
package service
import (
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils"
"gorm.io/gorm"
"log"
"mime/multipart"
"os"
"reflect"
)
type AppConfigService struct {
DbConfig *model.AppConfig
db *gorm.DB
}
func NewAppConfigService(db *gorm.DB) *AppConfigService {
service := &AppConfigService{
DbConfig: &defaultDbConfig,
db: db,
}
if err := service.InitDbConfig(); err != nil {
log.Fatalf("Failed to initialize app config service: %v", err)
}
return service
}
var defaultDbConfig = model.AppConfig{
AppName: model.AppConfigVariable{
Key: "appName",
Type: "string",
IsPublic: true,
Value: "Pocket ID",
},
SessionDuration: model.AppConfigVariable{
Key: "sessionDuration",
Type: "number",
Value: "60",
},
BackgroundImageType: model.AppConfigVariable{
Key: "backgroundImageType",
Type: "string",
IsInternal: true,
Value: "jpg",
},
LogoImageType: model.AppConfigVariable{
Key: "logoImageType",
Type: "string",
IsInternal: true,
Value: "svg",
},
}
func (s *AppConfigService) UpdateApplicationConfiguration(input model.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
savedConfigVariables := make([]model.AppConfigVariable, 10)
tx := s.db.Begin()
rt := reflect.ValueOf(input).Type()
rv := reflect.ValueOf(input)
for i := 0; i < rt.NumField(); i++ {
field := rt.Field(i)
key := field.Tag.Get("json")
value := rv.FieldByName(field.Name).String()
var applicationConfigurationVariable model.AppConfigVariable
if err := tx.First(&applicationConfigurationVariable, "key = ? AND is_internal = false", key).Error; err != nil {
tx.Rollback()
return nil, err
}
applicationConfigurationVariable.Value = value
if err := tx.Save(&applicationConfigurationVariable).Error; err != nil {
tx.Rollback()
return nil, err
}
savedConfigVariables[i] = applicationConfigurationVariable
}
tx.Commit()
if err := s.loadDbConfigFromDb(); err != nil {
return nil, err
}
return savedConfigVariables, nil
}
func (s *AppConfigService) UpdateImageType(imageName string, fileType string) error {
key := fmt.Sprintf("%sImageType", imageName)
err := s.db.Model(&model.AppConfigVariable{}).Where("key = ?", key).Update("value", fileType).Error
if err != nil {
return err
}
return s.loadDbConfigFromDb()
}
func (s *AppConfigService) ListApplicationConfiguration(showAll bool) ([]model.AppConfigVariable, error) {
var configuration []model.AppConfigVariable
var err error
if showAll {
err = s.db.Find(&configuration).Error
} else {
err = s.db.Find(&configuration, "is_public = true").Error
}
if err != nil {
return nil, err
}
return configuration, nil
}
func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, imageName string, oldImageType string) error {
fileType := utils.GetFileExtension(uploadedFile.Filename)
mimeType := utils.GetImageMimeType(fileType)
if mimeType == "" {
return common.ErrFileTypeNotSupported
}
// Delete the old image if it has a different file type
if fileType != oldImageType {
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType)
if err := os.Remove(oldImagePath); err != nil {
return err
}
}
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType)
if err := utils.SaveFile(uploadedFile, imagePath); err != nil {
return err
}
// Update the file type in the database
if err := s.UpdateImageType(imageName, fileType); err != nil {
return err
}
return nil
}
// InitDbConfig creates the default configuration values in the database if they do not exist,
// updates existing configurations if they differ from the default, and deletes any configurations
// that are not in the default configuration.
func (s *AppConfigService) InitDbConfig() error {
// Reflect to get the underlying value of DbConfig and its default configuration
defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig)
defaultKeys := make(map[string]struct{})
// Iterate over the fields of DbConfig
for i := 0; i < defaultConfigReflectValue.NumField(); i++ {
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable)
defaultKeys[defaultConfigVar.Key] = struct{}{}
var storedConfigVar model.AppConfigVariable
if err := s.db.First(&storedConfigVar, "key = ?", defaultConfigVar.Key).Error; err != nil {
// If the configuration does not exist, create it
if err := s.db.Create(&defaultConfigVar).Error; err != nil {
return err
}
continue
}
// Update existing configuration if it differs from the default
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal {
storedConfigVar.Type = defaultConfigVar.Type
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
if err := s.db.Save(&storedConfigVar).Error; err != nil {
return err
}
}
}
// Delete any configurations not in the default keys
var allConfigVars []model.AppConfigVariable
if err := s.db.Find(&allConfigVars).Error; err != nil {
return err
}
for _, config := range allConfigVars {
if _, exists := defaultKeys[config.Key]; !exists {
if err := s.db.Delete(&config).Error; err != nil {
return err
}
}
}
return s.loadDbConfigFromDb()
}
func (s *AppConfigService) loadDbConfigFromDb() error {
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
dbConfigField := dbConfigReflectValue.Field(i)
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
var storedConfigVar model.AppConfigVariable
if err := s.db.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
return err
}
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
return nil
}

View File

@@ -0,0 +1,248 @@
package service
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"github.com/golang-jwt/jwt/v5"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils"
"log"
"math/big"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"time"
)
const (
privateKeyPath = "data/keys/jwt_private_key.pem"
publicKeyPath = "data/keys/jwt_public_key.pem"
)
type JwtService struct {
publicKey *rsa.PublicKey
privateKey *rsa.PrivateKey
appConfigService *AppConfigService
}
func NewJwtService(appConfigService *AppConfigService) *JwtService {
service := &JwtService{
appConfigService: appConfigService,
}
// Ensure keys are generated or loaded
if err := service.loadOrGenerateKeys(); err != nil {
log.Fatalf("Failed to initialize jwt service: %v", err)
}
return service
}
type AccessTokenJWTClaims struct {
jwt.RegisteredClaims
IsAdmin bool `json:"isAdmin,omitempty"`
}
type JWK struct {
Kty string `json:"kty"`
Use string `json:"use"`
Kid string `json:"kid"`
Alg string `json:"alg"`
N string `json:"n"`
E string `json:"e"`
}
// loadOrGenerateKeys loads RSA keys from the given paths or generates them if they do not exist.
func (s *JwtService) loadOrGenerateKeys() error {
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
if err := s.generateKeys(); err != nil {
return err
}
}
privateKeyBytes, err := os.ReadFile(privateKeyPath)
if err != nil {
return errors.New("can't read jwt private key: " + err.Error())
}
s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
if err != nil {
return errors.New("can't parse jwt private key: " + err.Error())
}
publicKeyBytes, err := os.ReadFile(publicKeyPath)
if err != nil {
return errors.New("can't read jwt public key: " + err.Error())
}
s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
if err != nil {
return errors.New("can't parse jwt public key: " + err.Error())
}
return nil
}
func (s *JwtService) GenerateIDToken(user model.User, clientID string, scope string, nonce string) (string, error) {
profileClaims := map[string]interface{}{
"given_name": user.FirstName,
"family_name": user.LastName,
"email": user.Email,
"preferred_username": user.Username,
}
claims := jwt.MapClaims{
"sub": user.ID,
"aud": clientID,
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
"iat": jwt.NewNumericDate(time.Now()),
}
if nonce != "" {
claims["nonce"] = nonce
}
if strings.Contains(scope, "profile") {
for k, v := range profileClaims {
claims[k] = v
}
}
if strings.Contains(scope, "email") {
claims["email"] = user.Email
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
return token.SignedString(s.privateKey)
}
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
sessionDurationInMinutes, _ := strconv.Atoi(s.appConfigService.DbConfig.SessionDuration.Value)
claim := AccessTokenJWTClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: user.ID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Audience: jwt.ClaimStrings{utils.GetHostFromURL(common.EnvConfig.AppURL)},
},
IsAdmin: user.IsAdmin,
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
return token.SignedString(s.privateKey)
}
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.publicKey, nil
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
}
claims, isValid := token.Claims.(*AccessTokenJWTClaims)
if !isValid {
return nil, errors.New("can't parse claims")
}
if !slices.Contains(claims.Audience, utils.GetHostFromURL(common.EnvConfig.AppURL)) {
return nil, errors.New("audience doesn't match")
}
return claims, nil
}
// GetJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetJWK() (JWK, error) {
if s.publicKey == nil {
return JWK{}, errors.New("public key is not initialized")
}
jwk := JWK{
Kty: "RSA",
Use: "sig",
Kid: "1",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.publicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.publicKey.E)).Bytes()),
}
return jwk, nil
}
// generateKeys generates a new RSA key pair and saves them to the specified paths.
func (s *JwtService) generateKeys() error {
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
return errors.New("failed to create directories for keys: " + err.Error())
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return errors.New("failed to generate private key: " + err.Error())
}
s.privateKey = privateKey
if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil {
return err
}
publicKey := &privateKey.PublicKey
s.publicKey = publicKey
if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil {
return err
}
return nil
}
// savePEMKey saves a PEM encoded key to a file.
func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error {
keyFile, err := os.Create(path)
if err != nil {
return errors.New("failed to create key file: " + err.Error())
}
defer keyFile.Close()
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: keyType,
Bytes: keyBytes,
})
if _, err := keyFile.Write(keyPEM); err != nil {
return errors.New("failed to write key file: " + err.Error())
}
return nil
}
// loadKeys loads RSA keys from the given paths.
func (s *JwtService) loadKeys() error {
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
if err := s.generateKeys(); err != nil {
return err
}
}
privateKeyBytes, err := os.ReadFile(privateKeyPath)
if err != nil {
return fmt.Errorf("can't read jwt private key: %w", err)
}
s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
if err != nil {
return fmt.Errorf("can't parse jwt private key: %w", err)
}
publicKeyBytes, err := os.ReadFile(publicKeyPath)
if err != nil {
return fmt.Errorf("can't read jwt public key: %w", err)
}
s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
if err != nil {
return fmt.Errorf("can't parse jwt public key: %w", err)
}
return nil
}

View File

@@ -0,0 +1,282 @@
package service
import (
"errors"
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"mime/multipart"
"os"
"time"
)
type OidcService struct {
db *gorm.DB
jwtService *JwtService
}
func NewOidcService(db *gorm.DB, jwtService *JwtService) *OidcService {
return &OidcService{
db: db,
jwtService: jwtService,
}
}
func (s *OidcService) Authorize(req model.AuthorizeRequest, userID string) (string, error) {
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
s.db.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", req.ClientID, userID)
if userAuthorizedOIDCClient.Scope != req.Scope {
return "", common.ErrOidcMissingAuthorization
}
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
}
func (s *OidcService) AuthorizeNewClient(req model.AuthorizeNewClientDto, userID string) (string, error) {
userAuthorizedClient := model.UserAuthorizedOidcClient{
UserID: userID,
ClientID: req.ClientID,
Scope: req.Scope,
}
if err := s.db.Create(&userAuthorizedClient).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
err = s.db.Model(&userAuthorizedClient).Update("scope", req.Scope).Error
} else {
return "", err
}
}
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
}
func (s *OidcService) CreateIDToken(req model.OidcIdTokenDto) (string, error) {
if req.GrantType != "authorization_code" {
return "", common.ErrOidcGrantTypeNotSupported
}
clientID := req.ClientID
clientSecret := req.ClientSecret
if clientID == "" || clientSecret == "" {
return "", common.ErrOidcMissingClientCredentials
}
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", err
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", common.ErrOidcClientSecretInvalid
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", req.Code).Error
if err != nil {
return "", common.ErrOidcInvalidAuthorizationCode
}
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.Before(time.Now()) {
return "", common.ErrOidcInvalidAuthorizationCode
}
idToken, err := s.jwtService.GenerateIDToken(authorizationCodeMetaData.User, clientID, authorizationCodeMetaData.Scope, authorizationCodeMetaData.Nonce)
if err != nil {
return "", err
}
s.db.Delete(&authorizationCodeMetaData)
return idToken, nil
}
func (s *OidcService) GetClient(clientID string) (*model.OidcClient, error) {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return nil, err
}
return &client, nil
}
func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]model.OidcClient, utils.PaginationResponse, error) {
var clients []model.OidcClient
query := s.db.Model(&model.OidcClient{})
if searchTerm != "" {
searchPattern := "%" + searchTerm + "%"
query = query.Where("name LIKE ?", searchPattern)
}
pagination, err := utils.Paginate(page, pageSize, query, &clients)
if err != nil {
return nil, utils.PaginationResponse{}, err
}
return clients, pagination, nil
}
func (s *OidcService) CreateClient(input model.OidcClientCreateDto, userID string) (*model.OidcClient, error) {
client := model.OidcClient{
Name: input.Name,
CallbackURL: input.CallbackURL,
CreatedByID: userID,
}
if err := s.db.Create(&client).Error; err != nil {
return nil, err
}
return &client, nil
}
func (s *OidcService) UpdateClient(clientID string, input model.OidcClientCreateDto) (*model.OidcClient, error) {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return nil, err
}
client.Name = input.Name
client.CallbackURL = input.CallbackURL
if err := s.db.Save(&client).Error; err != nil {
return nil, err
}
return &client, nil
}
func (s *OidcService) DeleteClient(clientID string) error {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return err
}
if err := s.db.Delete(&client).Error; err != nil {
return err
}
return nil
}
func (s *OidcService) CreateClientSecret(clientID string) (string, error) {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", err
}
clientSecret, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return "", err
}
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
if err != nil {
return "", err
}
client.Secret = string(hashedSecret)
if err := s.db.Save(&client).Error; err != nil {
return "", err
}
return clientSecret, nil
}
func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", err
}
if client.ImageType == nil {
return "", "", errors.New("image not found")
}
imageType := *client.ImageType
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType)
mimeType := utils.GetImageMimeType(imageType)
return imagePath, mimeType, nil
}
func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error {
fileType := utils.GetFileExtension(file.Filename)
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
return common.ErrFileTypeNotSupported
}
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType)
if err := utils.SaveFile(file, imagePath); err != nil {
return err
}
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return err
}
if client.ImageType != nil && fileType != *client.ImageType {
oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
if err := os.Remove(oldImagePath); err != nil {
return err
}
}
client.ImageType = &fileType
if err := s.db.Save(&client).Error; err != nil {
return err
}
return nil
}
func (s *OidcService) DeleteClientLogo(clientID string) error {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return err
}
if client.ImageType == nil {
return errors.New("image not found")
}
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
if err := os.Remove(imagePath); err != nil {
return err
}
client.ImageType = nil
if err := s.db.Save(&client).Error; err != nil {
return err
}
return nil
}
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return "", err
}
oidcAuthorizationCode := model.OidcAuthorizationCode{
ExpiresAt: time.Now().Add(15 * time.Minute),
Code: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
Nonce: nonce,
}
if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil {
return "", err
}
return randomString, nil
}

View File

@@ -1,48 +1,33 @@
package handler
package service
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/base64"
"fmt"
"github.com/fxamacker/cbor/v2"
"log"
"os"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
"github.com/go-webauthn/webauthn/protocol"
"golang-rest-api-template/internal/common"
"golang-rest-api-template/internal/model"
"golang-rest-api-template/internal/utils"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils"
"gorm.io/gorm"
)
func RegisterTestRoutes(group *gin.RouterGroup) {
group.POST("/test/reset", resetAndSeedHandler)
type TestService struct {
db *gorm.DB
appConfigService *AppConfigService
}
func resetAndSeedHandler(c *gin.Context) {
if err := resetDatabase(); err != nil {
utils.UnknownHandlerError(c, err)
return
}
if err := resetApplicationImages(); err != nil {
utils.UnknownHandlerError(c, err)
return
}
if err := seedDatabase(); err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(200, gin.H{"message": "Database reset and seeded"})
func NewTestService(db *gorm.DB, appConfigService *AppConfigService) *TestService {
return &TestService{db: db, appConfigService: appConfigService}
}
// seedDatabase seeds the database with initial data and uses a transaction to ensure atomicity.
func seedDatabase() error {
return common.DB.Transaction(func(tx *gorm.DB) error {
func (s *TestService) SeedDatabase() error {
return s.db.Transaction(func(tx *gorm.DB) error {
users := []model.User{
{
Base: model.Base{
@@ -128,11 +113,16 @@ func seedDatabase() error {
return err
}
publicKey1, err := getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
publicKey2, err := getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA==")
if err != nil {
return err
}
webauthnCredentials := []model.WebauthnCredential{
{
Name: "Passkey 1",
CredentialID: "test-credential-1",
PublicKey: getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg=="),
PublicKey: publicKey1,
AttestationType: "none",
Transport: model.AuthenticatorTransportList{protocol.Internal},
UserID: users[0].ID,
@@ -140,7 +130,7 @@ func seedDatabase() error {
{
Name: "Passkey 2",
CredentialID: "test-credential-2",
PublicKey: getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA=="),
PublicKey: publicKey2,
AttestationType: "none",
Transport: model.AuthenticatorTransportList{protocol.Internal},
UserID: users[0].ID,
@@ -165,9 +155,8 @@ func seedDatabase() error {
})
}
// resetDatabase resets the database by deleting all rows from each table.
func resetDatabase() error {
err := common.DB.Transaction(func(tx *gorm.DB) error {
func (s *TestService) ResetDatabase() error {
err := s.db.Transaction(func(tx *gorm.DB) error {
var tables []string
if err := tx.Raw("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' AND name != 'schema_migrations';").Scan(&tables).Error; err != nil {
return err
@@ -183,13 +172,11 @@ func resetDatabase() error {
if err != nil {
return err
}
common.InitDbConfig()
return nil
err = s.appConfigService.InitDbConfig()
return err
}
// resetApplicationImages resets the application images by removing existing images and replacing them with the default ones
func resetApplicationImages() error {
func (s *TestService) ResetApplicationImages() error {
if err := os.RemoveAll(common.EnvConfig.UploadPath); err != nil {
log.Printf("Error removing directory: %v", err)
return err
@@ -204,20 +191,19 @@ func resetApplicationImages() error {
}
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key
func getCborPublicKey(base64PublicKey string) []byte {
func getCborPublicKey(base64PublicKey string) ([]byte, error) {
decodedKey, err := base64.StdEncoding.DecodeString(base64PublicKey)
if err != nil {
log.Fatalf("Failed to decode base64 key: %v", err)
return nil, fmt.Errorf("failed to decode base64 key: %w", err)
}
pubKey, err := x509.ParsePKIXPublicKey(decodedKey)
if err != nil {
log.Fatalf("Failed to parse public key: %v", err)
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
ecdsaPubKey, ok := pubKey.(*ecdsa.PublicKey)
if !ok {
log.Fatalf("Not an ECDSA public key")
return nil, fmt.Errorf("not an ECDSA public key")
}
coseKey := map[int]interface{}{
@@ -230,8 +216,8 @@ func getCborPublicKey(base64PublicKey string) []byte {
cborPublicKey, err := cbor.Marshal(coseKey)
if err != nil {
log.Fatalf("Failed to encode CBOR: %v", err)
return nil, fmt.Errorf("failed to marshal COSE key: %w", err)
}
return cborPublicKey
return cborPublicKey, nil
}

View File

@@ -0,0 +1,165 @@
package service
import (
"errors"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils"
"gorm.io/gorm"
"time"
)
type UserService struct {
db *gorm.DB
jwtService *JwtService
}
func NewUserService(db *gorm.DB, jwtService *JwtService) *UserService {
return &UserService{db: db, jwtService: jwtService}
}
func (s *UserService) ListUsers(searchTerm string, page int, pageSize int) ([]model.User, utils.PaginationResponse, error) {
var users []model.User
query := s.db.Model(&model.User{})
if searchTerm != "" {
searchPattern := "%" + searchTerm + "%"
query = query.Where("email LIKE ? OR first_name LIKE ? OR username LIKE ?", searchPattern, searchPattern, searchPattern)
}
pagination, err := utils.Paginate(page, pageSize, query, &users)
return users, pagination, err
}
func (s *UserService) GetUser(userID string) (model.User, error) {
var user model.User
err := s.db.Where("id = ?", userID).First(&user).Error
return user, err
}
func (s *UserService) DeleteUser(userID string) error {
var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
return err
}
return s.db.Delete(&user).Error
}
func (s *UserService) CreateUser(user *model.User) error {
if err := s.db.Create(user).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return s.checkDuplicatedFields(*user)
}
return err
}
return nil
}
func (s *UserService) UpdateUser(userID string, updatedUser model.User, updateOwnUser bool) (model.User, error) {
var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
return model.User{}, err
}
user.FirstName = updatedUser.FirstName
user.LastName = updatedUser.LastName
user.Email = updatedUser.Email
user.Username = updatedUser.Username
if !updateOwnUser {
user.IsAdmin = updatedUser.IsAdmin
}
if err := s.db.Save(&user).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return user, s.checkDuplicatedFields(user)
}
return user, err
}
return user, nil
}
func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(16)
if err != nil {
return "", err
}
oneTimeAccessToken := model.OneTimeAccessToken{
UserID: userID,
ExpiresAt: expiresAt,
Token: randomString,
}
if err := s.db.Create(&oneTimeAccessToken).Error; err != nil {
return "", err
}
return oneTimeAccessToken.Token, nil
}
func (s *UserService) ExchangeOneTimeAccessToken(token string) (model.User, string, error) {
var oneTimeAccessToken model.OneTimeAccessToken
if err := s.db.Where("token = ? AND expires_at > ?", token, utils.FormatDateForDb(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, "", common.ErrTokenInvalidOrExpired
}
return model.User{}, "", err
}
accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User)
if err != nil {
return model.User{}, "", err
}
if err := s.db.Delete(&oneTimeAccessToken).Error; err != nil {
return model.User{}, "", err
}
return oneTimeAccessToken.User, accessToken, nil
}
func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
var userCount int64
if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil {
return model.User{}, "", err
}
if userCount > 1 {
return model.User{}, "", common.ErrSetupAlreadyCompleted
}
user := model.User{
FirstName: "Admin",
LastName: "Admin",
Username: "admin",
Email: "admin@admin.com",
IsAdmin: true,
}
if err := s.db.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
return model.User{}, "", err
}
if len(user.Credentials) > 0 {
return model.User{}, "", common.ErrSetupAlreadyCompleted
}
token, err := s.jwtService.GenerateAccessToken(user)
if err != nil {
return model.User{}, "", err
}
return user, token, nil
}
func (s *UserService) checkDuplicatedFields(user model.User) error {
var existingUser model.User
if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil {
return common.ErrEmailTaken
}
if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil {
return common.ErrUsernameTaken
}
return nil
}

View File

@@ -0,0 +1,196 @@
package service
import (
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils"
"gorm.io/gorm"
"net/http"
"time"
)
type WebAuthnService struct {
db *gorm.DB
webAuthn *webauthn.WebAuthn
}
func NewWebAuthnService(db *gorm.DB, appConfigService *AppConfigService) *WebAuthnService {
webauthnConfig := &webauthn.Config{
RPDisplayName: appConfigService.DbConfig.AppName.Value,
RPID: utils.GetHostFromURL(common.EnvConfig.AppURL),
RPOrigins: []string{common.EnvConfig.AppURL},
Timeouts: webauthn.TimeoutsConfig{
Login: webauthn.TimeoutConfig{
Enforce: true,
Timeout: time.Second * 60,
TimeoutUVD: time.Second * 60,
},
Registration: webauthn.TimeoutConfig{
Enforce: true,
Timeout: time.Second * 60,
TimeoutUVD: time.Second * 60,
},
},
}
wa, _ := webauthn.New(webauthnConfig)
return &WebAuthnService{db: db, webAuthn: wa}
}
func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCredentialCreationOptions, error) {
var user model.User
if err := s.db.Preload("Credentials").Find(&user, "id = ?", userID).Error; err != nil {
return nil, err
}
options, session, err := s.webAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()))
if err != nil {
return nil, err
}
sessionToStore := &model.WebauthnSession{
ExpiresAt: session.Expires,
Challenge: session.Challenge,
UserVerification: string(session.UserVerification),
}
if err := s.db.Create(&sessionToStore).Error; err != nil {
return nil, err
}
return &model.PublicKeyCredentialCreationOptions{
Response: options.Response,
SessionID: sessionToStore.ID,
Timeout: s.webAuthn.Config.Timeouts.Registration.Timeout,
}, nil
}
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (*model.WebauthnCredential, error) {
var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
return nil, err
}
session := webauthn.SessionData{
Challenge: storedSession.Challenge,
Expires: storedSession.ExpiresAt,
UserID: []byte(userID),
}
var user model.User
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
return nil, err
}
credential, err := s.webAuthn.FinishRegistration(&user, session, r)
if err != nil {
return nil, err
}
credentialToStore := model.WebauthnCredential{
Name: "New Passkey",
CredentialID: string(credential.ID),
AttestationType: credential.AttestationType,
PublicKey: credential.PublicKey,
Transport: credential.Transport,
UserID: user.ID,
BackupEligible: credential.Flags.BackupEligible,
BackupState: credential.Flags.BackupState,
}
if err := s.db.Create(&credentialToStore).Error; err != nil {
return nil, err
}
return &credentialToStore, nil
}
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
options, session, err := s.webAuthn.BeginDiscoverableLogin()
if err != nil {
return nil, err
}
sessionToStore := &model.WebauthnSession{
ExpiresAt: session.Expires,
Challenge: session.Challenge,
UserVerification: string(session.UserVerification),
}
if err := s.db.Create(&sessionToStore).Error; err != nil {
return nil, err
}
return &model.PublicKeyCredentialRequestOptions{
Response: options.Response,
SessionID: sessionToStore.ID,
Timeout: s.webAuthn.Config.Timeouts.Registration.Timeout,
}, nil
}
func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (*model.User, error) {
var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
return nil, err
}
session := webauthn.SessionData{
Challenge: storedSession.Challenge,
Expires: storedSession.ExpiresAt,
}
var user *model.User
_, err := s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
if err := s.db.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil {
return nil, err
}
return user, nil
}, session, credentialAssertionData)
if err != nil {
return nil, common.ErrInvalidCredentials
}
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
return nil, err
}
return user, nil
}
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) {
var credentials []model.WebauthnCredential
if err := s.db.Find(&credentials, "user_id = ?", userID).Error; err != nil {
return nil, err
}
return credentials, nil
}
func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error {
var credential model.WebauthnCredential
if err := s.db.First(&credential, "id = ? AND user_id = ?", credentialID, userID).Error; err != nil {
return err
}
if err := s.db.Delete(&credential).Error; err != nil {
return err
}
return nil
}
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) error {
var credential model.WebauthnCredential
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil {
return err
}
credential.Name = name
if err := s.db.Save(&credential).Error; err != nil {
return err
}
return nil
}

View File

@@ -2,6 +2,7 @@ package utils
import (
"io"
"mime/multipart"
"os"
"path/filepath"
"strings"
@@ -71,3 +72,24 @@ func copyFile(srcFilePath, destFilePath string) error {
return nil
}
func SaveFile(file *multipart.FileHeader, dst string) error {
src, err := file.Open()
if err != nil {
return err
}
defer src.Close()
if err = os.MkdirAll(filepath.Dir(dst), 0o750); err != nil {
return err
}
out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, src)
return err
}

View File

@@ -1,15 +1,23 @@
package utils
import (
"errors"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"log"
"net/http"
"strings"
)
func UnknownHandlerError(c *gin.Context, err error) {
log.Println(err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
if errors.Is(err, gorm.ErrRecordNotFound) {
HandlerError(c, http.StatusNotFound, "Record not found")
return
} else {
log.Println(err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
}
}
func HandlerError(c *gin.Context, statusCode int, message string) {

View File

@@ -1,9 +1,7 @@
package utils
import (
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"strconv"
)
type PaginationResponse struct {
@@ -12,10 +10,7 @@ type PaginationResponse struct {
CurrentPage int `json:"currentPage"`
}
func Paginate(c *gin.Context, db *gorm.DB, result interface{}) (PaginationResponse, error) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
func Paginate(page int, pageSize int, db *gorm.DB, result interface{}) (PaginationResponse, error) {
if page < 1 {
page = 1
}

View File

@@ -0,0 +1,2 @@
ALTER TABLE app_config_variables
RENAME TO application_configuration_variables;

View File

@@ -0,0 +1,2 @@
ALTER TABLE application_configuration_variables
RENAME TO app_config_variables;