diff --git a/.vscode/launch.json b/.vscode/launch.json index f132fde..e2eb434 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,14 +4,14 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ - { "name": "Launch API", "type": "go", "request": "launch", "mode": "auto", "program": "${workspaceFolder}/cmd/api", - "envFile": "${workspaceFolder}/.env" + "envFile": "${workspaceFolder}/.env", + "args": ["--env=development"] } ] } \ No newline at end of file diff --git a/cmd/api/helpers.go b/cmd/api/helpers.go index 076d39d..25841bd 100644 --- a/cmd/api/helpers.go +++ b/cmd/api/helpers.go @@ -11,7 +11,10 @@ import ( "strconv" "party.at/party/internal/validator" "github.com/julienschmidt/httprouter" - + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "crypto/rand" ) type envelope map[string]interface{} @@ -191,3 +194,19 @@ func (app *application) background(fn func()) { fn() }() } + +func GenerateIssueKey() ([]byte, int, []byte, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, 0, nil, err + } + + // Encode private key as PEM + privDER := x509.MarshalPKCS1PrivateKey(key) + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: privDER, + }) + + return key.N.Bytes(), key.E, privPEM, err +} diff --git a/cmd/api/helpers_test.go b/cmd/api/helpers_test.go new file mode 100644 index 0000000..7ba564a --- /dev/null +++ b/cmd/api/helpers_test.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/julienschmidt/httprouter" +) + +func TestReadIDParam(t *testing.T) { + app := newTestApplication(t) + + const test_id int64 = 3 + + r := httptest.NewRequest(http.MethodGet, "/v1/issues/" + strconv.FormatInt(test_id, 10), nil) + + params := httprouter.Params{{Key: "id", Value: "3"}} + ctx := context.WithValue(r.Context(), httprouter.ParamsKey, params) + r = r.WithContext(ctx) + + id, err := app.readIDParam(r) + if err != nil { + t.Fatal(err) + } + + if id != test_id { + t.Errorf("want %d, got id %d", test_id, id) + } +} diff --git a/cmd/api/issues.go b/cmd/api/issues.go index fe8772c..887d278 100644 --- a/cmd/api/issues.go +++ b/cmd/api/issues.go @@ -12,35 +12,30 @@ import ( "errors" "party.at/party/internal/data" "party.at/party/internal/validator" + "encoding/hex" ) func (app *application) listIssuesHandler(w http.ResponseWriter, r *http.Request) { - // To keep things consistent with our other handlers, we'll define an input struct - // to hold the expected values from the request query string. + var input struct { Title string data.Filters } - // Initialize a new Validator instance. + v := validator.New() - // Call r.URL.Query() to get the url.Values map containing the query string data. + qs := r.URL.Query() - // Use our helpers to extract the title and genres query string values, falling back - // to defaults of an empty string and an empty slice respectively if they are not - // provided by the client. + input.Title = app.readString(qs, "title", "") // input.Genres = app.readCSV(qs, "genres", []string{}) - // Get the page and page_size query string values as integers. Notice that we set - // the default page value to 1 and default page_size to 20, and that we pass the - // validator instance as the final argument here. + input.Filters.Page = app.readInt(qs, "page", 1, v) input.Filters.PageSize = app.readInt(qs, "page_size", 20, v) - // Extract the sort query string value, falling back to "id" if it is not provided - // by the client (which will imply a ascending sort on issue ID). + input.Filters.Sort = app.readString(qs, "sort", "id") input.Filters.SortSafelist = []string{"id", "-id", "title", "-title", "description", "-description"} @@ -77,11 +72,20 @@ func (app *application) createIssueHandler(w http.ResponseWriter, r *http.Reques return } + n, e, private_pem, err := GenerateIssueKey() + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + issue := &data.Issue{ Title: input.Title, Description: input.Description, StartTime: input.StartTime, EndTime: input.EndTime, + N: n, + E: e, + PrivatePem: private_pem, } v := validator.New() @@ -219,3 +223,73 @@ func (app *application) deleteIssueHandler(w http.ResponseWriter, r *http.Reques app.serverErrorResponse(w, r, err) } } + +func (app *application) readIssuePubKeyHandler(w http.ResponseWriter, r *http.Request) { + id, err := app.readIDParam(r) + if err != nil { + app.notFoundResponse(w, r) + return + } + + issue, err := app.models.Issues.Get(id) + if err != nil { + switch { + case errors.Is(err, data.ErrRecordNotFound): + app.notFoundResponse(w, r) + default: + app.serverErrorResponse(w, r, err) + } + return + } + + type response struct { + N string `json:"n"` + E int `json:"e"` + } + + res := response{ + N: hex.EncodeToString(issue.N), + E: issue.E, + } + + + err = app.writeJSON(w, http.StatusOK, envelope{"public_key": res}, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} + +func (app *application) blindSignIssueVoteHandler(w http.ResponseWriter, r *http.Request) { + id, err := app.readIDParam(r) + if err != nil { + app.notFoundResponse(w, r) + return + } + + issue, err := app.models.Issues.Get(id) + if err != nil { + switch { + case errors.Is(err, data.ErrRecordNotFound): + app.notFoundResponse(w, r) + default: + app.serverErrorResponse(w, r, err) + } + return + } + + type response struct { + N string `json:"n"` + E int `json:"e"` + } + + res := response{ + N: hex.EncodeToString(issue.N), + E: issue.E, + } + + + err = app.writeJSON(w, http.StatusOK, envelope{"public_key": res}, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} diff --git a/cmd/api/issues_test.go b/cmd/api/issues_test.go new file mode 100644 index 0000000..c19d14b --- /dev/null +++ b/cmd/api/issues_test.go @@ -0,0 +1,65 @@ +package main + +import ( + "bytes" + "encoding/json" + "net/http" + "strconv" + "testing" + "time" +) + +func TestReadIssueHandler(t *testing.T) { + app := newTestApplication(t) + ts := newTestServer(t, app, app.routes()) + defer ts.Close() + + token := ts.registerAndLogin(t, uniqueEmail(), "pa$$word123") + + req := map[string]any{ + "title": "An old silent pond...", + "description": "A frog jumps into the pond", + "start_time": time.Now().Format(time.RFC3339), + "end_time": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + } + + code, _, res := ts.postJSONWithToken(t, "/v1/issues", token, req) + if code != http.StatusCreated { + t.Fatalf("seed issue: want 201 got %d: %s", code, res) + } + + var resp struct { + Issue struct { + ID int64 `json:"id"` + } `json:"issue"` + } + json.Unmarshal(res, &resp) + issueID := resp.Issue.ID + + tests := []struct { + name string + urlPath string + wantCode int + wantBody []byte + }{ + {"Valid ID", "/v1/issues/" + strconv.Itoa(int(issueID)), http.StatusOK, []byte("An old silent pond...")}, + {"Non-existent ID", "/v1/issues/" + strconv.Itoa(int(issueID + 1)), http.StatusNotFound, nil}, + {"Negative ID", "/v1/issues/-1", http.StatusNotFound, nil}, + {"Decimal ID", "/v1/issues/1.23", http.StatusNotFound, nil}, + {"String ID", "/v1/issues/foo", http.StatusNotFound, nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + code, _, body := ts.getWithToken(t, tt.urlPath, token) + + if code != tt.wantCode { + t.Errorf("want %d; got %d", tt.wantCode, code) + } + + if !bytes.Contains(body, tt.wantBody) { + t.Errorf("want body to contain %q; got %q", tt.wantBody, string(body)) + } + }) + } +} diff --git a/cmd/api/main.go b/cmd/api/main.go index 483499c..d9d66c7 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -24,6 +24,10 @@ import ( "party.at/party/internal/data" "party.at/party/internal/mailer" "party.at/party/internal/jsonlog" + + "crypto/rand" + "crypto/aes" + "crypto/cipher" ) var version string @@ -80,12 +84,31 @@ var upgrader = websocket.Upgrader{ func main() { - fmt.Printf("Full Args: %v\n", os.Args) + // 32 bytes = AES-256 + var key [32]byte + if _, err := rand.Read(key[:]); err != nil { + panic(err) + } + + message := []byte("hello secure world") + + nonce, ciphertext, err := encrypt(key, message) + if err != nil { + panic(err) + } + + plaintext, err := decrypt(key, nonce, ciphertext) + if err != nil { + panic(err) + } + + fmt.Printf("Original: %s\n", message) + fmt.Printf("Decrypted: %s\n", plaintext) var cfg config flag.IntVar(&cfg.port, "port", 4000, "API server port") - flag.StringVar(&cfg.env, "env", "development", "Environment (development|staging|production)") + flag.StringVar(&cfg.env, "env", "production", "Environment (development|staging|production)") flag.StringVar(&cfg.db.dsn, "db-dsn", os.Getenv("PARTY_DB_DSN"), "PostgreSQL DSN") //addr := flag.String("addr", ":8443", "HTTP network address") @@ -97,7 +120,7 @@ func main() { flag.IntVar(&cfg.smtp.port, "smtp-port", 25, "SMTP port") flag.StringVar(&cfg.smtp.username, "smtp-username", "98cf60028d7fcb", "SMTP username") flag.StringVar(&cfg.smtp.password, "smtp-password", "b9d4a35372e971", "SMTP password") - flag.StringVar(&cfg.smtp.sender, "smtp-sender", "Greenlight ", "SMTP sender") + flag.StringVar(&cfg.smtp.sender, "smtp-sender", "DPÖ ", "SMTP sender") flag.Float64Var(&cfg.limiter.rps, "limiter-rps", 2, "Rate limiter maximum requests per second") flag.IntVar(&cfg.limiter.burst, "limiter-burst", 4, "Rate limiter maximum burst") @@ -126,7 +149,6 @@ func main() { if err != nil { logger.PrintFatal(err, nil) } - defer db.Close() expvar.NewString("version").Set(version) @@ -188,3 +210,43 @@ func openDB(cfg config) (*sql.DB, error) { return db, nil } + +func encrypt(key [32]byte, plaintext []byte) (nonce, ciphertext []byte, err error) { + block, err := aes.NewCipher(key[:]) + if err != nil { + return nil, nil, err + } + + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, nil, err + } + + nonce = make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + panic(err) + } + + + ciphertext = aead.Seal(nil, nonce, plaintext, nil) + return nonce, ciphertext, nil +} + +func decrypt(key [32]byte, nonce, ciphertext []byte) ([]byte, error) { + block, err := aes.NewCipher(key[:]) + if err != nil { + return nil, err + } + + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + plaintext, err := aead.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, err + } + + return plaintext, nil +} diff --git a/cmd/api/middleware.go b/cmd/api/middleware.go index e6b74f8..8321fbf 100644 --- a/cmd/api/middleware.go +++ b/cmd/api/middleware.go @@ -64,7 +64,7 @@ func (app *application) rateLimit(next http.Handler) http.Handler { // Loop through all clients. If they haven't been seen within the last three // minutes, delete the corresponding entry from the map. for ip, client := range clients { - if time.Since(client.lastSeen) > 3*time.Minute { + if time.Since(client.lastSeen) > 3 * time.Minute { delete(clients, ip) } } diff --git a/cmd/api/routes.go b/cmd/api/routes.go index 90d1007..afd3d3d 100644 --- a/cmd/api/routes.go +++ b/cmd/api/routes.go @@ -27,15 +27,18 @@ func (app *application) routes() http.Handler { router.HandlerFunc(http.MethodPatch, "/v1/issues/:id", app.requirePermission("issues:write", app.updateIssueHandler)) router.HandlerFunc(http.MethodDelete, "/v1/issues/:id", app.requirePermission("issues:write", app.deleteIssueHandler)) - router.HandlerFunc(http.MethodPost, "/v1/users", app.createUserHandler) + router.HandlerFunc(http.MethodGet, "/v1/issues/:id/pubkey", app.requirePermission("issues:read", app.readIssuePubKeyHandler)) + router.HandlerFunc(http.MethodPost, "/v1/issues/:id/blind-sign", app.requirePermission("issues:read", app.blindSignIssueVoteHandler)) + + router.HandlerFunc(http.MethodPost, "/v1/users", app.createUserHandler) // router.HandlerFunc(http.MethodGet, "/v1/users/:id", app.readUserHandler) // router.HandlerFunc(http.MethodPatch, "/v1/users/:id", app.updateUserHandler) - router.HandlerFunc(http.MethodDelete, "/v1/users/:id", app.deleteUserHandler) - router.HandlerFunc(http.MethodPut, "/v1/users/activated", app.activateUserHandler) + router.HandlerFunc(http.MethodDelete, "/v1/users/:id", app.deleteUserHandler) + router.HandlerFunc(http.MethodPut, "/v1/users/activated", app.activateUserHandler) - router.HandlerFunc(http.MethodPost, "/v1/tokens/authentication", app.createAuthenticationTokenHandler) + router.HandlerFunc(http.MethodPost, "/v1/tokens/authentication", app.createAuthenticationTokenHandler) - router.Handler(http.MethodGet, "/debug/vars", expvar.Handler()) + router.Handler (http.MethodGet, "/debug/vars", expvar.Handler()) return app.metrics(app.recoverPanic(app.enableCORS(app.rateLimit(app.authenticate(router))))) } diff --git a/cmd/api/testutils_test.go b/cmd/api/testutils_test.go new file mode 100644 index 0000000..a844114 --- /dev/null +++ b/cmd/api/testutils_test.go @@ -0,0 +1,202 @@ +package main + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + "os" + "encoding/json" + "bytes" + "fmt" + "time" + + "party.at/party/internal/data" + "party.at/party/internal/jsonlog" +) + +func newTestApplication(t *testing.T) *application { + cfg := config{} + cfg.db.dsn = "postgres://party:password@localhost:5432/party?sslmode=disable" + cfg.db.maxOpenConns = 25 + cfg.db.maxIdleConns = 25 + cfg.db.maxIdleTime = "15m" + cfg.limiter.enabled = false + cfg.env = "development" + + logger := jsonlog.New(os.Stdout, jsonlog.LevelInfo) + + db, err := openDB(cfg) + if err != nil { + logger.PrintFatal(err, nil) + } + t.Cleanup(func() {db.Close()}) + + return &application{ + logger: logger, + models: data.NewModels(db), + config: cfg, + } +} + +type testServer struct { + *httptest.Server + app *application +} + +func newTestServer(t *testing.T, app *application, h http.Handler) *testServer { + ts := httptest.NewTLSServer(h) + return &testServer{ts, app} +} + +func (ts *testServer) postJSON(t *testing.T, path string, body any) (int, http.Header, []byte) { + t.Helper() + + b, err := json.Marshal(body) + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest(http.MethodPost, ts.URL+path, bytes.NewReader(b)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := ts.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + return resp.StatusCode, resp.Header, respBody +} + +func (ts *testServer) get(t *testing.T, path string) (int, http.Header, []byte) { + t.Helper() + + rs, err := ts.Client().Get(ts.URL + path) + if err != nil { + t.Fatal(err) + } + + defer rs.Body.Close() + body, err := io.ReadAll(rs.Body) + if err != nil { + t.Fatal(err) + } + + return rs.StatusCode, rs.Header, body +} + +// registers a user, activates them, logs in, returns the bearer token +func (ts *testServer) registerAndLogin(t *testing.T, email, password string) string { + t.Helper() + + // 1. Register + registerBody := map[string]any{ + "email": email, + "password": password, + "username": email, + "name": "Test User", + "alt_name" : "", + "provider_id": 1, + } + + code, _, body := ts.postJSON(t, "/v1/users", registerBody) + if code != http.StatusCreated { + t.Fatalf("register: want 201 got %d: %s", code, body) + } + + // 2. Activate — if your flow requires it, either hit the endpoint + // or directly flip the activated flag in the test DB + // ts.activateUser(t, email) + + // 3. Login + loginBody := map[string]string{"email": email, "password": password} + code, _, body = ts.postJSON(t, "/v1/tokens/authentication", loginBody) + if code != http.StatusCreated { + t.Fatalf("login: want 201 got %d: %s", code, body) + } + + // 4. Parse token out of response + var resp struct { + AuthenticationToken struct { + Token string `json:"token"` + } `json:"authentication_token"` + } + if err := json.Unmarshal(body, &resp); err != nil { + t.Fatalf("parse token: %v", err) + } + return resp.AuthenticationToken.Token +} + +// activateUser directly updates the DB — avoids needing a real email flow +// func (ts *testServer) activateUser(t *testing.T, email string) { +// t.Helper() +// _, err := ts.app.db.Exec("UPDATE users SET activated = true WHERE email = $1", email) +// if err != nil { +// t.Fatalf("activate user: %v", err) +// } +// } + +// like ts.get but adds Authorization header +func (ts *testServer) getWithToken(t *testing.T, path, token string) (int, http.Header, []byte) { + t.Helper() + req, err := http.NewRequest(http.MethodGet, ts.URL+path, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Authorization", "Bearer "+token) + + rs, err := ts.Client().Do(req) + if err != nil { + t.Fatal(err) + } + + defer rs.Body.Close() + body, err := io.ReadAll(rs.Body) + if err != nil { + t.Fatal(err) + } + + return rs.StatusCode, rs.Header, body +} + +func (ts *testServer) postJSONWithToken(t *testing.T, path, token string, body any) (int, http.Header, []byte) { + t.Helper() + + b, err := json.Marshal(body) + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest(http.MethodPost, ts.URL+path, bytes.NewReader(b)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + rs, err := ts.Client().Do(req) + if err != nil { + t.Fatal(err) + } + + defer rs.Body.Close() + respBody, err := io.ReadAll(rs.Body) + if err != nil { + t.Fatal(err) + } + + return rs.StatusCode, rs.Header, respBody +} + +func uniqueEmail() string { + return fmt.Sprintf("test_%d@example.com", time.Now().UnixNano()) +} diff --git a/cmd/api/users.go b/cmd/api/users.go index 281daf7..26607c5 100644 --- a/cmd/api/users.go +++ b/cmd/api/users.go @@ -5,23 +5,22 @@ import ( "net/http" "party.at/party/internal/data" "party.at/party/internal/validator" - "database/sql" "time" ) func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request) { var input struct { - ProviderId int64 `json:"provider_id"` - Username string `json:"username"` - PhoneNumber string `json:"phone_number"` - Country string `json:"country"` - Email string `json:"email"` - Password string `json:"password"` - Name string `json:"name"` - AltName string `json:"alt_name"` + ProviderId int64 `json:"provider_id"` + Username string `json:"username"` + PhoneNumber string `json:"phone_number"` + Country string `json:"country"` + Email string `json:"email"` + Password string `json:"password"` + Name string `json:"name"` + AltName *string `json:"alt_name"` DateOfBirth time.Time `json:"date_of_birth"` - Address string `json:"address"` + Address string `json:"address"` } err := app.readJSON(w, r, &input) @@ -35,12 +34,16 @@ func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request PhoneNumber: input.PhoneNumber, Country: input.Country, Name: input.Name, - AltName: sql.NullString{String: input.AltName, Valid: true}, + AltName: input.AltName, DateOfBirth: input.DateOfBirth, Address: input.Address, Activated: false, } + if app.config.env == "development" { + user.Activated = true + } + userIdentity := &data.UserIdentity{ ProviderID: input.ProviderId, ProviderUserID: input.Username, @@ -85,27 +88,43 @@ func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request return } - token, err := app.models.Tokens.New(user.ID, userIdentity.ID, 3 * 24 * time.Hour, data.ScopeActivation) + if app.config.env == "development" { + err = app.models.Permissions.AddForUser(user.ID, "issues:write") + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + } + + if app.config.env == "production" { + token, err := app.models.Tokens.New(user.ID, userIdentity.ID, 3 * 24 * time.Hour, data.ScopeActivation) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + app.background(func() { + data := map[string]interface{}{ + "token": token.Plaintext, + "userID": user.ID, + } + + err = app.mailer.Send(user.Email, "user_welcome.tmpl", data) + if err != nil { + app.logger.PrintError(err, nil) + } + }) + } + + authentication_token, err := app.models.Tokens.New(user.ID, userIdentity.ID, 24 * time.Hour, data.ScopeAuthentication) if err != nil { app.serverErrorResponse(w, r, err) return } - app.background(func() { - data := map[string]interface{}{ - "token": token.Plaintext, - "userID": user.ID, - } - - err = app.mailer.Send(user.Email, "user_welcome.tmpl", data) - if err != nil { - app.logger.PrintError(err, nil) - } - }) - // Write a JSON response containing the user data along with a 201 Created status // code. - err = app.writeJSON(w, http.StatusCreated, envelope{"user": user}, nil) + err = app.writeJSON(w, http.StatusCreated, envelope{"user": user, "authentication_token": authentication_token}, nil) if err != nil { app.serverErrorResponse(w, r, err) } diff --git a/internal/data/blind_sign_requests.go b/internal/data/blind_sign_requests.go new file mode 100644 index 0000000..80d5aa7 --- /dev/null +++ b/internal/data/blind_sign_requests.go @@ -0,0 +1,88 @@ +package data + +import ( + "time" + "database/sql" + "encoding/pem" + "context" + "errors" + "fmt" + "math/big" + "crypto/rsa" + "crypto/x509" +) + +type BlindSignRequest struct { + UserID int64 `json:"user_id"` + IssueID int64 `json:"issue_id"` + Created time.Time `json:"created"` +} + +type BlindSignRequestModel struct { + DB *sql.DB +} + +func (m BlindSignRequestModel) Insert(blind_sign *BlindSignRequest) error { + query := ` +INSERT INTO blind_sign_requests (user_id, issue_id) +VALUES ($1, $2) +RETURNING created` + + args := []interface{}{ + blind_sign.UserID, + blind_sign.IssueID, + } + + return m.DB.QueryRow(query, args...).Scan( + &blind_sign.Created, + ) +} + +func (m BlindSignRequestModel) BlindSign(issueID int64, blindedVoteBytes []byte) ([]byte, error) { + if issueID < 1 { + return nil, ErrRecordNotFound + } + + query := `SELECT rsa_private_pem FROM issues WHERE id = $1` + + ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) + defer cancel() + + var pemBytes []byte + err := m.DB.QueryRowContext(ctx, query, issueID).Scan(&pemBytes) + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, ErrRecordNotFound + default: + return nil, err + } + } + + key, err := parsePrivateKey(pemBytes) + if err != nil { + return nil, fmt.Errorf("parse private key: %w", err) + } + + m_ := new(big.Int).SetBytes(blindedVoteBytes) + + // Validate range: m′ must be in [1, n-1] + one := big.NewInt(1) + if m_.Cmp(one) < 0 || m_.Cmp(key.N) >= 0 { + return nil, ErrInvalidBlindedVote + } + + // s′ = m′^d mod n + sig := new(big.Int).Exp(m_, key.D, key.N) + + return sig.Bytes(), nil +} + +func parsePrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("failed to decode PEM block") + } + return x509.ParsePKCS1PrivateKey(block.Bytes) +} + diff --git a/internal/data/issues.go b/internal/data/issues.go index 11009b7..367e10f 100644 --- a/internal/data/issues.go +++ b/internal/data/issues.go @@ -10,13 +10,16 @@ import ( ) type Issue struct { - ID int64 `json:"id"` - Title string `json:"title"` - Description string `json:"description"` - StartTime time.Time `json:"start_time"` - EndTime time.Time `json:"end_time"` - Created time.Time `json:"created"` - Version int32 `json:"version"` + ID int64 `json:"id"` + Title string `json:"title"` + Description string `json:"description"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + N []byte `json:"-"` + E int `json:"-"` + PrivatePem []byte `json:"-"` + Created time.Time `json:"created"` + Version int32 `json:"version"` } func ValidateIssue(v *validator.Validator, issue *Issue) { @@ -36,8 +39,8 @@ type IssueModel struct { func (m IssueModel) Insert(issue *Issue) error { query := ` -INSERT INTO issues (title, description, start_time, end_time) -VALUES ($1, $2, $3, $4) +INSERT INTO issues (title, description, start_time, end_time, rsa_n, rsa_e, rsa_private_pem) +VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id, created, version` args := []interface{}{ @@ -45,6 +48,9 @@ RETURNING id, created, version` issue.Description, issue.StartTime, issue.EndTime, + issue.N, + issue.E, + issue.PrivatePem, } return m.DB.QueryRow(query, args...).Scan( @@ -60,32 +66,26 @@ func (m IssueModel) Get(id int64) (*Issue, error) { return nil, ErrRecordNotFound } - // Define the SQL query for retrieving the issue data. - query :=` -SELECT id, title, description, start_time, end_time, created, version + query := ` +SELECT id, title, description, start_time, end_time, rsa_n, rsa_e, rsa_private_pem, created, version FROM issues WHERE id = $1` - // Declare a Issue struct to hold the data returned by the query. var issue Issue - // Execute the query using the QueryRow() method, passing in the provided id value - // as a placeholder parameter, and scan the response data into the fields of the - // Issue struct. Importantly, notice that we need to convert the scan target for the - // genres column using the pq.Array() adapter function again. err := m.DB.QueryRow(query, id).Scan( &issue.ID, &issue.Title, &issue.Description, &issue.StartTime, &issue.EndTime, + &issue.N, + &issue.E, + &issue.PrivatePem, &issue.Created, &issue.Version, ) - // Handle any errors. If there was no matching issue found, Scan() will return - // a sql.ErrNoRows error. We check for this and return our custom ErrRecordNotFound - // error instead. if err != nil { switch { case errors.Is(err, sql.ErrNoRows): @@ -94,29 +94,37 @@ WHERE id = $1` return nil, err } } - // Otherwise, return a pointer to the Issue struct. + return &issue, nil } func (m IssueModel) Update(issue *Issue) error { query := ` UPDATE issues - SET title = $1, description = $2, start_time = $3, end_time = $4, version = version + 1 - WHERE id = $5 AND version = $6 + SET + title = $1, + description = $2, + start_time = $3, + end_time = $4, + rsa_n = $5, + rsa_e = $6, + rsa_private_pem = $7, + version = version + 1 + WHERE id = $8 AND version = $9 RETURNING version` - // Create an args slice containing the values for the placeholder parameters. args := []interface{}{ issue.Title, issue.Description, issue.StartTime, issue.EndTime, + issue.N, + issue.E, + issue.PrivatePem, issue.ID, issue.Version, } - // Use the QueryRow() method to execute the query, passing in the args slice as a - // variadic parameter and scanning the new version value into the issue struct. err := m.DB.QueryRow(query, args...).Scan(&issue.Version) if err != nil { switch { @@ -170,7 +178,7 @@ func (m IssueModel) GetAll(title string, filters Filters) ([]*Issue, Metadata, e // Construct the SQL query to retrieve all issue records. query := fmt.Sprintf(` SELECT - COUNT(*) OVER(), id, title, description, start_time, end_time, created, version + COUNT(*) OVER(), id, title, description, start_time, end_time, rsa_n, rsa_e, rsa_private_pem, created, version FROM issues WHERE @@ -217,6 +225,9 @@ func (m IssueModel) GetAll(title string, filters Filters) ([]*Issue, Metadata, e &issue.Description, &issue.StartTime, &issue.EndTime, + &issue.N, + &issue.E, + &issue.PrivatePem, &issue.Created, &issue.Version, ) diff --git a/internal/data/models.go b/internal/data/models.go index 1e9d5f0..a7ecda2 100644 --- a/internal/data/models.go +++ b/internal/data/models.go @@ -6,24 +6,27 @@ import ( ) var ( - ErrRecordNotFound = errors.New("record not found") - ErrEditConflict = errors.New("edit conflict") + ErrRecordNotFound = errors.New("record not found") + ErrEditConflict = errors.New("edit conflict") + ErrInvalidBlindedVote = errors.New("invalid blinded vote") ) type Models struct { - Users UserModel - UserIdentities UserIdentityModel - Issues IssueModel - Tokens TokenModel - Permissions PermissionModel + Users UserModel + UserIdentities UserIdentityModel + Issues IssueModel + Tokens TokenModel + Permissions PermissionModel + BlindSignRequests BlindSignRequestModel } func NewModels(db *sql.DB) Models { return Models{ - Users: UserModel{DB: db}, - UserIdentities: UserIdentityModel{DB: db}, - Issues: IssueModel{DB: db}, - Tokens: TokenModel{DB: db}, - Permissions: PermissionModel{DB: db}, + Users: UserModel{DB: db}, + UserIdentities: UserIdentityModel{DB: db}, + Issues: IssueModel{DB: db}, + Tokens: TokenModel{DB: db}, + Permissions: PermissionModel{DB: db}, + BlindSignRequests: BlindSignRequestModel{DB: db}, } } diff --git a/internal/data/permissions.go b/internal/data/permissions.go index 9469d3f..77248a4 100644 --- a/internal/data/permissions.go +++ b/internal/data/permissions.go @@ -31,7 +31,7 @@ INNER JOIN users_permissions ON users_permissions.permission_id = permissions.id INNER JOIN users ON users_permissions.user_id = users.id WHERE users.id = $1` - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) defer cancel() rows, err := m.DB.QueryContext(ctx, query, userID) @@ -62,7 +62,7 @@ func (m PermissionModel) AddForUser(userID int64, codes ...string) error { INSERT INTO users_permissions SELECT $1, permissions.id FROM permissions WHERE permissions.code = ANY($2)` - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) defer cancel() _, err := m.DB.ExecContext(ctx, query, userID, pq.Array(codes)) return err diff --git a/internal/data/tokens.go b/internal/data/tokens.go index 46d8d87..78ac62d 100644 --- a/internal/data/tokens.go +++ b/internal/data/tokens.go @@ -75,7 +75,7 @@ func (m TokenModel) Insert(token *Token) error { args := []interface{}{token.Hash, token.UserID, token.UserIdentityID, token.Expiry, token.Scope} - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) defer cancel() _, err := m.DB.ExecContext(ctx, query, args...) diff --git a/internal/data/user_identities.go b/internal/data/user_identities.go index 3e2bf96..9d7ae34 100644 --- a/internal/data/user_identities.go +++ b/internal/data/user_identities.go @@ -90,7 +90,7 @@ type UserIdentityModel struct { // userIdentity.Password.hash, // } -// ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) +// ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) // defer cancel() // err := m.DB.QueryRowContext(ctx, query, args...).Scan(&userIdentity.ID, &userIdentity.Version) @@ -214,7 +214,7 @@ AND tokens.expiry > $3` var userIdentity UserIdentity - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) defer cancel() err := m.DB.QueryRowContext(ctx, query, args...).Scan( @@ -252,7 +252,7 @@ func (m UserIdentityModel) Update(user *UserIdentity) error { user.Version, } - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) defer cancel() err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version) diff --git a/internal/data/users.go b/internal/data/users.go index 77925f9..3032a71 100644 --- a/internal/data/users.go +++ b/internal/data/users.go @@ -23,7 +23,7 @@ type User struct { PhoneNumber string `json:"phone_number"` Country string `json:"country"` Name string `json:"name"` - AltName sql.NullString `json:"alt_name"` + AltName *string `json:"alt_name"` DateOfBirth time.Time `json:"date_of_birth"` Address string `json:"address"` Created time.Time `json:"created"` @@ -60,8 +60,8 @@ func (m UserModel) ExecuteRegistrationTx(user *User, userIdentity *UserIdentity) defer tx.Rollback() query := ` -INSERT INTO users (email, phone_number, country, name, alt_name, date_of_birth, address) -VALUES ($1, $2, $3, $4, $5, $6, $7) +INSERT INTO users (email, phone_number, country, name, alt_name, date_of_birth, address, activated) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created, last_login, version` args := []interface{}{ @@ -72,6 +72,7 @@ RETURNING id, created, last_login, version` user.AltName, user.DateOfBirth, user.Address, + user.Activated, } err = tx.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.Created, &user.LastLogin, &user.Version) @@ -163,7 +164,7 @@ WHERE id = $1` func (m UserModel) GetByEmail(email string) (*User, error) { query :=` -SELECT id, email, phone, country, name, alt_name, date_of_birth, address, created, last_login, activated, version +SELECT id, email, phone_number, country, name, alt_name, date_of_birth, address, created, last_login, activated, version FROM users WHERE email = $1` @@ -205,7 +206,17 @@ func (m UserModel) GetForToken(tokenScope, tokenPlaintext string) (*User, error) // Set up the SQL query. query :=` -SELECT users.id, users.email, user.phone_number, user.country, users.name, users.date_of_birth, users.address, users.created, users.activated, users.version +SELECT + users.id, + users.email, + users.phone_number, + users.country, + users.name, + users.date_of_birth, + users.address, + users.created, + users.activated, + users.version FROM users INNER JOIN tokens ON users.id = tokens.user_id WHERE tokens.hash = $1 @@ -218,7 +229,7 @@ AND tokens.expiry > $3` // value to check against the token expiry. args := []interface{}{tokenHash[:], tokenScope, time.Now()} var user User - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) defer cancel() // Execute the query, scanning the return values into a User struct. If no matching @@ -251,7 +262,16 @@ AND tokens.expiry > $3` func (m UserModel) Update(user *User) error { query := ` UPDATE users - SET email = $1, phone_number = $2, country = $3, name = $4, alt_name = $5, date_of_birth = $6, address = $7, activated = $8, version = version + 1 + SET + email = $1, + phone_number = $2, + country = $3, + name = $4, + alt_name = $5, + date_of_birth = $6, + address = $7, + activated = $8, + version = version + 1 WHERE id = $9 AND version = $10 RETURNING version` @@ -269,7 +289,7 @@ func (m UserModel) Update(user *User) error { user.Version, } - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) defer cancel() err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version) diff --git a/internal/jsonlog/jsonlog.go b/internal/jsonlog/jsonlog.go index f39a8ec..9d0b15a 100644 --- a/internal/jsonlog/jsonlog.go +++ b/internal/jsonlog/jsonlog.go @@ -1,26 +1,23 @@ package jsonlog import ( -"encoding/json" -"io" -"os" -"runtime/debug" -"sync" -"time" + "encoding/json" + "io" + "os" + "runtime/debug" + "sync" + "time" ) -// Define a Level type to represent the severity level for a log entry. + type Level int8 -// Initialize constants which represent a specific severity level. We use the iota -// keyword as a shortcut to assign successive integer values to the constants. const ( - LevelInfo Level = iota // Has the value 0. - LevelError // Has the value 1. - LevelFatal // Has the value 2. - LevelOff // Has the value 3. + LevelInfo Level = iota + LevelError + LevelFatal + LevelOff ) -// Return a human-friendly string for the severity level. func (l Level) String() string { switch l { case LevelInfo: @@ -34,17 +31,12 @@ func (l Level) String() string { } } -// Define a custom Logger type. This holds the output destination that the log entries -// will be written to, the minimum severity level that log entries will be written for, -// plus a mutex for coordinating the writes. type Logger struct { out io.Writer minLevel Level mu sync.Mutex } -// Return a new Logger instance which writes log entries at or above a minimum severity -// level to a specific output destination. func New(out io.Writer, minLevel Level) *Logger { return &Logger{ out: out, @@ -52,9 +44,6 @@ func New(out io.Writer, minLevel Level) *Logger { } } -// Declare some helper methods for writing log entries at the different levels. Notice -// that these all accept a map as the second parameter which can contain any arbitrary -// 'properties' that you want to appear in the log entry. func (l *Logger) PrintInfo(message string, properties map[string]string) { l.print(LevelInfo, message, properties) } @@ -65,18 +54,14 @@ func (l *Logger) PrintError(err error, properties map[string]string) { func (l *Logger) PrintFatal(err error, properties map[string]string) { l.print(LevelFatal, err.Error(), properties) - os.Exit(1) // For entries at the FATAL level, we also terminate the application. + os.Exit(1) } -// Print is an internal method for writing the log entry. func (l *Logger) print(level Level, message string, properties map[string]string) (int, error) { - // If the severity level of the log entry is below the minimum severity for the - // logger, then return with no further action. if level < l.minLevel { return 0, nil } - // Declare an anonymous struct holding the data for the log entry. aux := struct { Level string `json:"level"` Time string `json:"time"` @@ -90,35 +75,23 @@ func (l *Logger) print(level Level, message string, properties map[string]string Properties: properties, } - // Include a stack trace for entries at the ERROR and FATAL levels. if level >= LevelError { aux.Trace = string(debug.Stack()) } - // Declare a line variable for holding the actual log entry text. var line []byte - // Marshal the anonymous struct to JSON and store it in the line variable. If there - // was a problem creating the JSON, set the contents of the log entry to be that - // plain-text error message instead. line, err := json.Marshal(aux) if err != nil { line = []byte(LevelError.String() + ": unable to marshal log message:" + err.Error()) } - // Lock the mutex so that no two writes to the output destination cannot happen - // concurrently. If we don't do this, it's possible that the text for two or more - // log entries will be intermingled in the output. l.mu.Lock() defer l.mu.Unlock() - // Write the log entry followed by a newline. return l.out.Write(append(line, '\n')) } -// We also implement a Write() method on our Logger type so that it satisfies the -// io.Writer interface. This writes a log entry at the ERROR level with no additional -// properties. func (l *Logger) Write(message []byte) (n int, err error) { return l.print(LevelError, string(message), nil) } diff --git a/internal/mailer/templates/user_welcome.tmpl b/internal/mailer/templates/user_welcome.tmpl index 25bb162..38a0088 100644 --- a/internal/mailer/templates/user_welcome.tmpl +++ b/internal/mailer/templates/user_welcome.tmpl @@ -26,7 +26,7 @@ The DigitalePartei Team

