mirror of
https://github.com/nikdoof/pocket-id.git
synced 2025-12-14 07:12:19 +00:00
refactor: use dependency injection in backend
This commit is contained in:
@@ -1,138 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/caarlos0/env/v11"
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
"golang-rest-api-template/internal/model"
|
||||
"log"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type EnvConfigSchema struct {
|
||||
AppEnv string `env:"APP_ENV"`
|
||||
AppURL string `env:"PUBLIC_APP_URL"`
|
||||
DBPath string `env:"DB_PATH"`
|
||||
UploadPath string `env:"UPLOAD_PATH"`
|
||||
Port string `env:"BACKEND_PORT"`
|
||||
Host string `env:"HOST"`
|
||||
}
|
||||
|
||||
var EnvConfig = &EnvConfigSchema{
|
||||
AppEnv: "production",
|
||||
DBPath: "data/pocket-id.db",
|
||||
UploadPath: "data/uploads",
|
||||
AppURL: "http://localhost",
|
||||
Port: "8080",
|
||||
Host: "localhost",
|
||||
}
|
||||
|
||||
var DbConfig = NewDefaultDbConfig()
|
||||
|
||||
func NewDefaultDbConfig() model.ApplicationConfiguration {
|
||||
return model.ApplicationConfiguration{
|
||||
AppName: model.ApplicationConfigurationVariable{
|
||||
Key: "appName",
|
||||
Type: "string",
|
||||
IsPublic: true,
|
||||
Value: "Pocket ID",
|
||||
},
|
||||
SessionDuration: model.ApplicationConfigurationVariable{
|
||||
Key: "sessionDuration",
|
||||
Type: "number",
|
||||
Value: "60",
|
||||
},
|
||||
BackgroundImageType: model.ApplicationConfigurationVariable{
|
||||
Key: "backgroundImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
Value: "jpg",
|
||||
},
|
||||
LogoImageType: model.ApplicationConfigurationVariable{
|
||||
Key: "logoImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
Value: "svg",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// LoadDbConfigFromDb refreshes the database configuration by loading the current values
|
||||
// from the database and updating the DbConfig struct.
|
||||
func LoadDbConfigFromDb() error {
|
||||
dbConfigReflectValue := reflect.ValueOf(&DbConfig).Elem()
|
||||
|
||||
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
|
||||
dbConfigField := dbConfigReflectValue.Field(i)
|
||||
currentConfigVar := dbConfigField.Interface().(model.ApplicationConfigurationVariable)
|
||||
var storedConfigVar model.ApplicationConfigurationVariable
|
||||
if err := DB.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitDbConfig creates the default configuration values in the database if they do not exist,
|
||||
// updates existing configurations if they differ from the default, and deletes any configurations
|
||||
// that are not in the default configuration.
|
||||
func InitDbConfig() {
|
||||
// Reflect to get the underlying value of DbConfig and its default configuration
|
||||
dbConfigReflectValue := reflect.ValueOf(&DbConfig).Elem()
|
||||
defaultDbConfig := NewDefaultDbConfig()
|
||||
defaultConfigReflectValue := reflect.ValueOf(&defaultDbConfig).Elem()
|
||||
defaultKeys := make(map[string]struct{})
|
||||
|
||||
// Iterate over the fields of DbConfig
|
||||
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
|
||||
dbConfigField := dbConfigReflectValue.Field(i)
|
||||
currentConfigVar := dbConfigField.Interface().(model.ApplicationConfigurationVariable)
|
||||
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.ApplicationConfigurationVariable)
|
||||
defaultKeys[currentConfigVar.Key] = struct{}{}
|
||||
|
||||
var storedConfigVar model.ApplicationConfigurationVariable
|
||||
if err := DB.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
|
||||
// If the configuration does not exist, create it
|
||||
if err := DB.Create(&defaultConfigVar).Error; err != nil {
|
||||
log.Fatalf("Failed to create default configuration: %v", err)
|
||||
}
|
||||
dbConfigField.Set(reflect.ValueOf(defaultConfigVar))
|
||||
continue
|
||||
}
|
||||
|
||||
// Update existing configuration if it differs from the default
|
||||
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal {
|
||||
storedConfigVar.Type = defaultConfigVar.Type
|
||||
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
|
||||
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
|
||||
if err := DB.Save(&storedConfigVar).Error; err != nil {
|
||||
log.Fatalf("Failed to update configuration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set the value in DbConfig
|
||||
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
|
||||
}
|
||||
|
||||
// Delete any configurations not in the default keys
|
||||
var allConfigVars []model.ApplicationConfigurationVariable
|
||||
if err := DB.Find(&allConfigVars).Error; err != nil {
|
||||
log.Fatalf("Failed to retrieve existing configurations: %v", err)
|
||||
}
|
||||
|
||||
for _, config := range allConfigVars {
|
||||
if _, exists := defaultKeys[config.Key]; !exists {
|
||||
if err := DB.Delete(&config).Error; err != nil {
|
||||
log.Fatalf("Failed to delete outdated configuration: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
"gorm.io/gorm/logger"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
func InitDatabase() {
|
||||
connectDatabase()
|
||||
sqlDb, err := DB.DB()
|
||||
if err != nil {
|
||||
log.Fatal("failed to get sql db", err)
|
||||
}
|
||||
driver, err := sqlite3.WithInstance(sqlDb, &sqlite3.Config{})
|
||||
m, err := migrate.NewWithDatabaseInstance(
|
||||
"file://migrations",
|
||||
"postgres", driver)
|
||||
if err != nil {
|
||||
log.Fatal("failed to create migration instance", err)
|
||||
}
|
||||
|
||||
err = m.Up()
|
||||
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
log.Fatal("failed to run migrations", err)
|
||||
}
|
||||
}
|
||||
|
||||
func connectDatabase() {
|
||||
var database *gorm.DB
|
||||
var err error
|
||||
|
||||
dbPath := EnvConfig.DBPath
|
||||
if EnvConfig.AppEnv == "test" {
|
||||
dbPath = "file::memory:?cache=shared"
|
||||
}
|
||||
|
||||
for i := 1; i <= 3; i++ {
|
||||
database, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
TranslateError: true,
|
||||
Logger: getLogger(),
|
||||
})
|
||||
if err == nil {
|
||||
break
|
||||
} else {
|
||||
log.Printf("Attempt %d: Failed to initialize database. Retrying...", i)
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
DB = database
|
||||
}
|
||||
|
||||
func getLogger() logger.Interface {
|
||||
isProduction := EnvConfig.AppEnv == "production"
|
||||
|
||||
var logLevel logger.LogLevel
|
||||
if isProduction {
|
||||
logLevel = logger.Error
|
||||
} else {
|
||||
logLevel = logger.Info
|
||||
}
|
||||
|
||||
// Create the GORM logger
|
||||
return logger.New(
|
||||
log.New(os.Stdout, "\r\n", log.LstdFlags),
|
||||
logger.Config{
|
||||
SlowThreshold: 200 * time.Millisecond,
|
||||
LogLevel: logLevel,
|
||||
IgnoreRecordNotFoundError: isProduction,
|
||||
ParameterizedQueries: isProduction,
|
||||
Colorful: !isProduction,
|
||||
},
|
||||
)
|
||||
}
|
||||
31
backend/internal/common/env_config.go
Normal file
31
backend/internal/common/env_config.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/caarlos0/env/v11"
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
"log"
|
||||
)
|
||||
|
||||
type EnvConfigSchema struct {
|
||||
AppEnv string `env:"APP_ENV"`
|
||||
AppURL string `env:"PUBLIC_APP_URL"`
|
||||
DBPath string `env:"DB_PATH"`
|
||||
UploadPath string `env:"UPLOAD_PATH"`
|
||||
Port string `env:"BACKEND_PORT"`
|
||||
Host string `env:"HOST"`
|
||||
}
|
||||
|
||||
var EnvConfig = &EnvConfigSchema{
|
||||
AppEnv: "production",
|
||||
DBPath: "data/pocket-id.db",
|
||||
UploadPath: "data/uploads",
|
||||
AppURL: "http://localhost",
|
||||
Port: "8080",
|
||||
Host: "localhost",
|
||||
}
|
||||
|
||||
func init() {
|
||||
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
18
backend/internal/common/errors.go
Normal file
18
backend/internal/common/errors.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package common
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrUsernameTaken = errors.New("username is already taken")
|
||||
ErrEmailTaken = errors.New("email is already taken")
|
||||
ErrSetupAlreadyCompleted = errors.New("setup already completed")
|
||||
ErrInvalidBody = errors.New("invalid request body")
|
||||
ErrTokenInvalidOrExpired = errors.New("token is invalid or expired")
|
||||
ErrOidcMissingAuthorization = errors.New("missing authorization")
|
||||
ErrOidcGrantTypeNotSupported = errors.New("grant type not supported")
|
||||
ErrOidcMissingClientCredentials = errors.New("client id or secret not provided")
|
||||
ErrOidcClientSecretInvalid = errors.New("invalid client secret")
|
||||
ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code")
|
||||
ErrFileTypeNotSupported = errors.New("file type not supported")
|
||||
ErrInvalidCredentials = errors.New("no user found with provided credentials")
|
||||
)
|
||||
@@ -1,209 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang-rest-api-template/internal/model"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"log"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
PrivateKey *rsa.PrivateKey
|
||||
PublicKey *rsa.PublicKey
|
||||
)
|
||||
|
||||
const (
|
||||
privateKeyPath = "data/keys/jwt_private_key.pem"
|
||||
publicKeyPath = "data/keys/jwt_public_key.pem"
|
||||
)
|
||||
|
||||
type accessTokenJWTClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
IsAdmin bool `json:"isAdmin,omitempty"`
|
||||
}
|
||||
|
||||
// GenerateIDToken generates an ID token for the given user, clientID, scope and nonce.
|
||||
func GenerateIDToken(user model.User, clientID string, scope string, nonce string) (tokenString string, err error) {
|
||||
profileClaims := map[string]interface{}{
|
||||
"given_name": user.FirstName,
|
||||
"family_name": user.LastName,
|
||||
"email": user.Email,
|
||||
"preferred_username": user.Username,
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"sub": user.ID,
|
||||
"aud": clientID,
|
||||
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
"iat": jwt.NewNumericDate(time.Now()),
|
||||
}
|
||||
|
||||
if nonce != "" {
|
||||
claims["nonce"] = nonce
|
||||
}
|
||||
if strings.Contains(scope, "profile") {
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
}
|
||||
if strings.Contains(scope, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
signedToken, err := token.SignedString(PrivateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return signedToken, nil
|
||||
}
|
||||
|
||||
// GenerateAccessToken generates an access token for the given user.
|
||||
func GenerateAccessToken(user model.User) (tokenString string, err error) {
|
||||
sessionDurationInMinutes, _ := strconv.Atoi(DbConfig.SessionDuration.Value)
|
||||
claim := accessTokenJWTClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: user.ID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Audience: jwt.ClaimStrings{utils.GetHostFromURL(EnvConfig.AppURL)},
|
||||
},
|
||||
IsAdmin: user.IsAdmin,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||
tokenString, err = token.SignedString(PrivateKey)
|
||||
return tokenString, err
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies the given access token and returns the claims if the token is valid.
|
||||
func VerifyAccessToken(tokenString string) (*accessTokenJWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &accessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return PublicKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
}
|
||||
|
||||
claims, isValid := token.Claims.(*accessTokenJWTClaims)
|
||||
if !isValid {
|
||||
return nil, errors.New("can't parse claims")
|
||||
}
|
||||
|
||||
if !slices.Contains(claims.Audience, utils.GetHostFromURL(EnvConfig.AppURL)) {
|
||||
return nil, errors.New("audience doesn't match")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
type JWK struct {
|
||||
Kty string `json:"kty"`
|
||||
Use string `json:"use"`
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
}
|
||||
|
||||
// GetJWK returns the JSON Web Key (JWK) for the public key.
|
||||
func GetJWK() (JWK, error) {
|
||||
if PublicKey == nil {
|
||||
return JWK{}, errors.New("public key is not initialized")
|
||||
}
|
||||
|
||||
// Create JWK from RSA public key
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Kid: "1", // Key ID can be set to any identifier. Here it's statically set to "1"
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(PublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(PublicKey.E)).Bytes()),
|
||||
}
|
||||
|
||||
return jwk, nil
|
||||
}
|
||||
|
||||
// generateKeys generates a new RSA key pair and saves the private and public keys to the data folder.
|
||||
func generateKeys() {
|
||||
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
|
||||
log.Fatal("Failed to create directories for keys", err)
|
||||
}
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to generate private key", err)
|
||||
}
|
||||
|
||||
privateKeyFile, err := os.Create(privateKeyPath)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to create private key file", err)
|
||||
}
|
||||
defer privateKeyFile.Close()
|
||||
|
||||
privateKeyPEM := pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
},
|
||||
)
|
||||
_, err = privateKeyFile.Write(privateKeyPEM)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to write private key file", err)
|
||||
}
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
publicKeyFile, err := os.Create(publicKeyPath)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to create public key file", err)
|
||||
}
|
||||
defer publicKeyFile.Close()
|
||||
|
||||
publicKeyPEM := pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: x509.MarshalPKCS1PublicKey(publicKey),
|
||||
},
|
||||
)
|
||||
_, err = publicKeyFile.Write(publicKeyPEM)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to write public key file", err)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
|
||||
generateKeys()
|
||||
}
|
||||
|
||||
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
log.Fatal("Can't read jwt private key", err)
|
||||
}
|
||||
PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
||||
if err != nil {
|
||||
log.Fatal("Can't parse jwt private key", err)
|
||||
}
|
||||
|
||||
publicKeyBytes, err := os.ReadFile(publicKeyPath)
|
||||
if err != nil {
|
||||
log.Fatal("Can't read jwt public key", err)
|
||||
}
|
||||
PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||
if err != nil {
|
||||
log.Fatal("Can't parse jwt public key", err)
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"golang-rest-api-template/internal/common"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Cors() gin.HandlerFunc {
|
||||
return cors.New(cors.Config{
|
||||
AllowOrigins: []string{common.EnvConfig.AppURL},
|
||||
AllowMethods: []string{"*"},
|
||||
AllowHeaders: []string{"*"},
|
||||
MaxAge: 12 * time.Hour,
|
||||
})
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func LimitFileSize(maxSize int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
|
||||
if err := c.Request.ParseMultipartForm(maxSize); err != nil {
|
||||
utils.HandlerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize)))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// formatFileSize formats a file size in bytes to a human-readable string
|
||||
func formatFileSize(size int64) string {
|
||||
const (
|
||||
KB = 1 << (10 * 1)
|
||||
MB = 1 << (10 * 2)
|
||||
GB = 1 << (10 * 3)
|
||||
)
|
||||
|
||||
switch {
|
||||
case size >= GB:
|
||||
return fmt.Sprintf("%.2f GB", float64(size)/GB)
|
||||
case size >= MB:
|
||||
return fmt.Sprintf("%.2f MB", float64(size)/MB)
|
||||
case size >= KB:
|
||||
return fmt.Sprintf("%.2f KB", float64(size)/KB)
|
||||
default:
|
||||
return fmt.Sprintf("%d bytes", size)
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func JWTAuth(adminOnly bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
// Extract the token from the cookie or the Authorization header
|
||||
token, err := c.Cookie("access_token")
|
||||
if err != nil {
|
||||
authorizationHeaderSplitted := strings.Split(c.GetHeader("Authorization"), " ")
|
||||
if len(authorizationHeaderSplitted) == 2 {
|
||||
token = authorizationHeaderSplitted[1]
|
||||
} else {
|
||||
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
claims, err := common.VerifyAccessToken(token)
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the user is an admin
|
||||
if adminOnly && !claims.IsAdmin {
|
||||
utils.HandlerError(c, http.StatusForbidden, "You don't have permission to access this resource")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("userID", claims.Subject)
|
||||
c.Set("userIsAdmin", claims.IsAdmin)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -1,76 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimiter is a Gin middleware for rate limiting based on client IP
|
||||
func RateLimiter(limit rate.Limit, burst int) gin.HandlerFunc {
|
||||
// Start the cleanup routine
|
||||
go cleanupClients()
|
||||
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
|
||||
// Skip rate limiting for localhost and test environment
|
||||
// If the client ip is localhost the request comes from the frontend
|
||||
if ip == "127.0.0.1" || ip == "::1" || common.EnvConfig.AppEnv == "test" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
limiter := getLimiter(ip, limit, burst)
|
||||
if !limiter.Allow() {
|
||||
utils.HandlerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
type client struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// Map to store the rate limiters per IP
|
||||
var clients = make(map[string]*client)
|
||||
var mu sync.Mutex
|
||||
|
||||
// Cleanup routine to remove stale clients that haven't been seen for a while
|
||||
func cleanupClients() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
mu.Lock()
|
||||
for ip, client := range clients {
|
||||
if time.Since(client.lastSeen) > 3*time.Minute {
|
||||
delete(clients, ip)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// getLimiter retrieves the rate limiter for a given IP address, creating one if it doesn't exist
|
||||
func getLimiter(ip string, limit rate.Limit, burst int) *rate.Limiter {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if client, exists := clients[ip]; exists {
|
||||
client.lastSeen = time.Now()
|
||||
return client.limiter
|
||||
}
|
||||
|
||||
limiter := rate.NewLimiter(limit, burst)
|
||||
clients[ip] = &client{limiter: limiter, lastSeen: time.Now()}
|
||||
return limiter
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
WebAuthn *webauthn.WebAuthn
|
||||
err error
|
||||
)
|
||||
|
||||
func init() {
|
||||
config := &webauthn.Config{
|
||||
RPDisplayName: DbConfig.AppName.Value,
|
||||
RPID: utils.GetHostFromURL(EnvConfig.AppURL),
|
||||
RPOrigins: []string{EnvConfig.AppURL},
|
||||
Timeouts: webauthn.TimeoutsConfig{
|
||||
Login: webauthn.TimeoutConfig{
|
||||
Enforce: true,
|
||||
Timeout: time.Second * 60,
|
||||
TimeoutUVD: time.Second * 60,
|
||||
},
|
||||
Registration: webauthn.TimeoutConfig{
|
||||
Enforce: true,
|
||||
Timeout: time.Second * 60,
|
||||
TimeoutUVD: time.Second * 60,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if WebAuthn, err = webauthn.New(config); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user