This commit is contained in:
Vicente Ferrari Smith 2026-04-08 07:51:15 +02:00
parent 943408255c
commit ffa5cd9cd7
38 changed files with 2938 additions and 68 deletions

2
.vscode/launch.json vendored
View File

@ -10,7 +10,7 @@
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/cmd/web"
"program": "${workspaceFolder}/cmd/api"
}
]
}

26
cmd/api/context.go Normal file
View File

@ -0,0 +1,26 @@
package main
import (
"context"
"net/http"
"party.at/party/internal/data"
)
type contextKey string
const userContextKey = "user"
func (app *application) contextSetUser(r *http.Request, user *data.User) *http.Request {
ctx := context.WithValue(r.Context(), userContextKey, user)
return r.WithContext(ctx)
}
func (app *application) contextGetUser(r *http.Request) *data.User {
user, ok := r.Context().Value(userContextKey).(*data.User)
if !ok {
panic("missing user value in request context")
}
return user
}

View File

@ -1,19 +1,19 @@
package main
import (
"context"
"github.com/jackc/pgx/v5"
// "context"
"database/sql"
"log"
)
func database_init(conn *pgx.Conn) {
func database_init_dont_use(conn *sql.DB) {
sql := `DROP TABLE IF EXISTS vote;
DROP TABLE IF EXISTS vote_token;
DROP TABLE IF EXISTS option;
DROP TABLE IF EXISTS issue;
DROP TABLE IF EXISTS issues;
DROP TABLE IF EXISTS account;`
_, err := conn.Exec(context.Background(), sql)
_, err := conn.Exec(sql)
if err != nil {
log.Fatal(err)
}
@ -29,7 +29,7 @@ created TIMESTAMPTZ DEFAULT now() NOT NULL,
last_login TIMESTAMPTZ
)`
_, err = conn.Exec(context.Background(), sql)
_, err = conn.Exec(sql)
if err != nil {
log.Fatal(err)
}
@ -74,12 +74,12 @@ INSERT INTO account
(username, password_hash, first_name, last_name, email, created)
VALUES('yyZGFAsmPEBoSf', '$argon2id$v=19$m=65536,t=1,p=1$vLVm5ol6GQkXPU0BPut92g$Kyvp7/dl3lGUszXRiXfgsgB/IY0EKulZVpVttXQaDDU', 'svVjyPUaAkN', 'mdngeHlf', 'giyahyd4141@gmail.com', '2024-11-28 19:02:32.806');`
_, err = conn.Exec(context.Background(), sql)
_, err = conn.Exec(sql)
if err != nil {
log.Fatal(err)
}
sql = `CREATE TABLE issue (
sql = `CREATE TABLE issues (
id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
title VARCHAR(255) NOT NULL,
description TEXT,
@ -88,32 +88,32 @@ end_time TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
)`
_, err = conn.Exec(context.Background(), sql)
_, err = conn.Exec(sql)
if err != nil {
log.Fatal(err)
}
sql = `CREATE TABLE option (
id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
issue_id BIGINT NOT NULL REFERENCES issue(id) ON DELETE CASCADE,
issue_id BIGINT NOT NULL REFERENCES issues(id) ON DELETE CASCADE,
label VARCHAR(255) NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
)`
_, err = conn.Exec(context.Background(), sql)
_, err = conn.Exec(sql)
if err != nil {
log.Fatal(err)
}
sql = `CREATE TABLE vote_token (
id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
issue_id BIGINT NOT NULL REFERENCES issue(id) ON DELETE CASCADE,
issue_id BIGINT NOT NULL REFERENCES issues(id) ON DELETE CASCADE,
token UUID NOT NULL UNIQUE DEFAULT gen_random_uuid(),
used BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
)`
_, err = conn.Exec(context.Background(), sql)
_, err = conn.Exec(sql)
if err != nil {
log.Fatal(err)
}
@ -125,7 +125,7 @@ option_id BIGINT NOT NULL REFERENCES option(id) ON DELETE CASCADE,
voted_at TIMESTAMPTZ NOT NULL DEFAULT now()
)`
_, err = conn.Exec(context.Background(), sql)
_, err = conn.Exec(sql)
if err != nil {
log.Fatal(err)
}
@ -134,7 +134,7 @@ voted_at TIMESTAMPTZ NOT NULL DEFAULT now()
CREATE INDEX idx_vote_tokens_issue_id ON vote_token(issue_id);
CREATE INDEX idx_options_issue_id ON option(issue_id);`
_, err = conn.Exec(context.Background(), sql)
_, err = conn.Exec(sql)
if err != nil {
log.Fatal(err)
}

104
cmd/api/errors.go Normal file
View File

@ -0,0 +1,104 @@
package main
import (
"fmt"
"net/http"
)
// The logError() method is a generic helper for logging an error message. Later in the
// book we'll upgrade this to use structured logging, and record additional information
// about the request including the HTTP method and URL.
func (app *application) logError(r *http.Request, err error) {
app.logger.PrintError(err, map[string]string{
"request_method": r.Method,
"request_url": r.URL.String(),
})
}
// The errorResponse() method is a generic helper for sending JSON-formatted error
// messages to the client with a given status code. Note that we're using an interface{}
// type for the message parameter, rather than just a string type, as this gives us
// more flexibility over the values that we can include in the response.
func (app *application) errorResponse(w http.ResponseWriter, r *http.Request, status int, message interface{}) {
env := envelope{"error": message}
// Write the response using the writeJSON() helper. If this happens to return an
// error then log it, and fall back to sending the client an empty response with a
// 500 Internal Server Error status code.
err := app.writeJSON(w, status, env, nil)
if err != nil {
app.logError(r, err)
w.WriteHeader(500)
}
}
// The serverErrorResponse() method will be used when our application encounters an
// unexpected problem at runtime. It logs the detailed error message, then uses the
// errorResponse() helper to send a 500 Internal Server Error status code and JSON
// response (containing a generic error message) to the client.
func (app *application) serverErrorResponse(w http.ResponseWriter, r *http.Request, err error) {
app.logError(r, err)
message :=
"the server encountered a problem and could not process your request"
app.errorResponse(w, r, http.StatusInternalServerError, message)
}
// The notFoundResponse() method will be used to send a 404 Not Found status code and
// JSON response to the client.
func (app *application) notFoundResponse(w http.ResponseWriter, r *http.Request) {
message := "the requested resource could not be found"
app.errorResponse(w, r, http.StatusNotFound, message)
}
// The methodNotAllowedResponse() method will be used to send a 405 Method Not Allowed
// status code and JSON response to the client.
func (app *application) methodNotAllowedResponse(w http.ResponseWriter, r *http.Request) {
message := fmt.Sprintf("the %s method is not supported for this resource", r.Method)
app.errorResponse(w, r, http.StatusMethodNotAllowed, message)
}
func (app *application) badRequestResponse(w http.ResponseWriter, r *http.Request, err error) {
app.errorResponse(w, r, http.StatusBadRequest, err.Error())
}
func (app *application) failedValidationResponse(w http.ResponseWriter, r *http.Request, errors map[string]string) {
app.errorResponse(w, r, http.StatusUnprocessableEntity, errors)
}
func (app *application) editConflictResponse(w http.ResponseWriter, r *http.Request) {
message := "unable to update the record due to an edit conflict, please try again"
app.errorResponse(w, r, http.StatusConflict, message)
}
func (app *application) rateLimitExceededResponse(w http.ResponseWriter, r *http.Request) {
message := "rate limit exceeded"
app.errorResponse(w, r, http.StatusTooManyRequests, message)
}
func (app *application) invalidCredentialsResponse(w http.ResponseWriter, r *http.Request) {
message := "invalid authentication credentials"
app.errorResponse(w, r, http.StatusUnauthorized, message)
}
func (app *application) invalidAuthenticationTokenResponse(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", "Bearer")
message := "invalid or missing authentication token"
app.errorResponse(w, r, http.StatusUnauthorized, message)
}
func (app *application) authenticationRequiredResponse(w http.ResponseWriter, r *http.Request) {
message :="you must be authenticated to access this resource"
app.errorResponse(w, r, http.StatusUnauthorized, message)
}
func (app *application) inactiveAccountResponse(w http.ResponseWriter, r *http.Request) {
message := "your user account must be activated to access this resource"
app.errorResponse(w, r, http.StatusForbidden, message)
}
func (app *application) notPermittedResponse(w http.ResponseWriter, r *http.Request) {
message := "your user account doesn't have the necessary permissions to access this resource"
app.errorResponse(w, r, http.StatusForbidden, message)
}

View File

@ -1,11 +1,17 @@
package main
import (
// "encoding/json"
"fmt"
"html/template"
"log"
"net/http"
"time"
"log"
"html/template"
// "github.com/julienschmidt/httprouter"
// "strconv"
// "errors"
// "party.at/party/internal/data"
// "party.at/party/internal/validator"
)
func home(w http.ResponseWriter, r *http.Request) {
@ -48,7 +54,7 @@ func ws(w http.ResponseWriter, r *http.Request) {
done := make(chan struct{})
go func() {
ticker := time.NewTicker(1 * time.Second);
ticker := time.NewTicker(1 * time.Second)
for {
select {
@ -60,7 +66,6 @@ func ws(w http.ResponseWriter, r *http.Request) {
"timestamp": t.Format(time.RFC3339),
}
if err := conn.WriteJSON(msg); err != nil {
fmt.Println("Write error:", err)
return
@ -87,6 +92,33 @@ func ws(w http.ResponseWriter, r *http.Request) {
}
}
// func handleMobileLogin(w http.ResponseWriter, r *http.Request) {
// // 1. Get the token from the request header
// rawIDToken := r.Header.Get("Authorization")
// // 2. Initialize the verifier (pointing to ID Austria's keys)
// verifier := provider.Verifier(&oidc.Config{ClientID: "YOUR_APP_ID"})
// // 3. Verify the signature and expiration
// idToken, err := verifier.Verify(ctx, rawIDToken)
// if err != nil {
// http.Error(w, "Invalid Token", http.StatusUnauthorized)
// return
// }
// // 4. Extract User Data
// var claims struct {
// Subject string `json:"sub"` // This is the unique ID
// Name string `json:"name"`
// }
// if err := idToken.Claims(&claims); err != nil {
// // Handle error
// }
// // 5. Create your own application session for the mobile app
// issueLocalSession(w, claims.Subject)
// }
func redirectToHTTPS(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "https://localhost:8443" + r.URL.RequestURI(), http.StatusMovedPermanently)
http.Redirect(w, r, "https://localhost:8443"+r.URL.RequestURI(), http.StatusMovedPermanently)
}

View File

@ -1,12 +1,22 @@
package main
import (
"fmt"
// "encoding/json"
//"fmt"
"net/http"
)
func (app *application) healthcheckHandler(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "status: available")
fmt.Fprintf(w, "environment: %s\n", app.config.env)
fmt.Fprintf(w, "version: %s\n", version)
env := envelope{
"status": "available",
"system_info": map[string]string{
"environment": app.config.env,
"version": version,
},
}
err := app.writeJSON(w, http.StatusOK, envelope{"health_check": env}, nil)
if err != nil {
app.serverErrorResponse(w,r, err)
}
}

193
cmd/api/helpers.go Normal file
View File

@ -0,0 +1,193 @@
package main
import (
"net/http"
"encoding/json"
"strings"
"errors"
"fmt"
"io"
"net/url"
"strconv"
"party.at/party/internal/validator"
"github.com/julienschmidt/httprouter"
)
type envelope map[string]interface{}
func (app *application) readString(qs url.Values, key string, defaultValue string) string {
// Extract the value for a given key from the query string. If no key exists this
// will return the empty string "".
s := qs.Get(key)
// If no key exists (or the value is empty) then return the default value.
if s == "" {
return defaultValue
}
// Otherwise return the string.
return s
}
func (app *application) readCSV(qs url.Values, key string, defaultValue []string) []string {
// Extract the value from the query string.
csv := qs.Get(key)
// If no key exists (or the value is empty) then return the default value.
if csv == "" {
return defaultValue
}
// Otherwise parse the value into a []string slice and return it.
return strings.Split(csv, ",")
}
func (app *application) readInt(qs url.Values, key string, defaultValue int, v *validator.Validator) int {
// Extract the value from the query string.
s := qs.Get(key)
// If no key exists (or the value is empty) then return the default value.
if s == "" {
return defaultValue
}
// Try to convert the value to an int. If this fails, add an error message to the
// validator instance and return the default value.
i, err := strconv.Atoi(s)
if err != nil {
v.AddError(key, "must be an integer value")
return defaultValue
}
// Otherwise, return the converted integer value.
return i
}
func (app *application) readIDParam(r *http.Request) (int64, error) {
params := httprouter.ParamsFromContext(r.Context())
id, err := strconv.ParseInt(params.ByName("id"), 10, 64)
if err != nil || id < 1 {
return 0, errors.New("invalid id parameter")
}
return id, nil
}
func (app *application) writeJSON(w http.ResponseWriter, status int, data envelope, headers http.Header) error {
// Encode the data to JSON, returning the error if there was one.
js, err := json.MarshalIndent(data, "", "\t")
if err != nil {
return err
}
// Append a newline to make it easier to view in terminal applications.
js = append(js, '\n')
// At this point, we know that we won't encounter any more errors before writing the
// response, so it's safe to add any headers that we want to include. We loop
// through the header map and add each header to the http.ResponseWriter header map.
// Note that it's OK if the provided header map is nil. Go doesn't throw an error
// if you try to range over (or generally, read from) a nil map.
for key, value := range headers {
w.Header()[key] = value
}
// Add the "Content-Type: application/json" header, then write the status code and
// JSON response.
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
w.Write(js)
return nil
}
func (app *application) readJSON(w http.ResponseWriter, r *http.Request, dst interface{}) error {
maxBytes := 1048576
r.Body = http.MaxBytesReader(w, r.Body, int64(maxBytes))
// Initialize the json.Decoder, and call the DisallowUnknownFields() method on it
// before decoding. This means that if the JSON from the client now includes any
// field which cannot be mapped to the target destination, the decoder will return
// an error instead of just ignoring the field.
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
// Decode the request body to the destination.
err := dec.Decode(dst)
if err != nil {
var syntaxError *json.SyntaxError
var unmarshalTypeError *json.UnmarshalTypeError
var invalidUnmarshalError *json.InvalidUnmarshalError
switch {
case errors.As(err, &syntaxError):
return fmt.Errorf("body contains badly-formed JSON (at character %d)", syntaxError.Offset)
case errors.Is(err, io.ErrUnexpectedEOF):
return errors.New("body contains badly-formed JSON")
case errors.As(err, &unmarshalTypeError):
if unmarshalTypeError.Field != "" {
return fmt.Errorf("body contains incorrect JSON type for field %q", unmarshalTypeError.Field)
}
return fmt.Errorf("body contains incorrect JSON type (at character %d)", unmarshalTypeError.Offset)
case errors.Is(err, io.EOF):
return errors.New("body must not be empty")
// If the JSON contains a field which cannot be mapped to the target destination
// then Decode() will now return an error message in the format "json: unknown
// field "<name>"". We check for this, extract the field name from the error,
// and interpolate it into our custom error message. Note that there's an open
// issue at https://github.com/golang/go/issues/29035 regarding turning this
// into a distinct error type in the future.
case strings.HasPrefix(err.Error(), "json: unknown field "):
fieldName := strings.TrimPrefix(err.Error(), "json: unknown field ")
return fmt.Errorf("body contains unknown key %s", fieldName)
// If the request body exceeds 1MB in size the decode will now fail with the
// error "http: request body too large". There is an open issue about turning
// this into a distinct error type at https://github.com/golang/go/issues/30715.
case err.Error() == "http: request body too large":
return fmt.Errorf("body must not be larger than %d bytes", maxBytes)
case errors.As(err, &invalidUnmarshalError):
panic(err)
default:
return err
}
}
// Call Decode() again, using a pointer to an empty anonymous struct as the
// destination. If the request body only contained a single JSON value this will
// return an io.EOF error. So if we get anything else, we know that there is
// additional data in the request body and we return our own custom error message.
err = dec.Decode(&struct{}{})
if err != io.EOF {
return errors.New("body must only contain a single JSON value")
}
return nil
}
func (app *application) background(fn func()) {
// Launch a background goroutine.
go func() {
// Recover any panic.
defer func() {
if err := recover(); err != nil {
app.logger.PrintError(fmt.Errorf("%s", err), nil)
}
}()
// Execute the arbitrary function that we passed as the parameter.
fn()
}()
}

221
cmd/api/issues.go Normal file
View File

@ -0,0 +1,221 @@
package main
import (
// "encoding/json"
"fmt"
// "html/template"
// "log"
"net/http"
"time"
// "github.com/julienschmidt/httprouter"
// "strconv"
"errors"
"party.at/party/internal/data"
"party.at/party/internal/validator"
)
func (app *application) listIssuesHandler(w http.ResponseWriter, r *http.Request) {
// To keep things consistent with our other handlers, we'll define an input struct
// to hold the expected values from the request query string.
var input struct {
Title string
data.Filters
}
// Initialize a new Validator instance.
v := validator.New()
// Call r.URL.Query() to get the url.Values map containing the query string data.
qs := r.URL.Query()
// Use our helpers to extract the title and genres query string values, falling back
// to defaults of an empty string and an empty slice respectively if they are not
// provided by the client.
input.Title = app.readString(qs, "title", "")
// input.Genres = app.readCSV(qs, "genres", []string{})
// Get the page and page_size query string values as integers. Notice that we set
// the default page value to 1 and default page_size to 20, and that we pass the
// validator instance as the final argument here.
input.Filters.Page = app.readInt(qs, "page", 1, v)
input.Filters.PageSize = app.readInt(qs, "page_size", 20, v)
// Extract the sort query string value, falling back to "id" if it is not provided
// by the client (which will imply a ascending sort on issue ID).
input.Filters.Sort = app.readString(qs, "sort", "id")
input.Filters.SortSafelist = []string{"id", "-id", "title", "-title", "description", "-description"}
if data.ValidateFilters(v, input.Filters); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
issues, metadata, err := app.models.Issues.GetAll(input.Title, input.Filters)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
err = app.writeJSON(w, http.StatusOK, envelope{"issues": issues, "metadata": metadata}, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}
func (app *application) createIssueHandler(w http.ResponseWriter, r *http.Request) {
var input struct {
Title string `json:"title"`
Description string `json:"description"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
}
err := app.readJSON(w, r, &input)
if err != nil {
// Use the new badRequestResponse() helper.
app.badRequestResponse(w, r, err)
return
}
issue := &data.Issue{
Title: input.Title,
Description: input.Description,
StartTime: input.StartTime,
EndTime: input.EndTime,
}
v := validator.New()
if data.ValidateIssue(v, issue); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
err = app.models.Issues.Insert(issue)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
headers := make(http.Header)
headers.Set("Location", fmt.Sprintf("/v1/issues/%d", issue.ID))
err = app.writeJSON(w, http.StatusCreated, envelope{"issue": issue}, headers)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}
func (app *application) readIssueHandler(w http.ResponseWriter, r *http.Request) {
id, err := app.readIDParam(r)
if err != nil {
app.notFoundResponse(w, r)
return
}
issue, err := app.models.Issues.Get(id)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.notFoundResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
// Encode the struct to JSON and send it as the HTTP response.
err = app.writeJSON(w, http.StatusOK, envelope{"issue": issue}, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}
func (app *application) updateIssueHandler(w http.ResponseWriter, r *http.Request) {
id, err := app.readIDParam(r)
if err != nil {
app.notFoundResponse(w, r)
return
}
issue, err := app.models.Issues.Get(id)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.notFoundResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
var input struct {
Title *string `json:"title"`
Description *string `json:"description"`
StartTime *time.Time `json:"start_time"`
EndTime *time.Time `json:"end_time"`
}
err = app.readJSON(w, r, &input)
if err != nil {
app.badRequestResponse(w, r, err)
return
}
if input.Title != nil { issue.Title = *input.Title }
if input.Description != nil { issue.Description = *input.Description }
if input.StartTime != nil { issue.StartTime = *input.StartTime }
if input.StartTime != nil { issue.EndTime = *input.EndTime }
v := validator.New()
if data.ValidateIssue(v, issue); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
err = app.models.Issues.Update(issue)
if err != nil {
switch {
case errors.Is(err, data.ErrEditConflict):
app.editConflictResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
// Write the updated issue record in a JSON response.
err = app.writeJSON(w, http.StatusOK, envelope{"issue": issue}, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}
func (app *application) deleteIssueHandler(w http.ResponseWriter, r *http.Request) {
id, err := app.readIDParam(r)
if err != nil {
app.notFoundResponse(w, r)
return
}
// Delete the issue from the database, sending a 404 Not Found response to the
// client if there isn't a matching record.
err = app.models.Issues.Delete(id)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.notFoundResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
// Return a 200 OK status code along with a success message.
err = app.writeJSON(w, http.StatusOK, envelope{"message": "issue successfully deleted"}, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}

View File

@ -3,31 +3,63 @@ package main
import (
"context"
"log"
"strings"
//"html/template"
"net/http"
"flag"
"net/http"
"github.com/gorilla/websocket"
"github.com/jackc/pgx/v5"
)
import (
"fmt"
"os"
"time"
"database/sql"
_ "github.com/lib/pq"
"party.at/party/internal/data"
"party.at/party/internal/mailer"
"party.at/party/internal/jsonlog"
)
const version = "1.0.0"
type config struct {
port int
env string
env string
db struct {
dsn string
maxOpenConns int
maxIdleConns int
maxIdleTime string
}
limiter struct {
rps float64
burst int
enabled bool
}
smtp struct {
host string
port int
username string
password string
sender string
}
cors struct {
trustedOrigins []string
}
}
type application struct {
config config
logger *log.Logger
logger *jsonlog.Logger
models data.Models
mailer mailer.Mailer
}
type Message struct {
@ -47,39 +79,84 @@ func main() {
var cfg config
flag.IntVar(&cfg.port, "port", 4000, "API server port")
flag.StringVar(&cfg.env, "env", "development", "Environment (development|staging|production)")
flag.StringVar(&cfg.env, "env", "development", "Environment (development|staging|production)")
flag.StringVar(&cfg.db.dsn, "db-dsn", os.Getenv("PARTY_DB_DSN"), "PostgreSQL DSN")
//addr := flag.String("addr", ":8443", "HTTP network address")
flag.IntVar(&cfg.db.maxOpenConns, "db-max-open-conns", 25, "PostgreSQL max open connections")
flag.IntVar(&cfg.db.maxIdleConns, "db-max-idle-conns", 25, "PostgreSQL max idle connections")
flag.StringVar(&cfg.db.maxIdleTime, "db-max-idle-time", "15m", "PostgreSQL max connection idle time")
flag.StringVar(&cfg.smtp.host, "smtp-host", "smtp.mailtrap.io", "SMTP host")
flag.IntVar(&cfg.smtp.port, "smtp-port", 25, "SMTP port")
flag.StringVar(&cfg.smtp.username, "smtp-username", "98cf60028d7fcb", "SMTP username")
flag.StringVar(&cfg.smtp.password, "smtp-password", "b9d4a35372e971", "SMTP password")
flag.StringVar(&cfg.smtp.sender, "smtp-sender", "Greenlight <no-reply@greenlight.alexedwards.net>", "SMTP sender")
flag.Float64Var(&cfg.limiter.rps, "limiter-rps", 2, "Rate limiter maximum requests per second")
flag.IntVar(&cfg.limiter.burst, "limiter-burst", 4, "Rate limiter maximum burst")
flag.BoolVar(&cfg.limiter.enabled, "limiter-enabled", true, "Enable rate limiter")
flag.Func("cors-trusted-origins", "Trusted CORS origins (space separated)", func(val string) error {
cfg.cors.trustedOrigins = strings.Fields(val)
return nil
})
flag.Parse()
logger := log.New(os.Stdout, "", log.Ldate | log.Ltime)
logger := jsonlog.New(os.Stdout, jsonlog.LevelInfo)
log.Printf("%s\n", cfg.db.dsn)
db, err := openDB(cfg)
if err != nil {
logger.PrintFatal(err, nil)
}
defer db.Close()
app := &application{
config: cfg,
logger: logger,
models: data.NewModels(db),
mailer: mailer.New(cfg.smtp.host, cfg.smtp.port, cfg.smtp.username, cfg.smtp.password, cfg.smtp.sender),
}
log.Println("Hello, Sailor!")
conn, err := pgx.Connect(context.Background(), "postgres://party:iK2SoVbDhdCki5n3LxGyP6zKpLspt4@losandesgames.com:5432/party")
err = app.serve()
if err != nil {
log.Fatal(err)
}
defer conn.Close(context.Background())
database_init(conn)
srv := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.port),
Handler: app.routes(),
IdleTimeout: time.Minute,
ReadTimeout: 10 * time.Second,
WriteTimeout: 30 * time.Second,
}
logger.Printf("starting %s server on %s", cfg.env, srv.Addr);
// Start HTTPS server (requires cert.pem and key.pem in current dir)
err = srv.ListenAndServe()
if err != nil {
panic(err)
logger.PrintFatal(err, nil)
}
}
func openDB(cfg config) (*sql.DB, error) {
db, err := sql.Open("postgres", cfg.db.dsn)
if err != nil {
return nil, err
}
db.SetMaxOpenConns(cfg.db.maxOpenConns)
db.SetMaxIdleConns(cfg.db.maxIdleConns)
duration, err := time.ParseDuration(cfg.db.maxIdleTime)
if err != nil {
return nil, err
}
db.SetConnMaxIdleTime(duration)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Use PingContext() to establish a new connection to the database, passing in the
// context we created above as a parameter. If the connection couldn't be
// established successfully within the 5 second deadline, then this will return an
// error.
err = db.PingContext(ctx)
if err != nil {
return nil, err
}
return db, nil
}

215
cmd/api/middleware.go Normal file
View File

@ -0,0 +1,215 @@
package main
import (
"fmt"
"net"
"net/http"
"sync"
"time"
"strings"
"errors"
"golang.org/x/time/rate"
"party.at/party/internal/data"
"party.at/party/internal/validator"
)
func (app *application) recoverPanic(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create a deferred function (which will always be run in the event of a panic
// as Go unwinds the stack).
defer func() {
// Use the builtin recover function to check if there has been a panic or
// not.
if err := recover(); err != nil {
// If there was a panic, set a "Connection: close" header on the
// response. This acts as a trigger to make Go's HTTP server
// automatically close the current connection after a response has been
// sent.
w.Header().Set("Connection", "close")
// The value returned by recover() has the type interface{}, so we use
// fmt.Errorf() to normalize it into an error and call our
// serverErrorResponse() helper. In turn, this will log the error using
// our custom Logger type at the ERROR level and send the client a 500
// Internal Server Error response.
app.serverErrorResponse(w, r, fmt.Errorf("%s", err))
}
}()
next.ServeHTTP(w, r)
})
}
func (app *application) rateLimit(next http.Handler) http.Handler {
type client struct {
limiter *rate.Limiter
lastSeen time.Time
}
var mu sync.Mutex
var clients = make(map[string]*client)
go func() {
for {
time.Sleep(time.Minute)
// Lock the mutex to prevent any rate limiter checks from happening while
// the cleanup is taking place.
mu.Lock()
// Loop through all clients. If they haven't been seen within the last three
// minutes, delete the corresponding entry from the map.
for ip, client := range clients {
if time.Since(client.lastSeen) > 3*time.Minute {
delete(clients, ip)
}
}
// Importantly, unlock the mutex when the cleanup is complete.
mu.Unlock()
}
}()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if app.config.limiter.enabled {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
mu.Lock()
if _, found := clients[ip]; !found {
clients[ip] = &client{limiter: rate.NewLimiter(rate.Limit(app.config.limiter.rps), app.config.limiter.burst)}
}
clients[ip].lastSeen = time.Now()
if !clients[ip].limiter.Allow() {
mu.Unlock()
app.rateLimitExceededResponse(w, r)
return
}
mu.Unlock()
}
next.ServeHTTP(w, r)
})
}
func (app *application) authenticate(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Vary", "Authorization")
authorizationHeader := r.Header.Get("Authorization")
if authorizationHeader == "" {
r = app.contextSetUser(r, data.AnonymousUser)
next.ServeHTTP(w, r)
return
}
headerParts := strings.Split(authorizationHeader, " ")
if len(headerParts) != 2 || headerParts[0] != "Bearer" {
app.invalidAuthenticationTokenResponse(w, r)
return
}
token := headerParts[1]
v := validator.New()
if data.ValidateTokenPlaintext(v, token); !v.Valid() {
app.invalidAuthenticationTokenResponse(w, r)
return
}
userIdentity, err := app.models.UserIdentities.GetForToken(data.ScopeAuthentication, token)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.invalidAuthenticationTokenResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
user, err := app.models.Users.Get(userIdentity.UserID)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.invalidCredentialsResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
// Call the contextSetUser() helper to add the user information to the request
// context.
r = app.contextSetUser(r, user)
// Call the next handler in the chain.
next.ServeHTTP(w, r)
})
}
func (app *application) requireAuthenticatedUser(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := app.contextGetUser(r)
if user.IsAnonymous() {
app.authenticationRequiredResponse(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func (app *application) requireActivatedUser(next http.HandlerFunc) http.HandlerFunc {
// Rather than returning this http.HandlerFunc we assign it to the variable fn.
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := app.contextGetUser(r)
// Check that a user is activated.
if !user.Activated {
app.inactiveAccountResponse(w, r)
return
}
next.ServeHTTP(w, r)
})
// Wrap fn with the requireAuthenticatedUser() middleware before returning it.
return app.requireAuthenticatedUser(fn)
}
func (app *application) requirePermission(code string, next http.HandlerFunc) http.HandlerFunc {
fn := func(w http.ResponseWriter, r *http.Request) {
// Retrieve the user from the request context.
user := app.contextGetUser(r)
// Get the slice of permissions for the user.
permissions, err := app.models.Permissions.GetAllForUser(user.ID)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
// Check if the slice includes the required permission. If it doesn't, then
// return a 403 Forbidden response.
if !permissions.Include(code) {
app.notPermittedResponse(w, r)
return
}
// Otherwise they have the required permission so we call the next handler in
// the chain.
next.ServeHTTP(w, r)
}
// Wrap this with the requireActivatedUser() middleware before returning it.
return app.requireActivatedUser(fn)
}

View File

@ -5,18 +5,34 @@ import (
"github.com/julienschmidt/httprouter"
)
func (app *application) routes() *httprouter.Router {
func (app *application) routes() http.Handler {
router := httprouter.New()
router.NotFound = http.HandlerFunc(app.notFoundResponse)
router.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse)
fileServer := http.FileServer(http.Dir("ui/static"))
router.HandlerFunc(http.MethodGet, "/", home)
router.HandlerFunc(http.MethodGet, "/ws", ws)
router.Handler(http.MethodGet, "/static/", http.StripPrefix("/static", fileServer))
router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler)
// router.HandlerFunc(http.MethodPost, "/v1/movies", app.createMovieHandler)
// router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.showMovieHandler)
router.HandlerFunc(http.MethodGet, "/", home)
router.HandlerFunc(http.MethodGet, "/ws", ws)
router.Handler (http.MethodGet, "/static/", http.StripPrefix("/static", fileServer))
router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler)
router.HandlerFunc(http.MethodGet, "/v1/issues", app.requirePermission("issues:read", app.listIssuesHandler))
return router
router.HandlerFunc(http.MethodPost, "/v1/issues", app.requirePermission("issues:write", app.createIssueHandler))
router.HandlerFunc(http.MethodGet, "/v1/issues/:id", app.requirePermission("issues:read", app.readIssueHandler))
router.HandlerFunc(http.MethodPatch, "/v1/issues/:id", app.requirePermission("issues:write", app.updateIssueHandler))
router.HandlerFunc(http.MethodDelete, "/v1/issues/:id", app.requirePermission("issues:write", app.deleteIssueHandler))
router.HandlerFunc(http.MethodPost, "/v1/users", app.createUserHandler)
// router.HandlerFunc(http.MethodGet, "/v1/users/:id", app.readUserHandler)
// router.HandlerFunc(http.MethodPatch, "/v1/users/:id", app.updateUserHandler)
router.HandlerFunc(http.MethodDelete, "/v1/users/:id", app.deleteUserHandler)
router.HandlerFunc(http.MethodPut, "/v1/users/activated", app.activateUserHandler)
router.HandlerFunc(http.MethodPost, "/v1/tokens/authentication", app.createAuthenticationTokenHandler)
return app.recoverPanic(app.enableCORS(app.rateLimit(app.authenticate(router))))
}

