This commit is contained in:
Vicente Ferrari Smith 2026-04-28 19:46:06 +02:00
parent 173c01b01b
commit 5772a7d855
22 changed files with 731 additions and 149 deletions

4
.vscode/launch.json vendored
View File

@ -4,14 +4,14 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0", "version": "0.2.0",
"configurations": [ "configurations": [
{ {
"name": "Launch API", "name": "Launch API",
"type": "go", "type": "go",
"request": "launch", "request": "launch",
"mode": "auto", "mode": "auto",
"program": "${workspaceFolder}/cmd/api", "program": "${workspaceFolder}/cmd/api",
"envFile": "${workspaceFolder}/.env" "envFile": "${workspaceFolder}/.env",
"args": ["--env=development"]
} }
] ]
} }

View File

@ -11,7 +11,10 @@ import (
"strconv" "strconv"
"party.at/party/internal/validator" "party.at/party/internal/validator"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"crypto/rand"
) )
type envelope map[string]interface{} type envelope map[string]interface{}
@ -191,3 +194,19 @@ func (app *application) background(fn func()) {
fn() fn()
}() }()
} }
func GenerateIssueKey() ([]byte, int, []byte, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, 0, nil, err
}
// Encode private key as PEM
privDER := x509.MarshalPKCS1PrivateKey(key)
privPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: privDER,
})
return key.N.Bytes(), key.E, privPEM, err
}

32
cmd/api/helpers_test.go Normal file
View File

@ -0,0 +1,32 @@
package main
import (
"context"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/julienschmidt/httprouter"
)
func TestReadIDParam(t *testing.T) {
app := newTestApplication(t)
const test_id int64 = 3
r := httptest.NewRequest(http.MethodGet, "/v1/issues/" + strconv.FormatInt(test_id, 10), nil)
params := httprouter.Params{{Key: "id", Value: "3"}}
ctx := context.WithValue(r.Context(), httprouter.ParamsKey, params)
r = r.WithContext(ctx)
id, err := app.readIDParam(r)
if err != nil {
t.Fatal(err)
}
if id != test_id {
t.Errorf("want %d, got id %d", test_id, id)
}
}

View File

