git-svn-id: svn://losandesgames.com/alfheim-website@49 15359d88-9307-4e75-a9c1-e5686e5897df
This commit is contained in:
parent
afb402223f
commit
391becf14e
@ -42,7 +42,7 @@ func favicon(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.ServeFile(w, r, "favicon.ico")
|
http.ServeFile(w, r, "favicon.ico")
|
||||||
}
|
}
|
||||||
|
|
||||||
func authenticated_user(w http.ResponseWriter, r *http.Request) int32 {
|
func authenticatedUser(w http.ResponseWriter, r *http.Request) int32 {
|
||||||
session, err := store.Get(r, "id")
|
session, err := store.Get(r, "id")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
@ -79,7 +79,7 @@ func home(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id := authenticated_user(w, r)
|
id := authenticatedUser(w, r)
|
||||||
account, _ := users.GetAccount(id)
|
account, _ := users.GetAccount(id)
|
||||||
|
|
||||||
active_subscription := subscriptions.HasActiveSubscription(id)
|
active_subscription := subscriptions.HasActiveSubscription(id)
|
||||||
@ -143,7 +143,7 @@ func login(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if len(errors) > 0 {
|
if len(errors) > 0 {
|
||||||
|
|
||||||
err := text.Execute(w, TemplateData{AuthenticatedUser: authenticated_user(w, r), FormErrors: errors})
|
err := text.Execute(w, TemplateData{AuthenticatedUser: authenticatedUser(w, r), FormErrors: errors})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
http.Error(w, "Internal Server Error", 500)
|
http.Error(w, "Internal Server Error", 500)
|
||||||
@ -155,7 +155,7 @@ func login(w http.ResponseWriter, r *http.Request) {
|
|||||||
id, err := users.Authenticate(username, password)
|
id, err := users.Authenticate(username, password)
|
||||||
if err == models.ErrInvalidCredentials {
|
if err == models.ErrInvalidCredentials {
|
||||||
errors["generic"] = "Email or Password is incorrect"
|
errors["generic"] = "Email or Password is incorrect"
|
||||||
err := text.Execute(w, TemplateData{AuthenticatedUser: authenticated_user(w, r), FormErrors: errors})
|
err := text.Execute(w, TemplateData{AuthenticatedUser: authenticatedUser(w, r), FormErrors: errors})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
http.Error(w, "Internal Server Error", 500)
|
http.Error(w, "Internal Server Error", 500)
|
||||||
@ -180,7 +180,7 @@ func logout(w http.ResponseWriter, r *http.Request) {
|
|||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
id := authenticated_user(w, r)
|
id := authenticatedUser(w, r)
|
||||||
account, err := users.GetAccount(id)
|
account, err := users.GetAccount(id)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -215,7 +215,7 @@ func register(w http.ResponseWriter, r *http.Request) {
|
|||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
id := authenticated_user(w, r)
|
id := authenticatedUser(w, r)
|
||||||
account, _ := users.GetAccount(id)
|
account, _ := users.GetAccount(id)
|
||||||
|
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
@ -251,7 +251,7 @@ func register(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if len(errors) > 0 {
|
if len(errors) > 0 {
|
||||||
|
|
||||||
err := text.Execute(w, TemplateData{AuthenticatedUser: authenticated_user(w, r), FormErrors: errors})
|
err := text.Execute(w, TemplateData{AuthenticatedUser: authenticatedUser(w, r), FormErrors: errors})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
http.Error(w, "Internal Server Error", 500)
|
http.Error(w, "Internal Server Error", 500)
|
||||||
@ -281,7 +281,7 @@ func account(w http.ResponseWriter, r *http.Request) {
|
|||||||
// log.Fatal(err)
|
// log.Fatal(err)
|
||||||
//}
|
//}
|
||||||
|
|
||||||
id := authenticated_user(w, r)
|
id := authenticatedUser(w, r)
|
||||||
account, err := users.GetAccount(id)
|
account, err := users.GetAccount(id)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -314,9 +314,9 @@ func account(w http.ResponseWriter, r *http.Request) {
|
|||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
func deleteaccount(w http.ResponseWriter, r *http.Request) {
|
func deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
id := authenticated_user(w, r)
|
id := authenticatedUser(w, r)
|
||||||
|
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodPost:
|
case http.MethodPost:
|
||||||
@ -337,7 +337,7 @@ func deleteaccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func subscribe(w http.ResponseWriter, r *http.Request) {
|
func subscribe(w http.ResponseWriter, r *http.Request) {
|
||||||
id := authenticated_user(w, r)
|
id := authenticatedUser(w, r)
|
||||||
account, err := users.GetAccount(id)
|
account, err := users.GetAccount(id)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -378,7 +378,7 @@ func subscribe(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func subscribe_stripe(w http.ResponseWriter, r *http.Request) {
|
func subscribe_stripe(w http.ResponseWriter, r *http.Request) {
|
||||||
id := authenticated_user(w, r)
|
id := authenticatedUser(w, r)
|
||||||
account, err := users.GetAccount(id)
|
account, err := users.GetAccount(id)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -413,7 +413,7 @@ func subscribe_stripe(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func managebilling(w http.ResponseWriter, r *http.Request) {
|
func managebilling(w http.ResponseWriter, r *http.Request) {
|
||||||
id := authenticated_user(w, r)
|
id := authenticatedUser(w, r)
|
||||||
account, err := users.GetAccount(id)
|
account, err := users.GetAccount(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
|
|||||||
@ -17,11 +17,11 @@ import (
|
|||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
"github.com/stripe/stripe-go/v78"
|
"github.com/stripe/stripe-go/v78"
|
||||||
|
|
||||||
"alfheimgame.com/alfheim/pkg/models"
|
"alfheimgame.com/alfheim/pkg/models/postgresql"
|
||||||
)
|
)
|
||||||
|
|
||||||
var users *models.Usermodel
|
var users *postgresql.AccountModel
|
||||||
var subscriptions *models.SubscriptionModel
|
var subscriptions *postgresql.SubscriptionModel
|
||||||
|
|
||||||
var key = []byte("super-secret-key")
|
var key = []byte("super-secret-key")
|
||||||
var store = sessions.NewCookieStore(key)
|
var store = sessions.NewCookieStore(key)
|
||||||
@ -32,7 +32,7 @@ var version string
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
addr := flag.String("addr", "127.0.0.1:8080", "HTTP network addr")
|
addr := flag.String("addr", "127.0.0.1:8080", "HTTP network addr")
|
||||||
prodaddr := flag.String("prodaddr", "127.0.0.1:4000", "HTTP network addr")
|
prodAddr := flag.String("prodaddr", "127.0.0.1:4000", "HTTP network addr")
|
||||||
//prodaddr := flag.String("prodaddr", "45.76.84.7:443", "HTTP network addr")
|
//prodaddr := flag.String("prodaddr", "45.76.84.7:443", "HTTP network addr")
|
||||||
|
|
||||||
production := flag.Bool("production", false, "Whether to use production port and TLS")
|
production := flag.Bool("production", false, "Whether to use production port and TLS")
|
||||||
@ -43,7 +43,7 @@ func main() {
|
|||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if *displayVersion {
|
if *displayVersion {
|
||||||
fmt.Println("Version: %s", version)
|
fmt.Printf("Version: %s\n", version)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,8 +59,8 @@ func main() {
|
|||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
users = &models.Usermodel{db}
|
users = &postgresql.AccountModel{DB: db}
|
||||||
subscriptions = &models.SubscriptionModel{db}
|
subscriptions = &postgresql.SubscriptionModel{DB: db}
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
@ -96,10 +96,10 @@ func main() {
|
|||||||
mux.HandleFunc("/login", login)
|
mux.HandleFunc("/login", login)
|
||||||
mux.HandleFunc("/logout", logout)
|
mux.HandleFunc("/logout", logout)
|
||||||
mux.HandleFunc("/register", register)
|
mux.HandleFunc("/register", register)
|
||||||
mux.HandleFunc("/account", require_authenticated_user(account))
|
mux.HandleFunc("/account", requireAuthenticatedUser(account))
|
||||||
mux.HandleFunc("/deleteaccount", require_authenticated_user(deleteaccount))
|
mux.HandleFunc("/deleteaccount", requireAuthenticatedUser(deleteAccount))
|
||||||
mux.HandleFunc("/subscribe", require_authenticated_user(subscribe_stripe))
|
mux.HandleFunc("/subscribe", requireAuthenticatedUser(subscribe_stripe))
|
||||||
mux.HandleFunc("/managebilling", require_authenticated_user(managebilling))
|
mux.HandleFunc("/managebilling", requireAuthenticatedUser(managebilling))
|
||||||
mux.HandleFunc("/webhook", webhooks)
|
mux.HandleFunc("/webhook", webhooks)
|
||||||
|
|
||||||
if *production {
|
if *production {
|
||||||
@ -120,8 +120,8 @@ func main() {
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
// log.Fatal(server.ListenAndServeTLS("", ""))
|
// log.Fatal(server.ListenAndServeTLS("", ""))
|
||||||
log.Fatal(http.ListenAndServe(*prodaddr, log_request(secure_headers(mux))))
|
log.Fatal(http.ListenAndServe(*prodAddr, logRequest(secureHeaders(mux))))
|
||||||
} else {
|
} else {
|
||||||
log.Fatal(http.ListenAndServe(*addr, log_request(secure_headers(mux))))
|
log.Fatal(http.ListenAndServe(*addr, logRequest(secureHeaders(mux))))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func secure_headers(next http.Handler) http.Handler {
|
func secureHeaders(next http.Handler) http.Handler {
|
||||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||||
w.Header().Set("X-Frame-Options", "deny")
|
w.Header().Set("X-Frame-Options", "deny")
|
||||||
@ -16,7 +16,7 @@ func secure_headers(next http.Handler) http.Handler {
|
|||||||
return http.HandlerFunc(fn)
|
return http.HandlerFunc(fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func log_request(next http.Handler) http.Handler {
|
func logRequest(next http.Handler) http.Handler {
|
||||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Printf("%s - %s %s %s", r.RemoteAddr, r.Proto, r.Method, r.URL)
|
log.Printf("%s - %s %s %s", r.RemoteAddr, r.Proto, r.Method, r.URL)
|
||||||
|
|
||||||
@ -26,12 +26,12 @@ func log_request(next http.Handler) http.Handler {
|
|||||||
return http.HandlerFunc(fn)
|
return http.HandlerFunc(fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func require_authenticated_user(next http.HandlerFunc) http.HandlerFunc {
|
func requireAuthenticatedUser(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// If the user is not authenticated, redirect them to the login page and
|
// If the user is not authenticated, redirect them to the login page and
|
||||||
// return from the middleware chain so that no subsequent handlers in
|
// return from the middleware chain so that no subsequent handlers in
|
||||||
// the chain are executed.
|
// the chain are executed.
|
||||||
if authenticated_user(w, r) == 0 {
|
if authenticatedUser(w, r) == 0 {
|
||||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,19 +5,9 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
//import "golang.org/x/crypto/bcrypt"
|
|
||||||
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/alexedwards/argon2id"
|
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
"github.com/stripe/stripe-go/v78"
|
|
||||||
"github.com/stripe/stripe-go/v78/customer"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrNoRecord = errors.New("no matching record found")
|
var ErrNoRecord = errors.New("no matching record found")
|
||||||
@ -40,14 +30,14 @@ type Account struct {
|
|||||||
type SubscriptionStatus string
|
type SubscriptionStatus string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
incomplete SubscriptionStatus = "incomplete"
|
Incomplete SubscriptionStatus = "incomplete"
|
||||||
incomplete_expired SubscriptionStatus = "incomplete_expired"
|
IncompleteExpired SubscriptionStatus = "incomplete_expired"
|
||||||
trialing SubscriptionStatus = "trialing"
|
Trialing SubscriptionStatus = "trialing"
|
||||||
active SubscriptionStatus = "active"
|
Active SubscriptionStatus = "active"
|
||||||
past_due SubscriptionStatus = "past_due"
|
PastDue SubscriptionStatus = "past_due"
|
||||||
canceled SubscriptionStatus = "canceled"
|
Canceled SubscriptionStatus = "canceled"
|
||||||
unpaid SubscriptionStatus = "unpaid"
|
Unpaid SubscriptionStatus = "unpaid"
|
||||||
paused SubscriptionStatus = "paused"
|
Paused SubscriptionStatus = "paused"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Subscription struct {
|
type Subscription struct {
|
||||||
@ -57,254 +47,3 @@ type Subscription struct {
|
|||||||
StripeCheckoutID string
|
StripeCheckoutID string
|
||||||
Status SubscriptionStatus
|
Status SubscriptionStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserModel struct {
|
|
||||||
DB *sql.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
type SubscriptionModel struct {
|
|
||||||
DB *sql.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Usermodel) Insert(username string, password string, firstname string, lastname string, email string) (int32, error) {
|
|
||||||
|
|
||||||
//hashedpassword, err := bcrypt.GenerateFromPassword([]byte(password), 12)
|
|
||||||
hashedpassword, err := argon2id.CreateHash(password, argon2id.DefaultParams)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
//log.Println(hashedpassword)
|
|
||||||
stmt := `INSERT INTO accounts (username, password, firstname, lastname, email, created) VALUES ($1, $2, $3, $4, $5, NOW()) RETURNING id`
|
|
||||||
|
|
||||||
var insertid int32
|
|
||||||
|
|
||||||
row := m.DB.QueryRow(stmt, username, string(hashedpassword), firstname, lastname, email)
|
|
||||||
if row.Err() != nil {
|
|
||||||
log.Println(row.Err())
|
|
||||||
return 0, row.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
err = row.Scan(&insertid)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
params := &stripe.CustomerParams{
|
|
||||||
Name: stripe.String(fmt.Sprintf("%s %s", firstname, lastname)),
|
|
||||||
Email: stripe.String(email),
|
|
||||||
}
|
|
||||||
customer, err := customer.New(params)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt = `UPDATE accounts SET stripe_id = $1 WHERE id = $2`
|
|
||||||
|
|
||||||
//log.Println(customer.ID, insertid)
|
|
||||||
|
|
||||||
_, err = m.DB.Exec(stmt, customer.ID, insertid)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return insertid, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *UserModel) Delete(id int32) error {
|
|
||||||
account, err := users.GetAccount(id)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if account.StripeID != "" {
|
|
||||||
/*result*/ _, err := customer.Del(account.StripeID, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
}
|
|
||||||
//log.Println(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := `DELETE FROM accounts WHERE id = $1`
|
|
||||||
|
|
||||||
_, err = m.DB.Exec(stmt, id)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Usermodel) GetAccount(id int32) (Account, error) {
|
|
||||||
if id == 0 {
|
|
||||||
return Account{}, ErrNoRecord
|
|
||||||
}
|
|
||||||
stmt := `SELECT id, username, password, color, firstname, lastname, email, created, stripe_id FROM accounts WHERE id = $1`
|
|
||||||
row := m.DB.QueryRow(stmt, id)
|
|
||||||
|
|
||||||
var account Account
|
|
||||||
err := row.Scan(&account.ID, &account.Username, &account.Password, &account.Color, &account.Firstname, &account.Lastname, &account.Email, &account.Created, &account.StripeID)
|
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return Account{}, sql.ErrNoRows
|
|
||||||
} else if err != nil {
|
|
||||||
return Account{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Usermodel) Authenticate(username string, password string) (int32, error) {
|
|
||||||
var id int32
|
|
||||||
var hashedpassword string
|
|
||||||
row := m.DB.QueryRow("SELECT id, password FROM accounts WHERE username = $1", username)
|
|
||||||
err := row.Scan(&id, &hashedpassword)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return 0, ErrInvalidCredentials
|
|
||||||
}
|
|
||||||
|
|
||||||
match, err := argon2id.ComparePasswordAndHash(password, hashedpassword)
|
|
||||||
if !match {
|
|
||||||
return 0, ErrInvalidCredentials
|
|
||||||
} else if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return id, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Usermodel) ExistsAccount(id int32) bool {
|
|
||||||
var exists bool
|
|
||||||
stmt := `SELECT EXISTS(SELECT 1 FROM accounts WHERE id = $1)`
|
|
||||||
row := m.DB.QueryRow(stmt, id)
|
|
||||||
if row.Err() != nil {
|
|
||||||
log.Println(row.Err())
|
|
||||||
}
|
|
||||||
row.Scan(&exists)
|
|
||||||
|
|
||||||
//log.Println(exists)
|
|
||||||
|
|
||||||
return exists
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SubscriptionModel) Insert(stripeid string, stripesubscriptionid string, stripecheckoutid string, status stripe.SubscriptionStatus) (int32, error) {
|
|
||||||
var id int32
|
|
||||||
stmt := `SELECT id FROM accounts WHERE stripe_id = $1`
|
|
||||||
|
|
||||||
row := m.DB.QueryRow(stmt, stripeid)
|
|
||||||
if row.Err() != nil {
|
|
||||||
log.Println(row.Err())
|
|
||||||
return 0, row.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
err := row.Scan(&id)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt = `INSERT INTO subscriptions (account_id, stripe_subscription_id, stripe_checkout_id, status) VALUES ($1, $2, $3, $4::subscription_status) RETURNING id`
|
|
||||||
|
|
||||||
var insertid int32
|
|
||||||
|
|
||||||
row = m.DB.QueryRow(stmt, id, string(stripesubscriptionid), string(stripecheckoutid), string(status))
|
|
||||||
if row.Err() != nil {
|
|
||||||
log.Println(row.Err())
|
|
||||||
return 0, row.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
err = row.Scan(&insertid)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return insertid, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SubscriptionModel) Delete(id int32) error {
|
|
||||||
stmt := `DELETE FROM accounts WHERE id = $1`
|
|
||||||
|
|
||||||
_, err := m.DB.Exec(stmt, id)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SubscriptionModel) GetSubscription(id int32) (Subscription, error) {
|
|
||||||
if id == 0 {
|
|
||||||
return Subscription{}, ErrNoRecord
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := `SELECT id, account_id, stripe_subscription_id, stripe_checkout_id, status FROM subscriptions WHERE id = $1`
|
|
||||||
row := m.DB.QueryRow(stmt, id)
|
|
||||||
|
|
||||||
var subscription Subscription
|
|
||||||
err := row.Scan(&subscription.ID, &subscription.AccountID, &subscription.StripeSubscriptionID, &subscription.StripeCheckoutID, &subscription.Status)
|
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return Subscription{}, sql.ErrNoRows
|
|
||||||
} else if err != nil {
|
|
||||||
return Subscription{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
//log.Println(subscription.Status)
|
|
||||||
|
|
||||||
return subscription, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SubscriptionModel) GetSubscriptionsFromAccount(accountid int32) ([]Subscription, error) {
|
|
||||||
if accountid == 0 {
|
|
||||||
return nil, ErrNoRecord
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := `SELECT id, account_id, stripe_subscription_id, stripe_checkout_id, status FROM subscriptions WHERE account_id = $1`
|
|
||||||
rows, err := m.DB.Query(stmt, accountid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var subscriptions []Subscription
|
|
||||||
for rows.Next() {
|
|
||||||
var subscription Subscription
|
|
||||||
err := rows.Scan(&subscription.ID, &subscription.AccountID, &subscription.StripeSubscriptionID, &subscription.StripeCheckoutID, &subscription.Status)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, sql.ErrNoRows
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
subscriptions = append(subscriptions, subscription)
|
|
||||||
}
|
|
||||||
|
|
||||||
return subscriptions, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SubscriptionModel) HasActiveSubscription(accountid int32) bool {
|
|
||||||
subscriptions, err := m.GetSubscriptionsFromAccount(accountid)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, v := range subscriptions {
|
|
||||||
if v.Status == active {
|
|
||||||
return true
|
|
||||||
} else if v.Status == trialing {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
150
pkg/models/postgresql/accounts.go
Normal file
150
pkg/models/postgresql/accounts.go
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
package postgresql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
//import "golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/alexedwards/argon2id"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
"github.com/stripe/stripe-go/v78"
|
||||||
|
"github.com/stripe/stripe-go/v78/customer"
|
||||||
|
|
||||||
|
"alfheimgame.com/alfheim/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AccountModel struct {
|
||||||
|
DB *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AccountModel) Insert(username string, password string, firstname string, lastname string, email string) (int32, error) {
|
||||||
|
|
||||||
|
//hashedpassword, err := bcrypt.GenerateFromPassword([]byte(password), 12)
|
||||||
|
hashedpassword, err := argon2id.CreateHash(password, argon2id.DefaultParams)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//log.Println(hashedpassword)
|
||||||
|
stmt := `INSERT INTO accounts (username, password, firstname, lastname, email, created) VALUES ($1, $2, $3, $4, $5, NOW()) RETURNING id`
|
||||||
|
|
||||||
|
var insertid int32
|
||||||
|
|
||||||
|
row := m.DB.QueryRow(stmt, username, string(hashedpassword), firstname, lastname, email)
|
||||||
|
if row.Err() != nil {
|
||||||
|
log.Println(row.Err())
|
||||||
|
return 0, row.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = row.Scan(&insertid)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
params := &stripe.CustomerParams{
|
||||||
|
Name: stripe.String(fmt.Sprintf("%s %s", firstname, lastname)),
|
||||||
|
Email: stripe.String(email),
|
||||||
|
}
|
||||||
|
customer, err := customer.New(params)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = `UPDATE accounts SET stripe_id = $1 WHERE id = $2`
|
||||||
|
|
||||||
|
//log.Println(customer.ID, insertid)
|
||||||
|
|
||||||
|
_, err = m.DB.Exec(stmt, customer.ID, insertid)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return insertid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AccountModel) Delete(id int32) error {
|
||||||
|
account, err := m.GetAccount(id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if account.StripeID != "" {
|
||||||
|
/*result*/ _, err := customer.Del(account.StripeID, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
|
//log.Println(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := `DELETE FROM accounts WHERE id = $1`
|
||||||
|
|
||||||
|
_, err = m.DB.Exec(stmt, id)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AccountModel) GetAccount(id int32) (models.Account, error) {
|
||||||
|
if id == 0 {
|
||||||
|
return models.Account{}, models.ErrNoRecord
|
||||||
|
}
|
||||||
|
stmt := `SELECT id, username, password, color, firstname, lastname, email, created, stripe_id FROM accounts WHERE id = $1`
|
||||||
|
row := m.DB.QueryRow(stmt, id)
|
||||||
|
|
||||||
|
var account models.Account
|
||||||
|
err := row.Scan(&account.ID, &account.Username, &account.Password, &account.Color, &account.Firstname, &account.Lastname, &account.Email, &account.Created, &account.StripeID)
|
||||||
|
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return models.Account{}, sql.ErrNoRows
|
||||||
|
} else if err != nil {
|
||||||
|
return models.Account{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AccountModel) Authenticate(username string, password string) (int32, error) {
|
||||||
|
var id int32
|
||||||
|
var hashedpassword string
|
||||||
|
row := m.DB.QueryRow("SELECT id, password FROM accounts WHERE username = $1", username)
|
||||||
|
err := row.Scan(&id, &hashedpassword)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return 0, models.ErrInvalidCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
match, err := argon2id.ComparePasswordAndHash(password, hashedpassword)
|
||||||
|
if !match {
|
||||||
|
return 0, models.ErrInvalidCredentials
|
||||||
|
} else if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AccountModel) ExistsAccount(id int32) bool {
|
||||||
|
var exists bool
|
||||||
|
stmt := `SELECT EXISTS(SELECT 1 FROM accounts WHERE id = $1)`
|
||||||
|
row := m.DB.QueryRow(stmt, id)
|
||||||
|
if row.Err() != nil {
|
||||||
|
log.Println(row.Err())
|
||||||
|
}
|
||||||
|
row.Scan(&exists)
|
||||||
|
|
||||||
|
//log.Println(exists)
|
||||||
|
|
||||||
|
return exists
|
||||||
|
}
|
||||||
133
pkg/models/postgresql/subscriptions.go
Normal file
133
pkg/models/postgresql/subscriptions.go
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
package postgresql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
//import "golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
"log"
|
||||||
|
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
"github.com/stripe/stripe-go/v78"
|
||||||
|
|
||||||
|
"alfheimgame.com/alfheim/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SubscriptionModel struct {
|
||||||
|
DB *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SubscriptionModel) Insert(stripeid string, stripesubscriptionid string, stripecheckoutid string, status stripe.SubscriptionStatus) (int32, error) {
|
||||||
|
var id int32
|
||||||
|
stmt := `SELECT id FROM accounts WHERE stripe_id = $1`
|
||||||
|
|
||||||
|
row := m.DB.QueryRow(stmt, stripeid)
|
||||||
|
if row.Err() != nil {
|
||||||
|
log.Println(row.Err())
|
||||||
|
return 0, row.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := row.Scan(&id)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = `INSERT INTO subscriptions (account_id, stripe_subscription_id, stripe_checkout_id, status) VALUES ($1, $2, $3, $4::subscription_status) RETURNING id`
|
||||||
|
|
||||||
|
var insertid int32
|
||||||
|
|
||||||
|
row = m.DB.QueryRow(stmt, id, string(stripesubscriptionid), string(stripecheckoutid), string(status))
|
||||||
|
if row.Err() != nil {
|
||||||
|
log.Println(row.Err())
|
||||||
|
return 0, row.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = row.Scan(&insertid)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return insertid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SubscriptionModel) Delete(id int32) error {
|
||||||
|
stmt := `DELETE FROM accounts WHERE id = $1`
|
||||||
|
|
||||||
|
_, err := m.DB.Exec(stmt, id)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SubscriptionModel) GetSubscription(id int32) (models.Subscription, error) {
|
||||||
|
if id == 0 {
|
||||||
|
return models.Subscription{}, models.ErrNoRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := `SELECT id, account_id, stripe_subscription_id, stripe_checkout_id, status FROM subscriptions WHERE id = $1`
|
||||||
|
row := m.DB.QueryRow(stmt, id)
|
||||||
|
|
||||||
|
var subscription models.Subscription
|
||||||
|
err := row.Scan(&subscription.ID, &subscription.AccountID, &subscription.StripeSubscriptionID, &subscription.StripeCheckoutID, &subscription.Status)
|
||||||
|
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return models.Subscription{}, sql.ErrNoRows
|
||||||
|
} else if err != nil {
|
||||||
|
return models.Subscription{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//log.Println(subscription.Status)
|
||||||
|
|
||||||
|
return subscription, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SubscriptionModel) GetSubscriptionsFromAccount(accountid int32) ([]models.Subscription, error) {
|
||||||
|
if accountid == 0 {
|
||||||
|
return nil, models.ErrNoRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := `SELECT id, account_id, stripe_subscription_id, stripe_checkout_id, status FROM subscriptions WHERE account_id = $1`
|
||||||
|
rows, err := m.DB.Query(stmt, accountid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var subscriptions []models.Subscription
|
||||||
|
for rows.Next() {
|
||||||
|
var subscription models.Subscription
|
||||||
|
err := rows.Scan(&subscription.ID, &subscription.AccountID, &subscription.StripeSubscriptionID, &subscription.StripeCheckoutID, &subscription.Status)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, sql.ErrNoRows
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
subscriptions = append(subscriptions, subscription)
|
||||||
|
}
|
||||||
|
|
||||||
|
return subscriptions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SubscriptionModel) HasActiveSubscription(accountid int32) bool {
|
||||||
|
subscriptions, err := m.GetSubscriptionsFromAccount(accountid)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range subscriptions {
|
||||||
|
switch v.Status {
|
||||||
|
case models.Active:
|
||||||
|
return true
|
||||||
|
case models.Trialing:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user