67
cmd/api/server.go Normal file
View File

@ -0,0 +1,67 @@
package main
import (
"context"
"errors"
"fmt"
"net/http"
"time"
"log"
"os"
"os/signal"
"syscall"
)
func (app *application) serve() error {
// Declare a HTTP server using the same settings as in our main() function.
srv := &http.Server{
Addr: fmt.Sprintf(":%d", app.config.port),
Handler: app.routes(),
ErrorLog: log.New(app.logger, "", 0),
IdleTimeout: time.Minute,
ReadTimeout: 10 * time.Second,
WriteTimeout: 30 * time.Second,
}
shutdownError := make(chan error)
go func() {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
s := <-quit
app.logger.PrintInfo("shutting down server", map[string]string{
"signal": s.String(),
})
ctx, cancel := context.WithTimeout(context.Background(), 5 * time.Second)
defer cancel()
shutdownError <- srv.Shutdown(ctx)
}()
// Likewise log a "starting server" message.
app.logger.PrintInfo("starting server", map[string]string{
"addr": srv.Addr,
"env": app.config.env,
})
err := srv.ListenAndServe()
if !errors.Is(err, http.ErrServerClosed) {
return err
}
err = <-shutdownError
if err != nil {
return err
}
app.logger.PrintInfo("stopped server", map[string]string{
"addr": srv.Addr,
})
return nil
}