@ -12,35 +12,30 @@ import (
"errors" "errors"
"party.at/party/internal/data" "party.at/party/internal/data"
"party.at/party/internal/validator" "party.at/party/internal/validator"
"encoding/hex"
) )
func (app *application) listIssuesHandler(w http.ResponseWriter, r *http.Request) { func (app *application) listIssuesHandler(w http.ResponseWriter, r *http.Request) {
// To keep things consistent with our other handlers, we'll define an input struct
// to hold the expected values from the request query string.
var input struct { var input struct {
Title string Title string
data.Filters data.Filters
} }
// Initialize a new Validator instance.
v := validator.New() v := validator.New()
// Call r.URL.Query() to get the url.Values map containing the query string data.
qs := r.URL.Query() qs := r.URL.Query()
// Use our helpers to extract the title and genres query string values, falling back
// to defaults of an empty string and an empty slice respectively if they are not
// provided by the client.
input.Title = app.readString(qs, "title", "") input.Title = app.readString(qs, "title", "")
// input.Genres = app.readCSV(qs, "genres", []string{}) // input.Genres = app.readCSV(qs, "genres", []string{})
// Get the page and page_size query string values as integers. Notice that we set
// the default page value to 1 and default page_size to 20, and that we pass the
// validator instance as the final argument here.
input.Filters.Page = app.readInt(qs, "page", 1, v) input.Filters.Page = app.readInt(qs, "page", 1, v)
input.Filters.PageSize = app.readInt(qs, "page_size", 20, v) input.Filters.PageSize = app.readInt(qs, "page_size", 20, v)
// Extract the sort query string value, falling back to "id" if it is not provided
// by the client (which will imply a ascending sort on issue ID).
input.Filters.Sort = app.readString(qs, "sort", "id") input.Filters.Sort = app.readString(qs, "sort", "id")
input.Filters.SortSafelist = []string{"id", "-id", "title", "-title", "description", "-description"} input.Filters.SortSafelist = []string{"id", "-id", "title", "-title", "description", "-description"}
@ -77,11 +72,20 @@ func (app *application) createIssueHandler(w http.ResponseWriter, r *http.Reques
return return
} }
n, e, private_pem, err := GenerateIssueKey()
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
issue := &data.Issue{ issue := &data.Issue{
Title: input.Title, Title: input.Title,
Description: input.Description, Description: input.Description,
StartTime: input.StartTime, StartTime: input.StartTime,
EndTime: input.EndTime, EndTime: input.EndTime,
N: n,
E: e,
PrivatePem: private_pem,
} }
v := validator.New() v := validator.New()
@ -219,3 +223,73 @@ func (app *application) deleteIssueHandler(w http.ResponseWriter, r *http.Reques
app.serverErrorResponse(w, r, err) app.serverErrorResponse(w, r, err)
} }
} }
func (app *application) readIssuePubKeyHandler(w http.ResponseWriter, r *http.Request) {
id, err := app.readIDParam(r)
if err != nil {
app.notFoundResponse(w, r)
return
}
issue, err := app.models.Issues.Get(id)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.notFoundResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
type response struct {
N string `json:"n"`
E int `json:"e"`
}
res := response{
N: hex.EncodeToString(issue.N),
E: issue.E,
}
err = app.writeJSON(w, http.StatusOK, envelope{"public_key": res}, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}
func (app *application) blindSignIssueVoteHandler(w http.ResponseWriter, r *http.Request) {
id, err := app.readIDParam(r)
if err != nil {
app.notFoundResponse(w, r)
return
}
issue, err := app.models.Issues.Get(id)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.notFoundResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
type response struct {
N string `json:"n"`
E int `json:"e"`
}
res := response{
N: hex.EncodeToString(issue.N),
E: issue.E,
}
err = app.writeJSON(w, http.StatusOK, envelope{"public_key": res}, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}

65
cmd/api/issues_test.go Normal file
View File

@ -0,0 +1,65 @@
package main
import (
"bytes"
"encoding/json"
"net/http"
"strconv"
"testing"
"time"
)
func TestReadIssueHandler(t *testing.T) {
app := newTestApplication(t)
ts := newTestServer(t, app, app.routes())
defer ts.Close()
token := ts.registerAndLogin(t, uniqueEmail(), "pa$$word123")
req := map[string]any{
"title": "An old silent pond...",
"description": "A frog jumps into the pond",
"start_time": time.Now().Format(time.RFC3339),
"end_time": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
}
code, _, res := ts.postJSONWithToken(t, "/v1/issues", token, req)
if code != http.StatusCreated {
t.Fatalf("seed issue: want 201 got %d: %s", code, res)
}
var resp struct {
Issue struct {
ID int64 `json:"id"`
} `json:"issue"`
}
json.Unmarshal(res, &resp)
issueID := resp.Issue.ID
tests := []struct {
name string
urlPath string
wantCode int
wantBody []byte
}{
{"Valid ID", "/v1/issues/" + strconv.Itoa(int(issueID)), http.StatusOK, []byte("An old silent pond...")},
{"Non-existent ID", "/v1/issues/" + strconv.Itoa(int(issueID + 1)), http.StatusNotFound, nil},
{"Negative ID", "/v1/issues/-1", http.StatusNotFound, nil},
{"Decimal ID", "/v1/issues/1.23", http.StatusNotFound, nil},
{"String ID", "/v1/issues/foo", http.StatusNotFound, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
code, _, body := ts.getWithToken(t, tt.urlPath, token)
if code != tt.wantCode {
t.Errorf("want %d; got %d", tt.wantCode, code)
}
if !bytes.Contains(body, tt.wantBody) {
t.Errorf("want body to contain %q; got %q", tt.wantBody, string(body))
}
})
}
}

View File

@ -24,6 +24,10 @@ import (
"party.at/party/internal/data" "party.at/party/internal/data"
"party.at/party/internal/mailer" "party.at/party/internal/mailer"
"party.at/party/internal/jsonlog" "party.at/party/internal/jsonlog"
"crypto/rand"
"crypto/aes"
"crypto/cipher"
) )
var version string var version string
@ -80,12 +84,31 @@ var upgrader = websocket.Upgrader{
func main() { func main() {
fmt.Printf("Full Args: %v\n", os.Args) // 32 bytes = AES-256
var key [32]byte
if _, err := rand.Read(key[:]); err != nil {
panic(err)
}
message := []byte("hello secure world")
nonce, ciphertext, err := encrypt(key, message)
if err != nil {
panic(err)
}
plaintext, err := decrypt(key, nonce, ciphertext)
if err != nil {
panic(err)
}
fmt.Printf("Original: %s\n", message)
fmt.Printf("Decrypted: %s\n", plaintext)
var cfg config var cfg config
flag.IntVar(&cfg.port, "port", 4000, "API server port") flag.IntVar(&cfg.port, "port", 4000, "API server port")
flag.StringVar(&cfg.env, "env", "development", "Environment (development|staging|production)") flag.StringVar(&cfg.env, "env", "production", "Environment (development|staging|production)")
flag.StringVar(&cfg.db.dsn, "db-dsn", os.Getenv("PARTY_DB_DSN"), "PostgreSQL DSN") flag.StringVar(&cfg.db.dsn, "db-dsn", os.Getenv("PARTY_DB_DSN"), "PostgreSQL DSN")
//addr := flag.String("addr", ":8443", "HTTP network address") //addr := flag.String("addr", ":8443", "HTTP network address")
@ -97,7 +120,7 @@ func main() {
flag.IntVar(&cfg.smtp.port, "smtp-port", 25, "SMTP port") flag.IntVar(&cfg.smtp.port, "smtp-port", 25, "SMTP port")
flag.StringVar(&cfg.smtp.username, "smtp-username", "98cf60028d7fcb", "SMTP username") flag.StringVar(&cfg.smtp.username, "smtp-username", "98cf60028d7fcb", "SMTP username")
flag.StringVar(&cfg.smtp.password, "smtp-password", "b9d4a35372e971", "SMTP password") flag.StringVar(&cfg.smtp.password, "smtp-password", "b9d4a35372e971", "SMTP password")
flag.StringVar(&cfg.smtp.sender, "smtp-sender", "Greenlight <no-reply@greenlight.alexedwards.net>", "SMTP sender") flag.StringVar(&cfg.smtp.sender, "smtp-sender", "DPÖ <no-reply@party.at>", "SMTP sender")
flag.Float64Var(&cfg.limiter.rps, "limiter-rps", 2, "Rate limiter maximum requests per second") flag.Float64Var(&cfg.limiter.rps, "limiter-rps", 2, "Rate limiter maximum requests per second")
flag.IntVar(&cfg.limiter.burst, "limiter-burst", 4, "Rate limiter maximum burst") flag.IntVar(&cfg.limiter.burst, "limiter-burst", 4, "Rate limiter maximum burst")
@ -126,7 +149,6 @@ func main() {
if err != nil { if err != nil {
logger.PrintFatal(err, nil) logger.PrintFatal(err, nil)
} }
defer db.Close() defer db.Close()
expvar.NewString("version").Set(version) expvar.NewString("version").Set(version)
@ -188,3 +210,43 @@ func openDB(cfg config) (*sql.DB, error) {
return db, nil return db, nil
} }
func encrypt(key [32]byte, plaintext []byte) (nonce, ciphertext []byte, err error) {
block, err := aes.NewCipher(key[:])
if err != nil {
return nil, nil, err
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, nil, err
}
nonce = make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
panic(err)
}
ciphertext = aead.Seal(nil, nonce, plaintext, nil)
return nonce, ciphertext, nil
}
func decrypt(key [32]byte, nonce, ciphertext []byte) ([]byte, error) {
block, err := aes.NewCipher(key[:])
if err != nil {
return nil, err
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
plaintext, err := aead.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plaintext, nil
}

View File

@ -27,6 +27,9 @@ func (app *application) routes() http.Handler {
router.HandlerFunc(http.MethodPatch, "/v1/issues/:id", app.requirePermission("issues:write", app.updateIssueHandler)) router.HandlerFunc(http.MethodPatch, "/v1/issues/:id", app.requirePermission("issues:write", app.updateIssueHandler))
router.HandlerFunc(http.MethodDelete, "/v1/issues/:id", app.requirePermission("issues:write", app.deleteIssueHandler)) router.HandlerFunc(http.MethodDelete, "/v1/issues/:id", app.requirePermission("issues:write", app.deleteIssueHandler))
router.HandlerFunc(http.MethodGet, "/v1/issues/:id/pubkey", app.requirePermission("issues:read", app.readIssuePubKeyHandler))
router.HandlerFunc(http.MethodPost, "/v1/issues/:id/blind-sign", app.requirePermission("issues:read", app.blindSignIssueVoteHandler))
router.HandlerFunc(http.MethodPost, "/v1/users", app.createUserHandler) router.HandlerFunc(http.MethodPost, "/v1/users", app.createUserHandler)
// router.HandlerFunc(http.MethodGet, "/v1/users/:id", app.readUserHandler) // router.HandlerFunc(http.MethodGet, "/v1/users/:id", app.readUserHandler)
// router.HandlerFunc(http.MethodPatch, "/v1/users/:id", app.updateUserHandler) // router.HandlerFunc(http.MethodPatch, "/v1/users/:id", app.updateUserHandler)

202
cmd/api/testutils_test.go Normal file
View File

@ -0,0 +1,202 @@
package main
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"os"
"encoding/json"
"bytes"
"fmt"
"time"
"party.at/party/internal/data"
"party.at/party/internal/jsonlog"
)
func newTestApplication(t *testing.T) *application {
cfg := config{}
cfg.db.dsn = "postgres://party:password@localhost:5432/party?sslmode=disable"
cfg.db.maxOpenConns = 25
cfg.db.maxIdleConns = 25
cfg.db.maxIdleTime = "15m"
cfg.limiter.enabled = false
cfg.env = "development"
logger := jsonlog.New(os.Stdout, jsonlog.LevelInfo)
db, err := openDB(cfg)
if err != nil {
logger.PrintFatal(err, nil)
}
t.Cleanup(func() {db.Close()})
return &application{
logger: logger,
models: data.NewModels(db),
config: cfg,
}
}
type testServer struct {
*httptest.Server
app *application
}
func newTestServer(t *testing.T, app *application, h http.Handler) *testServer {
ts := httptest.NewTLSServer(h)
return &testServer{ts, app}
}
func (ts *testServer) postJSON(t *testing.T, path string, body any) (int, http.Header, []byte) {
t.Helper()
b, err := json.Marshal(body)
if err != nil {
t.Fatal(err)
}
req, err := http.NewRequest(http.MethodPost, ts.URL+path, bytes.NewReader(b))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := ts.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
return resp.StatusCode, resp.Header, respBody
}
func (ts *testServer) get(t *testing.T, path string) (int, http.Header, []byte) {
t.Helper()
rs, err := ts.Client().Get(ts.URL + path)
if err != nil {
t.Fatal(err)
}
defer rs.Body.Close()
body, err := io.ReadAll(rs.Body)
if err != nil {
t.Fatal(err)
}
return rs.StatusCode, rs.Header, body
}
// registers a user, activates them, logs in, returns the bearer token
func (ts *testServer) registerAndLogin(t *testing.T, email, password string) string {
t.Helper()
// 1. Register
registerBody := map[string]any{
"email": email,
"password": password,
"username": email,
"name": "Test User",
"alt_name" : "",
"provider_id": 1,
}
code, _, body := ts.postJSON(t, "/v1/users", registerBody)
if code != http.StatusCreated {
t.Fatalf("register: want 201 got %d: %s", code, body)
}
// 2. Activate — if your flow requires it, either hit the endpoint
// or directly flip the activated flag in the test DB
// ts.activateUser(t, email)
// 3. Login
loginBody := map[string]string{"email": email, "password": password}
code, _, body = ts.postJSON(t, "/v1/tokens/authentication", loginBody)
if code != http.StatusCreated {
t.Fatalf("login: want 201 got %d: %s", code, body)
}
// 4. Parse token out of response
var resp struct {
AuthenticationToken struct {
Token string `json:"token"`
} `json:"authentication_token"`
}
if err := json.Unmarshal(body, &resp); err != nil {
t.Fatalf("parse token: %v", err)
}
return resp.AuthenticationToken.Token
}
// activateUser directly updates the DB — avoids needing a real email flow
// func (ts *testServer) activateUser(t *testing.T, email string) {
// t.Helper()
// _, err := ts.app.db.Exec("UPDATE users SET activated = true WHERE email = $1", email)
// if err != nil {
// t.Fatalf("activate user: %v", err)
// }
// }
// like ts.get but adds Authorization header
func (ts *testServer) getWithToken(t *testing.T, path, token string) (int, http.Header, []byte) {
t.Helper()
req, err := http.NewRequest(http.MethodGet, ts.URL+path, nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", "Bearer "+token)
rs, err := ts.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer rs.Body.Close()
body, err := io.ReadAll(rs.Body)
if err != nil {
t.Fatal(err)
}
return rs.StatusCode, rs.Header, body
}
func (ts *testServer) postJSONWithToken(t *testing.T, path, token string, body any) (int, http.Header, []byte) {
t.Helper()
b, err := json.Marshal(body)
if err != nil {
t.Fatal(err)
}
req, err := http.NewRequest(http.MethodPost, ts.URL+path, bytes.NewReader(b))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
rs, err := ts.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer rs.Body.Close()
respBody, err := io.ReadAll(rs.Body)
if err != nil {
t.Fatal(err)
}
return rs.StatusCode, rs.Header, respBody
}
func uniqueEmail() string {
return fmt.Sprintf("test_%d@example.com", time.Now().UnixNano())
}

View File

@ -5,7 +5,6 @@ import (
"net/http" "net/http"
"party.at/party/internal/data" "party.at/party/internal/data"
"party.at/party/internal/validator" "party.at/party/internal/validator"
"database/sql"
"time" "time"
) )
@ -19,7 +18,7 @@ func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request
Email string `json:"email"` Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
Name string `json:"name"` Name string `json:"name"`
AltName string `json:"alt_name"` AltName *string `json:"alt_name"`
DateOfBirth time.Time `json:"date_of_birth"` DateOfBirth time.Time `json:"date_of_birth"`
Address string `json:"address"` Address string `json:"address"`
} }
@ -35,12 +34,16 @@ func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request
PhoneNumber: input.PhoneNumber, PhoneNumber: input.PhoneNumber,
Country: input.Country, Country: input.Country,
Name: input.Name, Name: input.Name,
AltName: sql.NullString{String: input.AltName, Valid: true}, AltName: input.AltName,
DateOfBirth: input.DateOfBirth, DateOfBirth: input.DateOfBirth,
Address: input.Address, Address: input.Address,
Activated: false, Activated: false,
} }
if app.config.env == "development" {
user.Activated = true
}
userIdentity := &data.UserIdentity{ userIdentity := &data.UserIdentity{
ProviderID: input.ProviderId, ProviderID: input.ProviderId,
ProviderUserID: input.Username, ProviderUserID: input.Username,
@ -85,6 +88,15 @@ func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request
return return
} }
if app.config.env == "development" {
err = app.models.Permissions.AddForUser(user.ID, "issues:write")
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
}
if app.config.env == "production" {
token, err := app.models.Tokens.New(user.ID, userIdentity.ID, 3 * 24 * time.Hour, data.ScopeActivation) token, err := app.models.Tokens.New(user.ID, userIdentity.ID, 3 * 24 * time.Hour, data.ScopeActivation)
if err != nil { if err != nil {
app.serverErrorResponse(w, r, err) app.serverErrorResponse(w, r, err)
@ -102,10 +114,17 @@ func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request
app.logger.PrintError(err, nil) app.logger.PrintError(err, nil)
} }
}) })
}
authentication_token, err := app.models.Tokens.New(user.ID, userIdentity.ID, 24 * time.Hour, data.ScopeAuthentication)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
// Write a JSON response containing the user data along with a 201 Created status // Write a JSON response containing the user data along with a 201 Created status
// code. // code.
err = app.writeJSON(w, http.StatusCreated, envelope{"user": user}, nil) err = app.writeJSON(w, http.StatusCreated, envelope{"user": user, "authentication_token": authentication_token}, nil)
if err != nil { if err != nil {
app.serverErrorResponse(w, r, err) app.serverErrorResponse(w, r, err)
} }

