diff --git a/backend/internal/bootstrap/bootstrap.go b/backend/internal/bootstrap/bootstrap.go index 7fdabb6..3d120ae 100644 --- a/backend/internal/bootstrap/bootstrap.go +++ b/backend/internal/bootstrap/bootstrap.go @@ -2,7 +2,6 @@ package bootstrap import ( _ "github.com/golang-migrate/migrate/v4/source/file" - "github.com/stonith404/pocket-id/backend/internal/job" "github.com/stonith404/pocket-id/backend/internal/service" ) @@ -11,6 +10,5 @@ func Bootstrap() { appConfigService := service.NewAppConfigService(db) initApplicationImages() - job.RegisterJobs(db) initRouter(db, appConfigService) } diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index 9538535..993c147 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -38,21 +38,25 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService) jwtService := service.NewJwtService(appConfigService) webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService) - userService := service.NewUserService(db, jwtService, auditLogService) + userService := service.NewUserService(db, jwtService, auditLogService, emailService) customClaimService := service.NewCustomClaimService(db) oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService) testService := service.NewTestService(db, appConfigService) userGroupService := service.NewUserGroupService(db) ldapService := service.NewLdapService(db, appConfigService, userService, userGroupService) + rateLimitMiddleware := middleware.NewRateLimitMiddleware() + + // Setup global middleware r.Use(middleware.NewCorsMiddleware().Add()) r.Use(middleware.NewErrorHandlerMiddleware().Add()) - r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60)) + r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60)) r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false)) job.RegisterLdapJobs(ldapService, appConfigService) + job.RegisterDbCleanupJobs(db) - // Initialize middleware + // Initialize middleware for specific routes jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService, false) fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware() diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index 84393d6..0a09ff5 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -97,7 +97,7 @@ func (e *MissingPermissionError) HttpStatusCode() int { return http.StatusForbid type TooManyRequestsError struct{} func (e *TooManyRequestsError) Error() string { - return "Too many requests. Please wait a while before trying again." + return "Too many requests" } func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyRequests } diff --git a/backend/internal/controller/user_controller.go b/backend/internal/controller/user_controller.go index a210e03..f970130 100644 --- a/backend/internal/controller/user_controller.go +++ b/backend/internal/controller/user_controller.go @@ -30,6 +30,7 @@ func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt group.POST("/users/:id/one-time-access-token", jwtAuthMiddleware.Add(true), uc.createOneTimeAccessTokenHandler) group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler) group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler) + group.POST("/one-time-access-email", rateLimitMiddleware.Add(rate.Every(10*time.Minute), 3), uc.requestOneTimeAccessEmailHandler) } type UserController struct { @@ -145,7 +146,7 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) { return } - token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt, c.ClientIP(), c.Request.UserAgent()) + token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt) if err != nil { c.Error(err) return @@ -154,8 +155,24 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) { c.JSON(http.StatusCreated, gin.H{"token": token}) } +func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) { + var input dto.OneTimeAccessEmailDto + if err := c.ShouldBindJSON(&input); err != nil { + c.Error(err) + return + } + + err := uc.userService.RequestOneTimeAccessEmail(input.Email, input.RedirectPath) + if err != nil { + c.Error(err) + return + } + + c.Status(http.StatusNoContent) +} + func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { - user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token")) + user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token"), c.ClientIP(), c.Request.UserAgent()) if err != nil { c.Error(err) return diff --git a/backend/internal/dto/app_config_dto.go b/backend/internal/dto/app_config_dto.go index 6a66a61..5a03c2a 100644 --- a/backend/internal/dto/app_config_dto.go +++ b/backend/internal/dto/app_config_dto.go @@ -16,7 +16,6 @@ type AppConfigUpdateDto struct { SessionDuration string `json:"sessionDuration" binding:"required"` EmailsVerified string `json:"emailsVerified" binding:"required"` AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"` - EmailEnabled string `json:"emailEnabled" binding:"required"` SmtHost string `json:"smtpHost"` SmtpPort string `json:"smtpPort"` SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"` @@ -38,4 +37,6 @@ type AppConfigUpdateDto struct { LdapAttributeGroupUniqueIdentifier string `json:"ldapAttributeGroupUniqueIdentifier"` LdapAttributeGroupName string `json:"ldapAttributeGroupName"` LdapAttributeAdminGroup string `json:"ldapAttributeAdminGroup"` + EmailOneTimeAccessEnabled string `json:"emailOneTimeAccessEnabled" binding:"required"` + EmailLoginNotificationEnabled string `json:"emailLoginNotificationEnabled" binding:"required"` } diff --git a/backend/internal/dto/user_dto.go b/backend/internal/dto/user_dto.go index 9d93da6..930766d 100644 --- a/backend/internal/dto/user_dto.go +++ b/backend/internal/dto/user_dto.go @@ -26,3 +26,8 @@ type OneTimeAccessTokenCreateDto struct { UserID string `json:"userId" binding:"required"` ExpiresAt time.Time `json:"expiresAt" binding:"required"` } + +type OneTimeAccessEmailDto struct { + Email string `json:"email" binding:"required,email"` + RedirectPath string `json:"redirectPath"` +} diff --git a/backend/internal/job/db_cleanup.go b/backend/internal/job/db_cleanup.go index 00ce17f..0eebdba 100644 --- a/backend/internal/job/db_cleanup.go +++ b/backend/internal/job/db_cleanup.go @@ -10,7 +10,7 @@ import ( "time" ) -func RegisterJobs(db *gorm.DB) { +func RegisterDbCleanupJobs(db *gorm.DB) { scheduler, err := gocron.NewScheduler() if err != nil { log.Fatalf("Failed to create a new scheduler: %s", err) diff --git a/backend/internal/middleware/rate_limit.go b/backend/internal/middleware/rate_limit.go index 30c7f2c..83ef9f7 100644 --- a/backend/internal/middleware/rate_limit.go +++ b/backend/internal/middleware/rate_limit.go @@ -16,8 +16,12 @@ func NewRateLimitMiddleware() *RateLimitMiddleware { } func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc { + // Map to store the rate limiters per IP + var clients = make(map[string]*client) + var mu sync.Mutex + // Start the cleanup routine - go cleanupClients() + go cleanupClients(&mu, clients) return func(c *gin.Context) { ip := c.ClientIP() @@ -29,7 +33,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc { return } - limiter := getLimiter(ip, limit, burst) + limiter := getLimiter(ip, limit, burst, &mu, clients) if !limiter.Allow() { c.Error(&common.TooManyRequestsError{}) c.Abort() @@ -45,12 +49,8 @@ type client struct { lastSeen time.Time } -// Map to store the rate limiters per IP -var clients = make(map[string]*client) -var mu sync.Mutex - // Cleanup routine to remove stale clients that haven't been seen for a while -func cleanupClients() { +func cleanupClients(mu *sync.Mutex, clients map[string]*client) { for { time.Sleep(time.Minute) mu.Lock() @@ -64,7 +64,7 @@ func cleanupClients() { } // getLimiter retrieves the rate limiter for a given IP address, creating one if it doesn't exist -func getLimiter(ip string, limit rate.Limit, burst int) *rate.Limiter { +func getLimiter(ip string, limit rate.Limit, burst int, mu *sync.Mutex, clients map[string]*client) *rate.Limiter { mu.Lock() defer mu.Unlock() diff --git a/backend/internal/model/app_config.go b/backend/internal/model/app_config.go index 59475bd..ee65561 100644 --- a/backend/internal/model/app_config.go +++ b/backend/internal/model/app_config.go @@ -20,7 +20,6 @@ type AppConfig struct { LogoLightImageType AppConfigVariable LogoDarkImageType AppConfigVariable // Email - EmailEnabled AppConfigVariable SmtpHost AppConfigVariable SmtpPort AppConfigVariable SmtpFrom AppConfigVariable @@ -28,6 +27,8 @@ type AppConfig struct { SmtpPassword AppConfigVariable SmtpTls AppConfigVariable SmtpSkipCertVerify AppConfigVariable + EmailLoginNotificationEnabled AppConfigVariable + EmailOneTimeAccessEnabled AppConfigVariable // LDAP LdapEnabled AppConfigVariable LdapUrl AppConfigVariable diff --git a/backend/internal/service/app_config_service.go b/backend/internal/service/app_config_service.go index dd33e16..8ffcb14 100644 --- a/backend/internal/service/app_config_service.go +++ b/backend/internal/service/app_config_service.go @@ -73,12 +73,7 @@ var defaultDbConfig = model.AppConfig{ IsInternal: true, DefaultValue: "svg", }, - // Email - EmailEnabled: model.AppConfigVariable{ - Key: "emailEnabled", - Type: "bool", - DefaultValue: "false", - }, + // Email SmtpHost: model.AppConfigVariable{ Key: "smtpHost", Type: "string", @@ -109,6 +104,17 @@ var defaultDbConfig = model.AppConfig{ Type: "bool", DefaultValue: "false", }, + EmailLoginNotificationEnabled: model.AppConfigVariable{ + Key: "emailLoginNotificationEnabled", + Type: "bool", + DefaultValue: "false", + }, + EmailOneTimeAccessEnabled: model.AppConfigVariable{ + Key: "emailOneTimeAccessEnabled", + Type: "bool", + IsPublic: true, + DefaultValue: "false", + }, // LDAP LdapEnabled: model.AppConfigVariable{ Key: "ldapEnabled", @@ -182,6 +188,13 @@ func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]mode key := field.Tag.Get("json") value := rv.FieldByName(field.Name).String() + // If the emailEnabled is set to false, disable the emailOneTimeAccessEnabled + if key == s.DbConfig.EmailOneTimeAccessEnabled.Key { + if rv.FieldByName("EmailEnabled").String() == "false" { + value = "false" + } + } + var appConfigVariable model.AppConfigVariable if err := tx.First(&appConfigVariable, "key = ? AND is_internal = false", key).Error; err != nil { tx.Rollback() diff --git a/backend/internal/service/audit_log_service.go b/backend/internal/service/audit_log_service.go index 7c318d1..32af610 100644 --- a/backend/internal/service/audit_log_service.go +++ b/backend/internal/service/audit_log_service.go @@ -58,8 +58,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID return createdAuditLog } - // If the user hasn't logged in from the same device before, send an email - if count <= 1 { + // If the user hasn't logged in from the same device before and email notifications are enabled, send an email + if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.Value == "true" && count <= 1 { go func() { var user model.User s.db.Where("id = ?", userID).First(&user) diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index c6e5d70..06c7f98 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -3,7 +3,6 @@ package service import ( "bytes" "crypto/tls" - "errors" "fmt" "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/model" @@ -16,8 +15,13 @@ import ( "net/smtp" "net/textproto" ttemplate "text/template" + "time" ) +var netDialer = &net.Dialer{ + Timeout: 3 * time.Second, +} + type EmailService struct { appConfigService *AppConfigService db *gorm.DB @@ -58,11 +62,6 @@ func (srv *EmailService) SendTestEmail(recipientUserId string) error { } func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error { - // Check if SMTP settings are set - if srv.appConfigService.DbConfig.EmailEnabled.Value != "true" { - return errors.New("email not enabled") - } - data := &email.TemplateData[V]{ AppName: srv.appConfigService.DbConfig.AppName.Value, LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo", @@ -112,11 +111,13 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T tlsConfig, ) } - defer client.Quit() + if err != nil { return fmt.Errorf("failed to connect to SMTP server: %w", err) } + defer client.Close() + smtpUser := srv.appConfigService.DbConfig.SmtpUser.Value smtpPassword := srv.appConfigService.DbConfig.SmtpPassword.Value @@ -141,7 +142,11 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T } func (srv *EmailService) connectToSmtpServerUsingImplicitTLS(serverAddr string, tlsConfig *tls.Config) (*smtp.Client, error) { - conn, err := tls.Dial("tcp", serverAddr, tlsConfig) + tlsDialer := &tls.Dialer{ + NetDialer: netDialer, + Config: tlsConfig, + } + conn, err := tlsDialer.Dial("tcp", serverAddr) if err != nil { return nil, fmt.Errorf("failed to connect to SMTP server: %w", err) } @@ -156,7 +161,7 @@ func (srv *EmailService) connectToSmtpServerUsingImplicitTLS(serverAddr string, } func (srv *EmailService) connectToSmtpServerUsingStartTLS(serverAddr string, tlsConfig *tls.Config) (*smtp.Client, error) { - conn, err := net.Dial("tcp", serverAddr) + conn, err := netDialer.Dial("tcp", serverAddr) if err != nil { return nil, fmt.Errorf("failed to connect to SMTP server: %w", err) } diff --git a/backend/internal/service/email_service_templates.go b/backend/internal/service/email_service_templates.go index e6d0fb8..1f0962c 100644 --- a/backend/internal/service/email_service_templates.go +++ b/backend/internal/service/email_service_templates.go @@ -9,7 +9,7 @@ import ( /** How to add new template: - pick unique and descriptive template ${name} (for example "login-with-new-device") -- in backend/email-templates/ create "${name}_html.tmpl" and "${name}_text.tmpl" +- in backend/resources/email-templates/ create "${name}_html.tmpl" and "${name}_text.tmpl" - create xxxxTemplate and xxxxTemplateData (for example NewLoginTemplate and NewLoginTemplateData) - Path *must* be ${name} - add xxxTemplate.Path to "emailTemplatePaths" at the end @@ -27,6 +27,13 @@ var NewLoginTemplate = email.Template[NewLoginTemplateData]{ }, } +var OneTimeAccessTemplate = email.Template[OneTimeAccessTemplateData]{ + Path: "one-time-access", + Title: func(data *email.TemplateData[OneTimeAccessTemplateData]) string { + return "One time access" + }, +} + var TestTemplate = email.Template[struct{}]{ Path: "test", Title: func(data *email.TemplateData[struct{}]) string { @@ -42,5 +49,9 @@ type NewLoginTemplateData struct { DateTime time.Time } +type OneTimeAccessTemplateData = struct { + Link string +} + // this is list of all template paths used for preloading templates -var emailTemplatesPaths = []string{NewLoginTemplate.Path, TestTemplate.Path} +var emailTemplatesPaths = []string{NewLoginTemplate.Path, OneTimeAccessTemplate.Path, TestTemplate.Path} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 0fe2709..455b9d1 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -2,12 +2,17 @@ package service import ( "errors" + "fmt" "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/model" "github.com/stonith404/pocket-id/backend/internal/model/types" "github.com/stonith404/pocket-id/backend/internal/utils" + "github.com/stonith404/pocket-id/backend/internal/utils/email" "gorm.io/gorm" + "log" + "net/url" + "strings" "time" ) @@ -15,10 +20,11 @@ type UserService struct { db *gorm.DB jwtService *JwtService auditLogService *AuditLogService + emailService *EmailService } -func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService) *UserService { - return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService} +func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService) *UserService { + return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService} } func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) { @@ -99,7 +105,46 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u return user, nil } -func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time, ipAddress, userAgent string) (string, error) { +func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error { + var user model.User + if err := s.db.Where("email = ?", emailAddress).First(&user).Error; err != nil { + // Do not return error if user not found to prevent email enumeration + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } else { + return err + } + } + + oneTimeAccessToken, err := s.CreateOneTimeAccessToken(user.ID, time.Now().Add(time.Hour)) + if err != nil { + return err + } + + link := fmt.Sprintf("%s/login/%s", common.EnvConfig.AppURL, oneTimeAccessToken) + + // Add redirect path to the link + if strings.HasPrefix(redirectPath, "/") { + encodedRedirectPath := url.QueryEscape(redirectPath) + link = fmt.Sprintf("%s?redirect=%s", link, encodedRedirectPath) + } + + go func() { + err := SendEmail(s.emailService, email.Address{ + Name: user.Username, + Email: user.Email, + }, OneTimeAccessTemplate, &OneTimeAccessTemplateData{ + Link: link, + }) + if err != nil { + log.Printf("Failed to send email to '%s': %v\n", user.Email, err) + } + }() + + return nil +} + +func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) { randomString, err := utils.GenerateRandomAlphanumericString(16) if err != nil { return "", err @@ -115,12 +160,10 @@ func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Tim return "", err } - s.auditLogService.Create(model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, userID, model.AuditLogData{}) - return oneTimeAccessToken.Token, nil } -func (s *UserService) ExchangeOneTimeAccessToken(token string) (model.User, string, error) { +func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAgent string) (model.User, string, error) { var oneTimeAccessToken model.OneTimeAccessToken if err := s.db.Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -137,6 +180,10 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string) (model.User, stri return model.User{}, "", err } + if ipAddress != "" && userAgent != "" { + s.auditLogService.Create(model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}) + } + return oneTimeAccessToken.User, accessToken, nil } diff --git a/backend/internal/utils/email/email_service_templates.go b/backend/internal/utils/email/email_service_templates.go index 2701a03..d477272 100644 --- a/backend/internal/utils/email/email_service_templates.go +++ b/backend/internal/utils/email/email_service_templates.go @@ -9,8 +9,6 @@ import ( ttemplate "text/template" ) -const templateComponentsDir = "components" - type Template[V any] struct { Path string Title func(data *TemplateData[V]) string diff --git a/backend/resources/email-templates/components/style_html.tmpl b/backend/resources/email-templates/components/style_html.tmpl index d378806..f907dbe 100644 --- a/backend/resources/email-templates/components/style_html.tmpl +++ b/backend/resources/email-templates/components/style_html.tmpl @@ -76,5 +76,20 @@ font-size: 1rem; line-height: 1.5; } + .button { + border-radius: 0.375rem; + font-size: 1rem; + font-weight: 500; + background-color: #000000; + color: #ffffff; + padding: 0.7rem 1.5rem; + outline: none; + border: none; + text-decoration: none; + } + .button-container { + text-align: center; + margin-top: 24px; + } {{ end }} diff --git a/backend/resources/email-templates/login-with-new-device_html.tmpl b/backend/resources/email-templates/login-with-new-device_html.tmpl index c911e83..6c2c811 100644 --- a/backend/resources/email-templates/login-with-new-device_html.tmpl +++ b/backend/resources/email-templates/login-with-new-device_html.tmpl @@ -1,7 +1,7 @@ {{ define "base" }}
Client not found
{:else} -