diff --git a/handlers.go b/handlers.go index d97edb2..468db2e 100644 --- a/handlers.go +++ b/handlers.go @@ -13,18 +13,21 @@ import "strings"; import "unicode/utf8"; import "errors"; import "runtime/debug"; +import "github.com/stripe/stripe-go/v78"; +import "github.com/stripe/stripe-go/v78/price"; type templatedata struct { AuthenticatedUser int32; FormErrors map[string]string; Account Account; + Prices []stripe.Price } func favicon(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, "favicon.ico"); } -func authenticated_user(r *http.Request) int32 { +func authenticated_user(w http.ResponseWriter, r *http.Request) int32 { session, err := store.Get(r, "id"); if err != nil { fmt.Println(err); @@ -39,6 +42,19 @@ func authenticated_user(r *http.Request) int32 { return 0; } + // check if the saved id exists in the database, otherwise it's a bad id and has to be removed from the cookies + + exists := users.Exists_account(id); + + if !exists { + session, _ := store.Get(r, "id"); + + session.Values["id"] = 0; + session.Save(r, w); + + return 0; + } + return id; } @@ -48,7 +64,7 @@ func home(w http.ResponseWriter, r *http.Request) { return; } - id := authenticated_user(r); + id := authenticated_user(w, r); account, err := users.Get_account(id); text, err := template.ParseFiles("ui/base.html", "ui/index.html"); @@ -111,7 +127,7 @@ func login(w http.ResponseWriter, r *http.Request) { if len(errors) > 0 { - err := text.Execute(w, templatedata{AuthenticatedUser: authenticated_user(r), FormErrors: errors}); + err := text.Execute(w, templatedata{AuthenticatedUser: authenticated_user(w, r), FormErrors: errors}); if err != nil { log.Fatal(err); http.Error(w, "Internal Server Error", 500); @@ -123,7 +139,7 @@ func login(w http.ResponseWriter, r *http.Request) { id, err := users.Authenticate(username, password); if err == ErrInvalidCredentials { errors["generic"] = "Email or Password is incorrect"; - err := text.Execute(w, templatedata{AuthenticatedUser: authenticated_user(r), FormErrors: errors}); + err := text.Execute(w, templatedata{AuthenticatedUser: authenticated_user(w, r), FormErrors: errors}); if err != nil { log.Fatal(err); http.Error(w, "Internal Server Error", 500); @@ -148,7 +164,7 @@ func logout(w http.ResponseWriter, r *http.Request) { log.Fatal(err); } - id := authenticated_user(r); + id := authenticated_user(w, r); account, err := users.Get_account(id); switch r.Method { @@ -161,9 +177,9 @@ func logout(w http.ResponseWriter, r *http.Request) { } case http.MethodPost: - session, _ := store.Get(r, "id");; + session, _ := store.Get(r, "id"); - session.Values["id"] = 0;; + session.Values["id"] = 0; session.Save(r, w); http.Redirect(w, r, "/", http.StatusSeeOther); } @@ -176,7 +192,7 @@ func register(w http.ResponseWriter, r *http.Request) { log.Fatal(err); } - id := authenticated_user(r); + id := authenticated_user(w, r); account, err := users.Get_account(id); switch r.Method { @@ -207,7 +223,7 @@ func register(w http.ResponseWriter, r *http.Request) { if len(errors) > 0 { - err := text.Execute(w, templatedata{AuthenticatedUser: authenticated_user(r), FormErrors: errors}); + err := text.Execute(w, templatedata{AuthenticatedUser: authenticated_user(w, r), FormErrors: errors}); if err != nil { log.Fatal(err); http.Error(w, "Internal Server Error", 500); @@ -237,11 +253,12 @@ func account(w http.ResponseWriter, r *http.Request) { // log.Fatal(err);; //}; - id := authenticated_user(r); + id := authenticated_user(w, r); account, err := users.Get_account(id); - text, err := template.ParseFiles("ui/base.html", "ui/account.html"); + fmt.Println(id, account) + text, err := template.ParseFiles("ui/base.html", "ui/account.html"); if err != nil { http.Error(w, "Internal Server Error", 500); log.Fatal(err); @@ -256,7 +273,7 @@ func account(w http.ResponseWriter, r *http.Request) { } } - //case http.MethodPost:; + //case http.MethodPost: // text.Execute(w, false); // if err != nil {; // log.Fatal(err); @@ -264,9 +281,9 @@ func account(w http.ResponseWriter, r *http.Request) { // }; } -func deleteaccount(w http.ResponseWriter, r *http.Request) {; +func deleteaccount(w http.ResponseWriter, r *http.Request) { - id := authenticated_user(r); + id := authenticated_user(w, r); switch r.Method { case http.MethodPost: @@ -281,3 +298,32 @@ func deleteaccount(w http.ResponseWriter, r *http.Request) {; http.Redirect(w, r, "/", http.StatusSeeOther); } } + +func subscribe(w http.ResponseWriter, r *http.Request) { + id := authenticated_user(w, r); + account, err := users.Get_account(id); + + params := &stripe.PriceListParams{}; + params.Limit = stripe.Int64(3) + params.AddExpand("data.product"); + results := price.List(params); + + prices := make([]stripe.Price, 0); + + for results.Next() { + fmt.Println(results.Current()) + prices = append(prices, *results.Price()); + } + + text, err := template.ParseFiles("ui/base.html", "ui/subscribe.html"); + if err != nil { + http.Error(w, "Internal Server Error", 500); + log.Fatal(err); + } + + err = text.Execute(w, templatedata{AuthenticatedUser: id, Account: account, Prices: prices}); + if err != nil { + log.Fatal(err); + http.Error(w, "Internal Server Error", 500); + } +} diff --git a/main.go b/main.go index 9d90857..455242f 100644 --- a/main.go +++ b/main.go @@ -24,31 +24,6 @@ var store = sessions.NewCookieStore(key); var emailrx = regexp.MustCompile("/^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$/"); -func secure_headers(next http.Handler) http.Handler {; - fn := func(w http.ResponseWriter, r *http.Request) {; - w.Header().Set("X-XSS-Protection", "1; mode=block"); - w.Header().Set("X-Frame-Options", "deny"); - - next.ServeHTTP(w, r); - } - - return http.HandlerFunc(fn); -} - -func require_authenticated_user(next http.HandlerFunc) http.HandlerFunc { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // If the user is not authenticated, redirect them to the login page and; - // return from the middleware chain so that no subsequent handlers in; - // the chain are executed.; - if authenticated_user(r) == 0 { - http.Redirect(w, r, "/login", http.StatusSeeOther); - return; - } - // Otherwise call the next handler in the chain.; - next.ServeHTTP(w, r); - }); -} - func main() { addr := flag.String("addr", ":8080", "HTTP network address"); flag.Parse(); @@ -103,6 +78,7 @@ func main() { mux.HandleFunc("/register", register); mux.HandleFunc("/account", require_authenticated_user(account)); mux.HandleFunc("/deleteaccount", require_authenticated_user(deleteaccount)); + mux.HandleFunc("/subscribe", subscribe); log.Fatal(http.ListenAndServe(*addr, secure_headers(mux))); diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..b5daef6 --- /dev/null +++ b/middleware.go @@ -0,0 +1,28 @@ +package main; + +import "net/http"; + +func secure_headers(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-XSS-Protection", "1; mode=block"); + w.Header().Set("X-Frame-Options", "deny"); + + next.ServeHTTP(w, r); + } + + return http.HandlerFunc(fn); +} + +func require_authenticated_user(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // If the user is not authenticated, redirect them to the login page and; + // return from the middleware chain so that no subsequent handlers in; + // the chain are executed.; + if authenticated_user(w, r) == 0 { + http.Redirect(w, r, "/login", http.StatusSeeOther); + return; + } + // Otherwise call the next handler in the chain.; + next.ServeHTTP(w, r); + }); +} diff --git a/models.go b/models.go index 1b05f97..e857de7 100644 --- a/models.go +++ b/models.go @@ -45,8 +45,8 @@ func (m *Usermodel) Insert(username string, password string, firstname string, l row := m.DB.QueryRow(stmt, username, string(hashedpassword), firstname, lastname, email); if row.Err() != nil { - fmt.Println(err); - return err; + fmt.Println(row.Err()); + return row.Err(); } err = row.Scan(&insertid); @@ -129,3 +129,17 @@ func (m *Usermodel) Authenticate(username string, password string) (int32, error return id, nil; } + +func (m *Usermodel) Exists_account(id int32) bool { + var exists bool; + stmt := `SELECT EXISTS(SELECT 1 FROM accounts WHERE id = $1);`; + row := m.DB.QueryRow(stmt, id); + if row.Err() != nil { + fmt.Println(row.Err()); + } + row.Scan(&exists); + + fmt.Println(exists); + + return exists; +} diff --git a/ui/base.html b/ui/base.html index 4c03cb0..fe47c53 100644 --- a/ui/base.html +++ b/ui/base.html @@ -19,6 +19,8 @@