.
This commit is contained in:
parent
173c01b01b
commit
5772a7d855
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
@ -4,14 +4,14 @@
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
|
||||
{
|
||||
"name": "Launch API",
|
||||
"type": "go",
|
||||
"request": "launch",
|
||||
"mode": "auto",
|
||||
"program": "${workspaceFolder}/cmd/api",
|
||||
"envFile": "${workspaceFolder}/.env"
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"args": ["--env=development"]
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -11,7 +11,10 @@ import (
|
||||
"strconv"
|
||||
"party.at/party/internal/validator"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"crypto/rand"
|
||||
)
|
||||
|
||||
type envelope map[string]interface{}
|
||||
@ -191,3 +194,19 @@ func (app *application) background(fn func()) {
|
||||
fn()
|
||||
}()
|
||||
}
|
||||
|
||||
func GenerateIssueKey() ([]byte, int, []byte, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, 0, nil, err
|
||||
}
|
||||
|
||||
// Encode private key as PEM
|
||||
privDER := x509.MarshalPKCS1PrivateKey(key)
|
||||
privPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: privDER,
|
||||
})
|
||||
|
||||
return key.N.Bytes(), key.E, privPEM, err
|
||||
}
|
||||
|
||||
32
cmd/api/helpers_test.go
Normal file
32
cmd/api/helpers_test.go
Normal file
@ -0,0 +1,32 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
)
|
||||
|
||||
func TestReadIDParam(t *testing.T) {
|
||||
app := newTestApplication(t)
|
||||
|
||||
const test_id int64 = 3
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/v1/issues/" + strconv.FormatInt(test_id, 10), nil)
|
||||
|
||||
params := httprouter.Params{{Key: "id", Value: "3"}}
|
||||
ctx := context.WithValue(r.Context(), httprouter.ParamsKey, params)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
id, err := app.readIDParam(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if id != test_id {
|
||||
t.Errorf("want %d, got id %d", test_id, id)
|
||||
}
|
||||
}
|
||||
@ -12,35 +12,30 @@ import (
|
||||
"errors"
|
||||
"party.at/party/internal/data"
|
||||
"party.at/party/internal/validator"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func (app *application) listIssuesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// To keep things consistent with our other handlers, we'll define an input struct
|
||||
// to hold the expected values from the request query string.
|
||||
|
||||
var input struct {
|
||||
Title string
|
||||
data.Filters
|
||||
}
|
||||
|
||||
// Initialize a new Validator instance.
|
||||
|
||||
v := validator.New()
|
||||
// Call r.URL.Query() to get the url.Values map containing the query string data.
|
||||
|
||||
qs := r.URL.Query()
|
||||
|
||||
// Use our helpers to extract the title and genres query string values, falling back
|
||||
// to defaults of an empty string and an empty slice respectively if they are not
|
||||
// provided by the client.
|
||||
|
||||
input.Title = app.readString(qs, "title", "")
|
||||
// input.Genres = app.readCSV(qs, "genres", []string{})
|
||||
|
||||
// Get the page and page_size query string values as integers. Notice that we set
|
||||
// the default page value to 1 and default page_size to 20, and that we pass the
|
||||
// validator instance as the final argument here.
|
||||
|
||||
input.Filters.Page = app.readInt(qs, "page", 1, v)
|
||||
input.Filters.PageSize = app.readInt(qs, "page_size", 20, v)
|
||||
|
||||
// Extract the sort query string value, falling back to "id" if it is not provided
|
||||
// by the client (which will imply a ascending sort on issue ID).
|
||||
|
||||
input.Filters.Sort = app.readString(qs, "sort", "id")
|
||||
input.Filters.SortSafelist = []string{"id", "-id", "title", "-title", "description", "-description"}
|
||||
|
||||
@ -77,11 +72,20 @@ func (app *application) createIssueHandler(w http.ResponseWriter, r *http.Reques
|
||||
return
|
||||
}
|
||||
|
||||
n, e, private_pem, err := GenerateIssueKey()
|
||||
if err != nil {
|
||||
app.serverErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
issue := &data.Issue{
|
||||
Title: input.Title,
|
||||
Description: input.Description,
|
||||
StartTime: input.StartTime,
|
||||
EndTime: input.EndTime,
|
||||
N: n,
|
||||
E: e,
|
||||
PrivatePem: private_pem,
|
||||
}
|
||||
|
||||
v := validator.New()
|
||||
@ -219,3 +223,73 @@ func (app *application) deleteIssueHandler(w http.ResponseWriter, r *http.Reques
|
||||
app.serverErrorResponse(w, r, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *application) readIssuePubKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := app.readIDParam(r)
|
||||
if err != nil {
|
||||
app.notFoundResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
issue, err := app.models.Issues.Get(id)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, data.ErrRecordNotFound):
|
||||
app.notFoundResponse(w, r)
|
||||
default:
|
||||
app.serverErrorResponse(w, r, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type response struct {
|
||||
N string `json:"n"`
|
||||
E int `json:"e"`
|
||||
}
|
||||
|
||||
res := response{
|
||||
N: hex.EncodeToString(issue.N),
|
||||
E: issue.E,
|
||||
}
|
||||
|
||||
|
||||
err = app.writeJSON(w, http.StatusOK, envelope{"public_key": res}, nil)
|
||||
if err != nil {
|
||||
app.serverErrorResponse(w, r, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *application) blindSignIssueVoteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := app.readIDParam(r)
|
||||
if err != nil {
|
||||
app.notFoundResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
issue, err := app.models.Issues.Get(id)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, data.ErrRecordNotFound):
|
||||
app.notFoundResponse(w, r)
|
||||
default:
|
||||
app.serverErrorResponse(w, r, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type response struct {
|
||||
N string `json:"n"`
|
||||
E int `json:"e"`
|
||||
}
|
||||
|
||||
res := response{
|
||||
N: hex.EncodeToString(issue.N),
|
||||
E: issue.E,
|
||||
}
|
||||
|
||||
|
||||
err = app.writeJSON(w, http.StatusOK, envelope{"public_key": res}, nil)
|
||||
if err != nil {
|
||||
app.serverErrorResponse(w, r, err)
|
||||
}
|
||||
}
|
||||
|
||||
65
cmd/api/issues_test.go
Normal file
65
cmd/api/issues_test.go
Normal file
@ -0,0 +1,65 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestReadIssueHandler(t *testing.T) {
|
||||
app := newTestApplication(t)
|
||||
ts := newTestServer(t, app, app.routes())
|
||||
defer ts.Close()
|
||||
|
||||
token := ts.registerAndLogin(t, uniqueEmail(), "pa$$word123")
|
||||
|
||||
req := map[string]any{
|
||||
"title": "An old silent pond...",
|
||||
"description": "A frog jumps into the pond",
|
||||
"start_time": time.Now().Format(time.RFC3339),
|
||||
"end_time": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||
}
|
||||
|
||||
code, _, res := ts.postJSONWithToken(t, "/v1/issues", token, req)
|
||||
if code != http.StatusCreated {
|
||||
t.Fatalf("seed issue: want 201 got %d: %s", code, res)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Issue struct {
|
||||
ID int64 `json:"id"`
|
||||
} `json:"issue"`
|
||||
}
|
||||
json.Unmarshal(res, &resp)
|
||||
issueID := resp.Issue.ID
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
urlPath string
|
||||
wantCode int
|
||||
wantBody []byte
|
||||
}{
|
||||
{"Valid ID", "/v1/issues/" + strconv.Itoa(int(issueID)), http.StatusOK, []byte("An old silent pond...")},
|
||||
{"Non-existent ID", "/v1/issues/" + strconv.Itoa(int(issueID + 1)), http.StatusNotFound, nil},
|
||||
{"Negative ID", "/v1/issues/-1", http.StatusNotFound, nil},
|
||||
{"Decimal ID", "/v1/issues/1.23", http.StatusNotFound, nil},
|
||||
{"String ID", "/v1/issues/foo", http.StatusNotFound, nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
code, _, body := ts.getWithToken(t, tt.urlPath, token)
|
||||
|
||||
if code != tt.wantCode {
|
||||
t.Errorf("want %d; got %d", tt.wantCode, code)
|
||||
}
|
||||
|
||||
if !bytes.Contains(body, tt.wantBody) {
|
||||
t.Errorf("want body to contain %q; got %q", tt.wantBody, string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -24,6 +24,10 @@ import (
|
||||
"party.at/party/internal/data"
|
||||
"party.at/party/internal/mailer"
|
||||
"party.at/party/internal/jsonlog"
|
||||
|
||||
"crypto/rand"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
)
|
||||
|
||||
var version string
|
||||
@ -80,12 +84,31 @@ var upgrader = websocket.Upgrader{
|
||||
|
||||
func main() {
|
||||
|
||||
fmt.Printf("Full Args: %v\n", os.Args)
|
||||
// 32 bytes = AES-256
|
||||
var key [32]byte
|
||||
if _, err := rand.Read(key[:]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
message := []byte("hello secure world")
|
||||
|
||||
nonce, ciphertext, err := encrypt(key, message)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
plaintext, err := decrypt(key, nonce, ciphertext)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Original: %s\n", message)
|
||||
fmt.Printf("Decrypted: %s\n", plaintext)
|
||||
|
||||
var cfg config
|
||||
|
||||
flag.IntVar(&cfg.port, "port", 4000, "API server port")
|
||||
flag.StringVar(&cfg.env, "env", "development", "Environment (development|staging|production)")
|
||||
flag.StringVar(&cfg.env, "env", "production", "Environment (development|staging|production)")
|
||||
flag.StringVar(&cfg.db.dsn, "db-dsn", os.Getenv("PARTY_DB_DSN"), "PostgreSQL DSN")
|
||||
//addr := flag.String("addr", ":8443", "HTTP network address")
|
||||
|
||||
@ -97,7 +120,7 @@ func main() {
|
||||
flag.IntVar(&cfg.smtp.port, "smtp-port", 25, "SMTP port")
|
||||
flag.StringVar(&cfg.smtp.username, "smtp-username", "98cf60028d7fcb", "SMTP username")
|
||||
flag.StringVar(&cfg.smtp.password, "smtp-password", "b9d4a35372e971", "SMTP password")
|
||||
flag.StringVar(&cfg.smtp.sender, "smtp-sender", "Greenlight <no-reply@greenlight.alexedwards.net>", "SMTP sender")
|
||||
flag.StringVar(&cfg.smtp.sender, "smtp-sender", "DPÖ <no-reply@party.at>", "SMTP sender")
|
||||
|
||||
flag.Float64Var(&cfg.limiter.rps, "limiter-rps", 2, "Rate limiter maximum requests per second")
|
||||
flag.IntVar(&cfg.limiter.burst, "limiter-burst", 4, "Rate limiter maximum burst")
|
||||
@ -126,7 +149,6 @@ func main() {
|
||||
if err != nil {
|
||||
logger.PrintFatal(err, nil)
|
||||
}
|
||||
|
||||
defer db.Close()
|
||||
|
||||
expvar.NewString("version").Set(version)
|
||||
@ -188,3 +210,43 @@ func openDB(cfg config) (*sql.DB, error) {
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func encrypt(key [32]byte, plaintext []byte) (nonce, ciphertext []byte, err error) {
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
nonce = make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
||||
ciphertext = aead.Seal(nil, nonce, plaintext, nil)
|
||||
return nonce, ciphertext, nil
|
||||
}
|
||||
|
||||
func decrypt(key [32]byte, nonce, ciphertext []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plaintext, err := aead.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
@ -64,7 +64,7 @@ func (app *application) rateLimit(next http.Handler) http.Handler {
|
||||
// Loop through all clients. If they haven't been seen within the last three
|
||||
// minutes, delete the corresponding entry from the map.
|
||||
for ip, client := range clients {
|
||||
if time.Since(client.lastSeen) > 3*time.Minute {
|
||||
if time.Since(client.lastSeen) > 3 * time.Minute {
|
||||
delete(clients, ip)
|
||||
}
|
||||
}
|
||||
|
||||
@ -27,15 +27,18 @@ func (app *application) routes() http.Handler {
|
||||
router.HandlerFunc(http.MethodPatch, "/v1/issues/:id", app.requirePermission("issues:write", app.updateIssueHandler))
|
||||
router.HandlerFunc(http.MethodDelete, "/v1/issues/:id", app.requirePermission("issues:write", app.deleteIssueHandler))
|
||||
|
||||
router.HandlerFunc(http.MethodPost, "/v1/users", app.createUserHandler)
|
||||
router.HandlerFunc(http.MethodGet, "/v1/issues/:id/pubkey", app.requirePermission("issues:read", app.readIssuePubKeyHandler))
|
||||
router.HandlerFunc(http.MethodPost, "/v1/issues/:id/blind-sign", app.requirePermission("issues:read", app.blindSignIssueVoteHandler))
|
||||
|
||||
router.HandlerFunc(http.MethodPost, "/v1/users", app.createUserHandler)
|
||||
// router.HandlerFunc(http.MethodGet, "/v1/users/:id", app.readUserHandler)
|
||||
// router.HandlerFunc(http.MethodPatch, "/v1/users/:id", app.updateUserHandler)
|
||||
router.HandlerFunc(http.MethodDelete, "/v1/users/:id", app.deleteUserHandler)
|
||||
router.HandlerFunc(http.MethodPut, "/v1/users/activated", app.activateUserHandler)
|
||||
router.HandlerFunc(http.MethodDelete, "/v1/users/:id", app.deleteUserHandler)
|
||||
router.HandlerFunc(http.MethodPut, "/v1/users/activated", app.activateUserHandler)
|
||||
|
||||
router.HandlerFunc(http.MethodPost, "/v1/tokens/authentication", app.createAuthenticationTokenHandler)
|
||||
router.HandlerFunc(http.MethodPost, "/v1/tokens/authentication", app.createAuthenticationTokenHandler)
|
||||
|
||||
router.Handler(http.MethodGet, "/debug/vars", expvar.Handler())
|
||||
router.Handler (http.MethodGet, "/debug/vars", expvar.Handler())
|
||||
|
||||
return app.metrics(app.recoverPanic(app.enableCORS(app.rateLimit(app.authenticate(router)))))
|
||||
}
|
||||
|
||||
202
cmd/api/testutils_test.go
Normal file
202
cmd/api/testutils_test.go
Normal file
@ -0,0 +1,202 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"os"
|
||||
"encoding/json"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"party.at/party/internal/data"
|
||||
"party.at/party/internal/jsonlog"
|
||||
)
|
||||
|
||||
func newTestApplication(t *testing.T) *application {
|
||||
cfg := config{}
|
||||
cfg.db.dsn = "postgres://party:password@localhost:5432/party?sslmode=disable"
|
||||
cfg.db.maxOpenConns = 25
|
||||
cfg.db.maxIdleConns = 25
|
||||
cfg.db.maxIdleTime = "15m"
|
||||
cfg.limiter.enabled = false
|
||||
cfg.env = "development"
|
||||
|
||||
logger := jsonlog.New(os.Stdout, jsonlog.LevelInfo)
|
||||
|
||||
db, err := openDB(cfg)
|
||||
if err != nil {
|
||||
logger.PrintFatal(err, nil)
|
||||
}
|
||||
t.Cleanup(func() {db.Close()})
|
||||
|
||||
return &application{
|
||||
logger: logger,
|
||||
models: data.NewModels(db),
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
type testServer struct {
|
||||
*httptest.Server
|
||||
app *application
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, app *application, h http.Handler) *testServer {
|
||||
ts := httptest.NewTLSServer(h)
|
||||
return &testServer{ts, app}
|
||||
}
|
||||
|
||||
func (ts *testServer) postJSON(t *testing.T, path string, body any) (int, http.Header, []byte) {
|
||||
t.Helper()
|
||||
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, ts.URL+path, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := ts.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return resp.StatusCode, resp.Header, respBody
|
||||
}
|
||||
|
||||
func (ts *testServer) get(t *testing.T, path string) (int, http.Header, []byte) {
|
||||
t.Helper()
|
||||
|
||||
rs, err := ts.Client().Get(ts.URL + path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer rs.Body.Close()
|
||||
body, err := io.ReadAll(rs.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return rs.StatusCode, rs.Header, body
|
||||
}
|
||||
|
||||
// registers a user, activates them, logs in, returns the bearer token
|
||||
func (ts *testServer) registerAndLogin(t *testing.T, email, password string) string {
|
||||
t.Helper()
|
||||
|
||||
// 1. Register
|
||||
registerBody := map[string]any{
|
||||
"email": email,
|
||||
"password": password,
|
||||
"username": email,
|
||||
"name": "Test User",
|
||||
"alt_name" : "",
|
||||
"provider_id": 1,
|
||||
}
|
||||
|
||||
code, _, body := ts.postJSON(t, "/v1/users", registerBody)
|
||||
if code != http.StatusCreated {
|
||||
t.Fatalf("register: want 201 got %d: %s", code, body)
|
||||
}
|
||||
|
||||
// 2. Activate — if your flow requires it, either hit the endpoint
|
||||
// or directly flip the activated flag in the test DB
|
||||
// ts.activateUser(t, email)
|
||||
|
||||
// 3. Login
|
||||
loginBody := map[string]string{"email": email, "password": password}
|
||||
code, _, body = ts.postJSON(t, "/v1/tokens/authentication", loginBody)
|
||||
if code != http.StatusCreated {
|
||||
t.Fatalf("login: want 201 got %d: %s", code, body)
|
||||
}
|
||||
|
||||
// 4. Parse token out of response
|
||||
var resp struct {
|
||||
AuthenticationToken struct {
|
||||
Token string `json:"token"`
|
||||
} `json:"authentication_token"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
t.Fatalf("parse token: %v", err)
|
||||
}
|
||||
return resp.AuthenticationToken.Token
|
||||
}
|
||||
|
||||
// activateUser directly updates the DB — avoids needing a real email flow
|
||||
// func (ts *testServer) activateUser(t *testing.T, email string) {
|
||||
// t.Helper()
|
||||
// _, err := ts.app.db.Exec("UPDATE users SET activated = true WHERE email = $1", email)
|
||||
// if err != nil {
|
||||
// t.Fatalf("activate user: %v", err)
|
||||
// }
|
||||
// }
|
||||
|
||||
// like ts.get but adds Authorization header
|
||||
func (ts *testServer) getWithToken(t *testing.T, path, token string) (int, http.Header, []byte) {
|
||||
t.Helper()
|
||||
req, err := http.NewRequest(http.MethodGet, ts.URL+path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
rs, err := ts.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer rs.Body.Close()
|
||||
body, err := io.ReadAll(rs.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return rs.StatusCode, rs.Header, body
|
||||
}
|
||||
|
||||
func (ts *testServer) postJSONWithToken(t *testing.T, path, token string, body any) (int, http.Header, []byte) {
|
||||
t.Helper()
|
||||
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, ts.URL+path, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
rs, err := ts.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer rs.Body.Close()
|
||||
respBody, err := io.ReadAll(rs.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return rs.StatusCode, rs.Header, respBody
|
||||
}
|
||||
|
||||
func uniqueEmail() string {
|
||||
return fmt.Sprintf("test_%d@example.com", time.Now().UnixNano())
|
||||
}
|
||||
@ -5,23 +5,22 @@ import (
|
||||
"net/http"
|
||||
"party.at/party/internal/data"
|
||||
"party.at/party/internal/validator"
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var input struct {
|
||||
ProviderId int64 `json:"provider_id"`
|
||||
Username string `json:"username"`
|
||||
PhoneNumber string `json:"phone_number"`
|
||||
Country string `json:"country"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
Name string `json:"name"`
|
||||
AltName string `json:"alt_name"`
|
||||
ProviderId int64 `json:"provider_id"`
|
||||
Username string `json:"username"`
|
||||
PhoneNumber string `json:"phone_number"`
|
||||
Country string `json:"country"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
Name string `json:"name"`
|
||||
AltName *string `json:"alt_name"`
|
||||
DateOfBirth time.Time `json:"date_of_birth"`
|
||||
Address string `json:"address"`
|
||||
Address string `json:"address"`
|
||||
}
|
||||
|
||||
err := app.readJSON(w, r, &input)
|
||||
@ -35,12 +34,16 @@ func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request
|
||||
PhoneNumber: input.PhoneNumber,
|
||||
Country: input.Country,
|
||||
Name: input.Name,
|
||||
AltName: sql.NullString{String: input.AltName, Valid: true},
|
||||
AltName: input.AltName,
|
||||
DateOfBirth: input.DateOfBirth,
|
||||
Address: input.Address,
|
||||
Activated: false,
|
||||
}
|
||||
|
||||
if app.config.env == "development" {
|
||||
user.Activated = true
|
||||
}
|
||||
|
||||
userIdentity := &data.UserIdentity{
|
||||
ProviderID: input.ProviderId,
|
||||
ProviderUserID: input.Username,
|
||||
@ -85,27 +88,43 @@ func (app *application) createUserHandler(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
token, err := app.models.Tokens.New(user.ID, userIdentity.ID, 3 * 24 * time.Hour, data.ScopeActivation)
|
||||
if app.config.env == "development" {
|
||||
err = app.models.Permissions.AddForUser(user.ID, "issues:write")
|
||||
if err != nil {
|
||||
app.serverErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if app.config.env == "production" {
|
||||
token, err := app.models.Tokens.New(user.ID, userIdentity.ID, 3 * 24 * time.Hour, data.ScopeActivation)
|
||||
if err != nil {
|
||||
app.serverErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
app.background(func() {
|
||||
data := map[string]interface{}{
|
||||
"token": token.Plaintext,
|
||||
"userID": user.ID,
|
||||
}
|
||||
|
||||
err = app.mailer.Send(user.Email, "user_welcome.tmpl", data)
|
||||
if err != nil {
|
||||
app.logger.PrintError(err, nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
authentication_token, err := app.models.Tokens.New(user.ID, userIdentity.ID, 24 * time.Hour, data.ScopeAuthentication)
|
||||
if err != nil {
|
||||
app.serverErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
app.background(func() {
|
||||
data := map[string]interface{}{
|
||||
"token": token.Plaintext,
|
||||
"userID": user.ID,
|
||||
}
|
||||
|
||||
err = app.mailer.Send(user.Email, "user_welcome.tmpl", data)
|
||||
if err != nil {
|
||||
app.logger.PrintError(err, nil)
|
||||
}
|
||||
})
|
||||
|
||||
// Write a JSON response containing the user data along with a 201 Created status
|
||||
// code.
|
||||
err = app.writeJSON(w, http.StatusCreated, envelope{"user": user}, nil)
|
||||
err = app.writeJSON(w, http.StatusCreated, envelope{"user": user, "authentication_token": authentication_token}, nil)
|
||||
if err != nil {
|
||||
app.serverErrorResponse(w, r, err)
|
||||
}
|
||||
|
||||
88
internal/data/blind_sign_requests.go
Normal file
88
internal/data/blind_sign_requests.go
Normal file
@ -0,0 +1,88 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"time"
|
||||
"database/sql"
|
||||
"encoding/pem"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
type BlindSignRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
IssueID int64 `json:"issue_id"`
|
||||
Created time.Time `json:"created"`
|
||||
}
|
||||
|
||||
type BlindSignRequestModel struct {
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
func (m BlindSignRequestModel) Insert(blind_sign *BlindSignRequest) error {
|
||||
query := `
|
||||
INSERT INTO blind_sign_requests (user_id, issue_id)
|
||||
VALUES ($1, $2)
|
||||
RETURNING created`
|
||||
|
||||
args := []interface{}{
|
||||
blind_sign.UserID,
|
||||
blind_sign.IssueID,
|
||||
}
|
||||
|
||||
return m.DB.QueryRow(query, args...).Scan(
|
||||
&blind_sign.Created,
|
||||
)
|
||||
}
|
||||
|
||||
func (m BlindSignRequestModel) BlindSign(issueID int64, blindedVoteBytes []byte) ([]byte, error) {
|
||||
if issueID < 1 {
|
||||
return nil, ErrRecordNotFound
|
||||
}
|
||||
|
||||
query := `SELECT rsa_private_pem FROM issues WHERE id = $1`
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
||||
defer cancel()
|
||||
|
||||
var pemBytes []byte
|
||||
err := m.DB.QueryRowContext(ctx, query, issueID).Scan(&pemBytes)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
return nil, ErrRecordNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
key, err := parsePrivateKey(pemBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
|
||||
m_ := new(big.Int).SetBytes(blindedVoteBytes)
|
||||
|
||||
// Validate range: m′ must be in [1, n-1]
|
||||
one := big.NewInt(1)
|
||||
if m_.Cmp(one) < 0 || m_.Cmp(key.N) >= 0 {
|
||||
return nil, ErrInvalidBlindedVote
|
||||
}
|
||||
|
||||
// s′ = m′^d mod n
|
||||
sig := new(big.Int).Exp(m_, key.D, key.N)
|
||||
|
||||
return sig.Bytes(), nil
|
||||
}
|
||||
|
||||
func parsePrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode PEM block")
|
||||
}
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
}
|
||||
|
||||
@ -10,13 +10,16 @@ import (
|
||||
)
|
||||
|
||||
type Issue struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Created time.Time `json:"created"`
|
||||
Version int32 `json:"version"`
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
N []byte `json:"-"`
|
||||
E int `json:"-"`
|
||||
PrivatePem []byte `json:"-"`
|
||||
Created time.Time `json:"created"`
|
||||
Version int32 `json:"version"`
|
||||
}
|
||||
|
||||
func ValidateIssue(v *validator.Validator, issue *Issue) {
|
||||
@ -36,8 +39,8 @@ type IssueModel struct {
|
||||
|
||||
func (m IssueModel) Insert(issue *Issue) error {
|
||||
query := `
|
||||
INSERT INTO issues (title, description, start_time, end_time)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
INSERT INTO issues (title, description, start_time, end_time, rsa_n, rsa_e, rsa_private_pem)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id, created, version`
|
||||
|
||||
args := []interface{}{
|
||||
@ -45,6 +48,9 @@ RETURNING id, created, version`
|
||||
issue.Description,
|
||||
issue.StartTime,
|
||||
issue.EndTime,
|
||||
issue.N,
|
||||
issue.E,
|
||||
issue.PrivatePem,
|
||||
}
|
||||
|
||||
return m.DB.QueryRow(query, args...).Scan(
|
||||
@ -60,32 +66,26 @@ func (m IssueModel) Get(id int64) (*Issue, error) {
|
||||
return nil, ErrRecordNotFound
|
||||
}
|
||||
|
||||
// Define the SQL query for retrieving the issue data.
|
||||
query :=`
|
||||
SELECT id, title, description, start_time, end_time, created, version
|
||||
query := `
|
||||
SELECT id, title, description, start_time, end_time, rsa_n, rsa_e, rsa_private_pem, created, version
|
||||
FROM issues
|
||||
WHERE id = $1`
|
||||
|
||||
// Declare a Issue struct to hold the data returned by the query.
|
||||
var issue Issue
|
||||
|
||||
// Execute the query using the QueryRow() method, passing in the provided id value
|
||||
// as a placeholder parameter, and scan the response data into the fields of the
|
||||
// Issue struct. Importantly, notice that we need to convert the scan target for the
|
||||
// genres column using the pq.Array() adapter function again.
|
||||
err := m.DB.QueryRow(query, id).Scan(
|
||||
&issue.ID,
|
||||
&issue.Title,
|
||||
&issue.Description,
|
||||
&issue.StartTime,
|
||||
&issue.EndTime,
|
||||
&issue.N,
|
||||
&issue.E,
|
||||
&issue.PrivatePem,
|
||||
&issue.Created,
|
||||
&issue.Version,
|
||||
)
|
||||
|
||||
// Handle any errors. If there was no matching issue found, Scan() will return
|
||||
// a sql.ErrNoRows error. We check for this and return our custom ErrRecordNotFound
|
||||
// error instead.
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
@ -94,29 +94,37 @@ WHERE id = $1`
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Otherwise, return a pointer to the Issue struct.
|
||||
|
||||
return &issue, nil
|
||||
}
|
||||
|
||||
func (m IssueModel) Update(issue *Issue) error {
|
||||
query := `
|
||||
UPDATE issues
|
||||
SET title = $1, description = $2, start_time = $3, end_time = $4, version = version + 1
|
||||
WHERE id = $5 AND version = $6
|
||||
SET
|
||||
title = $1,
|
||||
description = $2,
|
||||
start_time = $3,
|
||||
end_time = $4,
|
||||
rsa_n = $5,
|
||||
rsa_e = $6,
|
||||
rsa_private_pem = $7,
|
||||
version = version + 1
|
||||
WHERE id = $8 AND version = $9
|
||||
RETURNING version`
|
||||
|
||||
// Create an args slice containing the values for the placeholder parameters.
|
||||
args := []interface{}{
|
||||
issue.Title,
|
||||
issue.Description,
|
||||
issue.StartTime,
|
||||
issue.EndTime,
|
||||
issue.N,
|
||||
issue.E,
|
||||
issue.PrivatePem,
|
||||
issue.ID,
|
||||
issue.Version,
|
||||
}
|
||||
|
||||
// Use the QueryRow() method to execute the query, passing in the args slice as a
|
||||
// variadic parameter and scanning the new version value into the issue struct.
|
||||
err := m.DB.QueryRow(query, args...).Scan(&issue.Version)
|
||||
if err != nil {
|
||||
switch {
|
||||
@ -170,7 +178,7 @@ func (m IssueModel) GetAll(title string, filters Filters) ([]*Issue, Metadata, e
|
||||
// Construct the SQL query to retrieve all issue records.
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COUNT(*) OVER(), id, title, description, start_time, end_time, created, version
|
||||
COUNT(*) OVER(), id, title, description, start_time, end_time, rsa_n, rsa_e, rsa_private_pem, created, version
|
||||
FROM
|
||||
issues
|
||||
WHERE
|
||||
@ -217,6 +225,9 @@ func (m IssueModel) GetAll(title string, filters Filters) ([]*Issue, Metadata, e
|
||||
&issue.Description,
|
||||
&issue.StartTime,
|
||||
&issue.EndTime,
|
||||
&issue.N,
|
||||
&issue.E,
|
||||
&issue.PrivatePem,
|
||||
&issue.Created,
|
||||
&issue.Version,
|
||||
)
|
||||
|
||||
@ -6,24 +6,27 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRecordNotFound = errors.New("record not found")
|
||||
ErrEditConflict = errors.New("edit conflict")
|
||||
ErrRecordNotFound = errors.New("record not found")
|
||||
ErrEditConflict = errors.New("edit conflict")
|
||||
ErrInvalidBlindedVote = errors.New("invalid blinded vote")
|
||||
)
|
||||
|
||||
type Models struct {
|
||||
Users UserModel
|
||||
UserIdentities UserIdentityModel
|
||||
Issues IssueModel
|
||||
Tokens TokenModel
|
||||
Permissions PermissionModel
|
||||
Users UserModel
|
||||
UserIdentities UserIdentityModel
|
||||
Issues IssueModel
|
||||
Tokens TokenModel
|
||||
Permissions PermissionModel
|
||||
BlindSignRequests BlindSignRequestModel
|
||||
}
|
||||
|
||||
func NewModels(db *sql.DB) Models {
|
||||
return Models{
|
||||
Users: UserModel{DB: db},
|
||||
UserIdentities: UserIdentityModel{DB: db},
|
||||
Issues: IssueModel{DB: db},
|
||||
Tokens: TokenModel{DB: db},
|
||||
Permissions: PermissionModel{DB: db},
|
||||
Users: UserModel{DB: db},
|
||||
UserIdentities: UserIdentityModel{DB: db},
|
||||
Issues: IssueModel{DB: db},
|
||||
Tokens: TokenModel{DB: db},
|
||||
Permissions: PermissionModel{DB: db},
|
||||
BlindSignRequests: BlindSignRequestModel{DB: db},
|
||||
}
|
||||
}
|
||||
|
||||
@ -31,7 +31,7 @@ INNER JOIN users_permissions ON users_permissions.permission_id = permissions.id
|
||||
INNER JOIN users ON users_permissions.user_id = users.id
|
||||
WHERE users.id = $1`
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
||||
defer cancel()
|
||||
|
||||
rows, err := m.DB.QueryContext(ctx, query, userID)
|
||||
@ -62,7 +62,7 @@ func (m PermissionModel) AddForUser(userID int64, codes ...string) error {
|
||||
INSERT INTO users_permissions
|
||||
SELECT $1, permissions.id FROM permissions WHERE permissions.code = ANY($2)`
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
||||
defer cancel()
|
||||
_, err := m.DB.ExecContext(ctx, query, userID, pq.Array(codes))
|
||||
return err
|
||||
|
||||
@ -75,7 +75,7 @@ func (m TokenModel) Insert(token *Token) error {
|
||||
|
||||
args := []interface{}{token.Hash, token.UserID, token.UserIdentityID, token.Expiry, token.Scope}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := m.DB.ExecContext(ctx, query, args...)
|
||||
|
||||
@ -90,7 +90,7 @@ type UserIdentityModel struct {
|
||||
// userIdentity.Password.hash,
|
||||
// }
|
||||
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
||||
// defer cancel()
|
||||
|
||||
// err := m.DB.QueryRowContext(ctx, query, args...).Scan(&userIdentity.ID, &userIdentity.Version)
|
||||
@ -214,7 +214,7 @@ AND tokens.expiry > $3`
|
||||
|
||||
var userIdentity UserIdentity
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := m.DB.QueryRowContext(ctx, query, args...).Scan(
|
||||
@ -252,7 +252,7 @@ func (m UserIdentityModel) Update(user *UserIdentity) error {
|
||||
user.Version,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version)
|
||||
|
||||
@ -23,7 +23,7 @@ type User struct {
|
||||
PhoneNumber string `json:"phone_number"`
|
||||
Country string `json:"country"`
|
||||
Name string `json:"name"`
|
||||
AltName sql.NullString `json:"alt_name"`
|
||||
AltName *string `json:"alt_name"`
|
||||
DateOfBirth time.Time `json:"date_of_birth"`
|
||||
Address string `json:"address"`
|
||||
Created time.Time `json:"created"`
|
||||
@ -60,8 +60,8 @@ func (m UserModel) ExecuteRegistrationTx(user *User, userIdentity *UserIdentity)
|
||||
defer tx.Rollback()
|
||||
|
||||
query := `
|
||||
INSERT INTO users (email, phone_number, country, name, alt_name, date_of_birth, address)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
INSERT INTO users (email, phone_number, country, name, alt_name, date_of_birth, address, activated)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, created, last_login, version`
|
||||
|
||||
args := []interface{}{
|
||||
@ -72,6 +72,7 @@ RETURNING id, created, last_login, version`
|
||||
user.AltName,
|
||||
user.DateOfBirth,
|
||||
user.Address,
|
||||
user.Activated,
|
||||
}
|
||||
|
||||
err = tx.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.Created, &user.LastLogin, &user.Version)
|
||||
@ -163,7 +164,7 @@ WHERE id = $1`
|
||||
|
||||
func (m UserModel) GetByEmail(email string) (*User, error) {
|
||||
query :=`
|
||||
SELECT id, email, phone, country, name, alt_name, date_of_birth, address, created, last_login, activated, version
|
||||
SELECT id, email, phone_number, country, name, alt_name, date_of_birth, address, created, last_login, activated, version
|
||||
FROM users
|
||||
WHERE email = $1`
|
||||
|
||||
@ -205,7 +206,17 @@ func (m UserModel) GetForToken(tokenScope, tokenPlaintext string) (*User, error)
|
||||
|
||||
// Set up the SQL query.
|
||||
query :=`
|
||||
SELECT users.id, users.email, user.phone_number, user.country, users.name, users.date_of_birth, users.address, users.created, users.activated, users.version
|
||||
SELECT
|
||||
users.id,
|
||||
users.email,
|
||||
users.phone_number,
|
||||
users.country,
|
||||
users.name,
|
||||
users.date_of_birth,
|
||||
users.address,
|
||||
users.created,
|
||||
users.activated,
|
||||
users.version
|
||||
FROM users
|
||||
INNER JOIN tokens ON users.id = tokens.user_id
|
||||
WHERE tokens.hash = $1
|
||||
@ -218,7 +229,7 @@ AND tokens.expiry > $3`
|
||||
// value to check against the token expiry.
|
||||
args := []interface{}{tokenHash[:], tokenScope, time.Now()}
|
||||
var user User
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Execute the query, scanning the return values into a User struct. If no matching
|
||||
@ -251,7 +262,16 @@ AND tokens.expiry > $3`
|
||||
func (m UserModel) Update(user *User) error {
|
||||
query := `
|
||||
UPDATE users
|
||||
SET email = $1, phone_number = $2, country = $3, name = $4, alt_name = $5, date_of_birth = $6, address = $7, activated = $8, version = version + 1
|
||||
SET
|
||||
email = $1,
|
||||
phone_number = $2,
|
||||
country = $3,
|
||||
name = $4,
|
||||
alt_name = $5,
|
||||
date_of_birth = $6,
|
||||
address = $7,
|
||||
activated = $8,
|
||||
version = version + 1
|
||||
WHERE id = $9 AND version = $10
|
||||
RETURNING version`
|
||||
|
||||
@ -269,7 +289,7 @@ func (m UserModel) Update(user *User) error {
|
||||
user.Version,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version)
|
||||
|
||||
@ -1,26 +1,23 @@
|
||||
package jsonlog
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
// Define a Level type to represent the severity level for a log entry.
|
||||
|
||||
type Level int8
|
||||
|
||||
// Initialize constants which represent a specific severity level. We use the iota
|
||||
// keyword as a shortcut to assign successive integer values to the constants.
|
||||
const (
|
||||
LevelInfo Level = iota // Has the value 0.
|
||||
LevelError // Has the value 1.
|
||||
LevelFatal // Has the value 2.
|
||||
LevelOff // Has the value 3.
|
||||
LevelInfo Level = iota
|
||||
LevelError
|
||||
LevelFatal
|
||||
LevelOff
|
||||
)
|
||||
|
||||
// Return a human-friendly string for the severity level.
|
||||
func (l Level) String() string {
|
||||
switch l {
|
||||
case LevelInfo:
|
||||
@ -34,17 +31,12 @@ func (l Level) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// Define a custom Logger type. This holds the output destination that the log entries
|
||||
// will be written to, the minimum severity level that log entries will be written for,
|
||||
// plus a mutex for coordinating the writes.
|
||||
type Logger struct {
|
||||
out io.Writer
|
||||
minLevel Level
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Return a new Logger instance which writes log entries at or above a minimum severity
|
||||
// level to a specific output destination.
|
||||
func New(out io.Writer, minLevel Level) *Logger {
|
||||
return &Logger{
|
||||
out: out,
|
||||
@ -52,9 +44,6 @@ func New(out io.Writer, minLevel Level) *Logger {
|
||||
}
|
||||
}
|
||||
|
||||
// Declare some helper methods for writing log entries at the different levels. Notice
|
||||
// that these all accept a map as the second parameter which can contain any arbitrary
|
||||
// 'properties' that you want to appear in the log entry.
|
||||
func (l *Logger) PrintInfo(message string, properties map[string]string) {
|
||||
l.print(LevelInfo, message, properties)
|
||||
}
|
||||
@ -65,18 +54,14 @@ func (l *Logger) PrintError(err error, properties map[string]string) {
|
||||
|
||||
func (l *Logger) PrintFatal(err error, properties map[string]string) {
|
||||
l.print(LevelFatal, err.Error(), properties)
|
||||
os.Exit(1) // For entries at the FATAL level, we also terminate the application.
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Print is an internal method for writing the log entry.
|
||||
func (l *Logger) print(level Level, message string, properties map[string]string) (int, error) {
|
||||
// If the severity level of the log entry is below the minimum severity for the
|
||||
// logger, then return with no further action.
|
||||
if level < l.minLevel {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Declare an anonymous struct holding the data for the log entry.
|
||||
aux := struct {
|
||||
Level string `json:"level"`
|
||||
Time string `json:"time"`
|
||||
@ -90,35 +75,23 @@ func (l *Logger) print(level Level, message string, properties map[string]string
|
||||
Properties: properties,
|
||||
}
|
||||
|
||||
// Include a stack trace for entries at the ERROR and FATAL levels.
|
||||
if level >= LevelError {
|
||||
aux.Trace = string(debug.Stack())
|
||||
}
|
||||
|
||||
// Declare a line variable for holding the actual log entry text.
|
||||
var line []byte
|
||||
|
||||
// Marshal the anonymous struct to JSON and store it in the line variable. If there
|
||||
// was a problem creating the JSON, set the contents of the log entry to be that
|
||||
// plain-text error message instead.
|
||||
line, err := json.Marshal(aux)
|
||||
if err != nil {
|
||||
line = []byte(LevelError.String() + ": unable to marshal log message:" + err.Error())
|
||||
}
|
||||
|
||||
// Lock the mutex so that no two writes to the output destination cannot happen
|
||||
// concurrently. If we don't do this, it's possible that the text for two or more
|
||||
// log entries will be intermingled in the output.
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// Write the log entry followed by a newline.
|
||||
return l.out.Write(append(line, '\n'))
|
||||
}
|
||||
|
||||
// We also implement a Write() method on our Logger type so that it satisfies the
|
||||
// io.Writer interface. This writes a log entry at the ERROR level with no additional
|
||||
// properties.
|
||||
func (l *Logger) Write(message []byte) (n int, err error) {
|
||||
return l.print(LevelError, string(message), nil)
|
||||
}
|
||||
|
||||
@ -26,7 +26,7 @@ The DigitalePartei Team
|
||||
<p>Hi,</p>
|
||||
<p>Thanks for signing up for a Digitale Partei Österreich account. We're excited to have you on board!</p>
|
||||
<p>For future reference, your user ID number is {{.ID}}.</p>
|
||||
<p>Your activation token is {{.token}}</p>
|
||||
<p><a href="dpoe://activate?token={{.token}}">Click here to activate!</a></p>
|
||||
<p>Thanks,</p>
|
||||
<p>The DigitalePartei Team</p>
|
||||
</body>
|
||||
|
||||
@ -13,18 +13,18 @@ CREATE TABLE IF NOT EXISTS users (
|
||||
version INT NOT NULL DEFAULT 1
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS auth_provider (
|
||||
CREATE TABLE IF NOT EXISTS auth_providers (
|
||||
id BIGSERIAL PRIMARY KEY, -- e.g., 'local', 'id_austria',
|
||||
description TEXT NOT NULL,
|
||||
active BOOLEAN DEFAULT false
|
||||
);
|
||||
|
||||
INSERT INTO auth_provider (description, active) VALUES ('local', true);
|
||||
INSERT INTO auth_provider (description, active) VALUES ('id_austria', false);
|
||||
INSERT INTO auth_providers (description, active) VALUES ('local', true);
|
||||
INSERT INTO auth_providers (description, active) VALUES ('id_austria', false);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_identities (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
provider_id BIGINT NOT NULL REFERENCES auth_provider(id),
|
||||
provider_id BIGINT NOT NULL REFERENCES auth_providers(id),
|
||||
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
|
||||
-- For local: the username. For OIDC: the 'sub' (Subject ID)
|
||||
@ -33,9 +33,9 @@ CREATE TABLE IF NOT EXISTS user_identities (
|
||||
-- Nullable because OIDC users won't have a password in your DB
|
||||
password bytea,
|
||||
|
||||
version INT NOT NULL DEFAULT 1,
|
||||
version INT NOT NULL DEFAULT 1
|
||||
|
||||
UNIQUE(provider_id, provider_user_id)
|
||||
-- UNIQUE(provider_id, provider_user_id)
|
||||
);
|
||||
|
||||
-- INSERT INTO users(
|
||||
|
||||
@ -3,6 +3,7 @@ DROP INDEX IF EXISTS idx_vote_tokens_issue_id;
|
||||
DROP INDEX IF EXISTS idx_options_issue_id;
|
||||
|
||||
DROP TABLE IF EXISTS votes;
|
||||
DROP TABLE IF EXISTS blind_sign_requests;
|
||||
DROP TABLE IF EXISTS vote_tokens;
|
||||
DROP TABLE IF EXISTS options;
|
||||
DROP TABLE IF EXISTS issues;
|
||||
|
||||
@ -4,6 +4,9 @@ CREATE TABLE IF NOT EXISTS issues (
|
||||
description TEXT,
|
||||
start_time TIMESTAMPTZ NOT NULL,
|
||||
end_time TIMESTAMPTZ NOT NULL,
|
||||
rsa_n BYTEA,
|
||||
rsa_e INT,
|
||||
rsa_private_pem BYTEA,
|
||||
created TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
version INT NOT NULL DEFAULT 1
|
||||
);
|
||||
@ -25,6 +28,13 @@ CREATE TABLE IF NOT EXISTS vote_tokens (
|
||||
version INT NOT NULL DEFAULT 1
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS blind_sign_requests (
|
||||
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
issue_id BIGINT NOT NULL REFERENCES issues(id) ON DELETE CASCADE,
|
||||
created TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
PRIMARY KEY (user_id, issue_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS votes (
|
||||
id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
token UUID NOT NULL UNIQUE REFERENCES vote_tokens(token) ON DELETE CASCADE,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user