View File

@ -0,0 +1,88 @@
package data
import (
"time"
"database/sql"
"encoding/pem"
"context"
"errors"
"fmt"
"math/big"
"crypto/rsa"
"crypto/x509"
)
type BlindSignRequest struct {
UserID int64 `json:"user_id"`
IssueID int64 `json:"issue_id"`
Created time.Time `json:"created"`
}
type BlindSignRequestModel struct {
DB *sql.DB
}
func (m BlindSignRequestModel) Insert(blind_sign *BlindSignRequest) error {
query := `
INSERT INTO blind_sign_requests (user_id, issue_id)
VALUES ($1, $2)
RETURNING created`
args := []interface{}{
blind_sign.UserID,
blind_sign.IssueID,
}
return m.DB.QueryRow(query, args...).Scan(
&blind_sign.Created,
)
}
func (m BlindSignRequestModel) BlindSign(issueID int64, blindedVoteBytes []byte) ([]byte, error) {
if issueID < 1 {
return nil, ErrRecordNotFound
}
query := `SELECT rsa_private_pem FROM issues WHERE id = $1`
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
defer cancel()
var pemBytes []byte
err := m.DB.QueryRowContext(ctx, query, issueID).Scan(&pemBytes)
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrRecordNotFound
default:
return nil, err
}
}
key, err := parsePrivateKey(pemBytes)
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
m_ := new(big.Int).SetBytes(blindedVoteBytes)
// Validate range: m must be in [1, n-1]
one := big.NewInt(1)
if m_.Cmp(one) < 0 || m_.Cmp(key.N) >= 0 {
return nil, ErrInvalidBlindedVote
}
// s = m^d mod n
sig := new(big.Int).Exp(m_, key.D, key.N)
return sig.Bytes(), nil
}
func parsePrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(pemBytes)
if block == nil {
return nil, errors.New("failed to decode PEM block")
}
return x509.ParsePKCS1PrivateKey(block.Bytes)
}

