feat: add support for Postgres database provider (#79)

This commit is contained in:
Elias Schneider
2024-12-12 17:21:28 +01:00
committed by GitHub
parent e9d83dd6c3
commit 9d20a98dbb
38 changed files with 433 additions and 81 deletions

View File

@@ -2,9 +2,13 @@ package bootstrap
import (
"errors"
"fmt"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/database"
postgresMigrate "github.com/golang-migrate/migrate/v4/database/postgres"
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/stonith404/pocket-id/backend/internal/common"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -19,15 +23,29 @@ func newDatabase() (db *gorm.DB) {
log.Fatalf("failed to connect to database: %v", err)
}
sqlDb, err := db.DB()
sqlDb.SetMaxOpenConns(1)
if err != nil {
log.Fatalf("failed to get sql.DB: %v", err)
}
driver, err := sqlite3.WithInstance(sqlDb, &sqlite3.Config{})
// Choose the correct driver for the database provider
var driver database.Driver
switch common.EnvConfig.DbProvider {
case common.DbProviderSqlite:
driver, err = sqliteMigrate.WithInstance(sqlDb, &sqliteMigrate.Config{})
case common.DbProviderPostgres:
driver, err = postgresMigrate.WithInstance(sqlDb, &postgresMigrate.Config{})
default:
log.Fatalf("unsupported database provider: %s", common.EnvConfig.DbProvider)
}
if err != nil {
log.Fatalf("failed to create migration driver: %v", err)
}
// Run migrations
m, err := migrate.NewWithDatabaseInstance(
"file://migrations",
"postgres", driver)
"file://migrations/"+string(common.EnvConfig.DbProvider),
"pocket-id", driver,
)
if err != nil {
log.Fatalf("failed to create migration instance: %v", err)
}
@@ -41,15 +59,20 @@ func newDatabase() (db *gorm.DB) {
}
func connectDatabase() (db *gorm.DB, err error) {
dbPath := common.EnvConfig.DBPath
var dialector gorm.Dialector
// Use in-memory database for testing
if common.EnvConfig.AppEnv == "test" {
dbPath = "file::memory:?cache=shared"
// Choose the correct database provider
switch common.EnvConfig.DbProvider {
case common.DbProviderSqlite:
dialector = sqlite.Open(common.EnvConfig.SqliteDBPath)
case common.DbProviderPostgres:
dialector = postgres.Open(common.EnvConfig.PostgresConnectionString)
default:
return nil, fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider)
}
for i := 1; i <= 3; i++ {
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
db, err = gorm.Open(dialector, &gorm.Config{
TranslateError: true,
Logger: getLogger(),
})

View File

@@ -6,32 +6,55 @@ import (
"log"
)
type DbProvider string
const (
DbProviderSqlite DbProvider = "sqlite"
DbProviderPostgres DbProvider = "postgres"
)
type EnvConfigSchema struct {
AppEnv string `env:"APP_ENV"`
AppURL string `env:"PUBLIC_APP_URL"`
DBPath string `env:"DB_PATH"`
UploadPath string `env:"UPLOAD_PATH"`
Port string `env:"BACKEND_PORT"`
Host string `env:"HOST"`
EmailTemplatesPath string `env:"EMAIL_TEMPLATES_PATH"`
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"`
GeoLiteDBPath string `env:"GEOLITE_DB_PATH"`
AppEnv string `env:"APP_ENV"`
AppURL string `env:"PUBLIC_APP_URL"`
DbProvider DbProvider `env:"DB_PROVIDER"`
SqliteDBPath string `env:"SQLITE_DB_PATH"`
PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"`
UploadPath string `env:"UPLOAD_PATH"`
Port string `env:"BACKEND_PORT"`
Host string `env:"HOST"`
EmailTemplatesPath string `env:"EMAIL_TEMPLATES_PATH"`
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"`
GeoLiteDBPath string `env:"GEOLITE_DB_PATH"`
}
var EnvConfig = &EnvConfigSchema{
AppEnv: "production",
DBPath: "data/pocket-id.db",
UploadPath: "data/uploads",
AppURL: "http://localhost",
Port: "8080",
Host: "localhost",
EmailTemplatesPath: "./email-templates",
MaxMindLicenseKey: "",
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
AppEnv: "production",
DbProvider: "sqlite",
SqliteDBPath: "data/pocket-id.db",
PostgresConnectionString: "",
UploadPath: "data/uploads",
AppURL: "http://localhost",
Port: "8080",
Host: "localhost",
EmailTemplatesPath: "./email-templates",
MaxMindLicenseKey: "",
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
}
func init() {
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
log.Fatal(err)
}
// Validate the environment variables
if EnvConfig.DbProvider != DbProviderSqlite && EnvConfig.DbProvider != DbProviderPostgres {
log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
}
if EnvConfig.DbProvider == DbProviderPostgres && EnvConfig.PostgresConnectionString == "" {
log.Fatal("Missing POSTGRES_CONNECTION_STRING environment variable")
}
if EnvConfig.DbProvider == DbProviderSqlite && EnvConfig.SqliteDBPath == "" {
log.Fatal("Missing SQLITE_DB_PATH environment variable")
}
}

View File

@@ -91,9 +91,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
return
}
userID := c.GetString("userID")
user, token, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent())
user, token, err := wc.webAuthnService.VerifyLogin(sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent())
if err != nil {
c.Error(err)
return

View File

@@ -2,6 +2,7 @@ package datatype
import (
"database/sql/driver"
"github.com/stonith404/pocket-id/backend/internal/common"
"time"
)
@@ -14,7 +15,11 @@ func (date *DateTime) Scan(value interface{}) (err error) {
}
func (date DateTime) Value() (driver.Value, error) {
return time.Time(date).Unix(), nil
if common.EnvConfig.DbProvider == common.DbProviderSqlite {
return time.Time(date).Unix(), nil
} else {
return time.Time(date), nil
}
}
func (date DateTime) UTC() time.Time {

View File

@@ -33,7 +33,7 @@ func (u User) WebAuthnCredentials() []webauthn.Credential {
for i, credential := range u.Credentials {
credentials[i] = webauthn.Credential{
ID: []byte(credential.CredentialID),
ID: credential.CredentialID,
AttestationType: credential.AttestationType,
PublicKey: credential.PublicKey,
Transport: credential.Transport,

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"github.com/go-webauthn/webauthn/protocol"
datatype "github.com/stonith404/pocket-id/backend/internal/model/types"
"time"
)
@@ -12,7 +13,7 @@ type WebauthnSession struct {
Base
Challenge string
ExpiresAt time.Time
ExpiresAt datatype.DateTime
UserVerification string
}
@@ -20,7 +21,7 @@ type WebauthnCredential struct {
Base
Name string
CredentialID string
CredentialID []byte
PublicKey []byte
AttestationType string
Transport AuthenticatorTransportList

View File

@@ -60,7 +60,7 @@ func (s *TestService) SeedDatabase() error {
userGroups := []model.UserGroup{
{
Base: model.Base{
ID: "4110f814-56f1-4b28-8998-752b69bc97c0e",
ID: "c7ae7c01-28a3-4f3c-9572-1ee734ea8368",
},
Name: "developers",
FriendlyName: "Developers",
@@ -146,7 +146,7 @@ func (s *TestService) SeedDatabase() error {
webauthnCredentials := []model.WebauthnCredential{
{
Name: "Passkey 1",
CredentialID: "test-credential-1",
CredentialID: []byte("test-credential-1"),
PublicKey: publicKey1,
AttestationType: "none",
Transport: model.AuthenticatorTransportList{protocol.Internal},
@@ -154,7 +154,7 @@ func (s *TestService) SeedDatabase() error {
},
{
Name: "Passkey 2",
CredentialID: "test-credential-2",
CredentialID: []byte("test-credential-2"),
PublicKey: publicKey2,
AttestationType: "none",
Transport: model.AuthenticatorTransportList{protocol.Internal},
@@ -169,7 +169,7 @@ func (s *TestService) SeedDatabase() error {
webauthnSession := model.WebauthnSession{
Challenge: "challenge",
ExpiresAt: time.Now().Add(1 * time.Hour),
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
UserVerification: "preferred",
}
if err := tx.Create(&webauthnSession).Error; err != nil {
@@ -183,13 +183,29 @@ func (s *TestService) SeedDatabase() error {
func (s *TestService) ResetDatabase() error {
err := s.db.Transaction(func(tx *gorm.DB) error {
var tables []string
if err := tx.Raw("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' AND name != 'schema_migrations';").Scan(&tables).Error; err != nil {
return err
switch common.EnvConfig.DbProvider {
case common.DbProviderSqlite:
// Query to get all tables for SQLite
if err := tx.Raw("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' AND name != 'schema_migrations';").Scan(&tables).Error; err != nil {
return err
}
case common.DbProviderPostgres:
// Query to get all tables for PostgreSQL
if err := tx.Raw(`
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public' AND tablename != 'schema_migrations';
`).Scan(&tables).Error; err != nil {
return err
}
default:
return fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider)
}
// Delete all rows from all tables
for _, table := range tables {
if err := tx.Exec("DELETE FROM " + table).Error; err != nil {
if err := tx.Exec(fmt.Sprintf("DELETE FROM %s;", table)).Error; err != nil {
return err
}
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/go-webauthn/webauthn/webauthn"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/model"
datatype "github.com/stonith404/pocket-id/backend/internal/model/types"
"github.com/stonith404/pocket-id/backend/internal/utils"
"gorm.io/gorm"
"net/http"
@@ -55,7 +56,7 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
}
sessionToStore := &model.WebauthnSession{
ExpiresAt: session.Expires,
ExpiresAt: datatype.DateTime(session.Expires),
Challenge: session.Challenge,
UserVerification: string(session.UserVerification),
}
@@ -79,7 +80,7 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
session := webauthn.SessionData{
Challenge: storedSession.Challenge,
Expires: storedSession.ExpiresAt,
Expires: storedSession.ExpiresAt.ToTime(),
UserID: []byte(userID),
}
@@ -95,7 +96,7 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
credentialToStore := model.WebauthnCredential{
Name: "New Passkey",
CredentialID: string(credential.ID),
CredentialID: credential.ID,
AttestationType: credential.AttestationType,
PublicKey: credential.PublicKey,
Transport: credential.Transport,
@@ -117,7 +118,7 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
}
sessionToStore := &model.WebauthnSession{
ExpiresAt: session.Expires,
ExpiresAt: datatype.DateTime(session.Expires),
Challenge: session.Challenge,
UserVerification: string(session.UserVerification),
}
@@ -133,7 +134,7 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
}, nil
}
func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
return model.User{}, "", err
@@ -141,7 +142,7 @@ func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssert
session := webauthn.SessionData{
Challenge: storedSession.Challenge,
Expires: storedSession.ExpiresAt,
Expires: storedSession.ExpiresAt.ToTime(),
}
var user *model.User
@@ -156,10 +157,6 @@ func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssert
return model.User{}, "", err
}
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
return model.User{}, "", err
}
token, err := s.jwtService.GenerateAccessToken(*user)
if err != nil {
return model.User{}, "", err