77
cmd/api/tokens.go Normal file
View File

@ -0,0 +1,77 @@
package main
import (
"errors"
"net/http"
"time"
"party.at/party/internal/data"
"party.at/party/internal/validator"
)
func (app *application) createAuthenticationTokenHandler(w http.ResponseWriter, r *http.Request) {
var input struct {
Email string `json:"email"`
Password string `json:"password"`
}
err := app.readJSON(w, r, &input)
if err != nil {
app.badRequestResponse(w, r, err)
return
}
// Validate the email and password provided by the client.
v := validator.New()
data.ValidateEmail(v, input.Email)
data.ValidatePasswordPlaintext(v, input.Password)
if !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
user, err := app.models.Users.GetByEmail(input.Email)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.invalidCredentialsResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
identities, err := app.models.UserIdentities.GetByUser(user.ID)
var authenticatedIdentity *data.UserIdentity
for _, user_identity := range identities {
match, err := user_identity.Password.Matches(input.Password)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
if match {
authenticatedIdentity = user_identity
break
}
}
if authenticatedIdentity == nil {
app.invalidCredentialsResponse(w, r)
return
}
token, err := app.models.Tokens.New(user.ID, authenticatedIdentity.ID, 24 * time.Hour, data.ScopeAuthentication)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
err = app.writeJSON(w, http.StatusCreated, envelope{"authentication_token": token}, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}

197
cmd/api/users.go Normal file
View File

@ -0,0 +1,197 @@
package main
import (
"errors"
"net/http"
"party.at/party/internal/data"
"party.at/party/internal/validator"
"database/sql"
"time"
)
func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request) {
var input struct {
ProviderId int64 `json:"provider_id"`
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
Name string `json:"name"`
AltName string `json:"alt_name"`
}
err := app.readJSON(w, r, &input)
if err != nil {
app.badRequestResponse(w, r, err)
return
}
user := &data.User{
Email: input.Email,
Name: input.Name,
AltName: sql.NullString{String: input.AltName, Valid: true},
Activated: false,
}
userIdentity := &data.UserIdentity{
ProviderID: input.ProviderId,
ProviderUserID: input.Username,
}
err = userIdentity.Password.Set(input.Password)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
v := validator.New()
if data.ValidateUser(v, user); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
if data.ValidateUserIdentity(v, userIdentity); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
err = app.models.Users.ExecuteRegistrationTx(user, userIdentity)
if err != nil {
switch {
case errors.Is(err, data.ErrDuplicateEmail):
v.AddError("email", "a user with this email address already exists")
app.failedValidationResponse(w, r, v.Errors)
case errors.Is(err, data.ErrDuplicateUser):
v.AddError("username", "a user with this username already exists")
app.failedValidationResponse(w, r, v.Errors)
default:
app.serverErrorResponse(w, r, err)
}
return
}
err = app.models.Permissions.AddForUser(user.ID, "issues:read")
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
token, err := app.models.Tokens.New(user.ID, userIdentity.ID, 3 * 24 * time.Hour, data.ScopeActivation)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
app.background(func() {
data := map[string]interface{}{
"token": token.Plaintext,
"userID": user.ID,
}
err = app.mailer.Send(user.Email, "user_welcome.tmpl", data)
if err != nil {
app.logger.PrintError(err, nil)
}
})
// Write a JSON response containing the user data along with a 201 Created status
// code.
err = app.writeJSON(w, http.StatusCreated, envelope{"user": user}, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}
func (app *application) deleteUserHandler(w http.ResponseWriter, r *http.Request) {
id, err := app.readIDParam(r)
if err != nil {
app.notFoundResponse(w, r)
return
}
// Delete the issue from the database, sending a 404 Not Found response to the
// client if there isn't a matching record.
err = app.models.Users.Delete(id)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.notFoundResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
// Return a 200 OK status code along with a success message.
err = app.writeJSON(w, http.StatusOK, envelope{"message": "user successfully deleted"}, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}
func (app *application) activateUserHandler(w http.ResponseWriter, r *http.Request) {
// Parse the plaintext activation token from the request body.
var input struct {
TokenPlaintext string `json:"token"`
}
err := app.readJSON(w, r, &input)
if err != nil {
app.badRequestResponse(w, r, err)
return
}
// Validate the plaintext token provided by the client.
v := validator.New()
if data.ValidateTokenPlaintext(v, input.TokenPlaintext); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
// Retrieve the details of the user associated with the token using the
// GetForToken() method (which we will create in a minute). If no matching record
// is found, then we let the client know that the token they provided is not valid.
user, err := app.models.Users.GetForToken(data.ScopeActivation, input.TokenPlaintext)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
v.AddError("token", "invalid or expired activation token")
app.failedValidationResponse(w, r, v.Errors)
default:
app.serverErrorResponse(w, r, err)
}
return
}
// Update the user's activation status.
user.Activated = true
// Save the updated user record in our database, checking for any edit conflicts in
// the same way that we did for our movie records.
err = app.models.Users.Update(user)
if err != nil {
switch {
case errors.Is(err, data.ErrEditConflict):
app.editConflictResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
// If everything went successfully, then we delete all activation tokens for the
// user.
err = app.models.Tokens.DeleteAllForUser(data.ScopeActivation, user.ID)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
// Send the updated user details to the client in a JSON response.
err = app.writeJSON(w, http.StatusOK, envelope{"user": user}, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}

8
go.mod
View File

@ -1,16 +1,18 @@
module party.at/party
go 1.24.2
go 1.25.0
require (
github.com/gorilla/websocket v1.5.3
github.com/jackc/pgx/v5 v5.7.6
github.com/julienschmidt/httprouter v1.3.0
)
require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/lib/pq v1.12.0 // indirect
github.com/wneessen/go-mail v0.7.2 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/text v0.24.0 // indirect
golang.org/x/text v0.29.0 // indirect
golang.org/x/time v0.15.0 // indirect
)

9
go.sum
View File

@ -13,6 +13,8 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/lib/pq v1.12.0 h1:mC1zeiNamwKBecjHarAr26c/+d8V5w/u4J0I/yASbJo=
github.com/lib/pq v1.12.0/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -20,12 +22,19 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/wneessen/go-mail v0.7.2 h1:xxPnhZ6IZLSgxShebmZ6DPKh1b6OJcoHfzy7UjOkzS8=
github.com/wneessen/go-mail v0.7.2/go.mod h1:+TkW6QP3EVkgTEqHtVmnAE/1MRhmzb8Y9/W3pweuS+k=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@ -0,0 +1,15 @@
package data
type AuthProvider struct {
ID int64 `json:"id"`
Description string `json:"description"`
Active bool `json:"active"`
}
type AuthProviderID int64
const (
ProviderUnknown AuthProviderID = 0
ProviderLocal AuthProviderID = 1
ProviderIDAustria AuthProviderID = 2
)

77
internal/data/filters.go Normal file
View File

@ -0,0 +1,77 @@
package data
import (
"math"
"strings"
"party.at/party/internal/validator"
)
type Filters struct {
Page int
PageSize int
Sort string
SortSafelist []string
}
func (f Filters) sortColumn() string {
for _, safeValue := range f.SortSafelist {
if f.Sort == safeValue {
return strings.TrimPrefix(f.Sort, "-")
}
}
panic("unsafe sort parameter: " + f.Sort)
}
// Return the sort direction ("ASC" or "DESC") depending on the prefix character of the
// Sort field.
func (f Filters) sortDirection() string {
if strings.HasPrefix(f.Sort, "-") {
return "DESC"
}
return "ASC"
}
func (f Filters) limit() int {
return f.PageSize
}
func (f Filters) offset() int {
return (f.Page - 1) * f.PageSize
}
func ValidateFilters(v *validator.Validator, f Filters) {
// Check that the page and page_size parameters contain sensible values.
v.Check(f.Page > 0, "page", "must be greater than zero")
v.Check(f.Page <= 10_000_000, "page", "must be a maximum of 10 million")
v.Check(f.PageSize > 0, "page_size", "must be greater than zero")
v.Check(f.PageSize <= 100, "page_size", "must be a maximum of 100")
// Check that the sort parameter matches a value in the safelist.
v.Check(validator.In(f.Sort, f.SortSafelist...), "sort", "invalid sort value")
}
type Metadata struct {
CurrentPage int `json:"current_page,omitempty"`
PageSize int `json:"page_size,omitempty"`
FirstPage int `json:"first_page,omitempty"`
LastPage int `json:"last_page,omitempty"`
TotalRecords int `json:"total_records,omitempty"`
}
func calculateMetadata(totalRecords, page, pageSize int) Metadata {
if totalRecords == 0 {
// Note that we return an empty Metadata struct if there are no records.
return Metadata{}
}
return Metadata{
CurrentPage: page,
PageSize: pageSize,
FirstPage: 1,
LastPage: int(math.Ceil(float64(totalRecords) / float64(pageSize))),
TotalRecords: totalRecords,
}
}

242
internal/data/issues.go Normal file
View File

@ -0,0 +1,242 @@
package data
import (
"time"
"party.at/party/internal/validator"
"database/sql"
"errors"
"context"
"fmt"
)
type Issue struct {
ID int64 `json:"id"`
Title string `json:"title"`
Description string `json:"description"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
Created time.Time `json:"created"`
Version int32 `json:"version"`
}
func ValidateIssue(v *validator.Validator, issue *Issue) {
v.Check(issue.Title != "", "title", "must be provided")
v.Check(len(issue.Title) <= 500, "title", "must not be more than 500 bytes long")
v.Check(issue.Description != "", "description", "must be provided")
v.Check(len(issue.Description) <= 500, "description", "must not be greater than 500 bytes long")
// v.Check(issue.StartTime != "", "start_time", "must be provided")
// v.Check(len(issue.StartTime) <= 500, "start_time", "must not be more than 500 bytes long")
// v.Check(issue.EndTime != "", "end_time", "must be provided")
// v.Check(len(issue.EndTime) <= 500, "end_time", "must not be more than 500 bytes long")
}
type IssueModel struct {
DB *sql.DB
}
func (m IssueModel) Insert(issue *Issue) error {
query := `
INSERT INTO issues (title, description, start_time, end_time)
VALUES ($1, $2, $3, $4)
RETURNING id, created, version`
args := []interface{}{
issue.Title,
issue.Description,
issue.StartTime,
issue.EndTime,
}
return m.DB.QueryRow(query, args...).Scan(
&issue.ID,
&issue.Created,
&issue.Version,
)
}
// Add a placeholder method for fetching a specific record from the issues table.
func (m IssueModel) Get(id int64) (*Issue, error) {
if id < 1 {
return nil, ErrRecordNotFound
}
// Define the SQL query for retrieving the issue data.
query :=`
SELECT id, title, description, start_time, end_time, created, version
FROM issues
WHERE id = $1`
// Declare a Issue struct to hold the data returned by the query.
var issue Issue
// Execute the query using the QueryRow() method, passing in the provided id value
// as a placeholder parameter, and scan the response data into the fields of the
// Issue struct. Importantly, notice that we need to convert the scan target for the
// genres column using the pq.Array() adapter function again.
err := m.DB.QueryRow(query, id).Scan(
&issue.ID,
&issue.Title,
&issue.Description,
&issue.StartTime,
&issue.EndTime,
&issue.Created,
&issue.Version,
)
// Handle any errors. If there was no matching issue found, Scan() will return
// a sql.ErrNoRows error. We check for this and return our custom ErrRecordNotFound
// error instead.
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrRecordNotFound
default:
return nil, err
}
}
// Otherwise, return a pointer to the Issue struct.
return &issue, nil
}
func (m IssueModel) Update(issue *Issue) error {
query := `
UPDATE issues
SET title = $1, description = $2, start_time = $3, end_time = $4, version = version + 1
WHERE id = $5 AND version = $6
RETURNING version`
// Create an args slice containing the values for the placeholder parameters.
args := []interface{}{
issue.Title,
issue.Description,
issue.StartTime,
issue.EndTime,
issue.ID,
issue.Version,
}
// Use the QueryRow() method to execute the query, passing in the args slice as a
// variadic parameter and scanning the new version value into the issue struct.
err := m.DB.QueryRow(query, args...).Scan(&issue.Version)
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return ErrEditConflict
default:
return err
}
}
return nil
}
// Add a placeholder method for deleting a specific record from the issues table.
func (m IssueModel) Delete(id int64) error {
if id < 1 {
return ErrRecordNotFound
}
// Construct the SQL query to delete the record.
query := `
DELETE FROM issues
WHERE id = $1`
// Execute the SQL query using the Exec() method, passing in the id variable as
// the value for the placeholder parameter. The Exec() method returns a sql.Result
// object.
result, err := m.DB.Exec(query, id)
if err != nil {
return err
}
// Call the RowsAffected() method on the sql.Result object to get the number of rows
// affected by the query.
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
// If no rows were affected, we know that the issues table didn't contain a record
// with the provided ID at the moment we tried to delete it. In that case we
// return an ErrRecordNotFound error.
if rowsAffected == 0 {
return ErrRecordNotFound
}
return nil
}
func (m IssueModel) GetAll(title string, filters Filters) ([]*Issue, Metadata, error) {
// Construct the SQL query to retrieve all issue records.
query := fmt.Sprintf(`
SELECT
COUNT(*) OVER(), id, title, description, start_time, end_time, created, version
FROM
issues
WHERE
(to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '')
ORDER BY
%s %s, id ASC
LIMIT
$2
OFFSET
$3`,
filters.sortColumn(),
filters.sortDirection(),
)
// Create a context with a 3-second timeout.
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
defer cancel()
args := []interface{}{title, filters.limit(), filters.offset()}
rows, err := m.DB.QueryContext(ctx, query, args...)
if err != nil {
return nil, Metadata{}, err
}
// Importantly, defer a call to rows.Close() to ensure that the resultset is closed
// before GetAll() returns.
defer rows.Close()
totalRecords := 0
issues := []*Issue{}
// Use rows.Next to iterate through the rows in the resultset.
for rows.Next() {
// Initialize an empty Issue struct to hold the data for an individual issue.
var issue Issue
// Scan the values from the row into the Issue struct. Again, note that we're
// using the pq.Array() adapter on the genres field here.
err := rows.Scan(
&totalRecords,
&issue.ID,
&issue.Title,
&issue.Description,
&issue.StartTime,
&issue.EndTime,
&issue.Created,
&issue.Version,
)
if err != nil {
return nil, Metadata{}, err
}
// Add the Issue struct to the slice.
issues = append(issues, &issue)
}
// When the rows.Next() loop has finished, call rows.Err() to retrieve any error
// that was encountered during the iteration.
if err = rows.Err(); err != nil {
return nil, Metadata{}, err
}
metadata := calculateMetadata(totalRecords, filters.Page, filters.PageSize)
// If everything went OK, then return the slice of issues.
return issues, metadata, nil
}

29
internal/data/models.go Normal file
View File

@ -0,0 +1,29 @@
package data
import (
"database/sql"
"errors"
)
var (
ErrRecordNotFound = errors.New("record not found")
ErrEditConflict = errors.New("edit conflict")
)
type Models struct {
Users UserModel
UserIdentities UserIdentityModel
Issues IssueModel
Tokens TokenModel
Permissions PermissionModel
}
func NewModels(db *sql.DB) Models {
return Models{
Users: UserModel{DB: db},
UserIdentities: UserIdentityModel{DB: db},
Issues: IssueModel{DB: db},
Tokens: TokenModel{DB: db},
Permissions: PermissionModel{DB: db},
}
}

View File

@ -0,0 +1,69 @@
package data
import (
"context"
"database/sql"
"time"
"github.com/lib/pq"
)
type Permissions []string
func (p Permissions) Include(code string) bool {
for i := range p {
if code == p[i] {
return true
}
}
return false
}
type PermissionModel struct {
DB *sql.DB
}
func (m PermissionModel) GetAllForUser(userID int64) (Permissions, error) {
query :=`
SELECT permissions.code
FROM permissions
INNER JOIN users_permissions ON users_permissions.permission_id = permissions.id
INNER JOIN users ON users_permissions.user_id = users.id
WHERE users.id = $1`
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
rows, err := m.DB.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var permissions Permissions
for rows.Next() {
var permission string
err := rows.Scan(&permission)
if err != nil {
return nil, err
}
permissions = append(permissions, permission)
}
if err = rows.Err(); err != nil {
return nil, err
}
return permissions, nil
}
func (m PermissionModel) AddForUser(userID int64, codes ...string) error {
query :=`
INSERT INTO users_permissions
SELECT $1, permissions.id FROM permissions WHERE permissions.code = ANY($2)`
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_, err := m.DB.ExecContext(ctx, query, userID, pq.Array(codes))
return err
}

95
internal/data/tokens.go Normal file
View File

@ -0,0 +1,95 @@
package data
import (
"crypto/rand"
"crypto/sha256"
"encoding/base32"
"time"
"context"
"database/sql"
"party.at/party/internal/validator"
)
const (
ScopeActivation = "activation"
ScopeAuthentication = "authentication"
)
type Token struct {
Plaintext string `json:"token"`
Hash []byte `json:"-"`
UserID int64 `json:"-"`
UserIdentityID int64 `json:"-"`
Expiry time.Time `json:"expiry"`
Scope string `json:"-"`
}
func generateToken(userID int64, userIdentityID int64, ttl time.Duration, scope string) (*Token, error) {
token := &Token{
UserID: userID,
UserIdentityID: userIdentityID,
Expiry: time.Now().Add(ttl),
Scope: scope,
}
// Initialize a zero-valued byte slice with a length of 16 bytes.
randomBytes := make([]byte, 16)
_, err := rand.Read(randomBytes)
if err != nil {
return nil, err
}
token.Plaintext = base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(randomBytes)
hash := sha256.Sum256([]byte(token.Plaintext))
token.Hash = hash[:]
return token, nil
}
func ValidateTokenPlaintext(v *validator.Validator, tokenPlaintext string) {
v.Check(tokenPlaintext != "", "token", "must be provided")
v.Check(len(tokenPlaintext) == 26, "token", "must be 26 bytes long")
}
type TokenModel struct {
DB *sql.DB
}
func (m TokenModel) New(userID int64, userIdentityID int64, ttl time.Duration, scope string) (*Token, error) {
token, err := generateToken(userID, userIdentityID, ttl, scope)
if err != nil {
return nil, err
}
err = m.Insert(token)
return token, err
}
func (m TokenModel) Insert(token *Token) error {
query :=`
INSERT INTO tokens (hash, user_id, user_identity_id, expiry, scope)
VALUES ($1, $2, $3, $4, $5)`
args := []interface{}{token.Hash, token.UserID, token.UserIdentityID, token.Expiry, token.Scope}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_, err := m.DB.ExecContext(ctx, query, args...)
return err
}
func (m TokenModel) DeleteAllForUser(scope string, userID int64) error {
query :=`
DELETE FROM tokens
WHERE scope = $1 AND user_id = $2`
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
defer cancel()
_, err := m.DB.ExecContext(ctx, query, scope, userID)
return err
}

View File

@ -0,0 +1,306 @@
package data
import (
"golang.org/x/crypto/bcrypt"
"errors"
"party.at/party/internal/validator"
"database/sql"
"context"
"time"
"crypto/sha256"
)
type UserIdentity struct {
ID int64 `json:"id"`
ProviderID int64 `json:"provider_id"`
UserID int64 `json:"user_id"`
ProviderUserID string `json:"provider_user_id"`
Password password `json:"-"`
Version int32 `json:"version"`
}
type password struct {
plaintext *string
hash []byte
}
func (p *password) Set(plaintextPassword string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(plaintextPassword), 12)
if err != nil {
return err
}
p.plaintext = &plaintextPassword
p.hash = hash
return nil
}
func (p *password) Matches(plaintextPassword string) (bool, error) {
err := bcrypt.CompareHashAndPassword(p.hash, []byte(plaintextPassword))
if err != nil {
switch {
case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
return false, nil
default:
return false, err
}
}
return true, nil
}
func ValidatePasswordPlaintext(v *validator.Validator, password string) {
v.Check(password != "", "password", "must be provided")
v.Check(len(password) >= 8, "password", "must be at least 8 bytes long")
v.Check(len(password) <= 72, "password", "must not be more than 72 bytes long")
}
func ValidateUserIdentity(v *validator.Validator, userIdentity *UserIdentity) {
v.Check(userIdentity.ProviderID == int64(ProviderLocal), "provider_id", "must be 1");
if userIdentity.ProviderID == int64(ProviderLocal) {
v.Check(userIdentity.ProviderUserID != "", "username", "must be provided")
v.Check(len(userIdentity.ProviderUserID) <= 500, "username", "must not be more than 500 bytes long")
}
if userIdentity.Password.plaintext != nil {
ValidatePasswordPlaintext(v,
*userIdentity.Password.plaintext)
}
if userIdentity.Password.hash == nil {
panic("missing password hash for user")
}
}
type UserIdentityModel struct {
DB *sql.DB
}
// func (m UserIdentityModel) Insert(userIdentity *UserIdentity) error {
// query := `
// INSERT INTO user_identities (provider_id, user_id, provider_user, password)
// VALUES ($1, $2, $3, $4)
// RETURNING id, version`
// args := []interface{}{
// userIdentity.ProviderID,
// userIdentity.UserID,
// userIdentity.ProviderUserID,
// userIdentity.Password.hash,
// }
// ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
// defer cancel()
// err := m.DB.QueryRowContext(ctx, query, args...).Scan(&userIdentity.ID, &userIdentity.Version)
// if err != nil {
// switch {
// case err.Error() ==
// `pq: duplicate key value violates unique constraint "users_email_key"`:
// return ErrDuplicateEmail
// default:
// return err
// }
// }
// return nil
// }
func (m UserIdentityModel) Get(id int64) (*UserIdentity, error) {
if id < 1 {
return nil, ErrRecordNotFound
}
query :=`
SELECT id, provider_id, user_id, provider_user, password, version
FROM user_identities
WHERE id = $1`
// Declare a User struct to hold the data returned by the query.
var userIdentity UserIdentity
// Execute the query using the QueryRow() method, passing in the provided id value
// as a placeholder parameter, and scan the response data into the fields of the
// User struct. Importantly, notice that we need to convert the scan target for the
// genres column using the pq.Array() adapter function again.
err := m.DB.QueryRow(query, id).Scan(
&userIdentity.ID,
&userIdentity.ProviderID,
&userIdentity.UserID,
&userIdentity.ProviderUserID,
&userIdentity.Password,
&userIdentity.Version,
)
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrRecordNotFound
default:
return nil, err
}
}
// Otherwise, return a pointer to the User struct.
return &userIdentity, nil
}
func (m UserIdentityModel) GetByUser(user_id int64) ([]*UserIdentity, error) {
if user_id < 1 {
return nil, ErrRecordNotFound
}
query :=`
SELECT identity.id, identity.provider_id, identity.user_id, identity.provider_user_id, identity.password, identity.version
FROM user_identities identity
JOIN users u on identity.user_id = u.id
WHERE u.id = $1`
// Execute the query using the QueryRow() method, passing in the provided id value
// as a placeholder parameter, and scan the response data into the fields of the
// User struct. Importantly, notice that we need to convert the scan target for the
// genres column using the pq.Array() adapter function again.
rows, err := m.DB.Query(query, user_id)
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrRecordNotFound
default:
return nil, err
}
}
defer rows.Close()
var identities []*UserIdentity
for rows.Next() {
var userIdentity UserIdentity
err := rows.Scan(
&userIdentity.ID,
&userIdentity.ProviderID,
&userIdentity.UserID,
&userIdentity.ProviderUserID,
&userIdentity.Password.hash,
&userIdentity.Version,
)
if err != nil {
return nil, err
}
identities = append(identities, &userIdentity)
}
if len(identities) == 0 {
return nil, ErrRecordNotFound
}
// Otherwise, return a pointer to the User struct.
return identities, nil
}
func (m UserIdentityModel) GetForToken(tokenScope, tokenPlaintext string) (*UserIdentity, error) {
tokenHash := sha256.Sum256([]byte(tokenPlaintext))
query := `
SELECT ui.id, ui.provider_id, ui.user_id, ui.provider_user_id, ui.password, ui.version
FROM user_identities ui
INNER JOIN tokens ON ui.id = tokens.user_identity_id
WHERE tokens.hash = $1
AND tokens.scope = $2
AND tokens.expiry > $3`
args := []interface{}{tokenHash[:], tokenScope, time.Now()}
var userIdentity UserIdentity
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err := m.DB.QueryRowContext(ctx, query, args...).Scan(
&userIdentity.ID,
&userIdentity.ProviderID,
&userIdentity.UserID,
&userIdentity.ProviderUserID,
&userIdentity.Password.hash,
&userIdentity.Version,
)
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrRecordNotFound
default:
return nil, err
}
}
return &userIdentity, nil
}
func (m UserIdentityModel) Update(user *UserIdentity) error {
query := `
UPDATE user_identities
SET provider_user_id = $1, password = $2, version = version + 1
WHERE id = $3 AND version = $4
RETURNING version`
// Create an args slice containing the values for the placeholder parameters.
args := []interface{}{
user.ProviderUserID,
user.Password.hash,
user.ID,
user.Version,
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version)
if err != nil {
switch {
case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
return ErrDuplicateEmail
case errors.Is(err, sql.ErrNoRows):
return ErrEditConflict
default:
return err
}
}
return nil
}
func (m UserIdentityModel) Delete(id int64) error {
if id < 1 {
return ErrRecordNotFound
}
// Construct the SQL query to delete the record.
query := `
DELETE FROM user_identities
WHERE id = $1`
// Execute the SQL query using the Exec() method, passing in the id variable as
// the value for the placeholder parameter. The Exec() method returns a sql.Result
// object.
result, err := m.DB.Exec(query, id)
if err != nil {
return err
}
// Call the RowsAffected() method on the sql.Result object to get the number of rows
// affected by the query.
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
// If no rows were affected, we know that the issues table didn't contain a record
// with the provided ID at the moment we tried to delete it. In that case we
// return an ErrRecordNotFound error.
if rowsAffected == 0 {
return ErrRecordNotFound
}
return nil
}