View File

@ -15,6 +15,9 @@ type Issue struct {
Description string `json:"description"` Description string `json:"description"`
StartTime time.Time `json:"start_time"` StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"` EndTime time.Time `json:"end_time"`
N []byte `json:"-"`
E int `json:"-"`
PrivatePem []byte `json:"-"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Version int32 `json:"version"` Version int32 `json:"version"`
} }
@ -36,8 +39,8 @@ type IssueModel struct {
func (m IssueModel) Insert(issue *Issue) error { func (m IssueModel) Insert(issue *Issue) error {
query := ` query := `
INSERT INTO issues (title, description, start_time, end_time) INSERT INTO issues (title, description, start_time, end_time, rsa_n, rsa_e, rsa_private_pem)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id, created, version` RETURNING id, created, version`
args := []interface{}{ args := []interface{}{
@ -45,6 +48,9 @@ RETURNING id, created, version`
issue.Description, issue.Description,
issue.StartTime, issue.StartTime,
issue.EndTime, issue.EndTime,
issue.N,
issue.E,
issue.PrivatePem,
} }
return m.DB.QueryRow(query, args...).Scan( return m.DB.QueryRow(query, args...).Scan(
@ -60,32 +66,26 @@ func (m IssueModel) Get(id int64) (*Issue, error) {
return nil, ErrRecordNotFound return nil, ErrRecordNotFound
} }
// Define the SQL query for retrieving the issue data.
query := ` query := `
SELECT id, title, description, start_time, end_time, created, version SELECT id, title, description, start_time, end_time, rsa_n, rsa_e, rsa_private_pem, created, version
FROM issues FROM issues
WHERE id = $1` WHERE id = $1`
// Declare a Issue struct to hold the data returned by the query.
var issue Issue var issue Issue
// Execute the query using the QueryRow() method, passing in the provided id value
// as a placeholder parameter, and scan the response data into the fields of the
// Issue struct. Importantly, notice that we need to convert the scan target for the
// genres column using the pq.Array() adapter function again.
err := m.DB.QueryRow(query, id).Scan( err := m.DB.QueryRow(query, id).Scan(
&issue.ID, &issue.ID,
&issue.Title, &issue.Title,
&issue.Description, &issue.Description,
&issue.StartTime, &issue.StartTime,
&issue.EndTime, &issue.EndTime,
&issue.N,
&issue.E,
&issue.PrivatePem,
&issue.Created, &issue.Created,
&issue.Version, &issue.Version,
) )
// Handle any errors. If there was no matching issue found, Scan() will return
// a sql.ErrNoRows error. We check for this and return our custom ErrRecordNotFound
// error instead.
if err != nil { if err != nil {
switch { switch {
case errors.Is(err, sql.ErrNoRows): case errors.Is(err, sql.ErrNoRows):
@ -94,29 +94,37 @@ WHERE id = $1`
return nil, err return nil, err
} }
} }
// Otherwise, return a pointer to the Issue struct.
return &issue, nil return &issue, nil
} }
func (m IssueModel) Update(issue *Issue) error { func (m IssueModel) Update(issue *Issue) error {
query := ` query := `
UPDATE issues UPDATE issues
SET title = $1, description = $2, start_time = $3, end_time = $4, version = version + 1 SET
WHERE id = $5 AND version = $6 title = $1,
description = $2,
start_time = $3,
end_time = $4,
rsa_n = $5,
rsa_e = $6,
rsa_private_pem = $7,
version = version + 1
WHERE id = $8 AND version = $9
RETURNING version` RETURNING version`
// Create an args slice containing the values for the placeholder parameters.
args := []interface{}{ args := []interface{}{
issue.Title, issue.Title,
issue.Description, issue.Description,
issue.StartTime, issue.StartTime,
issue.EndTime, issue.EndTime,
issue.N,
issue.E,
issue.PrivatePem,
issue.ID, issue.ID,
issue.Version, issue.Version,
} }
// Use the QueryRow() method to execute the query, passing in the args slice as a
// variadic parameter and scanning the new version value into the issue struct.
err := m.DB.QueryRow(query, args...).Scan(&issue.Version) err := m.DB.QueryRow(query, args...).Scan(&issue.Version)
if err != nil { if err != nil {
switch { switch {
@ -170,7 +178,7 @@ func (m IssueModel) GetAll(title string, filters Filters) ([]*Issue, Metadata, e
// Construct the SQL query to retrieve all issue records. // Construct the SQL query to retrieve all issue records.
query := fmt.Sprintf(` query := fmt.Sprintf(`
SELECT SELECT
COUNT(*) OVER(), id, title, description, start_time, end_time, created, version COUNT(*) OVER(), id, title, description, start_time, end_time, rsa_n, rsa_e, rsa_private_pem, created, version
FROM FROM
issues issues
WHERE WHERE
@ -217,6 +225,9 @@ func (m IssueModel) GetAll(title string, filters Filters) ([]*Issue, Metadata, e
&issue.Description, &issue.Description,
&issue.StartTime, &issue.StartTime,
&issue.EndTime, &issue.EndTime,
&issue.N,
&issue.E,
&issue.PrivatePem,
&issue.Created, &issue.Created,
&issue.Version, &issue.Version,
) )

View File

@ -8,6 +8,7 @@ import (
var ( var (
ErrRecordNotFound = errors.New("record not found") ErrRecordNotFound = errors.New("record not found")
ErrEditConflict = errors.New("edit conflict") ErrEditConflict = errors.New("edit conflict")
ErrInvalidBlindedVote = errors.New("invalid blinded vote")
) )
type Models struct { type Models struct {
@ -16,6 +17,7 @@ type Models struct {
Issues IssueModel Issues IssueModel
Tokens TokenModel Tokens TokenModel
Permissions PermissionModel Permissions PermissionModel
BlindSignRequests BlindSignRequestModel
} }
func NewModels(db *sql.DB) Models { func NewModels(db *sql.DB) Models {
@ -25,5 +27,6 @@ func NewModels(db *sql.DB) Models {
Issues: IssueModel{DB: db}, Issues: IssueModel{DB: db},
Tokens: TokenModel{DB: db}, Tokens: TokenModel{DB: db},
Permissions: PermissionModel{DB: db}, Permissions: PermissionModel{DB: db},
BlindSignRequests: BlindSignRequestModel{DB: db},
} }
} }

View File

@ -23,7 +23,7 @@ type User struct {
PhoneNumber string `json:"phone_number"` PhoneNumber string `json:"phone_number"`
Country string `json:"country"` Country string `json:"country"`
Name string `json:"name"` Name string `json:"name"`
AltName sql.NullString `json:"alt_name"` AltName *string `json:"alt_name"`
DateOfBirth time.Time `json:"date_of_birth"` DateOfBirth time.Time `json:"date_of_birth"`
Address string `json:"address"` Address string `json:"address"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
@ -60,8 +60,8 @@ func (m UserModel) ExecuteRegistrationTx(user *User, userIdentity *UserIdentity)
defer tx.Rollback() defer tx.Rollback()
query := ` query := `
INSERT INTO users (email, phone_number, country, name, alt_name, date_of_birth, address) INSERT INTO users (email, phone_number, country, name, alt_name, date_of_birth, address, activated)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, created, last_login, version` RETURNING id, created, last_login, version`
args := []interface{}{ args := []interface{}{
@ -72,6 +72,7 @@ RETURNING id, created, last_login, version`
user.AltName, user.AltName,
user.DateOfBirth, user.DateOfBirth,
user.Address, user.Address,
user.Activated,
} }
err = tx.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.Created, &user.LastLogin, &user.Version) err = tx.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.Created, &user.LastLogin, &user.Version)
@ -163,7 +164,7 @@ WHERE id = $1`
func (m UserModel) GetByEmail(email string) (*User, error) { func (m UserModel) GetByEmail(email string) (*User, error) {
query :=` query :=`
SELECT id, email, phone, country, name, alt_name, date_of_birth, address, created, last_login, activated, version SELECT id, email, phone_number, country, name, alt_name, date_of_birth, address, created, last_login, activated, version
FROM users FROM users
WHERE email = $1` WHERE email = $1`
@ -205,7 +206,17 @@ func (m UserModel) GetForToken(tokenScope, tokenPlaintext string) (*User, error)
// Set up the SQL query. // Set up the SQL query.
query :=` query :=`
SELECT users.id, users.email, user.phone_number, user.country, users.name, users.date_of_birth, users.address, users.created, users.activated, users.version SELECT
users.id,
users.email,
users.phone_number,
users.country,
users.name,
users.date_of_birth,
users.address,
users.created,
users.activated,
users.version
FROM users FROM users
INNER JOIN tokens ON users.id = tokens.user_id INNER JOIN tokens ON users.id = tokens.user_id
WHERE tokens.hash = $1 WHERE tokens.hash = $1
@ -251,7 +262,16 @@ AND tokens.expiry > $3`
func (m UserModel) Update(user *User) error { func (m UserModel) Update(user *User) error {
query := ` query := `
UPDATE users UPDATE users
SET email = $1, phone_number = $2, country = $3, name = $4, alt_name = $5, date_of_birth = $6, address = $7, activated = $8, version = version + 1 SET
email = $1,
phone_number = $2,
country = $3,
name = $4,
alt_name = $5,
date_of_birth = $6,
address = $7,
activated = $8,
version = version + 1
WHERE id = $9 AND version = $10 WHERE id = $9 AND version = $10
RETURNING version` RETURNING version`

View File

@ -8,19 +8,16 @@ import (
"sync" "sync"
"time" "time"
) )
// Define a Level type to represent the severity level for a log entry.
type Level int8 type Level int8
// Initialize constants which represent a specific severity level. We use the iota
// keyword as a shortcut to assign successive integer values to the constants.
const ( const (
LevelInfo Level = iota // Has the value 0. LevelInfo Level = iota
LevelError // Has the value 1. LevelError
LevelFatal // Has the value 2. LevelFatal
LevelOff // Has the value 3. LevelOff
) )
// Return a human-friendly string for the severity level.
func (l Level) String() string { func (l Level) String() string {
switch l { switch l {
case LevelInfo: case LevelInfo:
@ -34,17 +31,12 @@ func (l Level) String() string {
} }
} }
// Define a custom Logger type. This holds the output destination that the log entries
// will be written to, the minimum severity level that log entries will be written for,
// plus a mutex for coordinating the writes.
type Logger struct { type Logger struct {
out io.Writer out io.Writer
minLevel Level minLevel Level
mu sync.Mutex mu sync.Mutex
} }
// Return a new Logger instance which writes log entries at or above a minimum severity
// level to a specific output destination.
func New(out io.Writer, minLevel Level) *Logger { func New(out io.Writer, minLevel Level) *Logger {
return &Logger{ return &Logger{
out: out, out: out,
@ -52,9 +44,6 @@ func New(out io.Writer, minLevel Level) *Logger {
} }
} }
// Declare some helper methods for writing log entries at the different levels. Notice
// that these all accept a map as the second parameter which can contain any arbitrary
// 'properties' that you want to appear in the log entry.
func (l *Logger) PrintInfo(message string, properties map[string]string) { func (l *Logger) PrintInfo(message string, properties map[string]string) {
l.print(LevelInfo, message, properties) l.print(LevelInfo, message, properties)
} }
@ -65,18 +54,14 @@ func (l *Logger) PrintError(err error, properties map[string]string) {
func (l *Logger) PrintFatal(err error, properties map[string]string) { func (l *Logger) PrintFatal(err error, properties map[string]string) {
l.print(LevelFatal, err.Error(), properties) l.print(LevelFatal, err.Error(), properties)
os.Exit(1) // For entries at the FATAL level, we also terminate the application. os.Exit(1)
} }
// Print is an internal method for writing the log entry.
func (l *Logger) print(level Level, message string, properties map[string]string) (int, error) { func (l *Logger) print(level Level, message string, properties map[string]string) (int, error) {
// If the severity level of the log entry is below the minimum severity for the
// logger, then return with no further action.
if level < l.minLevel { if level < l.minLevel {
return 0, nil return 0, nil
} }
// Declare an anonymous struct holding the data for the log entry.
aux := struct { aux := struct {
Level string `json:"level"` Level string `json:"level"`
Time string `json:"time"` Time string `json:"time"`
@ -90,35 +75,23 @@ func (l *Logger) print(level Level, message string, properties map[string]string
Properties: properties, Properties: properties,
} }
// Include a stack trace for entries at the ERROR and FATAL levels.
if level >= LevelError { if level >= LevelError {
aux.Trace = string(debug.Stack()) aux.Trace = string(debug.Stack())
} }
// Declare a line variable for holding the actual log entry text.
var line []byte var line []byte
// Marshal the anonymous struct to JSON and store it in the line variable. If there
// was a problem creating the JSON, set the contents of the log entry to be that
// plain-text error message instead.
line, err := json.Marshal(aux) line, err := json.Marshal(aux)
if err != nil { if err != nil {
line = []byte(LevelError.String() + ": unable to marshal log message:" + err.Error()) line = []byte(LevelError.String() + ": unable to marshal log message:" + err.Error())
} }
// Lock the mutex so that no two writes to the output destination cannot happen
// concurrently. If we don't do this, it's possible that the text for two or more
// log entries will be intermingled in the output.
l.mu.Lock() l.mu.Lock()
defer l.mu.Unlock() defer l.mu.Unlock()
// Write the log entry followed by a newline.
return l.out.Write(append(line, '\n')) return l.out.Write(append(line, '\n'))
} }
// We also implement a Write() method on our Logger type so that it satisfies the
// io.Writer interface. This writes a log entry at the ERROR level with no additional
// properties.
func (l *Logger) Write(message []byte) (n int, err error) { func (l *Logger) Write(message []byte) (n int, err error) {
return l.print(LevelError, string(message), nil) return l.print(LevelError, string(message), nil)
} }

View File

@ -26,7 +26,7 @@ The DigitalePartei Team
<p>Hi,</p> <p>Hi,</p>
<p>Thanks for signing up for a Digitale Partei Österreich account. We're excited to have you on board!</p> <p>Thanks for signing up for a Digitale Partei Österreich account. We're excited to have you on board!</p>
<p>For future reference, your user ID number is {{.ID}}.</p> <p>For future reference, your user ID number is {{.ID}}.</p>
<p>Your activation token is {{.token}}</p> <p><a href="dpoe://activate?token={{.token}}">Click here to activate!</a></p>
<p>Thanks,</p> <p>Thanks,</p>
<p>The DigitalePartei Team</p> <p>The DigitalePartei Team</p>
</body> </body>

View File

@ -13,18 +13,18 @@ CREATE TABLE IF NOT EXISTS users (
version INT NOT NULL DEFAULT 1 version INT NOT NULL DEFAULT 1
); );
CREATE TABLE IF NOT EXISTS auth_provider ( CREATE TABLE IF NOT EXISTS auth_providers (
id BIGSERIAL PRIMARY KEY, -- e.g., 'local', 'id_austria', id BIGSERIAL PRIMARY KEY, -- e.g., 'local', 'id_austria',
description TEXT NOT NULL, description TEXT NOT NULL,
active BOOLEAN DEFAULT false active BOOLEAN DEFAULT false
); );
INSERT INTO auth_provider (description, active) VALUES ('local', true); INSERT INTO auth_providers (description, active) VALUES ('local', true);
INSERT INTO auth_provider (description, active) VALUES ('id_austria', false); INSERT INTO auth_providers (description, active) VALUES ('id_austria', false);
CREATE TABLE IF NOT EXISTS user_identities ( CREATE TABLE IF NOT EXISTS user_identities (
id BIGSERIAL PRIMARY KEY, id BIGSERIAL PRIMARY KEY,
provider_id BIGINT NOT NULL REFERENCES auth_provider(id), provider_id BIGINT NOT NULL REFERENCES auth_providers(id),
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
-- For local: the username. For OIDC: the 'sub' (Subject ID) -- For local: the username. For OIDC: the 'sub' (Subject ID)
@ -33,9 +33,9 @@ CREATE TABLE IF NOT EXISTS user_identities (
-- Nullable because OIDC users won't have a password in your DB -- Nullable because OIDC users won't have a password in your DB
password bytea, password bytea,
version INT NOT NULL DEFAULT 1, version INT NOT NULL DEFAULT 1
UNIQUE(provider_id, provider_user_id) -- UNIQUE(provider_id, provider_user_id)
); );
-- INSERT INTO users( -- INSERT INTO users(

View File

@ -3,6 +3,7 @@ DROP INDEX IF EXISTS idx_vote_tokens_issue_id;
DROP INDEX IF EXISTS idx_options_issue_id; DROP INDEX IF EXISTS idx_options_issue_id;
DROP TABLE IF EXISTS votes; DROP TABLE IF EXISTS votes;
DROP TABLE IF EXISTS blind_sign_requests;
DROP TABLE IF EXISTS vote_tokens; DROP TABLE IF EXISTS vote_tokens;
DROP TABLE IF EXISTS options; DROP TABLE IF EXISTS options;
DROP TABLE IF EXISTS issues; DROP TABLE IF EXISTS issues;

View File

@ -4,6 +4,9 @@ CREATE TABLE IF NOT EXISTS issues (
description TEXT, description TEXT,
start_time TIMESTAMPTZ NOT NULL, start_time TIMESTAMPTZ NOT NULL,
end_time TIMESTAMPTZ NOT NULL, end_time TIMESTAMPTZ NOT NULL,
rsa_n BYTEA,
rsa_e INT,
rsa_private_pem BYTEA,
created TIMESTAMPTZ NOT NULL DEFAULT now(), created TIMESTAMPTZ NOT NULL DEFAULT now(),
version INT NOT NULL DEFAULT 1 version INT NOT NULL DEFAULT 1
); );
@ -25,6 +28,13 @@ CREATE TABLE IF NOT EXISTS vote_tokens (
version INT NOT NULL DEFAULT 1 version INT NOT NULL DEFAULT 1
); );
CREATE TABLE IF NOT EXISTS blind_sign_requests (
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
issue_id BIGINT NOT NULL REFERENCES issues(id) ON DELETE CASCADE,
created TIMESTAMPTZ NOT NULL DEFAULT now(),
PRIMARY KEY (user_id, issue_id)
);
CREATE TABLE IF NOT EXISTS votes ( CREATE TABLE IF NOT EXISTS votes (
id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
token UUID NOT NULL UNIQUE REFERENCES vote_tokens(token) ON DELETE CASCADE, token UUID NOT NULL UNIQUE REFERENCES vote_tokens(token) ON DELETE CASCADE,