diff --git a/.vscode/launch.json b/.vscode/launch.json index bc86fbf..689a0d4 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,7 +10,7 @@ "type": "go", "request": "launch", "mode": "auto", - "program": "${workspaceFolder}/cmd/web" + "program": "${workspaceFolder}/cmd/api" } ] } \ No newline at end of file diff --git a/cmd/api/context.go b/cmd/api/context.go new file mode 100644 index 0000000..d2572c8 --- /dev/null +++ b/cmd/api/context.go @@ -0,0 +1,26 @@ +package main + +import ( + "context" + "net/http" + + "party.at/party/internal/data" +) + +type contextKey string + +const userContextKey = "user" + +func (app *application) contextSetUser(r *http.Request, user *data.User) *http.Request { + ctx := context.WithValue(r.Context(), userContextKey, user) + return r.WithContext(ctx) +} + +func (app *application) contextGetUser(r *http.Request) *data.User { + user, ok := r.Context().Value(userContextKey).(*data.User) + if !ok { + panic("missing user value in request context") + } + + return user +} diff --git a/cmd/api/db.go b/cmd/api/db.go index c63ac5a..3439f1f 100644 --- a/cmd/api/db.go +++ b/cmd/api/db.go @@ -1,19 +1,19 @@ package main import ( - "context" - "github.com/jackc/pgx/v5" + // "context" + "database/sql" "log" ) -func database_init(conn *pgx.Conn) { +func database_init_dont_use(conn *sql.DB) { sql := `DROP TABLE IF EXISTS vote; DROP TABLE IF EXISTS vote_token; DROP TABLE IF EXISTS option; -DROP TABLE IF EXISTS issue; +DROP TABLE IF EXISTS issues; DROP TABLE IF EXISTS account;` - _, err := conn.Exec(context.Background(), sql) + _, err := conn.Exec(sql) if err != nil { log.Fatal(err) } @@ -29,7 +29,7 @@ created TIMESTAMPTZ DEFAULT now() NOT NULL, last_login TIMESTAMPTZ )` - _, err = conn.Exec(context.Background(), sql) + _, err = conn.Exec(sql) if err != nil { log.Fatal(err) } @@ -74,12 +74,12 @@ INSERT INTO account (username, password_hash, first_name, last_name, email, created) VALUES('yyZGFAsmPEBoSf', '$argon2id$v=19$m=65536,t=1,p=1$vLVm5ol6GQkXPU0BPut92g$Kyvp7/dl3lGUszXRiXfgsgB/IY0EKulZVpVttXQaDDU', 'svVjyPUaAkN', 'mdngeHlf', 'giyahyd4141@gmail.com', '2024-11-28 19:02:32.806');` - _, err = conn.Exec(context.Background(), sql) + _, err = conn.Exec(sql) if err != nil { log.Fatal(err) } - sql = `CREATE TABLE issue ( + sql = `CREATE TABLE issues ( id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, title VARCHAR(255) NOT NULL, description TEXT, @@ -88,32 +88,32 @@ end_time TIMESTAMPTZ NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT now() )` - _, err = conn.Exec(context.Background(), sql) + _, err = conn.Exec(sql) if err != nil { log.Fatal(err) } sql = `CREATE TABLE option ( id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -issue_id BIGINT NOT NULL REFERENCES issue(id) ON DELETE CASCADE, +issue_id BIGINT NOT NULL REFERENCES issues(id) ON DELETE CASCADE, label VARCHAR(255) NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT now() )` - _, err = conn.Exec(context.Background(), sql) + _, err = conn.Exec(sql) if err != nil { log.Fatal(err) } sql = `CREATE TABLE vote_token ( id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -issue_id BIGINT NOT NULL REFERENCES issue(id) ON DELETE CASCADE, +issue_id BIGINT NOT NULL REFERENCES issues(id) ON DELETE CASCADE, token UUID NOT NULL UNIQUE DEFAULT gen_random_uuid(), used BOOLEAN NOT NULL DEFAULT FALSE, created_at TIMESTAMPTZ NOT NULL DEFAULT now() )` - _, err = conn.Exec(context.Background(), sql) + _, err = conn.Exec(sql) if err != nil { log.Fatal(err) } @@ -125,7 +125,7 @@ option_id BIGINT NOT NULL REFERENCES option(id) ON DELETE CASCADE, voted_at TIMESTAMPTZ NOT NULL DEFAULT now() )` - _, err = conn.Exec(context.Background(), sql) + _, err = conn.Exec(sql) if err != nil { log.Fatal(err) } @@ -134,7 +134,7 @@ voted_at TIMESTAMPTZ NOT NULL DEFAULT now() CREATE INDEX idx_vote_tokens_issue_id ON vote_token(issue_id); CREATE INDEX idx_options_issue_id ON option(issue_id);` - _, err = conn.Exec(context.Background(), sql) + _, err = conn.Exec(sql) if err != nil { log.Fatal(err) } diff --git a/cmd/api/errors.go b/cmd/api/errors.go new file mode 100644 index 0000000..b470581 --- /dev/null +++ b/cmd/api/errors.go @@ -0,0 +1,104 @@ +package main + +import ( + "fmt" + "net/http" +) + +// The logError() method is a generic helper for logging an error message. Later in the +// book we'll upgrade this to use structured logging, and record additional information +// about the request including the HTTP method and URL. +func (app *application) logError(r *http.Request, err error) { + app.logger.PrintError(err, map[string]string{ + "request_method": r.Method, + "request_url": r.URL.String(), + }) +} + +// The errorResponse() method is a generic helper for sending JSON-formatted error +// messages to the client with a given status code. Note that we're using an interface{} +// type for the message parameter, rather than just a string type, as this gives us +// more flexibility over the values that we can include in the response. +func (app *application) errorResponse(w http.ResponseWriter, r *http.Request, status int, message interface{}) { + env := envelope{"error": message} + // Write the response using the writeJSON() helper. If this happens to return an + // error then log it, and fall back to sending the client an empty response with a + // 500 Internal Server Error status code. + err := app.writeJSON(w, status, env, nil) + if err != nil { + app.logError(r, err) + w.WriteHeader(500) + } +} + +// The serverErrorResponse() method will be used when our application encounters an +// unexpected problem at runtime. It logs the detailed error message, then uses the +// errorResponse() helper to send a 500 Internal Server Error status code and JSON +// response (containing a generic error message) to the client. +func (app *application) serverErrorResponse(w http.ResponseWriter, r *http.Request, err error) { + app.logError(r, err) + message := + "the server encountered a problem and could not process your request" + app.errorResponse(w, r, http.StatusInternalServerError, message) + +} + + +// The notFoundResponse() method will be used to send a 404 Not Found status code and +// JSON response to the client. +func (app *application) notFoundResponse(w http.ResponseWriter, r *http.Request) { + message := "the requested resource could not be found" + app.errorResponse(w, r, http.StatusNotFound, message) +} + +// The methodNotAllowedResponse() method will be used to send a 405 Method Not Allowed +// status code and JSON response to the client. +func (app *application) methodNotAllowedResponse(w http.ResponseWriter, r *http.Request) { + message := fmt.Sprintf("the %s method is not supported for this resource", r.Method) + app.errorResponse(w, r, http.StatusMethodNotAllowed, message) +} + +func (app *application) badRequestResponse(w http.ResponseWriter, r *http.Request, err error) { + app.errorResponse(w, r, http.StatusBadRequest, err.Error()) +} + +func (app *application) failedValidationResponse(w http.ResponseWriter, r *http.Request, errors map[string]string) { + app.errorResponse(w, r, http.StatusUnprocessableEntity, errors) +} + +func (app *application) editConflictResponse(w http.ResponseWriter, r *http.Request) { + message := "unable to update the record due to an edit conflict, please try again" + app.errorResponse(w, r, http.StatusConflict, message) +} + +func (app *application) rateLimitExceededResponse(w http.ResponseWriter, r *http.Request) { + message := "rate limit exceeded" + app.errorResponse(w, r, http.StatusTooManyRequests, message) +} + +func (app *application) invalidCredentialsResponse(w http.ResponseWriter, r *http.Request) { + message := "invalid authentication credentials" + app.errorResponse(w, r, http.StatusUnauthorized, message) +} + +func (app *application) invalidAuthenticationTokenResponse(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", "Bearer") + + message := "invalid or missing authentication token" + app.errorResponse(w, r, http.StatusUnauthorized, message) +} + +func (app *application) authenticationRequiredResponse(w http.ResponseWriter, r *http.Request) { + message :="you must be authenticated to access this resource" + app.errorResponse(w, r, http.StatusUnauthorized, message) +} + +func (app *application) inactiveAccountResponse(w http.ResponseWriter, r *http.Request) { + message := "your user account must be activated to access this resource" + app.errorResponse(w, r, http.StatusForbidden, message) +} + +func (app *application) notPermittedResponse(w http.ResponseWriter, r *http.Request) { + message := "your user account doesn't have the necessary permissions to access this resource" + app.errorResponse(w, r, http.StatusForbidden, message) +} diff --git a/cmd/api/handlers.go b/cmd/api/handlers.go index 946978d..9bd127c 100644 --- a/cmd/api/handlers.go +++ b/cmd/api/handlers.go @@ -1,11 +1,17 @@ package main import ( + // "encoding/json" "fmt" + "html/template" + "log" "net/http" "time" - "log" - "html/template" + // "github.com/julienschmidt/httprouter" + // "strconv" + // "errors" + // "party.at/party/internal/data" + // "party.at/party/internal/validator" ) func home(w http.ResponseWriter, r *http.Request) { @@ -48,7 +54,7 @@ func ws(w http.ResponseWriter, r *http.Request) { done := make(chan struct{}) go func() { - ticker := time.NewTicker(1 * time.Second); + ticker := time.NewTicker(1 * time.Second) for { select { @@ -60,7 +66,6 @@ func ws(w http.ResponseWriter, r *http.Request) { "timestamp": t.Format(time.RFC3339), } - if err := conn.WriteJSON(msg); err != nil { fmt.Println("Write error:", err) return @@ -87,6 +92,33 @@ func ws(w http.ResponseWriter, r *http.Request) { } } +// func handleMobileLogin(w http.ResponseWriter, r *http.Request) { +// // 1. Get the token from the request header +// rawIDToken := r.Header.Get("Authorization") + +// // 2. Initialize the verifier (pointing to ID Austria's keys) +// verifier := provider.Verifier(&oidc.Config{ClientID: "YOUR_APP_ID"}) + +// // 3. Verify the signature and expiration +// idToken, err := verifier.Verify(ctx, rawIDToken) +// if err != nil { +// http.Error(w, "Invalid Token", http.StatusUnauthorized) +// return +// } + +// // 4. Extract User Data +// var claims struct { +// Subject string `json:"sub"` // This is the unique ID +// Name string `json:"name"` +// } +// if err := idToken.Claims(&claims); err != nil { +// // Handle error +// } + +// // 5. Create your own application session for the mobile app +// issueLocalSession(w, claims.Subject) +// } + func redirectToHTTPS(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "https://localhost:8443" + r.URL.RequestURI(), http.StatusMovedPermanently) + http.Redirect(w, r, "https://localhost:8443"+r.URL.RequestURI(), http.StatusMovedPermanently) } diff --git a/cmd/api/healthcheck.go b/cmd/api/healthcheck.go index 9da97d3..e8de93b 100644 --- a/cmd/api/healthcheck.go +++ b/cmd/api/healthcheck.go @@ -1,12 +1,22 @@ package main import ( - "fmt" + // "encoding/json" + //"fmt" "net/http" ) func (app *application) healthcheckHandler(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "status: available") - fmt.Fprintf(w, "environment: %s\n", app.config.env) - fmt.Fprintf(w, "version: %s\n", version) + env := envelope{ + "status": "available", + "system_info": map[string]string{ + "environment": app.config.env, + "version": version, + }, + } + + err := app.writeJSON(w, http.StatusOK, envelope{"health_check": env}, nil) + if err != nil { + app.serverErrorResponse(w,r, err) + } } \ No newline at end of file diff --git a/cmd/api/helpers.go b/cmd/api/helpers.go new file mode 100644 index 0000000..076d39d --- /dev/null +++ b/cmd/api/helpers.go @@ -0,0 +1,193 @@ +package main + +import ( + "net/http" + "encoding/json" + "strings" + "errors" + "fmt" + "io" + "net/url" + "strconv" + "party.at/party/internal/validator" + "github.com/julienschmidt/httprouter" + +) + +type envelope map[string]interface{} + +func (app *application) readString(qs url.Values, key string, defaultValue string) string { + // Extract the value for a given key from the query string. If no key exists this + // will return the empty string "". + s := qs.Get(key) + // If no key exists (or the value is empty) then return the default value. + if s == "" { + return defaultValue + } + + // Otherwise return the string. + return s +} + +func (app *application) readCSV(qs url.Values, key string, defaultValue []string) []string { + // Extract the value from the query string. + csv := qs.Get(key) + + // If no key exists (or the value is empty) then return the default value. + if csv == "" { + return defaultValue + } + + // Otherwise parse the value into a []string slice and return it. + return strings.Split(csv, ",") +} + +func (app *application) readInt(qs url.Values, key string, defaultValue int, v *validator.Validator) int { + // Extract the value from the query string. + s := qs.Get(key) + + // If no key exists (or the value is empty) then return the default value. + if s == "" { + return defaultValue + } + + // Try to convert the value to an int. If this fails, add an error message to the + // validator instance and return the default value. + i, err := strconv.Atoi(s) + if err != nil { + v.AddError(key, "must be an integer value") + return defaultValue + } + + // Otherwise, return the converted integer value. + return i +} + +func (app *application) readIDParam(r *http.Request) (int64, error) { + params := httprouter.ParamsFromContext(r.Context()) + + id, err := strconv.ParseInt(params.ByName("id"), 10, 64) + if err != nil || id < 1 { + return 0, errors.New("invalid id parameter") + } + + return id, nil +} + +func (app *application) writeJSON(w http.ResponseWriter, status int, data envelope, headers http.Header) error { + // Encode the data to JSON, returning the error if there was one. + js, err := json.MarshalIndent(data, "", "\t") + if err != nil { + return err + } + + + // Append a newline to make it easier to view in terminal applications. + js = append(js, '\n') + + + // At this point, we know that we won't encounter any more errors before writing the + // response, so it's safe to add any headers that we want to include. We loop + // through the header map and add each header to the http.ResponseWriter header map. + // Note that it's OK if the provided header map is nil. Go doesn't throw an error + // if you try to range over (or generally, read from) a nil map. + for key, value := range headers { + w.Header()[key] = value + } + + + // Add the "Content-Type: application/json" header, then write the status code and + // JSON response. + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + w.Write(js) + + return nil +} + +func (app *application) readJSON(w http.ResponseWriter, r *http.Request, dst interface{}) error { + maxBytes := 1048576 + + r.Body = http.MaxBytesReader(w, r.Body, int64(maxBytes)) + + // Initialize the json.Decoder, and call the DisallowUnknownFields() method on it + // before decoding. This means that if the JSON from the client now includes any + // field which cannot be mapped to the target destination, the decoder will return + // an error instead of just ignoring the field. + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + // Decode the request body to the destination. + err := dec.Decode(dst) + if err != nil { + var syntaxError *json.SyntaxError + var unmarshalTypeError *json.UnmarshalTypeError + var invalidUnmarshalError *json.InvalidUnmarshalError + + switch { + case errors.As(err, &syntaxError): + return fmt.Errorf("body contains badly-formed JSON (at character %d)", syntaxError.Offset) + + case errors.Is(err, io.ErrUnexpectedEOF): + return errors.New("body contains badly-formed JSON") + + case errors.As(err, &unmarshalTypeError): + if unmarshalTypeError.Field != "" { + return fmt.Errorf("body contains incorrect JSON type for field %q", unmarshalTypeError.Field) + } + + return fmt.Errorf("body contains incorrect JSON type (at character %d)", unmarshalTypeError.Offset) + + case errors.Is(err, io.EOF): + return errors.New("body must not be empty") + + // If the JSON contains a field which cannot be mapped to the target destination + // then Decode() will now return an error message in the format "json: unknown + // field """. We check for this, extract the field name from the error, + // and interpolate it into our custom error message. Note that there's an open + // issue at https://github.com/golang/go/issues/29035 regarding turning this + // into a distinct error type in the future. + case strings.HasPrefix(err.Error(), "json: unknown field "): + fieldName := strings.TrimPrefix(err.Error(), "json: unknown field ") + return fmt.Errorf("body contains unknown key %s", fieldName) + + + // If the request body exceeds 1MB in size the decode will now fail with the + // error "http: request body too large". There is an open issue about turning + // this into a distinct error type at https://github.com/golang/go/issues/30715. + case err.Error() == "http: request body too large": + return fmt.Errorf("body must not be larger than %d bytes", maxBytes) + + case errors.As(err, &invalidUnmarshalError): + panic(err) + + default: + return err + } + } + + // Call Decode() again, using a pointer to an empty anonymous struct as the + // destination. If the request body only contained a single JSON value this will + // return an io.EOF error. So if we get anything else, we know that there is + // additional data in the request body and we return our own custom error message. + err = dec.Decode(&struct{}{}) + if err != io.EOF { + return errors.New("body must only contain a single JSON value") + } + + return nil +} + +func (app *application) background(fn func()) { + // Launch a background goroutine. + go func() { + // Recover any panic. + defer func() { + if err := recover(); err != nil { + app.logger.PrintError(fmt.Errorf("%s", err), nil) + } + }() + + // Execute the arbitrary function that we passed as the parameter. + fn() + }() +} diff --git a/cmd/api/issues.go b/cmd/api/issues.go new file mode 100644 index 0000000..fe8772c --- /dev/null +++ b/cmd/api/issues.go @@ -0,0 +1,221 @@ +package main + +import ( + // "encoding/json" + "fmt" + // "html/template" + // "log" + "net/http" + "time" + // "github.com/julienschmidt/httprouter" + // "strconv" + "errors" + "party.at/party/internal/data" + "party.at/party/internal/validator" +) + +func (app *application) listIssuesHandler(w http.ResponseWriter, r *http.Request) { + // To keep things consistent with our other handlers, we'll define an input struct + // to hold the expected values from the request query string. + var input struct { + Title string + data.Filters + } + + // Initialize a new Validator instance. + v := validator.New() + // Call r.URL.Query() to get the url.Values map containing the query string data. + qs := r.URL.Query() + + // Use our helpers to extract the title and genres query string values, falling back + // to defaults of an empty string and an empty slice respectively if they are not + // provided by the client. + input.Title = app.readString(qs, "title", "") + // input.Genres = app.readCSV(qs, "genres", []string{}) + + // Get the page and page_size query string values as integers. Notice that we set + // the default page value to 1 and default page_size to 20, and that we pass the + // validator instance as the final argument here. + input.Filters.Page = app.readInt(qs, "page", 1, v) + input.Filters.PageSize = app.readInt(qs, "page_size", 20, v) + + // Extract the sort query string value, falling back to "id" if it is not provided + // by the client (which will imply a ascending sort on issue ID). + input.Filters.Sort = app.readString(qs, "sort", "id") + input.Filters.SortSafelist = []string{"id", "-id", "title", "-title", "description", "-description"} + + if data.ValidateFilters(v, input.Filters); !v.Valid() { + app.failedValidationResponse(w, r, v.Errors) + return + } + + issues, metadata, err := app.models.Issues.GetAll(input.Title, input.Filters) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + err = app.writeJSON(w, http.StatusOK, envelope{"issues": issues, "metadata": metadata}, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} + +func (app *application) createIssueHandler(w http.ResponseWriter, r *http.Request) { + + var input struct { + Title string `json:"title"` + Description string `json:"description"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + } + + err := app.readJSON(w, r, &input) + if err != nil { + // Use the new badRequestResponse() helper. + app.badRequestResponse(w, r, err) + return + } + + issue := &data.Issue{ + Title: input.Title, + Description: input.Description, + StartTime: input.StartTime, + EndTime: input.EndTime, + } + + v := validator.New() + + if data.ValidateIssue(v, issue); !v.Valid() { + app.failedValidationResponse(w, r, v.Errors) + return + } + + err = app.models.Issues.Insert(issue) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + headers := make(http.Header) + headers.Set("Location", fmt.Sprintf("/v1/issues/%d", issue.ID)) + + err = app.writeJSON(w, http.StatusCreated, envelope{"issue": issue}, headers) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} + +func (app *application) readIssueHandler(w http.ResponseWriter, r *http.Request) { + id, err := app.readIDParam(r) + if err != nil { + app.notFoundResponse(w, r) + return + } + + issue, err := app.models.Issues.Get(id) + if err != nil { + switch { + case errors.Is(err, data.ErrRecordNotFound): + app.notFoundResponse(w, r) + default: + app.serverErrorResponse(w, r, err) + } + return + } + + // Encode the struct to JSON and send it as the HTTP response. + err = app.writeJSON(w, http.StatusOK, envelope{"issue": issue}, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} + +func (app *application) updateIssueHandler(w http.ResponseWriter, r *http.Request) { + id, err := app.readIDParam(r) + if err != nil { + app.notFoundResponse(w, r) + return + } + + issue, err := app.models.Issues.Get(id) + if err != nil { + switch { + case errors.Is(err, data.ErrRecordNotFound): + app.notFoundResponse(w, r) + default: + app.serverErrorResponse(w, r, err) + } + return + } + + var input struct { + Title *string `json:"title"` + Description *string `json:"description"` + StartTime *time.Time `json:"start_time"` + EndTime *time.Time `json:"end_time"` + } + + err = app.readJSON(w, r, &input) + if err != nil { + app.badRequestResponse(w, r, err) + return + } + + if input.Title != nil { issue.Title = *input.Title } + if input.Description != nil { issue.Description = *input.Description } + if input.StartTime != nil { issue.StartTime = *input.StartTime } + if input.StartTime != nil { issue.EndTime = *input.EndTime } + + v := validator.New() + if data.ValidateIssue(v, issue); !v.Valid() { + app.failedValidationResponse(w, r, v.Errors) + return + } + + err = app.models.Issues.Update(issue) + if err != nil { + switch { + case errors.Is(err, data.ErrEditConflict): + app.editConflictResponse(w, r) + default: + app.serverErrorResponse(w, r, err) + } + + return + } + + // Write the updated issue record in a JSON response. + err = app.writeJSON(w, http.StatusOK, envelope{"issue": issue}, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} + +func (app *application) deleteIssueHandler(w http.ResponseWriter, r *http.Request) { + id, err := app.readIDParam(r) + if err != nil { + app.notFoundResponse(w, r) + return + } + + // Delete the issue from the database, sending a 404 Not Found response to the + // client if there isn't a matching record. + err = app.models.Issues.Delete(id) + if err != nil { + switch { + case errors.Is(err, data.ErrRecordNotFound): + app.notFoundResponse(w, r) + default: + app.serverErrorResponse(w, r, err) + } + + return + } + + // Return a 200 OK status code along with a success message. + err = app.writeJSON(w, http.StatusOK, envelope{"message": "issue successfully deleted"}, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} diff --git a/cmd/api/main.go b/cmd/api/main.go index 6a8e492..7f45805 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -3,31 +3,63 @@ package main import ( "context" "log" + "strings" + //"html/template" - "net/http" "flag" + "net/http" "github.com/gorilla/websocket" - "github.com/jackc/pgx/v5" -) - -import ( - "fmt" "os" "time" + + "database/sql" + + _ "github.com/lib/pq" + + "party.at/party/internal/data" + "party.at/party/internal/mailer" + "party.at/party/internal/jsonlog" ) const version = "1.0.0" type config struct { port int - env string + env string + + db struct { + dsn string + maxOpenConns int + maxIdleConns int + maxIdleTime string + } + + limiter struct { + rps float64 + burst int + enabled bool + } + + smtp struct { + host string + port int + username string + password string + sender string + } + + cors struct { + trustedOrigins []string + } } type application struct { config config - logger *log.Logger + logger *jsonlog.Logger + models data.Models + mailer mailer.Mailer } type Message struct { @@ -47,39 +79,84 @@ func main() { var cfg config flag.IntVar(&cfg.port, "port", 4000, "API server port") - flag.StringVar(&cfg.env, "env", "development", "Environment (development|staging|production)") + flag.StringVar(&cfg.env, "env", "development", "Environment (development|staging|production)") + flag.StringVar(&cfg.db.dsn, "db-dsn", os.Getenv("PARTY_DB_DSN"), "PostgreSQL DSN") //addr := flag.String("addr", ":8443", "HTTP network address") + + flag.IntVar(&cfg.db.maxOpenConns, "db-max-open-conns", 25, "PostgreSQL max open connections") + flag.IntVar(&cfg.db.maxIdleConns, "db-max-idle-conns", 25, "PostgreSQL max idle connections") + flag.StringVar(&cfg.db.maxIdleTime, "db-max-idle-time", "15m", "PostgreSQL max connection idle time") + + flag.StringVar(&cfg.smtp.host, "smtp-host", "smtp.mailtrap.io", "SMTP host") + flag.IntVar(&cfg.smtp.port, "smtp-port", 25, "SMTP port") + flag.StringVar(&cfg.smtp.username, "smtp-username", "98cf60028d7fcb", "SMTP username") + flag.StringVar(&cfg.smtp.password, "smtp-password", "b9d4a35372e971", "SMTP password") + flag.StringVar(&cfg.smtp.sender, "smtp-sender", "Greenlight ", "SMTP sender") + + flag.Float64Var(&cfg.limiter.rps, "limiter-rps", 2, "Rate limiter maximum requests per second") + flag.IntVar(&cfg.limiter.burst, "limiter-burst", 4, "Rate limiter maximum burst") + flag.BoolVar(&cfg.limiter.enabled, "limiter-enabled", true, "Enable rate limiter") + + flag.Func("cors-trusted-origins", "Trusted CORS origins (space separated)", func(val string) error { + cfg.cors.trustedOrigins = strings.Fields(val) + return nil + }) + flag.Parse() - logger := log.New(os.Stdout, "", log.Ldate | log.Ltime) + logger := jsonlog.New(os.Stdout, jsonlog.LevelInfo) + log.Printf("%s\n", cfg.db.dsn) + + db, err := openDB(cfg) + if err != nil { + logger.PrintFatal(err, nil) + } + + defer db.Close() + app := &application{ config: cfg, logger: logger, + models: data.NewModels(db), + mailer: mailer.New(cfg.smtp.host, cfg.smtp.port, cfg.smtp.username, cfg.smtp.password, cfg.smtp.sender), } + + log.Println("Hello, Sailor!") - conn, err := pgx.Connect(context.Background(), "postgres://party:iK2SoVbDhdCki5n3LxGyP6zKpLspt4@losandesgames.com:5432/party") + + err = app.serve() if err != nil { - log.Fatal(err) - } - - defer conn.Close(context.Background()) - - database_init(conn) - - srv := &http.Server{ - Addr: fmt.Sprintf(":%d", cfg.port), - Handler: app.routes(), - IdleTimeout: time.Minute, - ReadTimeout: 10 * time.Second, - WriteTimeout: 30 * time.Second, - } - - logger.Printf("starting %s server on %s", cfg.env, srv.Addr); - - // Start HTTPS server (requires cert.pem and key.pem in current dir) - err = srv.ListenAndServe() - if err != nil { - panic(err) + logger.PrintFatal(err, nil) } } + +func openDB(cfg config) (*sql.DB, error) { + db, err := sql.Open("postgres", cfg.db.dsn) + if err != nil { + return nil, err + } + + + db.SetMaxOpenConns(cfg.db.maxOpenConns) + db.SetMaxIdleConns(cfg.db.maxIdleConns) + duration, err := time.ParseDuration(cfg.db.maxIdleTime) + if err != nil { + return nil, err + } + db.SetConnMaxIdleTime(duration) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Use PingContext() to establish a new connection to the database, passing in the + // context we created above as a parameter. If the connection couldn't be + // established successfully within the 5 second deadline, then this will return an + // error. + err = db.PingContext(ctx) + if err != nil { + return nil, err + } + + return db, nil +} diff --git a/cmd/api/middleware.go b/cmd/api/middleware.go new file mode 100644 index 0000000..81642fc --- /dev/null +++ b/cmd/api/middleware.go @@ -0,0 +1,215 @@ +package main + +import ( + "fmt" + "net" + "net/http" + "sync" + "time" + "strings" + "errors" + + "golang.org/x/time/rate" + "party.at/party/internal/data" + "party.at/party/internal/validator" +) + +func (app *application) recoverPanic(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Create a deferred function (which will always be run in the event of a panic + // as Go unwinds the stack). + defer func() { + // Use the builtin recover function to check if there has been a panic or + // not. + if err := recover(); err != nil { + // If there was a panic, set a "Connection: close" header on the + // response. This acts as a trigger to make Go's HTTP server + // automatically close the current connection after a response has been + // sent. + w.Header().Set("Connection", "close") + // The value returned by recover() has the type interface{}, so we use + // fmt.Errorf() to normalize it into an error and call our + // serverErrorResponse() helper. In turn, this will log the error using + // our custom Logger type at the ERROR level and send the client a 500 + // Internal Server Error response. + app.serverErrorResponse(w, r, fmt.Errorf("%s", err)) + } + }() + next.ServeHTTP(w, r) + }) +} + +func (app *application) rateLimit(next http.Handler) http.Handler { + + type client struct { + limiter *rate.Limiter + lastSeen time.Time + } + + var mu sync.Mutex + var clients = make(map[string]*client) + + go func() { + for { + time.Sleep(time.Minute) + + // Lock the mutex to prevent any rate limiter checks from happening while + // the cleanup is taking place. + mu.Lock() + + // Loop through all clients. If they haven't been seen within the last three + // minutes, delete the corresponding entry from the map. + for ip, client := range clients { + if time.Since(client.lastSeen) > 3*time.Minute { + delete(clients, ip) + } + } + + // Importantly, unlock the mutex when the cleanup is complete. + mu.Unlock() + } + }() + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if app.config.limiter.enabled { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + mu.Lock() + + if _, found := clients[ip]; !found { + clients[ip] = &client{limiter: rate.NewLimiter(rate.Limit(app.config.limiter.rps), app.config.limiter.burst)} + } + + clients[ip].lastSeen = time.Now() + + if !clients[ip].limiter.Allow() { + mu.Unlock() + app.rateLimitExceededResponse(w, r) + return + } + + mu.Unlock() + } + + next.ServeHTTP(w, r) + }) +} + +func (app *application) authenticate(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Vary", "Authorization") + + authorizationHeader := r.Header.Get("Authorization") + + if authorizationHeader == "" { + r = app.contextSetUser(r, data.AnonymousUser) + next.ServeHTTP(w, r) + return + } + + headerParts := strings.Split(authorizationHeader, " ") + if len(headerParts) != 2 || headerParts[0] != "Bearer" { + app.invalidAuthenticationTokenResponse(w, r) + return + } + + token := headerParts[1] + + v := validator.New() + + if data.ValidateTokenPlaintext(v, token); !v.Valid() { + app.invalidAuthenticationTokenResponse(w, r) + return + } + + userIdentity, err := app.models.UserIdentities.GetForToken(data.ScopeAuthentication, token) + if err != nil { + switch { + case errors.Is(err, data.ErrRecordNotFound): + app.invalidAuthenticationTokenResponse(w, r) + default: + app.serverErrorResponse(w, r, err) + } + return + } + + user, err := app.models.Users.Get(userIdentity.UserID) + if err != nil { + switch { + case errors.Is(err, data.ErrRecordNotFound): + app.invalidCredentialsResponse(w, r) + default: + app.serverErrorResponse(w, r, err) + } + return + } + + // Call the contextSetUser() helper to add the user information to the request + // context. + r = app.contextSetUser(r, user) + // Call the next handler in the chain. + next.ServeHTTP(w, r) + }) +} + +func (app *application) requireAuthenticatedUser(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := app.contextGetUser(r) + if user.IsAnonymous() { + app.authenticationRequiredResponse(w, r) + return + } + + next.ServeHTTP(w, r) + }) +} + +func (app *application) requireActivatedUser(next http.HandlerFunc) http.HandlerFunc { + // Rather than returning this http.HandlerFunc we assign it to the variable fn. + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := app.contextGetUser(r) + + // Check that a user is activated. + if !user.Activated { + app.inactiveAccountResponse(w, r) + return + } + next.ServeHTTP(w, r) + }) + + // Wrap fn with the requireAuthenticatedUser() middleware before returning it. + return app.requireAuthenticatedUser(fn) +} + +func (app *application) requirePermission(code string, next http.HandlerFunc) http.HandlerFunc { + fn := func(w http.ResponseWriter, r *http.Request) { + // Retrieve the user from the request context. + user := app.contextGetUser(r) + + // Get the slice of permissions for the user. + permissions, err := app.models.Permissions.GetAllForUser(user.ID) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + // Check if the slice includes the required permission. If it doesn't, then + // return a 403 Forbidden response. + if !permissions.Include(code) { + app.notPermittedResponse(w, r) + return + } + + // Otherwise they have the required permission so we call the next handler in + // the chain. + next.ServeHTTP(w, r) + } + + // Wrap this with the requireActivatedUser() middleware before returning it. + return app.requireActivatedUser(fn) +} diff --git a/cmd/api/routes.go b/cmd/api/routes.go index fb7457a..e798ad4 100644 --- a/cmd/api/routes.go +++ b/cmd/api/routes.go @@ -5,18 +5,34 @@ import ( "github.com/julienschmidt/httprouter" ) -func (app *application) routes() *httprouter.Router { +func (app *application) routes() http.Handler { router := httprouter.New() + router.NotFound = http.HandlerFunc(app.notFoundResponse) + router.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse) + fileServer := http.FileServer(http.Dir("ui/static")) - router.HandlerFunc(http.MethodGet, "/", home) - router.HandlerFunc(http.MethodGet, "/ws", ws) - router.Handler(http.MethodGet, "/static/", http.StripPrefix("/static", fileServer)) - router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler) - // router.HandlerFunc(http.MethodPost, "/v1/movies", app.createMovieHandler) - // router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.showMovieHandler) + router.HandlerFunc(http.MethodGet, "/", home) + router.HandlerFunc(http.MethodGet, "/ws", ws) + router.Handler (http.MethodGet, "/static/", http.StripPrefix("/static", fileServer)) + router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler) + + router.HandlerFunc(http.MethodGet, "/v1/issues", app.requirePermission("issues:read", app.listIssuesHandler)) - return router + router.HandlerFunc(http.MethodPost, "/v1/issues", app.requirePermission("issues:write", app.createIssueHandler)) + router.HandlerFunc(http.MethodGet, "/v1/issues/:id", app.requirePermission("issues:read", app.readIssueHandler)) + router.HandlerFunc(http.MethodPatch, "/v1/issues/:id", app.requirePermission("issues:write", app.updateIssueHandler)) + router.HandlerFunc(http.MethodDelete, "/v1/issues/:id", app.requirePermission("issues:write", app.deleteIssueHandler)) + + router.HandlerFunc(http.MethodPost, "/v1/users", app.createUserHandler) + // router.HandlerFunc(http.MethodGet, "/v1/users/:id", app.readUserHandler) + // router.HandlerFunc(http.MethodPatch, "/v1/users/:id", app.updateUserHandler) + router.HandlerFunc(http.MethodDelete, "/v1/users/:id", app.deleteUserHandler) + router.HandlerFunc(http.MethodPut, "/v1/users/activated", app.activateUserHandler) + + router.HandlerFunc(http.MethodPost, "/v1/tokens/authentication", app.createAuthenticationTokenHandler) + + return app.recoverPanic(app.enableCORS(app.rateLimit(app.authenticate(router)))) } diff --git a/cmd/api/server.go b/cmd/api/server.go new file mode 100644 index 0000000..5464815 --- /dev/null +++ b/cmd/api/server.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + "log" + "os" + "os/signal" + "syscall" +) + +func (app *application) serve() error { + // Declare a HTTP server using the same settings as in our main() function. + srv := &http.Server{ + Addr: fmt.Sprintf(":%d", app.config.port), + Handler: app.routes(), + ErrorLog: log.New(app.logger, "", 0), + IdleTimeout: time.Minute, + ReadTimeout: 10 * time.Second, + WriteTimeout: 30 * time.Second, + } + + shutdownError := make(chan error) + + go func() { + + quit := make(chan os.Signal, 1) + + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + + s := <-quit + + app.logger.PrintInfo("shutting down server", map[string]string{ + "signal": s.String(), + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5 * time.Second) + defer cancel() + + shutdownError <- srv.Shutdown(ctx) + }() + + // Likewise log a "starting server" message. + app.logger.PrintInfo("starting server", map[string]string{ + "addr": srv.Addr, + "env": app.config.env, + }) + + err := srv.ListenAndServe() + if !errors.Is(err, http.ErrServerClosed) { + return err + } + + err = <-shutdownError + if err != nil { + return err + } + + app.logger.PrintInfo("stopped server", map[string]string{ + "addr": srv.Addr, + }) + + return nil +} diff --git a/cmd/api/tokens.go b/cmd/api/tokens.go new file mode 100644 index 0000000..e3eb913 --- /dev/null +++ b/cmd/api/tokens.go @@ -0,0 +1,77 @@ +package main + +import ( + "errors" + "net/http" + "time" + "party.at/party/internal/data" + "party.at/party/internal/validator" +) + +func (app *application) createAuthenticationTokenHandler(w http.ResponseWriter, r *http.Request) { + var input struct { + Email string `json:"email"` + Password string `json:"password"` + } + + err := app.readJSON(w, r, &input) + if err != nil { + app.badRequestResponse(w, r, err) + return + } + + // Validate the email and password provided by the client. + v := validator.New() + + data.ValidateEmail(v, input.Email) + data.ValidatePasswordPlaintext(v, input.Password) + + if !v.Valid() { + app.failedValidationResponse(w, r, v.Errors) + return + } + + user, err := app.models.Users.GetByEmail(input.Email) + if err != nil { + switch { + case errors.Is(err, data.ErrRecordNotFound): + app.invalidCredentialsResponse(w, r) + default: + app.serverErrorResponse(w, r, err) + } + return + } + + identities, err := app.models.UserIdentities.GetByUser(user.ID) + + var authenticatedIdentity *data.UserIdentity + + for _, user_identity := range identities { + match, err := user_identity.Password.Matches(input.Password) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + if match { + authenticatedIdentity = user_identity + break + } + } + + if authenticatedIdentity == nil { + app.invalidCredentialsResponse(w, r) + return + } + + token, err := app.models.Tokens.New(user.ID, authenticatedIdentity.ID, 24 * time.Hour, data.ScopeAuthentication) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + err = app.writeJSON(w, http.StatusCreated, envelope{"authentication_token": token}, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} diff --git a/cmd/api/users.go b/cmd/api/users.go new file mode 100644 index 0000000..dbb473d --- /dev/null +++ b/cmd/api/users.go @@ -0,0 +1,197 @@ +package main + +import ( + "errors" + "net/http" + "party.at/party/internal/data" + "party.at/party/internal/validator" + "database/sql" + "time" +) + +func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request) { + + var input struct { + ProviderId int64 `json:"provider_id"` + Username string `json:"username"` + Email string `json:"email"` + Password string `json:"password"` + Name string `json:"name"` + AltName string `json:"alt_name"` + } + + err := app.readJSON(w, r, &input) + if err != nil { + app.badRequestResponse(w, r, err) + return + } + + user := &data.User{ + Email: input.Email, + Name: input.Name, + AltName: sql.NullString{String: input.AltName, Valid: true}, + Activated: false, + } + + userIdentity := &data.UserIdentity{ + ProviderID: input.ProviderId, + ProviderUserID: input.Username, + } + + err = userIdentity.Password.Set(input.Password) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + v := validator.New() + if data.ValidateUser(v, user); !v.Valid() { + app.failedValidationResponse(w, r, v.Errors) + return + } + + if data.ValidateUserIdentity(v, userIdentity); !v.Valid() { + app.failedValidationResponse(w, r, v.Errors) + return + } + + err = app.models.Users.ExecuteRegistrationTx(user, userIdentity) + if err != nil { + switch { + case errors.Is(err, data.ErrDuplicateEmail): + v.AddError("email", "a user with this email address already exists") + app.failedValidationResponse(w, r, v.Errors) + case errors.Is(err, data.ErrDuplicateUser): + v.AddError("username", "a user with this username already exists") + app.failedValidationResponse(w, r, v.Errors) + default: + app.serverErrorResponse(w, r, err) + } + + return + } + + err = app.models.Permissions.AddForUser(user.ID, "issues:read") + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + token, err := app.models.Tokens.New(user.ID, userIdentity.ID, 3 * 24 * time.Hour, data.ScopeActivation) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + app.background(func() { + data := map[string]interface{}{ + "token": token.Plaintext, + "userID": user.ID, + } + + err = app.mailer.Send(user.Email, "user_welcome.tmpl", data) + if err != nil { + app.logger.PrintError(err, nil) + } + }) + + // Write a JSON response containing the user data along with a 201 Created status + // code. + err = app.writeJSON(w, http.StatusCreated, envelope{"user": user}, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} + +func (app *application) deleteUserHandler(w http.ResponseWriter, r *http.Request) { + id, err := app.readIDParam(r) + if err != nil { + app.notFoundResponse(w, r) + return + } + + // Delete the issue from the database, sending a 404 Not Found response to the + // client if there isn't a matching record. + err = app.models.Users.Delete(id) + if err != nil { + switch { + case errors.Is(err, data.ErrRecordNotFound): + app.notFoundResponse(w, r) + default: + app.serverErrorResponse(w, r, err) + } + + return + } + + // Return a 200 OK status code along with a success message. + err = app.writeJSON(w, http.StatusOK, envelope{"message": "user successfully deleted"}, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} + +func (app *application) activateUserHandler(w http.ResponseWriter, r *http.Request) { + // Parse the plaintext activation token from the request body. + var input struct { + TokenPlaintext string `json:"token"` + } + + err := app.readJSON(w, r, &input) + if err != nil { + app.badRequestResponse(w, r, err) + return + } + + // Validate the plaintext token provided by the client. + v := validator.New() + if data.ValidateTokenPlaintext(v, input.TokenPlaintext); !v.Valid() { + app.failedValidationResponse(w, r, v.Errors) + return + } + + // Retrieve the details of the user associated with the token using the + // GetForToken() method (which we will create in a minute). If no matching record + // is found, then we let the client know that the token they provided is not valid. + user, err := app.models.Users.GetForToken(data.ScopeActivation, input.TokenPlaintext) + if err != nil { + switch { + case errors.Is(err, data.ErrRecordNotFound): + v.AddError("token", "invalid or expired activation token") + app.failedValidationResponse(w, r, v.Errors) + default: + app.serverErrorResponse(w, r, err) + } + return + } + + // Update the user's activation status. + user.Activated = true + + // Save the updated user record in our database, checking for any edit conflicts in + // the same way that we did for our movie records. + err = app.models.Users.Update(user) + if err != nil { + switch { + case errors.Is(err, data.ErrEditConflict): + app.editConflictResponse(w, r) + default: + app.serverErrorResponse(w, r, err) + } + return + } + + // If everything went successfully, then we delete all activation tokens for the + // user. + err = app.models.Tokens.DeleteAllForUser(data.ScopeActivation, user.ID) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + // Send the updated user details to the client in a JSON response. + err = app.writeJSON(w, http.StatusOK, envelope{"user": user}, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} diff --git a/go.mod b/go.mod index d13eb6c..6022ac9 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,18 @@ module party.at/party -go 1.24.2 +go 1.25.0 require ( github.com/gorilla/websocket v1.5.3 - github.com/jackc/pgx/v5 v5.7.6 github.com/julienschmidt/httprouter v1.3.0 ) require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/lib/pq v1.12.0 // indirect + github.com/wneessen/go-mail v0.7.2 // indirect golang.org/x/crypto v0.37.0 // indirect - golang.org/x/text v0.24.0 // indirect + golang.org/x/text v0.29.0 // indirect + golang.org/x/time v0.15.0 // indirect ) diff --git a/go.sum b/go.sum index 3a3d815..ff32336 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/lib/pq v1.12.0 h1:mC1zeiNamwKBecjHarAr26c/+d8V5w/u4J0I/yASbJo= +github.com/lib/pq v1.12.0/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -20,12 +22,19 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/wneessen/go-mail v0.7.2 h1:xxPnhZ6IZLSgxShebmZ6DPKh1b6OJcoHfzy7UjOkzS8= +github.com/wneessen/go-mail v0.7.2/go.mod h1:+TkW6QP3EVkgTEqHtVmnAE/1MRhmzb8Y9/W3pweuS+k= golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/data/auth_provider.go b/internal/data/auth_provider.go new file mode 100644 index 0000000..92b14ba --- /dev/null +++ b/internal/data/auth_provider.go @@ -0,0 +1,15 @@ +package data + +type AuthProvider struct { + ID int64 `json:"id"` + Description string `json:"description"` + Active bool `json:"active"` +} + +type AuthProviderID int64 + +const ( + ProviderUnknown AuthProviderID = 0 + ProviderLocal AuthProviderID = 1 + ProviderIDAustria AuthProviderID = 2 +) diff --git a/internal/data/filters.go b/internal/data/filters.go new file mode 100644 index 0000000..1419910 --- /dev/null +++ b/internal/data/filters.go @@ -0,0 +1,77 @@ +package data + +import ( + "math" + "strings" + + "party.at/party/internal/validator" +) + +type Filters struct { + Page int + PageSize int + Sort string + SortSafelist []string +} + +func (f Filters) sortColumn() string { + for _, safeValue := range f.SortSafelist { + if f.Sort == safeValue { + return strings.TrimPrefix(f.Sort, "-") + } + } + + panic("unsafe sort parameter: " + f.Sort) +} + +// Return the sort direction ("ASC" or "DESC") depending on the prefix character of the +// Sort field. +func (f Filters) sortDirection() string { + if strings.HasPrefix(f.Sort, "-") { + return "DESC" + } + + return "ASC" +} + +func (f Filters) limit() int { + return f.PageSize +} + +func (f Filters) offset() int { + return (f.Page - 1) * f.PageSize +} + +func ValidateFilters(v *validator.Validator, f Filters) { + // Check that the page and page_size parameters contain sensible values. + v.Check(f.Page > 0, "page", "must be greater than zero") + v.Check(f.Page <= 10_000_000, "page", "must be a maximum of 10 million") + v.Check(f.PageSize > 0, "page_size", "must be greater than zero") + v.Check(f.PageSize <= 100, "page_size", "must be a maximum of 100") + + // Check that the sort parameter matches a value in the safelist. + v.Check(validator.In(f.Sort, f.SortSafelist...), "sort", "invalid sort value") +} + +type Metadata struct { + CurrentPage int `json:"current_page,omitempty"` + PageSize int `json:"page_size,omitempty"` + FirstPage int `json:"first_page,omitempty"` + LastPage int `json:"last_page,omitempty"` + TotalRecords int `json:"total_records,omitempty"` +} + +func calculateMetadata(totalRecords, page, pageSize int) Metadata { + if totalRecords == 0 { + // Note that we return an empty Metadata struct if there are no records. + return Metadata{} + } + + return Metadata{ + CurrentPage: page, + PageSize: pageSize, + FirstPage: 1, + LastPage: int(math.Ceil(float64(totalRecords) / float64(pageSize))), + TotalRecords: totalRecords, + } +} diff --git a/internal/data/issues.go b/internal/data/issues.go new file mode 100644 index 0000000..11009b7 --- /dev/null +++ b/internal/data/issues.go @@ -0,0 +1,242 @@ +package data + +import ( + "time" + "party.at/party/internal/validator" + "database/sql" + "errors" + "context" + "fmt" +) + +type Issue struct { + ID int64 `json:"id"` + Title string `json:"title"` + Description string `json:"description"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + Created time.Time `json:"created"` + Version int32 `json:"version"` +} + +func ValidateIssue(v *validator.Validator, issue *Issue) { + v.Check(issue.Title != "", "title", "must be provided") + v.Check(len(issue.Title) <= 500, "title", "must not be more than 500 bytes long") + v.Check(issue.Description != "", "description", "must be provided") + v.Check(len(issue.Description) <= 500, "description", "must not be greater than 500 bytes long") + // v.Check(issue.StartTime != "", "start_time", "must be provided") + // v.Check(len(issue.StartTime) <= 500, "start_time", "must not be more than 500 bytes long") + // v.Check(issue.EndTime != "", "end_time", "must be provided") + // v.Check(len(issue.EndTime) <= 500, "end_time", "must not be more than 500 bytes long") +} + +type IssueModel struct { + DB *sql.DB +} + +func (m IssueModel) Insert(issue *Issue) error { + query := ` +INSERT INTO issues (title, description, start_time, end_time) +VALUES ($1, $2, $3, $4) +RETURNING id, created, version` + + args := []interface{}{ + issue.Title, + issue.Description, + issue.StartTime, + issue.EndTime, + } + + return m.DB.QueryRow(query, args...).Scan( + &issue.ID, + &issue.Created, + &issue.Version, + ) +} + +// Add a placeholder method for fetching a specific record from the issues table. +func (m IssueModel) Get(id int64) (*Issue, error) { + if id < 1 { + return nil, ErrRecordNotFound + } + + // Define the SQL query for retrieving the issue data. + query :=` +SELECT id, title, description, start_time, end_time, created, version +FROM issues +WHERE id = $1` + + // Declare a Issue struct to hold the data returned by the query. + var issue Issue + + // Execute the query using the QueryRow() method, passing in the provided id value + // as a placeholder parameter, and scan the response data into the fields of the + // Issue struct. Importantly, notice that we need to convert the scan target for the + // genres column using the pq.Array() adapter function again. + err := m.DB.QueryRow(query, id).Scan( + &issue.ID, + &issue.Title, + &issue.Description, + &issue.StartTime, + &issue.EndTime, + &issue.Created, + &issue.Version, + ) + + // Handle any errors. If there was no matching issue found, Scan() will return + // a sql.ErrNoRows error. We check for this and return our custom ErrRecordNotFound + // error instead. + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, ErrRecordNotFound + default: + return nil, err + } + } + // Otherwise, return a pointer to the Issue struct. + return &issue, nil +} + +func (m IssueModel) Update(issue *Issue) error { + query := ` + UPDATE issues + SET title = $1, description = $2, start_time = $3, end_time = $4, version = version + 1 + WHERE id = $5 AND version = $6 + RETURNING version` + + // Create an args slice containing the values for the placeholder parameters. + args := []interface{}{ + issue.Title, + issue.Description, + issue.StartTime, + issue.EndTime, + issue.ID, + issue.Version, + } + + // Use the QueryRow() method to execute the query, passing in the args slice as a + // variadic parameter and scanning the new version value into the issue struct. + err := m.DB.QueryRow(query, args...).Scan(&issue.Version) + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return ErrEditConflict + default: + return err + } + } + + return nil +} + +// Add a placeholder method for deleting a specific record from the issues table. +func (m IssueModel) Delete(id int64) error { + if id < 1 { + return ErrRecordNotFound + } + + // Construct the SQL query to delete the record. + query := ` + DELETE FROM issues + WHERE id = $1` + + // Execute the SQL query using the Exec() method, passing in the id variable as + // the value for the placeholder parameter. The Exec() method returns a sql.Result + // object. + result, err := m.DB.Exec(query, id) + if err != nil { + return err + } + + // Call the RowsAffected() method on the sql.Result object to get the number of rows + // affected by the query. + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + // If no rows were affected, we know that the issues table didn't contain a record + // with the provided ID at the moment we tried to delete it. In that case we + // return an ErrRecordNotFound error. + if rowsAffected == 0 { + return ErrRecordNotFound + } + + return nil +} + +func (m IssueModel) GetAll(title string, filters Filters) ([]*Issue, Metadata, error) { + // Construct the SQL query to retrieve all issue records. + query := fmt.Sprintf(` + SELECT + COUNT(*) OVER(), id, title, description, start_time, end_time, created, version + FROM + issues + WHERE + (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '') + ORDER BY + %s %s, id ASC + LIMIT + $2 + OFFSET + $3`, + filters.sortColumn(), + filters.sortDirection(), + ) + + // Create a context with a 3-second timeout. + ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) + defer cancel() + + args := []interface{}{title, filters.limit(), filters.offset()} + + rows, err := m.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, Metadata{}, err + } + + // Importantly, defer a call to rows.Close() to ensure that the resultset is closed + // before GetAll() returns. + defer rows.Close() + + totalRecords := 0 + issues := []*Issue{} + + // Use rows.Next to iterate through the rows in the resultset. + for rows.Next() { + // Initialize an empty Issue struct to hold the data for an individual issue. + var issue Issue + + // Scan the values from the row into the Issue struct. Again, note that we're + // using the pq.Array() adapter on the genres field here. + err := rows.Scan( + &totalRecords, + &issue.ID, + &issue.Title, + &issue.Description, + &issue.StartTime, + &issue.EndTime, + &issue.Created, + &issue.Version, + ) + + if err != nil { + return nil, Metadata{}, err + } + + // Add the Issue struct to the slice. + issues = append(issues, &issue) + } + + // When the rows.Next() loop has finished, call rows.Err() to retrieve any error + // that was encountered during the iteration. + if err = rows.Err(); err != nil { + return nil, Metadata{}, err + } + + metadata := calculateMetadata(totalRecords, filters.Page, filters.PageSize) + + // If everything went OK, then return the slice of issues. + return issues, metadata, nil +} diff --git a/internal/data/models.go b/internal/data/models.go new file mode 100644 index 0000000..1e9d5f0 --- /dev/null +++ b/internal/data/models.go @@ -0,0 +1,29 @@ +package data + +import ( + "database/sql" + "errors" +) + +var ( + ErrRecordNotFound = errors.New("record not found") + ErrEditConflict = errors.New("edit conflict") +) + +type Models struct { + Users UserModel + UserIdentities UserIdentityModel + Issues IssueModel + Tokens TokenModel + Permissions PermissionModel +} + +func NewModels(db *sql.DB) Models { + return Models{ + Users: UserModel{DB: db}, + UserIdentities: UserIdentityModel{DB: db}, + Issues: IssueModel{DB: db}, + Tokens: TokenModel{DB: db}, + Permissions: PermissionModel{DB: db}, + } +} diff --git a/internal/data/permissions.go b/internal/data/permissions.go new file mode 100644 index 0000000..9469d3f --- /dev/null +++ b/internal/data/permissions.go @@ -0,0 +1,69 @@ +package data + +import ( + "context" + "database/sql" + "time" + "github.com/lib/pq" +) + +type Permissions []string + +func (p Permissions) Include(code string) bool { + for i := range p { + if code == p[i] { + return true + } + } + + return false +} + +type PermissionModel struct { + DB *sql.DB +} + +func (m PermissionModel) GetAllForUser(userID int64) (Permissions, error) { + query :=` +SELECT permissions.code +FROM permissions +INNER JOIN users_permissions ON users_permissions.permission_id = permissions.id +INNER JOIN users ON users_permissions.user_id = users.id +WHERE users.id = $1` + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + rows, err := m.DB.QueryContext(ctx, query, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var permissions Permissions + for rows.Next() { + var permission string + err := rows.Scan(&permission) + if err != nil { + return nil, err + } + + permissions = append(permissions, permission) + } + if err = rows.Err(); err != nil { + return nil, err + } + + return permissions, nil +} + +func (m PermissionModel) AddForUser(userID int64, codes ...string) error { + query :=` +INSERT INTO users_permissions +SELECT $1, permissions.id FROM permissions WHERE permissions.code = ANY($2)` + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _, err := m.DB.ExecContext(ctx, query, userID, pq.Array(codes)) + return err +} diff --git a/internal/data/tokens.go b/internal/data/tokens.go new file mode 100644 index 0000000..46d8d87 --- /dev/null +++ b/internal/data/tokens.go @@ -0,0 +1,95 @@ +package data + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base32" + "time" + "context" + "database/sql" + + "party.at/party/internal/validator" +) + +const ( + ScopeActivation = "activation" + ScopeAuthentication = "authentication" +) + +type Token struct { + Plaintext string `json:"token"` + Hash []byte `json:"-"` + UserID int64 `json:"-"` + UserIdentityID int64 `json:"-"` + Expiry time.Time `json:"expiry"` + Scope string `json:"-"` +} + +func generateToken(userID int64, userIdentityID int64, ttl time.Duration, scope string) (*Token, error) { + token := &Token{ + UserID: userID, + UserIdentityID: userIdentityID, + Expiry: time.Now().Add(ttl), + Scope: scope, + } + + // Initialize a zero-valued byte slice with a length of 16 bytes. + randomBytes := make([]byte, 16) + + _, err := rand.Read(randomBytes) + if err != nil { + return nil, err + } + + token.Plaintext = base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(randomBytes) + + hash := sha256.Sum256([]byte(token.Plaintext)) + token.Hash = hash[:] + + return token, nil +} + +func ValidateTokenPlaintext(v *validator.Validator, tokenPlaintext string) { + v.Check(tokenPlaintext != "", "token", "must be provided") + v.Check(len(tokenPlaintext) == 26, "token", "must be 26 bytes long") +} + +type TokenModel struct { + DB *sql.DB +} + +func (m TokenModel) New(userID int64, userIdentityID int64, ttl time.Duration, scope string) (*Token, error) { + token, err := generateToken(userID, userIdentityID, ttl, scope) + if err != nil { + return nil, err + } + + err = m.Insert(token) + return token, err +} + +func (m TokenModel) Insert(token *Token) error { + query :=` + INSERT INTO tokens (hash, user_id, user_identity_id, expiry, scope) + VALUES ($1, $2, $3, $4, $5)` + + args := []interface{}{token.Hash, token.UserID, token.UserIdentityID, token.Expiry, token.Scope} + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + _, err := m.DB.ExecContext(ctx, query, args...) + return err +} + +func (m TokenModel) DeleteAllForUser(scope string, userID int64) error { + query :=` +DELETE FROM tokens +WHERE scope = $1 AND user_id = $2` + + ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) + defer cancel() + + _, err := m.DB.ExecContext(ctx, query, scope, userID) + return err +} diff --git a/internal/data/user_identities.go b/internal/data/user_identities.go new file mode 100644 index 0000000..72c6564 --- /dev/null +++ b/internal/data/user_identities.go @@ -0,0 +1,306 @@ +package data + +import ( + "golang.org/x/crypto/bcrypt" + "errors" + "party.at/party/internal/validator" + "database/sql" + "context" + "time" + "crypto/sha256" +) + +type UserIdentity struct { + ID int64 `json:"id"` + ProviderID int64 `json:"provider_id"` + UserID int64 `json:"user_id"` + ProviderUserID string `json:"provider_user_id"` + Password password `json:"-"` + Version int32 `json:"version"` +} + +type password struct { + plaintext *string + hash []byte +} + +func (p *password) Set(plaintextPassword string) error { + hash, err := bcrypt.GenerateFromPassword([]byte(plaintextPassword), 12) + if err != nil { + return err + } + + p.plaintext = &plaintextPassword + p.hash = hash + return nil +} + +func (p *password) Matches(plaintextPassword string) (bool, error) { + err := bcrypt.CompareHashAndPassword(p.hash, []byte(plaintextPassword)) + if err != nil { + switch { + case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword): + return false, nil + default: + return false, err + } + } + + return true, nil +} + +func ValidatePasswordPlaintext(v *validator.Validator, password string) { + v.Check(password != "", "password", "must be provided") + v.Check(len(password) >= 8, "password", "must be at least 8 bytes long") + v.Check(len(password) <= 72, "password", "must not be more than 72 bytes long") +} + +func ValidateUserIdentity(v *validator.Validator, userIdentity *UserIdentity) { + v.Check(userIdentity.ProviderID == int64(ProviderLocal), "provider_id", "must be 1"); + + if userIdentity.ProviderID == int64(ProviderLocal) { + v.Check(userIdentity.ProviderUserID != "", "username", "must be provided") + v.Check(len(userIdentity.ProviderUserID) <= 500, "username", "must not be more than 500 bytes long") + } + + if userIdentity.Password.plaintext != nil { + ValidatePasswordPlaintext(v, + *userIdentity.Password.plaintext) + } + + if userIdentity.Password.hash == nil { + panic("missing password hash for user") + } +} + +type UserIdentityModel struct { + DB *sql.DB +} + +// func (m UserIdentityModel) Insert(userIdentity *UserIdentity) error { +// query := ` +// INSERT INTO user_identities (provider_id, user_id, provider_user, password) +// VALUES ($1, $2, $3, $4) +// RETURNING id, version` + +// args := []interface{}{ +// userIdentity.ProviderID, +// userIdentity.UserID, +// userIdentity.ProviderUserID, +// userIdentity.Password.hash, +// } + +// ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) +// defer cancel() + +// err := m.DB.QueryRowContext(ctx, query, args...).Scan(&userIdentity.ID, &userIdentity.Version) +// if err != nil { +// switch { +// case err.Error() == +// `pq: duplicate key value violates unique constraint "users_email_key"`: +// return ErrDuplicateEmail +// default: +// return err +// } +// } + +// return nil +// } + +func (m UserIdentityModel) Get(id int64) (*UserIdentity, error) { + if id < 1 { + return nil, ErrRecordNotFound + } + + query :=` +SELECT id, provider_id, user_id, provider_user, password, version +FROM user_identities +WHERE id = $1` + + // Declare a User struct to hold the data returned by the query. + var userIdentity UserIdentity + + // Execute the query using the QueryRow() method, passing in the provided id value + // as a placeholder parameter, and scan the response data into the fields of the + // User struct. Importantly, notice that we need to convert the scan target for the + // genres column using the pq.Array() adapter function again. + err := m.DB.QueryRow(query, id).Scan( + &userIdentity.ID, + &userIdentity.ProviderID, + &userIdentity.UserID, + &userIdentity.ProviderUserID, + &userIdentity.Password, + &userIdentity.Version, + ) + + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, ErrRecordNotFound + default: + return nil, err + } + } + // Otherwise, return a pointer to the User struct. + return &userIdentity, nil +} + +func (m UserIdentityModel) GetByUser(user_id int64) ([]*UserIdentity, error) { + if user_id < 1 { + return nil, ErrRecordNotFound + } + + query :=` +SELECT identity.id, identity.provider_id, identity.user_id, identity.provider_user_id, identity.password, identity.version +FROM user_identities identity +JOIN users u on identity.user_id = u.id +WHERE u.id = $1` + + // Execute the query using the QueryRow() method, passing in the provided id value + // as a placeholder parameter, and scan the response data into the fields of the + // User struct. Importantly, notice that we need to convert the scan target for the + // genres column using the pq.Array() adapter function again. + rows, err := m.DB.Query(query, user_id) + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, ErrRecordNotFound + default: + return nil, err + } + } + defer rows.Close() + + var identities []*UserIdentity + + for rows.Next() { + var userIdentity UserIdentity + + err := rows.Scan( + &userIdentity.ID, + &userIdentity.ProviderID, + &userIdentity.UserID, + &userIdentity.ProviderUserID, + &userIdentity.Password.hash, + &userIdentity.Version, + ) + if err != nil { + return nil, err + } + + identities = append(identities, &userIdentity) + } + + if len(identities) == 0 { + return nil, ErrRecordNotFound + } + + // Otherwise, return a pointer to the User struct. + return identities, nil +} + +func (m UserIdentityModel) GetForToken(tokenScope, tokenPlaintext string) (*UserIdentity, error) { + tokenHash := sha256.Sum256([]byte(tokenPlaintext)) + + query := ` +SELECT ui.id, ui.provider_id, ui.user_id, ui.provider_user_id, ui.password, ui.version +FROM user_identities ui +INNER JOIN tokens ON ui.id = tokens.user_identity_id +WHERE tokens.hash = $1 +AND tokens.scope = $2 +AND tokens.expiry > $3` + + args := []interface{}{tokenHash[:], tokenScope, time.Now()} + + var userIdentity UserIdentity + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + err := m.DB.QueryRowContext(ctx, query, args...).Scan( + &userIdentity.ID, + &userIdentity.ProviderID, + &userIdentity.UserID, + &userIdentity.ProviderUserID, + &userIdentity.Password.hash, + &userIdentity.Version, + ) + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, ErrRecordNotFound + default: + return nil, err + } + } + + return &userIdentity, nil +} + +func (m UserIdentityModel) Update(user *UserIdentity) error { + query := ` + UPDATE user_identities + SET provider_user_id = $1, password = $2, version = version + 1 + WHERE id = $3 AND version = $4 + RETURNING version` + + // Create an args slice containing the values for the placeholder parameters. + args := []interface{}{ + user.ProviderUserID, + user.Password.hash, + user.ID, + user.Version, + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version) + if err != nil { + switch { + case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`: + return ErrDuplicateEmail + case errors.Is(err, sql.ErrNoRows): + return ErrEditConflict + default: + return err + } + } + + return nil +} + +func (m UserIdentityModel) Delete(id int64) error { + if id < 1 { + return ErrRecordNotFound + } + + // Construct the SQL query to delete the record. + query := ` + DELETE FROM user_identities + WHERE id = $1` + + // Execute the SQL query using the Exec() method, passing in the id variable as + // the value for the placeholder parameter. The Exec() method returns a sql.Result + // object. + result, err := m.DB.Exec(query, id) + if err != nil { + return err + } + + // Call the RowsAffected() method on the sql.Result object to get the number of rows + // affected by the query. + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + // If no rows were affected, we know that the issues table didn't contain a record + // with the provided ID at the moment we tried to delete it. In that case we + // return an ErrRecordNotFound error. + if rowsAffected == 0 { + return ErrRecordNotFound + } + + return nil +} diff --git a/internal/data/users.go b/internal/data/users.go new file mode 100644 index 0000000..58fbc3f --- /dev/null +++ b/internal/data/users.go @@ -0,0 +1,329 @@ +package data + +import ( + "context" + "time" + "party.at/party/internal/validator" + "database/sql" + "github.com/lib/pq" + "errors" + "crypto/sha256" +) + +var ( + ErrDuplicateEmail = errors.New("duplicate email") + ErrDuplicateUser = errors.New("duplicate username") +) + +var AnonymousUser = &User{} + +type User struct { + ID int64 `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + AltName sql.NullString `json:"alt_name"` + Created time.Time `json:"created"` + LastLogin time.Time `json:"last_login"` + Activated bool `json:"activated"` + Version int32 `json:"-"` +} + +func (u *User) IsAnonymous() bool { + return u == AnonymousUser +} + +func ValidateEmail(v *validator.Validator, email string) { + v.Check(email != "", "email", "must be provided") + v.Check(validator.Matches(email, validator.EmailRX), "email" , "must be a valid email address") +} + +func ValidateUser(v *validator.Validator, user *User) { + ValidateEmail(v, user.Email) +} + +type UserModel struct { + DB *sql.DB +} + +func (m UserModel) ExecuteRegistrationTx(user *User, userIdentity *UserIdentity) error { + ctx, cancel := context.WithTimeout(context.Background(), 5 * time.Second) + defer cancel() + + tx, err := m.DB.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + query := ` +INSERT INTO users (email, name, alt_name) +VALUES ($1, $2, $3) +RETURNING id, created, last_login, version` + + args := []interface{}{ + user.Email, + user.Name, + user.AltName, + } + + err = tx.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.Created, &user.LastLogin, &user.Version) + if pgErr, ok := err.(*pq.Error); ok { + if pgErr.Code == "23505" { + if pgErr.Constraint == "users_email_key" { + return ErrDuplicateEmail + } + } + } + if err != nil { + return err + } + + userIdentity.UserID = user.ID + + // Insert Identity + query = ` +INSERT INTO user_identities (provider_id, user_id, provider_user_id, password) +VALUES ($1, $2, $3, $4) +RETURNING id, version` + + args = []interface{}{ + userIdentity.ProviderID, + userIdentity.UserID, + userIdentity.ProviderUserID, + userIdentity.Password.hash, + } + + err = tx.QueryRowContext(ctx, query, args...).Scan(&userIdentity.ID, &userIdentity.Version) + if pgErr, ok := err.(*pq.Error); ok { + if pgErr.Code == "23505" { + if pgErr.Constraint == "user_identities_provider_id_provider_user_id_key" { + return ErrDuplicateUser + } + } + } + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +// func (m UserModel) Insert(user *User) error { +// query := ` +// INSERT INTO users (email, name, alt_name, activated) +// VALUES ($1, $2, $3, $4) +// RETURNING id, created, last_login, version` + +// args := []interface{}{ +// user.Email, +// user.Name, +// user.AltName, +// user.Activated, +// } + +// ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) +// defer cancel() + +// err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.Created, &user.LastLogin, &user.Version) +// if err != nil { +// switch { +// case err.Error() == +// `pq: duplicate key value violates unique constraint "users_email_key"`: +// return ErrDuplicateEmail +// default: +// return err +// } +// } + +// return nil +// } + +func (m UserModel) Get(id int64) (*User, error) { + if id < 1 { + return nil, ErrRecordNotFound + } + + // Define the SQL query for retrieving the issue data. + query :=` +SELECT id, email, name, alt_name, created, last_login, activated, version +FROM users +WHERE id = $1` + + // Declare a User struct to hold the data returned by the query. + var user User + + err := m.DB.QueryRow(query, id).Scan( + &user.ID, + &user.Email, + &user.Name, + &user.AltName, + &user.Created, + &user.LastLogin, + &user.Activated, + &user.Version, + ) + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, ErrRecordNotFound + default: + return nil, err + } + } + + return &user, nil +} + +func (m UserModel) GetByEmail(email string) (*User, error) { + query :=` +SELECT id, email, name, alt_name, created, last_login, activated, version +FROM users +WHERE email = $1` + + var user User + ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) + defer cancel() + + err := m.DB.QueryRowContext(ctx, query, email).Scan( + &user.ID, + &user.Email, + &user.Name, + &user.AltName, + &user.Created, + &user.LastLogin, + &user.Activated, + &user.Version, + ) + + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, ErrRecordNotFound + default: + return nil, err + } + } + + return &user, nil +} + +func (m UserModel) GetForToken(tokenScope, tokenPlaintext string) (*User, error) { + // Calculate the SHA-256 hash of the plaintext token provided by the client. + // Remember that this returns a byte *array* with length 32, not a slice. + tokenHash := sha256.Sum256([]byte(tokenPlaintext)) + + // Set up the SQL query. + query :=` +SELECT users.id, users.created, users.name, users.email, users.activated, users.version +FROM users +INNER JOIN tokens ON users.id = tokens.user_id +WHERE tokens.hash = $1 +AND tokens.scope = $2 +AND tokens.expiry > $3` + + // Create a slice containing the query arguments. Notice how we use the [:] operator + // to get a slice containing the token hash, rather than passing in the array (which + // is not supported by the pq driver), and that we pass the current time as the + // value to check against the token expiry. + args := []interface{}{tokenHash[:], tokenScope, time.Now()} + var user User + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // Execute the query, scanning the return values into a User struct. If no matching + // record is found we return an ErrRecordNotFound error. + err := m.DB.QueryRowContext(ctx, query, args...).Scan( + &user.ID, + &user.Created, + &user.Name, + &user.Email, + &user.Activated, + &user.Version, + ) + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, ErrRecordNotFound + default: + return nil, err + } + } + + // Return the matching user. + return &user, nil +} + +func (m UserModel) Update(user *User) error { + query := ` + UPDATE users + SET email = $1, name = $2, alt_name = $3, activated = $4, version = version + 1 + WHERE id = $5 AND version = $6 + RETURNING version` + + // Create an args slice containing the values for the placeholder parameters. + args := []interface{}{ + user.Email, + user.Name, + user.AltName, + user.Activated, + user.ID, + user.Version, + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version) + if err != nil { + switch { + case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`: + return ErrDuplicateEmail + case errors.Is(err, sql.ErrNoRows): + return ErrEditConflict + default: + return err + } + } + + return nil +} + +func (m UserModel) Delete(id int64) error { + if id < 1 { + return ErrRecordNotFound + } + + // Construct the SQL query to delete the record. + query := ` +DELETE FROM users +WHERE id = $1` + + // Execute the SQL query using the Exec() method, passing in the id variable as + // the value for the placeholder parameter. The Exec() method returns a sql.Result + // object. + result, err := m.DB.Exec(query, id) + if err != nil { + return err + } + + // Call the RowsAffected() method on the sql.Result object to get the number of rows + // affected by the query. + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + // If no rows were affected, we know that the issues table didn't contain a record + // with the provided ID at the moment we tried to delete it. In that case we + // return an ErrRecordNotFound error. + if rowsAffected == 0 { + return ErrRecordNotFound + } + + return nil +} diff --git a/internal/jsonlog/jsonlog.go b/internal/jsonlog/jsonlog.go new file mode 100644 index 0000000..f39a8ec --- /dev/null +++ b/internal/jsonlog/jsonlog.go @@ -0,0 +1,124 @@ +package jsonlog + +import ( +"encoding/json" +"io" +"os" +"runtime/debug" +"sync" +"time" +) +// Define a Level type to represent the severity level for a log entry. +type Level int8 + +// Initialize constants which represent a specific severity level. We use the iota +// keyword as a shortcut to assign successive integer values to the constants. +const ( + LevelInfo Level = iota // Has the value 0. + LevelError // Has the value 1. + LevelFatal // Has the value 2. + LevelOff // Has the value 3. +) + +// Return a human-friendly string for the severity level. +func (l Level) String() string { + switch l { + case LevelInfo: + return "INFO" + case LevelError: + return "ERROR" + case LevelFatal: + return "FATAL" + default: + return "" + } +} + +// Define a custom Logger type. This holds the output destination that the log entries +// will be written to, the minimum severity level that log entries will be written for, +// plus a mutex for coordinating the writes. +type Logger struct { + out io.Writer + minLevel Level + mu sync.Mutex +} + +// Return a new Logger instance which writes log entries at or above a minimum severity +// level to a specific output destination. +func New(out io.Writer, minLevel Level) *Logger { + return &Logger{ + out: out, + minLevel: minLevel, + } +} + +// Declare some helper methods for writing log entries at the different levels. Notice +// that these all accept a map as the second parameter which can contain any arbitrary +// 'properties' that you want to appear in the log entry. +func (l *Logger) PrintInfo(message string, properties map[string]string) { + l.print(LevelInfo, message, properties) +} + +func (l *Logger) PrintError(err error, properties map[string]string) { + l.print(LevelError, err.Error(), properties) +} + +func (l *Logger) PrintFatal(err error, properties map[string]string) { + l.print(LevelFatal, err.Error(), properties) + os.Exit(1) // For entries at the FATAL level, we also terminate the application. +} + +// Print is an internal method for writing the log entry. +func (l *Logger) print(level Level, message string, properties map[string]string) (int, error) { + // If the severity level of the log entry is below the minimum severity for the + // logger, then return with no further action. + if level < l.minLevel { + return 0, nil + } + + // Declare an anonymous struct holding the data for the log entry. + aux := struct { + Level string `json:"level"` + Time string `json:"time"` + Message string `json:"message"` + Properties map[string]string `json:"properties,omitempty"` + Trace string `json:"trace,omitempty"` + }{ + Level: level.String(), + Time: time.Now().UTC().Format(time.RFC3339), + Message: message, + Properties: properties, + } + + // Include a stack trace for entries at the ERROR and FATAL levels. + if level >= LevelError { + aux.Trace = string(debug.Stack()) + } + + // Declare a line variable for holding the actual log entry text. + var line []byte + + // Marshal the anonymous struct to JSON and store it in the line variable. If there + // was a problem creating the JSON, set the contents of the log entry to be that + // plain-text error message instead. + line, err := json.Marshal(aux) + if err != nil { + line = []byte(LevelError.String() + ": unable to marshal log message:" + err.Error()) + } + + // Lock the mutex so that no two writes to the output destination cannot happen + // concurrently. If we don't do this, it's possible that the text for two or more + // log entries will be intermingled in the output. + l.mu.Lock() + defer l.mu.Unlock() + + // Write the log entry followed by a newline. + return l.out.Write(append(line, '\n')) +} + +// We also implement a Write() method on our Logger type so that it satisfies the +// io.Writer interface. This writes a log entry at the ERROR level with no additional +// properties. +func (l *Logger) Write(message []byte) (n int, err error) { + return l.print(LevelError, string(message), nil) +} diff --git a/internal/mailer/mailer.go b/internal/mailer/mailer.go new file mode 100644 index 0000000..ccb95ef --- /dev/null +++ b/internal/mailer/mailer.go @@ -0,0 +1,96 @@ +package mailer + +import ( + "bytes" + "embed" + "html/template" + "log" + // "time" + + "github.com/wneessen/go-mail" +) + +//go:embed "templates" +var templateFS embed.FS + +type Mailer struct { + dialer *mail.Client + sender string +} + +func New(host string, port int, username, password, sender string) Mailer { + // Initialize a new mail.Dialer instance with the given SMTP server settings. We + // also configure this to use a 5-second timeout whenever we send an email. + dialer, err := mail.NewClient(host, + mail.WithSMTPAuth(mail.SMTPAuthAutoDiscover), mail.WithTLSPortPolicy(mail.TLSMandatory), + mail.WithUsername(username), mail.WithPassword(password), + ) + if err != nil { + log.Fatalf("failed to deliver mail: %s", err) + } + + // dialer. = 5 * time.Second + // Return a Mailer instance containing the dialer and sender information. + return Mailer{ + dialer: dialer, + sender: sender, + } +} + +// Define a Send() method on the Mailer type. This takes the recipient email address +// as the first parameter, the name of the file containing the templates, and any +// dynamic data for the templates as an interface{} parameter. +func (m Mailer) Send(recipient, templateFile string, data interface{}) error { + // Use the ParseFS() method to parse the required template file from the embedded + // file system. + tmpl, err := template.New("email").ParseFS(templateFS, "templates/"+templateFile) + if err != nil { + return err + } + + // Execute the named template "subject", passing in the dynamic data and storing the + // result in a bytes.Buffer variable. + subject := new(bytes.Buffer) + err = tmpl.ExecuteTemplate(subject, "subject", data) + if err != nil { + return err + } + + // Follow the same pattern to execute the "plainBody" template and store the result + // in the plainBody variable. + plainBody := new(bytes.Buffer) + err = tmpl.ExecuteTemplate(plainBody, "plainBody", data) + if err != nil { + return err + } + + // And likewise with the "htmlBody" template. + htmlBody := new(bytes.Buffer) + err = tmpl.ExecuteTemplate(htmlBody, "htmlBody", data) + if err != nil { + return err + } + + // Use the mail.NewMessage() function to initialize a new mail.Message instance. + // Then we use the SetHeader() method to set the email recipient, sender and subject + // headers, the SetBody() method to set the plain-text body, and the AddAlternative() + // method to set the HTML body. It's important to note that AddAlternative() should + // always be called *after* SetBody(). + msg := mail.NewMsg() + msg.SetAddrHeader("To", recipient) + msg.SetAddrHeader("From", m.sender) + msg.SetGenHeader("Subject", subject.String()) + msg.SetBodyString("text/plain", plainBody.String()) + msg.AddAlternativeString("text/html", htmlBody.String()) + + // Call the DialAndSend() method on the dialer, passing in the message to send. This + // opens a connection to the SMTP server, sends the message, then closes the + // connection. If there is a timeout, it will return a "dial tcp: i/o timeout" + // error. + err = m.dialer.DialAndSend(msg) + if err != nil { + return err + } + + return nil +} diff --git a/internal/mailer/templates/user_welcome.tmpl b/internal/mailer/templates/user_welcome.tmpl new file mode 100644 index 0000000..25bb162 --- /dev/null +++ b/internal/mailer/templates/user_welcome.tmpl @@ -0,0 +1,35 @@ +{{define "subject"}}Welcome to Digitale Partei Österreich!{{end}} + +{{define "plainBody"}} +Hi, + +Thanks for signing up for a Digitale Partei Österreich account. We're excited to have you on board! +For future reference, your user ID number is {{.ID}}. + +Your activation token is {{.token}} + +Thanks, + +The DigitalePartei Team +{{end}} + +{{define "htmlBody"}} + + + + + + + + + +

Hi,

+

Thanks for signing up for a Digitale Partei Österreich account. We're excited to have you on board!

+

For future reference, your user ID number is {{.ID}}.

+

Your activation token is {{.token}}

+

Thanks,

+

The DigitalePartei Team

+ + + +{{end}} diff --git a/internal/validator/validator.go b/internal/validator/validator.go new file mode 100644 index 0000000..3cf621e --- /dev/null +++ b/internal/validator/validator.go @@ -0,0 +1,61 @@ +package validator + +import ( + "regexp" +) + +// Declare a regular expression for sanity checking the format of email addresses (we'll +// use this later in the book). If you're interested, this regular expression pattern is +// taken from https://html.spec.whatwg.org/#valid-e-mail-address. Note: if you're +// reading this in PDF or EPUB format and cannot see the full pattern, please see the +// note further down the page. +var ( + EmailRX = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$") +) + +type Validator struct { + Errors map[string]string +} + +func New() *Validator { + return &Validator{Errors: make(map[string]string)} +} + +func (v *Validator) Valid() bool { + return len(v.Errors) == 0 +} + +func (v *Validator) AddError(key, message string) { + if _, exists := v.Errors[key]; !exists { + v.Errors[key] = message + } +} + +func (v *Validator) Check(ok bool, key, message string) { + if !ok { + v.AddError(key, message) + } +} + +func In(value string, list ...string) bool { + for i := range list { + if value == list[i] { + return true + } + } + return false +} + +func Matches(value string, rx *regexp.Regexp) bool { + return rx.MatchString(value) +} + +func Unique(values []string) bool { + uniqueValues := make(map[string]bool) + + for _, value := range values { + uniqueValues[value] = true + } + + return len(values) == len(uniqueValues) +} diff --git a/migrations/000001_create_users_table.down.sql b/migrations/000001_create_users_table.down.sql new file mode 100644 index 0000000..9b5985e --- /dev/null +++ b/migrations/000001_create_users_table.down.sql @@ -0,0 +1,5 @@ +DROP TABLE IF EXISTS user_identities; + +DROP TABLE IF EXISTS auth_provider; + +DROP TABLE IF EXISTS users; diff --git a/migrations/000001_create_users_table.up.sql b/migrations/000001_create_users_table.up.sql new file mode 100644 index 0000000..d29f312 --- /dev/null +++ b/migrations/000001_create_users_table.up.sql @@ -0,0 +1,65 @@ +CREATE TABLE IF NOT EXISTS users ( + id BIGSERIAL PRIMARY KEY, + email citext UNIQUE NOT NULL, + name TEXT NOT NULL, + alt_name TEXT, + created TIMESTAMPTZ NOT NULL DEFAULT now(), + last_login TIMESTAMPTZ NOT NULL DEFAULT '1970-01-01 00:00:00+00', + activated BOOL NOT NULL DEFAULT false, + version INT NOT NULL DEFAULT 1 +); + +CREATE TABLE IF NOT EXISTS auth_provider ( + id BIGSERIAL PRIMARY KEY, -- e.g., 'local', 'id_austria', + description TEXT NOT NULL, + active BOOLEAN DEFAULT false +); + +INSERT INTO auth_provider (description, active) VALUES ('local', true); +INSERT INTO auth_provider (description, active) VALUES ('id_austria', false); + +CREATE TABLE IF NOT EXISTS user_identities ( + id BIGSERIAL PRIMARY KEY, + provider_id BIGINT NOT NULL REFERENCES auth_provider(id), + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + + -- For local: the username. For OIDC: the 'sub' (Subject ID) + provider_user_id TEXT NOT NULL, + + -- Nullable because OIDC users won't have a password in your DB + password bytea, + + version INT NOT NULL DEFAULT 1, + + UNIQUE(provider_id, provider_user_id) +); + +-- INSERT INTO users( +-- username, +-- email, +-- password, +-- name, +-- alt_name +-- ) +-- VALUES( +-- 'vik', +-- 'vikhenzo@gmail.com', +-- '$argon2id$v=19$m=65536,t=1,p=32$+dQ9uB7kKL7t7G3bI+TOMw$Wvic27W6SYH6Fx2Pp84irhVJ/blVh5qINlkv58bpgEc', +-- 'Vicente Ferrari Smith', +-- 'Vicente' +-- ); + +-- INSERT INTO users( +-- username, +-- email, +-- password, +-- name, +-- alt_name +-- ) +-- VALUES( +-- 'mkBflwkpe', +-- 'bellrebekaou@gmail.com', +-- '$argon2id$v=19$m=65536,t=1,p=1$BWcWdp8bhgS84LWqUCb2IA$DGF/FzQbSNHnfZraE9F2qvfdBGf5XB81+w00QgY/jG0', +-- 'zWkxKTNolTgJwO', +-- 'OahedOBLSo' +-- ); diff --git a/migrations/000002_create_additional_tables.down.sql b/migrations/000002_create_additional_tables.down.sql new file mode 100644 index 0000000..89ae30c --- /dev/null +++ b/migrations/000002_create_additional_tables.down.sql @@ -0,0 +1,8 @@ +DROP INDEX IF EXISTS idx_votes_option_id; +DROP INDEX IF EXISTS idx_vote_tokens_issue_id; +DROP INDEX IF EXISTS idx_options_issue_id; + +DROP TABLE IF EXISTS votes; +DROP TABLE IF EXISTS vote_tokens; +DROP TABLE IF EXISTS options; +DROP TABLE IF EXISTS issues; diff --git a/migrations/000002_create_additional_tables.up.sql b/migrations/000002_create_additional_tables.up.sql new file mode 100644 index 0000000..69c3d3a --- /dev/null +++ b/migrations/000002_create_additional_tables.up.sql @@ -0,0 +1,38 @@ +CREATE TABLE IF NOT EXISTS issues ( + id BIGSERIAL PRIMARY KEY, + title VARCHAR(255) NOT NULL, + description TEXT, + start_time TIMESTAMPTZ NOT NULL, + end_time TIMESTAMPTZ NOT NULL, + created TIMESTAMPTZ NOT NULL DEFAULT now(), + version INT NOT NULL DEFAULT 1 +); + +CREATE TABLE IF NOT EXISTS options ( + id BIGSERIAL PRIMARY KEY, + issue_id BIGINT NOT NULL REFERENCES issues(id) ON DELETE CASCADE, + label VARCHAR(255) NOT NULL, + created TIMESTAMPTZ NOT NULL DEFAULT now(), + version INT NOT NULL DEFAULT 1 +); + +CREATE TABLE IF NOT EXISTS vote_tokens ( + id BIGSERIAL PRIMARY KEY, + issue_id BIGINT NOT NULL REFERENCES issues(id) ON DELETE CASCADE, + token UUID NOT NULL UNIQUE DEFAULT gen_random_uuid(), + used BOOLEAN NOT NULL DEFAULT FALSE, + created TIMESTAMPTZ NOT NULL DEFAULT now(), + version INT NOT NULL DEFAULT 1 +); + +CREATE TABLE IF NOT EXISTS votes ( + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + token UUID NOT NULL UNIQUE REFERENCES vote_tokens(token) ON DELETE CASCADE, + option_id BIGINT NOT NULL REFERENCES options(id) ON DELETE CASCADE, + created TIMESTAMPTZ NOT NULL DEFAULT now(), + version INT NOT NULL DEFAULT 1 +); + +CREATE INDEX idx_votes_option_id ON votes(option_id); +CREATE INDEX idx_vote_tokens_issue_id ON vote_tokens(issue_id); +CREATE INDEX idx_options_issue_id ON options(issue_id); diff --git a/migrations/000003_add_issue_indices.down.sql b/migrations/000003_add_issue_indices.down.sql new file mode 100644 index 0000000..cd6db11 --- /dev/null +++ b/migrations/000003_add_issue_indices.down.sql @@ -0,0 +1,3 @@ +DROP INDEX IF EXISTS issue_title_idx; + +DROP INDEX IF EXISTS issue_description_idx; diff --git a/migrations/000003_add_issue_indices.up.sql b/migrations/000003_add_issue_indices.up.sql new file mode 100644 index 0000000..db9f9b2 --- /dev/null +++ b/migrations/000003_add_issue_indices.up.sql @@ -0,0 +1,3 @@ +CREATE INDEX IF NOT EXISTS issue_title_idx ON issues USING GIN (to_tsvector('simple', title)); + +CREATE INDEX IF NOT EXISTS issue_description_idx ON issues USING GIN (to_tsvector('simple', description)); diff --git a/migrations/000004_create_tokens_table.down.sql b/migrations/000004_create_tokens_table.down.sql new file mode 100644 index 0000000..1029218 --- /dev/null +++ b/migrations/000004_create_tokens_table.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS tokens; diff --git a/migrations/000004_create_tokens_table.up.sql b/migrations/000004_create_tokens_table.up.sql new file mode 100644 index 0000000..0f06c59 --- /dev/null +++ b/migrations/000004_create_tokens_table.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS tokens ( + hash bytea PRIMARY KEY, + user_id bigint NOT NULL REFERENCES users ON DELETE CASCADE, + user_identity_id bigint REFERENCES user_identities ON DELETE CASCADE, + expiry timestamp(0) with time zone NOT NULL, + scope text NOT NULL +); diff --git a/migrations/000005_add_permissions.down.sql b/migrations/000005_add_permissions.down.sql new file mode 100644 index 0000000..60e29eb --- /dev/null +++ b/migrations/000005_add_permissions.down.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS users_permissions; +DROP TABLE IF EXISTS permissions; diff --git a/migrations/000005_add_permissions.up.sql b/migrations/000005_add_permissions.up.sql new file mode 100644 index 0000000..f6fa9eb --- /dev/null +++ b/migrations/000005_add_permissions.up.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS permissions ( + id bigserial PRIMARY KEY, + code text NOT NULL +); + +CREATE TABLE IF NOT EXISTS users_permissions ( + user_id bigint NOT NULL REFERENCES users ON DELETE CASCADE, + permission_id bigint NOT NULL REFERENCES permissions ON DELETE CASCADE, + PRIMARY KEY (user_id, permission_id) +); + +-- Add the two permissions to the table. +INSERT INTO permissions (code) +VALUES ('issues:read'), ('issues:write');