329
internal/data/users.go Normal file
View File

@ -0,0 +1,329 @@
package data
import (
"context"
"time"
"party.at/party/internal/validator"
"database/sql"
"github.com/lib/pq"
"errors"
"crypto/sha256"
)
var (
ErrDuplicateEmail = errors.New("duplicate email")
ErrDuplicateUser = errors.New("duplicate username")
)
var AnonymousUser = &User{}
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
AltName sql.NullString `json:"alt_name"`
Created time.Time `json:"created"`
LastLogin time.Time `json:"last_login"`
Activated bool `json:"activated"`
Version int32 `json:"-"`
}
func (u *User) IsAnonymous() bool {
return u == AnonymousUser
}
func ValidateEmail(v *validator.Validator, email string) {
v.Check(email != "", "email", "must be provided")
v.Check(validator.Matches(email, validator.EmailRX), "email" , "must be a valid email address")
}
func ValidateUser(v *validator.Validator, user *User) {
ValidateEmail(v, user.Email)
}
type UserModel struct {
DB *sql.DB
}
func (m UserModel) ExecuteRegistrationTx(user *User, userIdentity *UserIdentity) error {
ctx, cancel := context.WithTimeout(context.Background(), 5 * time.Second)
defer cancel()
tx, err := m.DB.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
query := `
INSERT INTO users (email, name, alt_name)
VALUES ($1, $2, $3)
RETURNING id, created, last_login, version`
args := []interface{}{
user.Email,
user.Name,
user.AltName,
}
err = tx.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.Created, &user.LastLogin, &user.Version)
if pgErr, ok := err.(*pq.Error); ok {
if pgErr.Code == "23505" {
if pgErr.Constraint == "users_email_key" {
return ErrDuplicateEmail
}
}
}
if err != nil {
return err
}
userIdentity.UserID = user.ID
// Insert Identity
query = `
INSERT INTO user_identities (provider_id, user_id, provider_user_id, password)
VALUES ($1, $2, $3, $4)
RETURNING id, version`
args = []interface{}{
userIdentity.ProviderID,
userIdentity.UserID,
userIdentity.ProviderUserID,
userIdentity.Password.hash,
}
err = tx.QueryRowContext(ctx, query, args...).Scan(&userIdentity.ID, &userIdentity.Version)
if pgErr, ok := err.(*pq.Error); ok {
if pgErr.Code == "23505" {
if pgErr.Constraint == "user_identities_provider_id_provider_user_id_key" {
return ErrDuplicateUser
}
}
}
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
return nil
}
// func (m UserModel) Insert(user *User) error {
// query := `
// INSERT INTO users (email, name, alt_name, activated)
// VALUES ($1, $2, $3, $4)
// RETURNING id, created, last_login, version`
// args := []interface{}{
// user.Email,
// user.Name,
// user.AltName,
// user.Activated,
// }
// ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
// defer cancel()
// err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.Created, &user.LastLogin, &user.Version)
// if err != nil {
// switch {
// case err.Error() ==
// `pq: duplicate key value violates unique constraint "users_email_key"`:
// return ErrDuplicateEmail
// default:
// return err
// }
// }
// return nil
// }
func (m UserModel) Get(id int64) (*User, error) {
if id < 1 {
return nil, ErrRecordNotFound
}
// Define the SQL query for retrieving the issue data.
query :=`
SELECT id, email, name, alt_name, created, last_login, activated, version
FROM users
WHERE id = $1`
// Declare a User struct to hold the data returned by the query.
var user User
err := m.DB.QueryRow(query, id).Scan(
&user.ID,
&user.Email,
&user.Name,
&user.AltName,
&user.Created,
&user.LastLogin,
&user.Activated,
&user.Version,
)
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrRecordNotFound
default:
return nil, err
}
}
return &user, nil
}
func (m UserModel) GetByEmail(email string) (*User, error) {
query :=`
SELECT id, email, name, alt_name, created, last_login, activated, version
FROM users
WHERE email = $1`
var user User
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
defer cancel()
err := m.DB.QueryRowContext(ctx, query, email).Scan(
&user.ID,
&user.Email,
&user.Name,
&user.AltName,
&user.Created,
&user.LastLogin,
&user.Activated,
&user.Version,
)
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrRecordNotFound
default:
return nil, err
}
}
return &user, nil
}
func (m UserModel) GetForToken(tokenScope, tokenPlaintext string) (*User, error) {
// Calculate the SHA-256 hash of the plaintext token provided by the client.
// Remember that this returns a byte *array* with length 32, not a slice.
tokenHash := sha256.Sum256([]byte(tokenPlaintext))
// Set up the SQL query.
query :=`
SELECT users.id, users.created, users.name, users.email, users.activated, users.version
FROM users
INNER JOIN tokens ON users.id = tokens.user_id
WHERE tokens.hash = $1
AND tokens.scope = $2
AND tokens.expiry > $3`
// Create a slice containing the query arguments. Notice how we use the [:] operator
// to get a slice containing the token hash, rather than passing in the array (which
// is not supported by the pq driver), and that we pass the current time as the
// value to check against the token expiry.
args := []interface{}{tokenHash[:], tokenScope, time.Now()}
var user User
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
// Execute the query, scanning the return values into a User struct. If no matching
// record is found we return an ErrRecordNotFound error.
err := m.DB.QueryRowContext(ctx, query, args...).Scan(
&user.ID,
&user.Created,
&user.Name,
&user.Email,
&user.Activated,
&user.Version,
)
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrRecordNotFound
default:
return nil, err
}
}
// Return the matching user.
return &user, nil
}
func (m UserModel) Update(user *User) error {
query := `
UPDATE users
SET email = $1, name = $2, alt_name = $3, activated = $4, version = version + 1
WHERE id = $5 AND version = $6
RETURNING version`
// Create an args slice containing the values for the placeholder parameters.
args := []interface{}{
user.Email,
user.Name,
user.AltName,
user.Activated,
user.ID,
user.Version,
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version)
if err != nil {
switch {
case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
return ErrDuplicateEmail
case errors.Is(err, sql.ErrNoRows):
return ErrEditConflict
default:
return err
}
}
return nil
}
func (m UserModel) Delete(id int64) error {
if id < 1 {
return ErrRecordNotFound
}
// Construct the SQL query to delete the record.
query := `
DELETE FROM users
WHERE id = $1`
// Execute the SQL query using the Exec() method, passing in the id variable as
// the value for the placeholder parameter. The Exec() method returns a sql.Result
// object.
result, err := m.DB.Exec(query, id)
if err != nil {
return err
}
// Call the RowsAffected() method on the sql.Result object to get the number of rows
// affected by the query.
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
// If no rows were affected, we know that the issues table didn't contain a record
// with the provided ID at the moment we tried to delete it. In that case we
// return an ErrRecordNotFound error.
if rowsAffected == 0 {
return ErrRecordNotFound
}
return nil
}

