diff --git a/cmd/party/api/users.go b/cmd/party/api/users.go index fe0b52b..1d7c5a2 100644 --- a/cmd/party/api/users.go +++ b/cmd/party/api/users.go @@ -4,7 +4,6 @@ import ( "errors" "net/http" "strconv" - "time" "party.at/party/cmd/party/common" "party.at/party/internal/data" @@ -41,38 +40,14 @@ func (api *Api) ListUsers(w http.ResponseWriter, r *http.Request) { } func (api *Api) CreateUser(w http.ResponseWriter, r *http.Request) { - var input struct { - ProviderId int64 `json:"provider_id"` - Username string `json:"username"` - PhoneNumber string `json:"phone_number"` - Country string `json:"country"` - Email string `json:"email"` - Password string `json:"password"` - Name string `json:"name"` - AltName *string `json:"alt_name"` - DateOfBirth time.Time `json:"date_of_birth"` - Address string `json:"address"` - } + var input common.RegisterUserInput if err := common.ReadJSON(w, r, &input); err != nil { api.errorResponse(w, r, data.ErrBadRequest) return } - - - user, authToken, err := api.App.RegisterUser(common.RegisterUserInput{ - ProviderID: input.ProviderId, - Username: input.Username, - PhoneNumber: input.PhoneNumber, - Country: input.Country, - Email: input.Email, - Password: input.Password, - Name: input.Name, - AltName: input.AltName, - DateOfBirth: input.DateOfBirth, - Address: input.Address, - }) + user, authToken, err := api.App.RegisterUser(input) if err != nil { api.errorResponse(w, r, err) return @@ -84,9 +59,11 @@ func (api *Api) CreateUser(w http.ResponseWriter, r *http.Request) { } func (api *Api) ReadUser(w http.ResponseWriter, r *http.Request) { - var id int64 - param := r.PathValue("id") + user := common.GetUser(r) + param := r.PathValue("id") + + var id int64 if param == "me" { id = common.GetUser(r).ID } else { @@ -98,12 +75,7 @@ func (api *Api) ReadUser(w http.ResponseWriter, r *http.Request) { } } - if common.GetUser(r).ID != id { - api.errorResponse(w, r, data.ErrNotPermitted) - return - } - - user, err := api.App.GetUser(id) + target, err := api.App.GetUser(user, id) if err != nil { if errors.Is(err, data.ErrRecordNotFound) { api.errorResponse(w, r, data.ErrRecordNotFound) @@ -113,7 +85,7 @@ func (api *Api) ReadUser(w http.ResponseWriter, r *http.Request) { return } - if err = common.WriteJSON(w, http.StatusOK, common.Envelope{"user": user}, nil); err != nil { + if err = common.WriteJSON(w, http.StatusOK, common.Envelope{"user": target}, nil); err != nil { api.ServerErrorResponse(w, r, err) } } diff --git a/cmd/party/common/users.go b/cmd/party/common/users.go index 143e0c3..4456169 100644 --- a/cmd/party/common/users.go +++ b/cmd/party/common/users.go @@ -8,16 +8,17 @@ import ( ) type RegisterUserInput struct { - ProviderID int64 - Username string - PhoneNumber string - Country string - Email string - Password string - Name string - AltName *string - DateOfBirth time.Time - Address string + ProviderID int64 `json:"provider_id"` + Username string `json:"username"` + PhoneNumber string `json:"phone_number"` + Country string `json:"country"` + Email string `json:"email"` + Password string `json:"password"` + Name string `json:"name"` + AltName *string `json:"alt_name"` + DateOfBirth time.Time `json:"date_of_birth"` + Address string `json:"address"` + Role string `json:"role"` } func (app *Application) RegisterUser(input RegisterUserInput) (*data.User, *data.Token, error) { @@ -53,8 +54,8 @@ func (app *Application) RegisterUser(input RegisterUserInput) (*data.User, *data } role := "viewer" - if app.Config.Env == "development" { - role = "admin" + if app.Config.Env == "development" && input.Role != "" { + role = input.Role } if err := app.Models.Roles.AssignToUser(user.ID, role); err != nil { @@ -85,8 +86,54 @@ func (app *Application) RegisterUser(input RegisterUserInput) (*data.User, *data return user, authToken, nil } -func (app *Application) GetUser(id int64) (*data.User, error) { - return app.Models.Users.Get(id) +func (app *Application) GetUser(user *data.User, id int64) (*data.User, error) { + var ret data.User + + roles, err := app.Models.Roles.GetAllForUser(user.ID) + if err != nil { + return &ret, err + } + + target, err := app.Models.Users.Get(id) + if err != nil { + return &ret, err + } + + if user.ID == target.ID { + return target, nil + } + + if roles[0].Code == "viewer" { + return nil, data.ErrNotPermitted + } + + if roles[0].Code == "contributor" { + return nil, data.ErrInvalidCredentials + } + + if roles[0].Code == "member_of_parliament" { + ret.Email = target.Email + ret.PhoneNumber = target.PhoneNumber + ret.Country = target.Country + ret.Name = target.Name + ret.AltName = target.AltName + ret.DateOfBirth = target.DateOfBirth + ret.Address = target.Address + } + + if roles[0].Code == "party_leadership" { + + } + + if roles[0].Code == "admin" { + ret.ID = target.ID + ret.LastLogin = target.LastLogin + ret.Activated = target.Activated + ret.Created = target.Created + ret.Version = target.Version + } + + return &ret, nil } func (app *Application) ListUsers(filters data.Filters) ([]*data.User, data.Metadata, error) { diff --git a/cmd/party/testutils_test.go b/cmd/party/testutils_test.go index a0ead73..ff12d07 100644 --- a/cmd/party/testutils_test.go +++ b/cmd/party/testutils_test.go @@ -92,8 +92,14 @@ func (ts *testServer) get(t *testing.T, path string) (int, http.Header, []byte) return rs.StatusCode, rs.Header, body } -func (ts *testServer) registerAndLogin(t *testing.T, email, password string) string { +func (ts *testServer) registerAndLogin(t *testing.T, email, password string, roles ...string) string { t.Helper() + + var role string + if len(role) > 0 { + role = roles[0] + } + registerBody := map[string]any{ "email": email, "password": password, @@ -101,6 +107,7 @@ func (ts *testServer) registerAndLogin(t *testing.T, email, password string) str "name": "Test User", "alt_name": "", "provider_id": 1, + "role": role, } code, _, body := ts.postJSON(t, "/v1/users", registerBody) if code != http.StatusCreated { diff --git a/cmd/party/users_test.go b/cmd/party/users_test.go new file mode 100644 index 0000000..abb174f --- /dev/null +++ b/cmd/party/users_test.go @@ -0,0 +1,27 @@ +package main + +import ( + "bytes" + // "net/http" + // "strconv" + "testing" + // "time" +) + +func TestReadUserHandler(t *testing.T) { + app := newTestApplication(t) + ts := newTestServer(t, app, routes(app)) + defer ts.Close() + + token := ts.registerAndLogin(t, uniqueEmail(), "pa$$word123", "viewer") + + code, _, body := ts.getWithToken(t, "/v1/users/1", token) + + if code != 200 { + t.Errorf("want %d; got %d", 200, code) + } + + if !bytes.Contains(body, []byte("An old silent pond...")) { + t.Errorf("want body to contain %q; got %q", "", string(body)) + } +} diff --git a/cmd/party/web/users.go b/cmd/party/web/users.go index a23fa96..9be22cd 100644 --- a/cmd/party/web/users.go +++ b/cmd/party/web/users.go @@ -75,7 +75,7 @@ func (web *Web) ProfilePage(w http.ResponseWriter, r *http.Request) { return } - fullUser, err := web.App.GetUser(user.ID) + fullUser, err := web.App.GetUser(user, user.ID) if err != nil { if errors.Is(err, data.ErrRecordNotFound) { http.NotFound(w, r) diff --git a/internal/data/errors.go b/internal/data/errors.go index 99527cb..b6be00d 100644 --- a/internal/data/errors.go +++ b/internal/data/errors.go @@ -7,26 +7,28 @@ import ( type ErrorCode int const ( + // 400 variants + // 401 variants - ErrCodeInvalidCredentials ErrorCode = 4011 - ErrCodeInvalidAuthToken ErrorCode = 4012 - ErrCodeAuthRequired ErrorCode = 4013 + errCodeInvalidCredentials ErrorCode = 4011 + errCodeInvalidAuthToken ErrorCode = 4012 + errCodeAuthRequired ErrorCode = 4013 // 403 variants - ErrCodeInactiveAccount ErrorCode = 4031 - ErrCodeNotPermitted ErrorCode = 4032 + errCodeInactiveAccount ErrorCode = 4031 + errCodeNotPermitted ErrorCode = 4032 // 409 variants - ErrCodeEditConflict ErrorCode = 4091 - ErrCodeAlreadyVoted ErrorCode = 4092 - ErrCodeAlreadyBlindSigned ErrorCode = 4093 - ErrCodeVoteAlreadyCast ErrorCode = 4094 + errCodeEditConflict ErrorCode = 4091 + errCodeAlreadyVoted ErrorCode = 4092 + errCodeAlreadyBlindSigned ErrorCode = 4093 + errCodeVoteAlreadyCast ErrorCode = 4094 // 422 variants - ErrCodeValidationFailed ErrorCode = 4221 - ErrCodeBlindedVoteRange ErrorCode = 4222 - ErrCodeInvalidSignature ErrorCode = 4223 - ErrCodeHasNotStarted ErrorCode = 4224 + errCodeValidationFailed ErrorCode = 4221 + errCodeBlindedVoteRange ErrorCode = 4222 + errCodeInvalidSignature ErrorCode = 4223 + errCodeHasNotStarted ErrorCode = 4224 ) type Error struct { @@ -60,41 +62,41 @@ func New(httpCode int, code ErrorCode, text string) error { var ( // 400 Bad Request - ErrFailedPEM = New(400, 0, "failed to decode PEM block") - ErrBadlyFormedJSON = New(400, 0, "body contains badly-formed JSON") - ErrBodyEmpty = New(400, 0, "body must not be empty") - ErrSingleValue = New(400, 0, "body must only contain a single JSON value") - ErrInvalidID = New(400, 0, "invalid id parameter") - ErrBadRequest = New(400, 0, "the server cannot process the request due to a client error") + ErrFailedPEM = New(400, 400, "failed to decode PEM block") + ErrBadlyFormedJSON = New(400, 400, "body contains badly-formed JSON") + ErrBodyEmpty = New(400, 400, "body must not be empty") + ErrSingleValue = New(400, 400, "body must only contain a single JSON value") + ErrInvalidID = New(400, 400, "invalid id parameter") + ErrBadRequest = New(400, 400, "the server cannot process the request due to a client error") // 401 Unauthorized - ErrInvalidCredentials = New(401, 4011, "invalid credentials") - ErrInvalidAuthToken = New(401, 4012, "invalid or missing authentication token") - ErrNoToken = New(401, 4012, "token must be provided") - ErrAuthRequired = New(401, 4013, "you must be authenticated to access this resource") - ErrHasNotStarted = New(401, 4224, "the vote has not yet started") + ErrInvalidCredentials = New(401, errCodeInvalidCredentials, "invalid credentials") + ErrInvalidAuthToken = New(401, errCodeInvalidAuthToken, "invalid or missing authentication token") + ErrNoToken = New(401, errCodeInvalidAuthToken, "token must be provided") + ErrAuthRequired = New(401, errCodeAuthRequired, "you must be authenticated to access this resource") // 403 Forbidden - ErrInactiveAccount = New(403, 4031, "your user account must be activated to access this resource") - ErrNotPermitted = New(403, 4032, "your user account doesn't have the necessary permissions to access this resource") + ErrInactiveAccount = New(403, errCodeInactiveAccount, "your user account must be activated to access this resource") + ErrNotPermitted = New(403, errCodeNotPermitted, "your user account doesn't have the necessary permissions to access this resource") // 404 Not Found - ErrRecordNotFound = New(404, 0, "record not found") - ErrNoPath = New(404, 0, "path is required") + ErrRecordNotFound = New(404, 404, "record not found") + ErrNoPath = New(404, 404, "path is required") // 409 Conflict - ErrEditConflict = New(409, 4091, "edit conflict") - ErrDuplicateVote = New(409, 4092, "this signature has already been used to cast a vote") - ErrDuplicateBlindSign = New(409, 4093, "user has already requested a blind signature for this issue") - ErrDuplicateSignature = New(409, 4094, "this signature has already been used to cast a vote") - ErrDuplicateEmail = New(409, 0, "duplicate email") - ErrDuplicateUser = New(409, 0, "duplicate username") + ErrEditConflict = New(409, errCodeEditConflict, "edit conflict") + ErrDuplicateVote = New(409, errCodeAlreadyVoted, "this signature has already been used to cast a vote") + ErrDuplicateBlindSign = New(409, errCodeAlreadyBlindSigned, "user has already requested a blind signature for this issue") + ErrDuplicateSignature = New(409, errCodeVoteAlreadyCast, "this signature has already been used to cast a vote") + ErrDuplicateEmail = New(409, 409, "duplicate email") + ErrDuplicateUser = New(409, 409, "duplicate username") // 422 Unprocessable Entity - ErrValidationFailed = New(422, 4221, "validation failed") - ErrInvalidBlindedVote = New(422, 4222, "blinded_vote is out of valid range [1, n-1]") - ErrInvalidSignature = New(422, 4223, "signature verification failed") + ErrValidationFailed = New(422, errCodeValidationFailed, "validation failed") + ErrInvalidBlindedVote = New(422, errCodeBlindedVoteRange, "blinded_vote is out of valid range [1, n-1]") + ErrInvalidSignature = New(422, errCodeInvalidSignature, "signature verification failed") + ErrHasNotStarted = New(422, errCodeHasNotStarted, "the vote has not yet started") // 429 Too Many Requests - ErrRateLimitExceeded = New(429, 0, "rate limit exceeded") + ErrRateLimitExceeded = New(429, 429, "rate limit exceeded") ) diff --git a/internal/data/users.go b/internal/data/users.go index ab88a43..ed63d05 100644 --- a/internal/data/users.go +++ b/internal/data/users.go @@ -15,17 +15,17 @@ import ( var AnonymousUser = &User{} type User struct { - ID int64 `json:"id"` - Email string `json:"email"` - PhoneNumber string `json:"phone_number"` - Country string `json:"country"` - Name string `json:"name"` - AltName *string `json:"alt_name"` - DateOfBirth time.Time `json:"date_of_birth"` - Address string `json:"address"` - Created time.Time `json:"created"` - LastLogin time.Time `json:"last_login"` - Activated bool `json:"activated"` + ID int64 `json:"id,omitempty"` + Email string `json:"email,omitempty"` + PhoneNumber string `json:"phone_number,omitempty"` + Country string `json:"country,omitempty"` + Name string `json:"name,omitempty"` + AltName *string `json:"alt_name,omitempty"` + DateOfBirth time.Time `json:"date_of_birth,omitempty"` + Address string `json:"address,omitempty"` + Created time.Time `json:"created,omitempty"` + LastLogin time.Time `json:"last_login,omitempty"` + Activated bool `json:"activated,omitempty"` Version int32 `json:"-"` } @@ -306,10 +306,29 @@ func (m UserModel) Update(user *User) error { func (m UserModel) GetAll(filters Filters) ([]*User, Metadata, error) { query := fmt.Sprintf(` - SELECT COUNT(*) OVER(), id, email, phone_number, country, name, alt_name, date_of_birth, address, created, last_login, activated, version - FROM users - ORDER BY %s %s, id ASC - LIMIT $1 OFFSET $2`, +SELECT + COUNT(*) OVER(), + id, + email, + phone_number, + country, + name, + alt_name, + date_of_birth, + address, + created, + last_login, + activated, + version +FROM + users +ORDER BY + %s %s, + id ASC +LIMIT + $1 +OFFSET + $2`, filters.sortColumn(), filters.sortDirection(), )