Hi,

Thanks for signing up for a Digitale Partei Österreich account. We're excited to have you on board!

For future reference, your user ID number is {{.ID}}.

-

Your activation token is {{.token}}

+

Click here to activate!

Thanks,

The DigitalePartei Team

diff --git a/migrations/000001_create_users_table.up.sql b/migrations/000001_create_users_table.up.sql index 5b453aa..9192ea5 100644 --- a/migrations/000001_create_users_table.up.sql +++ b/migrations/000001_create_users_table.up.sql @@ -13,18 +13,18 @@ CREATE TABLE IF NOT EXISTS users ( version INT NOT NULL DEFAULT 1 ); -CREATE TABLE IF NOT EXISTS auth_provider ( +CREATE TABLE IF NOT EXISTS auth_providers ( id BIGSERIAL PRIMARY KEY, -- e.g., 'local', 'id_austria', description TEXT NOT NULL, active BOOLEAN DEFAULT false ); -INSERT INTO auth_provider (description, active) VALUES ('local', true); -INSERT INTO auth_provider (description, active) VALUES ('id_austria', false); +INSERT INTO auth_providers (description, active) VALUES ('local', true); +INSERT INTO auth_providers (description, active) VALUES ('id_austria', false); CREATE TABLE IF NOT EXISTS user_identities ( id BIGSERIAL PRIMARY KEY, - provider_id BIGINT NOT NULL REFERENCES auth_provider(id), + provider_id BIGINT NOT NULL REFERENCES auth_providers(id), user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, -- For local: the username. For OIDC: the 'sub' (Subject ID) @@ -33,9 +33,9 @@ CREATE TABLE IF NOT EXISTS user_identities ( -- Nullable because OIDC users won't have a password in your DB password bytea, - version INT NOT NULL DEFAULT 1, + version INT NOT NULL DEFAULT 1 - UNIQUE(provider_id, provider_user_id) + -- UNIQUE(provider_id, provider_user_id) ); -- INSERT INTO users( diff --git a/migrations/000002_create_additional_tables.down.sql b/migrations/000002_create_additional_tables.down.sql index 89ae30c..e514226 100644 --- a/migrations/000002_create_additional_tables.down.sql +++ b/migrations/000002_create_additional_tables.down.sql @@ -3,6 +3,7 @@ DROP INDEX IF EXISTS idx_vote_tokens_issue_id; DROP INDEX IF EXISTS idx_options_issue_id; DROP TABLE IF EXISTS votes; +DROP TABLE IF EXISTS blind_sign_requests; DROP TABLE IF EXISTS vote_tokens; DROP TABLE IF EXISTS options; DROP TABLE IF EXISTS issues; diff --git a/migrations/000002_create_additional_tables.up.sql b/migrations/000002_create_additional_tables.up.sql index 69c3d3a..e225d53 100644 --- a/migrations/000002_create_additional_tables.up.sql +++ b/migrations/000002_create_additional_tables.up.sql @@ -4,6 +4,9 @@ CREATE TABLE IF NOT EXISTS issues ( description TEXT, start_time TIMESTAMPTZ NOT NULL, end_time TIMESTAMPTZ NOT NULL, + rsa_n BYTEA, + rsa_e INT, + rsa_private_pem BYTEA, created TIMESTAMPTZ NOT NULL DEFAULT now(), version INT NOT NULL DEFAULT 1 ); @@ -25,6 +28,13 @@ CREATE TABLE IF NOT EXISTS vote_tokens ( version INT NOT NULL DEFAULT 1 ); +CREATE TABLE IF NOT EXISTS blind_sign_requests ( + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + issue_id BIGINT NOT NULL REFERENCES issues(id) ON DELETE CASCADE, + created TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (user_id, issue_id) +); + CREATE TABLE IF NOT EXISTS votes ( id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, token UUID NOT NULL UNIQUE REFERENCES vote_tokens(token) ON DELETE CASCADE,