124
internal/jsonlog/jsonlog.go Normal file
View File

@ -0,0 +1,124 @@
package jsonlog
import (
"encoding/json"
"io"
"os"
"runtime/debug"
"sync"
"time"
)
// Define a Level type to represent the severity level for a log entry.
type Level int8
// Initialize constants which represent a specific severity level. We use the iota
// keyword as a shortcut to assign successive integer values to the constants.
const (
LevelInfo Level = iota // Has the value 0.
LevelError // Has the value 1.
LevelFatal // Has the value 2.
LevelOff // Has the value 3.
)
// Return a human-friendly string for the severity level.
func (l Level) String() string {
switch l {
case LevelInfo:
return "INFO"
case LevelError:
return "ERROR"
case LevelFatal:
return "FATAL"
default:
return ""
}
}
// Define a custom Logger type. This holds the output destination that the log entries
// will be written to, the minimum severity level that log entries will be written for,
// plus a mutex for coordinating the writes.
type Logger struct {
out io.Writer
minLevel Level
mu sync.Mutex
}
// Return a new Logger instance which writes log entries at or above a minimum severity
// level to a specific output destination.
func New(out io.Writer, minLevel Level) *Logger {
return &Logger{
out: out,
minLevel: minLevel,
}
}
// Declare some helper methods for writing log entries at the different levels. Notice
// that these all accept a map as the second parameter which can contain any arbitrary
// 'properties' that you want to appear in the log entry.
func (l *Logger) PrintInfo(message string, properties map[string]string) {
l.print(LevelInfo, message, properties)
}
func (l *Logger) PrintError(err error, properties map[string]string) {
l.print(LevelError, err.Error(), properties)
}
func (l *Logger) PrintFatal(err error, properties map[string]string) {
l.print(LevelFatal, err.Error(), properties)
os.Exit(1) // For entries at the FATAL level, we also terminate the application.
}
// Print is an internal method for writing the log entry.
func (l *Logger) print(level Level, message string, properties map[string]string) (int, error) {
// If the severity level of the log entry is below the minimum severity for the
// logger, then return with no further action.
if level < l.minLevel {
return 0, nil
}
// Declare an anonymous struct holding the data for the log entry.
aux := struct {
Level string `json:"level"`
Time string `json:"time"`
Message string `json:"message"`
Properties map[string]string `json:"properties,omitempty"`
Trace string `json:"trace,omitempty"`
}{
Level: level.String(),
Time: time.Now().UTC().Format(time.RFC3339),
Message: message,
Properties: properties,
}
// Include a stack trace for entries at the ERROR and FATAL levels.
if level >= LevelError {
aux.Trace = string(debug.Stack())
}
// Declare a line variable for holding the actual log entry text.
var line []byte
// Marshal the anonymous struct to JSON and store it in the line variable. If there
// was a problem creating the JSON, set the contents of the log entry to be that
// plain-text error message instead.
line, err := json.Marshal(aux)
if err != nil {
line = []byte(LevelError.String() + ": unable to marshal log message:" + err.Error())
}
// Lock the mutex so that no two writes to the output destination cannot happen
// concurrently. If we don't do this, it's possible that the text for two or more
// log entries will be intermingled in the output.
l.mu.Lock()
defer l.mu.Unlock()
// Write the log entry followed by a newline.
return l.out.Write(append(line, '\n'))
}
// We also implement a Write() method on our Logger type so that it satisfies the
// io.Writer interface. This writes a log entry at the ERROR level with no additional
// properties.
func (l *Logger) Write(message []byte) (n int, err error) {
return l.print(LevelError, string(message), nil)
}

