diff --git a/cmd/web/handlers.go b/cmd/web/handlers.go index 92a22b7..cddadaa 100644 --- a/cmd/web/handlers.go +++ b/cmd/web/handlers.go @@ -42,7 +42,7 @@ func favicon(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, "favicon.ico") } -func authenticated_user(w http.ResponseWriter, r *http.Request) int32 { +func authenticatedUser(w http.ResponseWriter, r *http.Request) int32 { session, err := store.Get(r, "id") if err != nil { log.Println(err) @@ -79,7 +79,7 @@ func home(w http.ResponseWriter, r *http.Request) { return } - id := authenticated_user(w, r) + id := authenticatedUser(w, r) account, _ := users.GetAccount(id) active_subscription := subscriptions.HasActiveSubscription(id) @@ -143,7 +143,7 @@ func login(w http.ResponseWriter, r *http.Request) { if len(errors) > 0 { - err := text.Execute(w, TemplateData{AuthenticatedUser: authenticated_user(w, r), FormErrors: errors}) + err := text.Execute(w, TemplateData{AuthenticatedUser: authenticatedUser(w, r), FormErrors: errors}) if err != nil { log.Println(err) http.Error(w, "Internal Server Error", 500) @@ -155,7 +155,7 @@ func login(w http.ResponseWriter, r *http.Request) { id, err := users.Authenticate(username, password) if err == models.ErrInvalidCredentials { errors["generic"] = "Email or Password is incorrect" - err := text.Execute(w, TemplateData{AuthenticatedUser: authenticated_user(w, r), FormErrors: errors}) + err := text.Execute(w, TemplateData{AuthenticatedUser: authenticatedUser(w, r), FormErrors: errors}) if err != nil { log.Println(err) http.Error(w, "Internal Server Error", 500) @@ -180,7 +180,7 @@ func logout(w http.ResponseWriter, r *http.Request) { log.Println(err) } - id := authenticated_user(w, r) + id := authenticatedUser(w, r) account, err := users.GetAccount(id) if err != nil { @@ -215,7 +215,7 @@ func register(w http.ResponseWriter, r *http.Request) { log.Println(err) } - id := authenticated_user(w, r) + id := authenticatedUser(w, r) account, _ := users.GetAccount(id) // if err != nil { @@ -251,7 +251,7 @@ func register(w http.ResponseWriter, r *http.Request) { if len(errors) > 0 { - err := text.Execute(w, TemplateData{AuthenticatedUser: authenticated_user(w, r), FormErrors: errors}) + err := text.Execute(w, TemplateData{AuthenticatedUser: authenticatedUser(w, r), FormErrors: errors}) if err != nil { log.Println(err) http.Error(w, "Internal Server Error", 500) @@ -281,7 +281,7 @@ func account(w http.ResponseWriter, r *http.Request) { // log.Fatal(err) //} - id := authenticated_user(w, r) + id := authenticatedUser(w, r) account, err := users.GetAccount(id) if err != nil { @@ -314,9 +314,9 @@ func account(w http.ResponseWriter, r *http.Request) { // } } -func deleteaccount(w http.ResponseWriter, r *http.Request) { +func deleteAccount(w http.ResponseWriter, r *http.Request) { - id := authenticated_user(w, r) + id := authenticatedUser(w, r) switch r.Method { case http.MethodPost: @@ -337,7 +337,7 @@ func deleteaccount(w http.ResponseWriter, r *http.Request) { } func subscribe(w http.ResponseWriter, r *http.Request) { - id := authenticated_user(w, r) + id := authenticatedUser(w, r) account, err := users.GetAccount(id) if err != nil { @@ -378,7 +378,7 @@ func subscribe(w http.ResponseWriter, r *http.Request) { } func subscribe_stripe(w http.ResponseWriter, r *http.Request) { - id := authenticated_user(w, r) + id := authenticatedUser(w, r) account, err := users.GetAccount(id) if err != nil { @@ -413,7 +413,7 @@ func subscribe_stripe(w http.ResponseWriter, r *http.Request) { } func managebilling(w http.ResponseWriter, r *http.Request) { - id := authenticated_user(w, r) + id := authenticatedUser(w, r) account, err := users.GetAccount(id) if err != nil { log.Println(err) diff --git a/cmd/web/main.go b/cmd/web/main.go index 9559288..e0ab851 100644 --- a/cmd/web/main.go +++ b/cmd/web/main.go @@ -17,11 +17,11 @@ import ( _ "github.com/lib/pq" "github.com/stripe/stripe-go/v78" - "alfheimgame.com/alfheim/pkg/models" + "alfheimgame.com/alfheim/pkg/models/postgresql" ) -var users *models.Usermodel -var subscriptions *models.SubscriptionModel +var users *postgresql.AccountModel +var subscriptions *postgresql.SubscriptionModel var key = []byte("super-secret-key") var store = sessions.NewCookieStore(key) @@ -32,7 +32,7 @@ var version string func main() { addr := flag.String("addr", "127.0.0.1:8080", "HTTP network addr") - prodaddr := flag.String("prodaddr", "127.0.0.1:4000", "HTTP network addr") + prodAddr := flag.String("prodaddr", "127.0.0.1:4000", "HTTP network addr") //prodaddr := flag.String("prodaddr", "45.76.84.7:443", "HTTP network addr") production := flag.Bool("production", false, "Whether to use production port and TLS") @@ -43,7 +43,7 @@ func main() { flag.Parse() if *displayVersion { - fmt.Println("Version: %s", version) + fmt.Printf("Version: %s\n", version) return } @@ -59,8 +59,8 @@ func main() { } defer db.Close() - users = &models.Usermodel{db} - subscriptions = &models.SubscriptionModel{db} + users = &postgresql.AccountModel{DB: db} + subscriptions = &postgresql.SubscriptionModel{DB: db} mux := http.NewServeMux() @@ -96,10 +96,10 @@ func main() { mux.HandleFunc("/login", login) mux.HandleFunc("/logout", logout) mux.HandleFunc("/register", register) - mux.HandleFunc("/account", require_authenticated_user(account)) - mux.HandleFunc("/deleteaccount", require_authenticated_user(deleteaccount)) - mux.HandleFunc("/subscribe", require_authenticated_user(subscribe_stripe)) - mux.HandleFunc("/managebilling", require_authenticated_user(managebilling)) + mux.HandleFunc("/account", requireAuthenticatedUser(account)) + mux.HandleFunc("/deleteaccount", requireAuthenticatedUser(deleteAccount)) + mux.HandleFunc("/subscribe", requireAuthenticatedUser(subscribe_stripe)) + mux.HandleFunc("/managebilling", requireAuthenticatedUser(managebilling)) mux.HandleFunc("/webhook", webhooks) if *production { @@ -120,8 +120,8 @@ func main() { // } // log.Fatal(server.ListenAndServeTLS("", "")) - log.Fatal(http.ListenAndServe(*prodaddr, log_request(secure_headers(mux)))) + log.Fatal(http.ListenAndServe(*prodAddr, logRequest(secureHeaders(mux)))) } else { - log.Fatal(http.ListenAndServe(*addr, log_request(secure_headers(mux)))) + log.Fatal(http.ListenAndServe(*addr, logRequest(secureHeaders(mux)))) } } diff --git a/cmd/web/middleware.go b/cmd/web/middleware.go index 1c35e50..4a4ec98 100644 --- a/cmd/web/middleware.go +++ b/cmd/web/middleware.go @@ -5,7 +5,7 @@ import ( "net/http" ) -func secure_headers(next http.Handler) http.Handler { +func secureHeaders(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-XSS-Protection", "1; mode=block") w.Header().Set("X-Frame-Options", "deny") @@ -16,7 +16,7 @@ func secure_headers(next http.Handler) http.Handler { return http.HandlerFunc(fn) } -func log_request(next http.Handler) http.Handler { +func logRequest(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { log.Printf("%s - %s %s %s", r.RemoteAddr, r.Proto, r.Method, r.URL) @@ -26,12 +26,12 @@ func log_request(next http.Handler) http.Handler { return http.HandlerFunc(fn) } -func require_authenticated_user(next http.HandlerFunc) http.HandlerFunc { +func requireAuthenticatedUser(next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // If the user is not authenticated, redirect them to the login page and // return from the middleware chain so that no subsequent handlers in // the chain are executed. - if authenticated_user(w, r) == 0 { + if authenticatedUser(w, r) == 0 { http.Redirect(w, r, "/login", http.StatusSeeOther) return } diff --git a/pkg/models/models.go b/pkg/models/models.go index 3f7b2f9..3cd6c64 100644 --- a/pkg/models/models.go +++ b/pkg/models/models.go @@ -5,19 +5,9 @@ package models import ( - "database/sql" "errors" - "fmt" "time" - - //import "golang.org/x/crypto/bcrypt" - - "log" - - "github.com/alexedwards/argon2id" _ "github.com/lib/pq" - "github.com/stripe/stripe-go/v78" - "github.com/stripe/stripe-go/v78/customer" ) var ErrNoRecord = errors.New("no matching record found") @@ -40,14 +30,14 @@ type Account struct { type SubscriptionStatus string const ( - incomplete SubscriptionStatus = "incomplete" - incomplete_expired SubscriptionStatus = "incomplete_expired" - trialing SubscriptionStatus = "trialing" - active SubscriptionStatus = "active" - past_due SubscriptionStatus = "past_due" - canceled SubscriptionStatus = "canceled" - unpaid SubscriptionStatus = "unpaid" - paused SubscriptionStatus = "paused" + Incomplete SubscriptionStatus = "incomplete" + IncompleteExpired SubscriptionStatus = "incomplete_expired" + Trialing SubscriptionStatus = "trialing" + Active SubscriptionStatus = "active" + PastDue SubscriptionStatus = "past_due" + Canceled SubscriptionStatus = "canceled" + Unpaid SubscriptionStatus = "unpaid" + Paused SubscriptionStatus = "paused" ) type Subscription struct { @@ -57,254 +47,3 @@ type Subscription struct { StripeCheckoutID string Status SubscriptionStatus } - -type UserModel struct { - DB *sql.DB -} - -type SubscriptionModel struct { - DB *sql.DB -} - -func (m *Usermodel) Insert(username string, password string, firstname string, lastname string, email string) (int32, error) { - - //hashedpassword, err := bcrypt.GenerateFromPassword([]byte(password), 12) - hashedpassword, err := argon2id.CreateHash(password, argon2id.DefaultParams) - - if err != nil { - log.Println(err) - return 0, err - } - - //log.Println(hashedpassword) - stmt := `INSERT INTO accounts (username, password, firstname, lastname, email, created) VALUES ($1, $2, $3, $4, $5, NOW()) RETURNING id` - - var insertid int32 - - row := m.DB.QueryRow(stmt, username, string(hashedpassword), firstname, lastname, email) - if row.Err() != nil { - log.Println(row.Err()) - return 0, row.Err() - } - - err = row.Scan(&insertid) - if err != nil { - log.Println(err) - return 0, err - } - - params := &stripe.CustomerParams{ - Name: stripe.String(fmt.Sprintf("%s %s", firstname, lastname)), - Email: stripe.String(email), - } - customer, err := customer.New(params) - - if err != nil { - log.Println(err) - return 0, err - } - - stmt = `UPDATE accounts SET stripe_id = $1 WHERE id = $2` - - //log.Println(customer.ID, insertid) - - _, err = m.DB.Exec(stmt, customer.ID, insertid) - if err != nil { - log.Println(err) - return 0, err - } - - return insertid, nil -} - -func (m *UserModel) Delete(id int32) error { - account, err := users.GetAccount(id) - - if err != nil { - log.Println(err) - return err - } - - if account.StripeID != "" { - /*result*/ _, err := customer.Del(account.StripeID, nil) - if err != nil { - log.Println(err) - } - //log.Println(result) - } - - stmt := `DELETE FROM accounts WHERE id = $1` - - _, err = m.DB.Exec(stmt, id) - if err != nil { - log.Println(err) - } - - return nil -} - -func (m *Usermodel) GetAccount(id int32) (Account, error) { - if id == 0 { - return Account{}, ErrNoRecord - } - stmt := `SELECT id, username, password, color, firstname, lastname, email, created, stripe_id FROM accounts WHERE id = $1` - row := m.DB.QueryRow(stmt, id) - - var account Account - err := row.Scan(&account.ID, &account.Username, &account.Password, &account.Color, &account.Firstname, &account.Lastname, &account.Email, &account.Created, &account.StripeID) - - if err == sql.ErrNoRows { - return Account{}, sql.ErrNoRows - } else if err != nil { - return Account{}, err - } - - return account, nil -} - -func (m *Usermodel) Authenticate(username string, password string) (int32, error) { - var id int32 - var hashedpassword string - row := m.DB.QueryRow("SELECT id, password FROM accounts WHERE username = $1", username) - err := row.Scan(&id, &hashedpassword) - if err == sql.ErrNoRows { - return 0, ErrInvalidCredentials - } - - match, err := argon2id.ComparePasswordAndHash(password, hashedpassword) - if !match { - return 0, ErrInvalidCredentials - } else if err != nil { - return 0, err - } - - return id, nil -} - -func (m *Usermodel) ExistsAccount(id int32) bool { - var exists bool - stmt := `SELECT EXISTS(SELECT 1 FROM accounts WHERE id = $1)` - row := m.DB.QueryRow(stmt, id) - if row.Err() != nil { - log.Println(row.Err()) - } - row.Scan(&exists) - - //log.Println(exists) - - return exists -} - -func (m *SubscriptionModel) Insert(stripeid string, stripesubscriptionid string, stripecheckoutid string, status stripe.SubscriptionStatus) (int32, error) { - var id int32 - stmt := `SELECT id FROM accounts WHERE stripe_id = $1` - - row := m.DB.QueryRow(stmt, stripeid) - if row.Err() != nil { - log.Println(row.Err()) - return 0, row.Err() - } - - err := row.Scan(&id) - if err != nil { - log.Println(err) - return 0, err - } - - stmt = `INSERT INTO subscriptions (account_id, stripe_subscription_id, stripe_checkout_id, status) VALUES ($1, $2, $3, $4::subscription_status) RETURNING id` - - var insertid int32 - - row = m.DB.QueryRow(stmt, id, string(stripesubscriptionid), string(stripecheckoutid), string(status)) - if row.Err() != nil { - log.Println(row.Err()) - return 0, row.Err() - } - - err = row.Scan(&insertid) - if err != nil { - log.Println(err) - return 0, err - } - - return insertid, nil -} - -func (m *SubscriptionModel) Delete(id int32) error { - stmt := `DELETE FROM accounts WHERE id = $1` - - _, err := m.DB.Exec(stmt, id) - if err != nil { - log.Println(err) - return err - } - - return nil -} - -func (m *SubscriptionModel) GetSubscription(id int32) (Subscription, error) { - if id == 0 { - return Subscription{}, ErrNoRecord - } - - stmt := `SELECT id, account_id, stripe_subscription_id, stripe_checkout_id, status FROM subscriptions WHERE id = $1` - row := m.DB.QueryRow(stmt, id) - - var subscription Subscription - err := row.Scan(&subscription.ID, &subscription.AccountID, &subscription.StripeSubscriptionID, &subscription.StripeCheckoutID, &subscription.Status) - - if err == sql.ErrNoRows { - return Subscription{}, sql.ErrNoRows - } else if err != nil { - return Subscription{}, err - } - - //log.Println(subscription.Status) - - return subscription, nil -} - -func (m *SubscriptionModel) GetSubscriptionsFromAccount(accountid int32) ([]Subscription, error) { - if accountid == 0 { - return nil, ErrNoRecord - } - - stmt := `SELECT id, account_id, stripe_subscription_id, stripe_checkout_id, status FROM subscriptions WHERE account_id = $1` - rows, err := m.DB.Query(stmt, accountid) - if err != nil { - return nil, err - } - defer rows.Close() - - var subscriptions []Subscription - for rows.Next() { - var subscription Subscription - err := rows.Scan(&subscription.ID, &subscription.AccountID, &subscription.StripeSubscriptionID, &subscription.StripeCheckoutID, &subscription.Status) - if err == sql.ErrNoRows { - return nil, sql.ErrNoRows - } else if err != nil { - return nil, err - } - - subscriptions = append(subscriptions, subscription) - } - - return subscriptions, nil -} - -func (m *SubscriptionModel) HasActiveSubscription(accountid int32) bool { - subscriptions, err := m.GetSubscriptionsFromAccount(accountid) - if err != nil { - return false - } - - for _, v := range subscriptions { - if v.Status == active { - return true - } else if v.Status == trialing { - return true - } - } - - return false -} diff --git a/pkg/models/postgresql/accounts.go b/pkg/models/postgresql/accounts.go new file mode 100644 index 0000000..77e01f0 --- /dev/null +++ b/pkg/models/postgresql/accounts.go @@ -0,0 +1,150 @@ +package postgresql + +import ( + "database/sql" + "fmt" + + //import "golang.org/x/crypto/bcrypt" + + "log" + + "github.com/alexedwards/argon2id" + _ "github.com/lib/pq" + "github.com/stripe/stripe-go/v78" + "github.com/stripe/stripe-go/v78/customer" + + "alfheimgame.com/alfheim/pkg/models" +) + +type AccountModel struct { + DB *sql.DB +} + +func (m *AccountModel) Insert(username string, password string, firstname string, lastname string, email string) (int32, error) { + + //hashedpassword, err := bcrypt.GenerateFromPassword([]byte(password), 12) + hashedpassword, err := argon2id.CreateHash(password, argon2id.DefaultParams) + + if err != nil { + log.Println(err) + return 0, err + } + + //log.Println(hashedpassword) + stmt := `INSERT INTO accounts (username, password, firstname, lastname, email, created) VALUES ($1, $2, $3, $4, $5, NOW()) RETURNING id` + + var insertid int32 + + row := m.DB.QueryRow(stmt, username, string(hashedpassword), firstname, lastname, email) + if row.Err() != nil { + log.Println(row.Err()) + return 0, row.Err() + } + + err = row.Scan(&insertid) + if err != nil { + log.Println(err) + return 0, err + } + + params := &stripe.CustomerParams{ + Name: stripe.String(fmt.Sprintf("%s %s", firstname, lastname)), + Email: stripe.String(email), + } + customer, err := customer.New(params) + + if err != nil { + log.Println(err) + return 0, err + } + + stmt = `UPDATE accounts SET stripe_id = $1 WHERE id = $2` + + //log.Println(customer.ID, insertid) + + _, err = m.DB.Exec(stmt, customer.ID, insertid) + if err != nil { + log.Println(err) + return 0, err + } + + return insertid, nil +} + +func (m *AccountModel) Delete(id int32) error { + account, err := m.GetAccount(id) + + if err != nil { + log.Println(err) + return err + } + + if account.StripeID != "" { + /*result*/ _, err := customer.Del(account.StripeID, nil) + if err != nil { + log.Println(err) + } + //log.Println(result) + } + + stmt := `DELETE FROM accounts WHERE id = $1` + + _, err = m.DB.Exec(stmt, id) + if err != nil { + log.Println(err) + } + + return nil +} + +func (m *AccountModel) GetAccount(id int32) (models.Account, error) { + if id == 0 { + return models.Account{}, models.ErrNoRecord + } + stmt := `SELECT id, username, password, color, firstname, lastname, email, created, stripe_id FROM accounts WHERE id = $1` + row := m.DB.QueryRow(stmt, id) + + var account models.Account + err := row.Scan(&account.ID, &account.Username, &account.Password, &account.Color, &account.Firstname, &account.Lastname, &account.Email, &account.Created, &account.StripeID) + + if err == sql.ErrNoRows { + return models.Account{}, sql.ErrNoRows + } else if err != nil { + return models.Account{}, err + } + + return account, nil +} + +func (m *AccountModel) Authenticate(username string, password string) (int32, error) { + var id int32 + var hashedpassword string + row := m.DB.QueryRow("SELECT id, password FROM accounts WHERE username = $1", username) + err := row.Scan(&id, &hashedpassword) + if err == sql.ErrNoRows { + return 0, models.ErrInvalidCredentials + } + + match, err := argon2id.ComparePasswordAndHash(password, hashedpassword) + if !match { + return 0, models.ErrInvalidCredentials + } else if err != nil { + return 0, err + } + + return id, nil +} + +func (m *AccountModel) ExistsAccount(id int32) bool { + var exists bool + stmt := `SELECT EXISTS(SELECT 1 FROM accounts WHERE id = $1)` + row := m.DB.QueryRow(stmt, id) + if row.Err() != nil { + log.Println(row.Err()) + } + row.Scan(&exists) + + //log.Println(exists) + + return exists +} diff --git a/pkg/models/postgresql/subscriptions.go b/pkg/models/postgresql/subscriptions.go new file mode 100644 index 0000000..7b31a33 --- /dev/null +++ b/pkg/models/postgresql/subscriptions.go @@ -0,0 +1,133 @@ +package postgresql + +import ( + "database/sql" + + //import "golang.org/x/crypto/bcrypt" + + "log" + + _ "github.com/lib/pq" + "github.com/stripe/stripe-go/v78" + + "alfheimgame.com/alfheim/pkg/models" +) + +type SubscriptionModel struct { + DB *sql.DB +} + +func (m *SubscriptionModel) Insert(stripeid string, stripesubscriptionid string, stripecheckoutid string, status stripe.SubscriptionStatus) (int32, error) { + var id int32 + stmt := `SELECT id FROM accounts WHERE stripe_id = $1` + + row := m.DB.QueryRow(stmt, stripeid) + if row.Err() != nil { + log.Println(row.Err()) + return 0, row.Err() + } + + err := row.Scan(&id) + if err != nil { + log.Println(err) + return 0, err + } + + stmt = `INSERT INTO subscriptions (account_id, stripe_subscription_id, stripe_checkout_id, status) VALUES ($1, $2, $3, $4::subscription_status) RETURNING id` + + var insertid int32 + + row = m.DB.QueryRow(stmt, id, string(stripesubscriptionid), string(stripecheckoutid), string(status)) + if row.Err() != nil { + log.Println(row.Err()) + return 0, row.Err() + } + + err = row.Scan(&insertid) + if err != nil { + log.Println(err) + return 0, err + } + + return insertid, nil +} + +func (m *SubscriptionModel) Delete(id int32) error { + stmt := `DELETE FROM accounts WHERE id = $1` + + _, err := m.DB.Exec(stmt, id) + if err != nil { + log.Println(err) + return err + } + + return nil +} + +func (m *SubscriptionModel) GetSubscription(id int32) (models.Subscription, error) { + if id == 0 { + return models.Subscription{}, models.ErrNoRecord + } + + stmt := `SELECT id, account_id, stripe_subscription_id, stripe_checkout_id, status FROM subscriptions WHERE id = $1` + row := m.DB.QueryRow(stmt, id) + + var subscription models.Subscription + err := row.Scan(&subscription.ID, &subscription.AccountID, &subscription.StripeSubscriptionID, &subscription.StripeCheckoutID, &subscription.Status) + + if err == sql.ErrNoRows { + return models.Subscription{}, sql.ErrNoRows + } else if err != nil { + return models.Subscription{}, err + } + + //log.Println(subscription.Status) + + return subscription, nil +} + +func (m *SubscriptionModel) GetSubscriptionsFromAccount(accountid int32) ([]models.Subscription, error) { + if accountid == 0 { + return nil, models.ErrNoRecord + } + + stmt := `SELECT id, account_id, stripe_subscription_id, stripe_checkout_id, status FROM subscriptions WHERE account_id = $1` + rows, err := m.DB.Query(stmt, accountid) + if err != nil { + return nil, err + } + defer rows.Close() + + var subscriptions []models.Subscription + for rows.Next() { + var subscription models.Subscription + err := rows.Scan(&subscription.ID, &subscription.AccountID, &subscription.StripeSubscriptionID, &subscription.StripeCheckoutID, &subscription.Status) + if err == sql.ErrNoRows { + return nil, sql.ErrNoRows + } else if err != nil { + return nil, err + } + + subscriptions = append(subscriptions, subscription) + } + + return subscriptions, nil +} + +func (m *SubscriptionModel) HasActiveSubscription(accountid int32) bool { + subscriptions, err := m.GetSubscriptionsFromAccount(accountid) + if err != nil { + return false + } + + for _, v := range subscriptions { + switch v.Status { + case models.Active: + return true + case models.Trialing: + return true + } + } + + return false +}