344 lines
7.8 KiB
Go
344 lines
7.8 KiB
Go
package data
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
"party.at/party/internal/validator"
|
|
"database/sql"
|
|
"github.com/lib/pq"
|
|
"errors"
|
|
"crypto/sha256"
|
|
)
|
|
|
|
// Sentinel errors returned by UserModel when a write collides with an
// existing row under a unique constraint.
var (
	// ErrDuplicateEmail is returned when the email address violates the
	// users_email_key unique constraint.
	ErrDuplicateEmail = errors.New("duplicate email")

	// ErrDuplicateUser is returned when the (provider_id, provider_user_id)
	// pair violates the user_identities_provider_id_provider_user_id_key
	// unique constraint.
	ErrDuplicateUser = errors.New("duplicate username")
)
|
|
|
|
// AnonymousUser is a sentinel pointer to a zero-valued User. Only its pointer
// identity matters: IsAnonymous compares against this exact pointer, so any
// other zero-valued *User is not treated as anonymous.
var AnonymousUser = new(User)
|
|
|
|
// User mirrors a row of the users table and doubles as its JSON
// representation. encoding/json emits keys in field order, so the order below
// is also the order of keys in serialized output.
type User struct {
	ID          int64     `json:"id"`
	Email       string    `json:"email"`
	PhoneNumber string    `json:"phone_number"`
	Country     string    `json:"country"`
	Name        string    `json:"name"`
	AltName     *string   `json:"alt_name"` // pointer, presumably because alt_name is nullable; nil encodes as JSON null
	DateOfBirth time.Time `json:"date_of_birth"`
	Address     string    `json:"address"`
	Created     time.Time `json:"created"`
	LastLogin   time.Time `json:"last_login"`
	Activated   bool      `json:"activated"`
	Version     int32     `json:"-"` // optimistic-locking counter checked and bumped by Update; never serialized
}
|
|
|
|
func (u *User) IsAnonymous() bool {
|
|
return u == AnonymousUser
|
|
}
|
|
|
|
func ValidateEmail(v *validator.Validator, email string) {
|
|
v.Check(email != "", "email", "must be provided")
|
|
v.Check(validator.Matches(email, validator.EmailRX), "email" , "must be a valid email address")
|
|
}
|
|
|
|
func ValidateUser(v *validator.Validator, user *User) {
|
|
ValidateEmail(v, user.Email)
|
|
}
|
|
|
|
// UserModel groups the queries against the users table (and, during
// registration, user_identities).
type UserModel struct {
	DB *sql.DB // handle used by every UserModel method
}
|
|
|
|
func (m UserModel) ExecuteRegistrationTx(user *User, userIdentity *UserIdentity) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5 * time.Second)
|
|
defer cancel()
|
|
|
|
tx, err := m.DB.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
query := `
|
|
INSERT INTO users (email, phone_number, country, name, alt_name, date_of_birth, address, activated)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|
RETURNING id, created, last_login, version`
|
|
|
|
args := []interface{}{
|
|
user.Email,
|
|
user.PhoneNumber,
|
|
user.Country,
|
|
user.Name,
|
|
user.AltName,
|
|
user.DateOfBirth,
|
|
user.Address,
|
|
user.Activated,
|
|
}
|
|
|
|
err = tx.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.Created, &user.LastLogin, &user.Version)
|
|
if pgErr, ok := err.(*pq.Error); ok {
|
|
if pgErr.Code == "23505" {
|
|
if pgErr.Constraint == "users_email_key" {
|
|
return ErrDuplicateEmail
|
|
}
|
|
}
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
userIdentity.UserID = user.ID
|
|
|
|
// Insert Identity
|
|
query = `
|
|
INSERT INTO user_identities (provider_id, user_id, provider_user_id, password)
|
|
VALUES ($1, $2, $3, $4)
|
|
RETURNING id, version`
|
|
|
|
args = []interface{}{
|
|
userIdentity.ProviderID,
|
|
userIdentity.UserID,
|
|
userIdentity.ProviderUserID,
|
|
userIdentity.Password.hash,
|
|
}
|
|
|
|
err = tx.QueryRowContext(ctx, query, args...).Scan(&userIdentity.ID, &userIdentity.Version)
|
|
if pgErr, ok := err.(*pq.Error); ok {
|
|
if pgErr.Code == "23505" {
|
|
if pgErr.Constraint == "user_identities_provider_id_provider_user_id_key" {
|
|
return ErrDuplicateUser
|
|
}
|
|
}
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m UserModel) Get(id int64) (*User, error) {
|
|
if id < 1 {
|
|
return nil, ErrRecordNotFound
|
|
}
|
|
|
|
// Define the SQL query for retrieving the issue data.
|
|
query :=`
|
|
SELECT id, email, phone_number, country, name, alt_name, date_of_birth, address, created, last_login, activated, version
|
|
FROM users
|
|
WHERE id = $1`
|
|
|
|
// Declare a User struct to hold the data returned by the query.
|
|
var user User
|
|
|
|
err := m.DB.QueryRow(query, id).Scan(
|
|
&user.ID,
|
|
&user.Email,
|
|
&user.PhoneNumber,
|
|
&user.Country,
|
|
&user.Name,
|
|
&user.AltName,
|
|
&user.DateOfBirth,
|
|
&user.Address,
|
|
&user.Created,
|
|
&user.LastLogin,
|
|
&user.Activated,
|
|
&user.Version,
|
|
)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return nil, ErrRecordNotFound
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func (m UserModel) GetByEmail(email string) (*User, error) {
|
|
query :=`
|
|
SELECT id, email, phone_number, country, name, alt_name, date_of_birth, address, created, last_login, activated, version
|
|
FROM users
|
|
WHERE email = $1`
|
|
|
|
var user User
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
|
defer cancel()
|
|
|
|
err := m.DB.QueryRowContext(ctx, query, email).Scan(
|
|
&user.ID,
|
|
&user.Email,
|
|
&user.PhoneNumber,
|
|
&user.Country,
|
|
&user.Name,
|
|
&user.AltName,
|
|
&user.DateOfBirth,
|
|
&user.Address,
|
|
&user.Created,
|
|
&user.LastLogin,
|
|
&user.Activated,
|
|
&user.Version,
|
|
)
|
|
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return nil, ErrRecordNotFound
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func (m UserModel) GetForToken(tokenScope, tokenPlaintext string) (*User, error) {
|
|
// Calculate the SHA-256 hash of the plaintext token provided by the client.
|
|
// Remember that this returns a byte *array* with length 32, not a slice.
|
|
tokenHash := sha256.Sum256([]byte(tokenPlaintext))
|
|
|
|
// Set up the SQL query.
|
|
query :=`
|
|
SELECT
|
|
users.id,
|
|
users.email,
|
|
users.phone_number,
|
|
users.country,
|
|
users.name,
|
|
users.date_of_birth,
|
|
users.address,
|
|
users.created,
|
|
users.activated,
|
|
users.version
|
|
FROM users
|
|
INNER JOIN tokens ON users.id = tokens.user_id
|
|
WHERE tokens.hash = $1
|
|
AND tokens.scope = $2
|
|
AND tokens.expiry > $3`
|
|
|
|
// Create a slice containing the query arguments. Notice how we use the [:] operator
|
|
// to get a slice containing the token hash, rather than passing in the array (which
|
|
// is not supported by the pq driver), and that we pass the current time as the
|
|
// value to check against the token expiry.
|
|
args := []interface{}{tokenHash[:], tokenScope, time.Now()}
|
|
var user User
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
|
defer cancel()
|
|
|
|
// Execute the query, scanning the return values into a User struct. If no matching
|
|
// record is found we return an ErrRecordNotFound error.
|
|
err := m.DB.QueryRowContext(ctx, query, args...).Scan(
|
|
&user.ID,
|
|
&user.Email,
|
|
&user.PhoneNumber,
|
|
&user.Country,
|
|
&user.Name,
|
|
&user.DateOfBirth,
|
|
&user.Address,
|
|
&user.Created,
|
|
&user.Activated,
|
|
&user.Version,
|
|
)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return nil, ErrRecordNotFound
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Return the matching user.
|
|
return &user, nil
|
|
}
|
|
|
|
func (m UserModel) Update(user *User) error {
|
|
query := `
|
|
UPDATE users
|
|
SET
|
|
email = $1,
|
|
phone_number = $2,
|
|
country = $3,
|
|
name = $4,
|
|
alt_name = $5,
|
|
date_of_birth = $6,
|
|
address = $7,
|
|
activated = $8,
|
|
version = version + 1
|
|
WHERE id = $9 AND version = $10
|
|
RETURNING version`
|
|
|
|
// Create an args slice containing the values for the placeholder parameters.
|
|
args := []interface{}{
|
|
user.Email,
|
|
user.PhoneNumber,
|
|
user.Country,
|
|
user.Name,
|
|
user.AltName,
|
|
user.DateOfBirth,
|
|
user.Address,
|
|
user.Activated,
|
|
user.ID,
|
|
user.Version,
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
|
defer cancel()
|
|
|
|
err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version)
|
|
if err != nil {
|
|
switch {
|
|
case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
|
|
return ErrDuplicateEmail
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return ErrEditConflict
|
|
default:
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m UserModel) Delete(id int64) error {
|
|
if id < 1 {
|
|
return ErrRecordNotFound
|
|
}
|
|
|
|
// Construct the SQL query to delete the record.
|
|
query := `
|
|
DELETE FROM users
|
|
WHERE id = $1`
|
|
|
|
// Execute the SQL query using the Exec() method, passing in the id variable as
|
|
// the value for the placeholder parameter. The Exec() method returns a sql.Result
|
|
// object.
|
|
result, err := m.DB.Exec(query, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Call the RowsAffected() method on the sql.Result object to get the number of rows
|
|
// affected by the query.
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// If no rows were affected, we know that the issues table didn't contain a record
|
|
// with the provided ID at the moment we tried to delete it. In that case we
|
|
// return an ErrRecordNotFound error.
|
|
if rowsAffected == 0 {
|
|
return ErrRecordNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|