96
internal/mailer/mailer.go Normal file
View File

@ -0,0 +1,96 @@
package mailer
import (
"bytes"
"embed"
"html/template"
"log"
// "time"
"github.com/wneessen/go-mail"
)
//go:embed "templates"
var templateFS embed.FS
type Mailer struct {
dialer *mail.Client
sender string
}
func New(host string, port int, username, password, sender string) Mailer {
// Initialize a new mail.Dialer instance with the given SMTP server settings. We
// also configure this to use a 5-second timeout whenever we send an email.
dialer, err := mail.NewClient(host,
mail.WithSMTPAuth(mail.SMTPAuthAutoDiscover), mail.WithTLSPortPolicy(mail.TLSMandatory),
mail.WithUsername(username), mail.WithPassword(password),
)
if err != nil {
log.Fatalf("failed to deliver mail: %s", err)
}
// dialer. = 5 * time.Second
// Return a Mailer instance containing the dialer and sender information.
return Mailer{
dialer: dialer,
sender: sender,
}
}
// Define a Send() method on the Mailer type. This takes the recipient email address
// as the first parameter, the name of the file containing the templates, and any
// dynamic data for the templates as an interface{} parameter.
func (m Mailer) Send(recipient, templateFile string, data interface{}) error {
// Use the ParseFS() method to parse the required template file from the embedded
// file system.
tmpl, err := template.New("email").ParseFS(templateFS, "templates/"+templateFile)
if err != nil {
return err
}
// Execute the named template "subject", passing in the dynamic data and storing the
// result in a bytes.Buffer variable.
subject := new(bytes.Buffer)
err = tmpl.ExecuteTemplate(subject, "subject", data)
if err != nil {
return err
}
// Follow the same pattern to execute the "plainBody" template and store the result
// in the plainBody variable.
plainBody := new(bytes.Buffer)
err = tmpl.ExecuteTemplate(plainBody, "plainBody", data)
if err != nil {
return err
}
// And likewise with the "htmlBody" template.
htmlBody := new(bytes.Buffer)
err = tmpl.ExecuteTemplate(htmlBody, "htmlBody", data)
if err != nil {
return err
}
// Use the mail.NewMessage() function to initialize a new mail.Message instance.
// Then we use the SetHeader() method to set the email recipient, sender and subject
// headers, the SetBody() method to set the plain-text body, and the AddAlternative()
// method to set the HTML body. It's important to note that AddAlternative() should
// always be called *after* SetBody().
msg := mail.NewMsg()
msg.SetAddrHeader("To", recipient)
msg.SetAddrHeader("From", m.sender)
msg.SetGenHeader("Subject", subject.String())
msg.SetBodyString("text/plain", plainBody.String())
msg.AddAlternativeString("text/html", htmlBody.String())
// Call the DialAndSend() method on the dialer, passing in the message to send. This
// opens a connection to the SMTP server, sends the message, then closes the
// connection. If there is a timeout, it will return a "dial tcp: i/o timeout"
// error.
err = m.dialer.DialAndSend(msg)
if err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,35 @@
{{define "subject"}}Welcome to Digitale Partei Österreich!{{end}}
{{define "plainBody"}}
Hi,
Thanks for signing up for a Digitale Partei Österreich account. We're excited to have you on board!
For future reference, your user ID number is {{.ID}}.
Your activation token is {{.token}}
Thanks,
The DigitalePartei Team
{{end}}
{{define "htmlBody"}}
<!doctype html>
<html>
<head>
<meta name="viewport" content="width=device-width" />
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
</head>
<body>
<p>Hi,</p>
<p>Thanks for signing up for a Digitale Partei Österreich account. We're excited to have you on board!</p>
<p>For future reference, your user ID number is {{.ID}}.</p>
<p>Your activation token is {{.token}}</p>
<p>Thanks,</p>
<p>The DigitalePartei Team</p>
</body>
</html>
{{end}}

View File

@ -0,0 +1,61 @@
package validator
import (
"regexp"
)
// Declare a regular expression for sanity checking the format of email addresses (we'll
// use this later in the book). If you're interested, this regular expression pattern is
// taken from https://html.spec.whatwg.org/#valid-e-mail-address. Note: if you're
// reading this in PDF or EPUB format and cannot see the full pattern, please see the
// note further down the page.
var (
EmailRX = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$")
)
type Validator struct {
Errors map[string]string
}
func New() *Validator {
return &Validator{Errors: make(map[string]string)}
}
func (v *Validator) Valid() bool {
return len(v.Errors) == 0
}
func (v *Validator) AddError(key, message string) {
if _, exists := v.Errors[key]; !exists {
v.Errors[key] = message
}
}
func (v *Validator) Check(ok bool, key, message string) {
if !ok {
v.AddError(key, message)
}
}
func In(value string, list ...string) bool {
for i := range list {
if value == list[i] {
return true
}
}
return false
}
func Matches(value string, rx *regexp.Regexp) bool {
return rx.MatchString(value)
}
func Unique(values []string) bool {
uniqueValues := make(map[string]bool)
for _, value := range values {
uniqueValues[value] = true
}
return len(values) == len(uniqueValues)
}

View File

@ -0,0 +1,5 @@
DROP TABLE IF EXISTS user_identities;
DROP TABLE IF EXISTS auth_provider;
DROP TABLE IF EXISTS users;

View File

@ -0,0 +1,65 @@
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
email citext UNIQUE NOT NULL,
name TEXT NOT NULL,
alt_name TEXT,
created TIMESTAMPTZ NOT NULL DEFAULT now(),
last_login TIMESTAMPTZ NOT NULL DEFAULT '1970-01-01 00:00:00+00',
activated BOOL NOT NULL DEFAULT false,
version INT NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS auth_provider (
id BIGSERIAL PRIMARY KEY, -- e.g., 'local', 'id_austria',
description TEXT NOT NULL,
active BOOLEAN DEFAULT false
);
INSERT INTO auth_provider (description, active) VALUES ('local', true);
INSERT INTO auth_provider (description, active) VALUES ('id_austria', false);
CREATE TABLE IF NOT EXISTS user_identities (
id BIGSERIAL PRIMARY KEY,
provider_id BIGINT NOT NULL REFERENCES auth_provider(id),
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
-- For local: the username. For OIDC: the 'sub' (Subject ID)
provider_user_id TEXT NOT NULL,
-- Nullable because OIDC users won't have a password in your DB
password bytea,
version INT NOT NULL DEFAULT 1,
UNIQUE(provider_id, provider_user_id)
);
-- INSERT INTO users(
-- username,
-- email,
-- password,
-- name,
-- alt_name
-- )
-- VALUES(
-- 'vik',
-- 'vikhenzo@gmail.com',
-- '$argon2id$v=19$m=65536,t=1,p=32$+dQ9uB7kKL7t7G3bI+TOMw$Wvic27W6SYH6Fx2Pp84irhVJ/blVh5qINlkv58bpgEc',
-- 'Vicente Ferrari Smith',
-- 'Vicente'
-- );
-- INSERT INTO users(
-- username,
-- email,
-- password,
-- name,
-- alt_name
-- )
-- VALUES(
-- 'mkBflwkpe',
-- 'bellrebekaou@gmail.com',
-- '$argon2id$v=19$m=65536,t=1,p=1$BWcWdp8bhgS84LWqUCb2IA$DGF/FzQbSNHnfZraE9F2qvfdBGf5XB81+w00QgY/jG0',
-- 'zWkxKTNolTgJwO',
-- 'OahedOBLSo'
-- );

View File

@ -0,0 +1,8 @@
DROP INDEX IF EXISTS idx_votes_option_id;
DROP INDEX IF EXISTS idx_vote_tokens_issue_id;
DROP INDEX IF EXISTS idx_options_issue_id;
DROP TABLE IF EXISTS votes;
DROP TABLE IF EXISTS vote_tokens;
DROP TABLE IF EXISTS options;
DROP TABLE IF EXISTS issues;

View File

@ -0,0 +1,38 @@
CREATE TABLE IF NOT EXISTS issues (
id BIGSERIAL PRIMARY KEY,
title VARCHAR(255) NOT NULL,
description TEXT,
start_time TIMESTAMPTZ NOT NULL,
end_time TIMESTAMPTZ NOT NULL,
created TIMESTAMPTZ NOT NULL DEFAULT now(),
version INT NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS options (
id BIGSERIAL PRIMARY KEY,
issue_id BIGINT NOT NULL REFERENCES issues(id) ON DELETE CASCADE,
label VARCHAR(255) NOT NULL,
created TIMESTAMPTZ NOT NULL DEFAULT now(),
version INT NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS vote_tokens (
id BIGSERIAL PRIMARY KEY,
issue_id BIGINT NOT NULL REFERENCES issues(id) ON DELETE CASCADE,
token UUID NOT NULL UNIQUE DEFAULT gen_random_uuid(),
used BOOLEAN NOT NULL DEFAULT FALSE,
created TIMESTAMPTZ NOT NULL DEFAULT now(),
version INT NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS votes (
id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
token UUID NOT NULL UNIQUE REFERENCES vote_tokens(token) ON DELETE CASCADE,
option_id BIGINT NOT NULL REFERENCES options(id) ON DELETE CASCADE,
created TIMESTAMPTZ NOT NULL DEFAULT now(),
version INT NOT NULL DEFAULT 1
);
CREATE INDEX idx_votes_option_id ON votes(option_id);
CREATE INDEX idx_vote_tokens_issue_id ON vote_tokens(issue_id);
CREATE INDEX idx_options_issue_id ON options(issue_id);

View File

@ -0,0 +1,3 @@
DROP INDEX IF EXISTS issue_title_idx;
DROP INDEX IF EXISTS issue_description_idx;

View File

@ -0,0 +1,3 @@
CREATE INDEX IF NOT EXISTS issue_title_idx ON issues USING GIN (to_tsvector('simple', title));
CREATE INDEX IF NOT EXISTS issue_description_idx ON issues USING GIN (to_tsvector('simple', description));

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS tokens;

View File

@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS tokens (
hash bytea PRIMARY KEY,
user_id bigint NOT NULL REFERENCES users ON DELETE CASCADE,
user_identity_id bigint REFERENCES user_identities ON DELETE CASCADE,
expiry timestamp(0) with time zone NOT NULL,
scope text NOT NULL
);

View File

@ -0,0 +1,2 @@
DROP TABLE IF EXISTS users_permissions;
DROP TABLE IF EXISTS permissions;

View File

@ -0,0 +1,14 @@
CREATE TABLE IF NOT EXISTS permissions (
id bigserial PRIMARY KEY,
code text NOT NULL
);
CREATE TABLE IF NOT EXISTS users_permissions (
user_id bigint NOT NULL REFERENCES users ON DELETE CASCADE,
permission_id bigint NOT NULL REFERENCES permissions ON DELETE CASCADE,
PRIMARY KEY (user_id, permission_id)
);
-- Add the two permissions to the table.
INSERT INTO permissions (code)
VALUES ('issues:read'), ('issues:write');