package data import ( "context" "time" "party.at/party/internal/validator" "database/sql" "github.com/lib/pq" "errors" "crypto/sha256" ) var ( ErrDuplicateEmail = errors.New("duplicate email") ErrDuplicateUser = errors.New("duplicate username") ) var AnonymousUser = &User{} type User struct { ID int64 `json:"id"` Email string `json:"email"` PhoneNumber string `json:"phone_number"` Country string `json:"country"` Name string `json:"name"` AltName sql.NullString `json:"alt_name"` DateOfBirth time.Time `json:"date_of_birth"` Address string `json:"address"` Created time.Time `json:"created"` LastLogin time.Time `json:"last_login"` Activated bool `json:"activated"` Version int32 `json:"-"` } func (u *User) IsAnonymous() bool { return u == AnonymousUser } func ValidateEmail(v *validator.Validator, email string) { v.Check(email != "", "email", "must be provided") v.Check(validator.Matches(email, validator.EmailRX), "email" , "must be a valid email address") } func ValidateUser(v *validator.Validator, user *User) { ValidateEmail(v, user.Email) } type UserModel struct { DB *sql.DB } func (m UserModel) ExecuteRegistrationTx(user *User, userIdentity *UserIdentity) error { ctx, cancel := context.WithTimeout(context.Background(), 5 * time.Second) defer cancel() tx, err := m.DB.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() query := ` INSERT INTO users (email, phone_number, country, name, alt_name, date_of_birth, address) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id, created, last_login, version` args := []interface{}{ user.Email, user.PhoneNumber, user.Country, user.Name, user.AltName, user.DateOfBirth, user.Address, } err = tx.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.Created, &user.LastLogin, &user.Version) if pgErr, ok := err.(*pq.Error); ok { if pgErr.Code == "23505" { if pgErr.Constraint == "users_email_key" { return ErrDuplicateEmail } } } if err != nil { return err } userIdentity.UserID = user.ID // Insert Identity query = ` INSERT INTO user_identities (provider_id, user_id, provider_user_id, password) VALUES ($1, $2, $3, $4) RETURNING id, version` args = []interface{}{ userIdentity.ProviderID, userIdentity.UserID, userIdentity.ProviderUserID, userIdentity.Password.hash, } err = tx.QueryRowContext(ctx, query, args...).Scan(&userIdentity.ID, &userIdentity.Version) if pgErr, ok := err.(*pq.Error); ok { if pgErr.Code == "23505" { if pgErr.Constraint == "user_identities_provider_id_provider_user_id_key" { return ErrDuplicateUser } } } if err != nil { return err } err = tx.Commit() if err != nil { return err } return nil } func (m UserModel) Get(id int64) (*User, error) { if id < 1 { return nil, ErrRecordNotFound } // Define the SQL query for retrieving the issue data. query :=` SELECT id, email, phone_number, country, name, alt_name, date_of_birth, address, created, last_login, activated, version FROM users WHERE id = $1` // Declare a User struct to hold the data returned by the query. var user User err := m.DB.QueryRow(query, id).Scan( &user.ID, &user.Email, &user.PhoneNumber, &user.Country, &user.Name, &user.AltName, &user.DateOfBirth, &user.Address, &user.Created, &user.LastLogin, &user.Activated, &user.Version, ) if err != nil { switch { case errors.Is(err, sql.ErrNoRows): return nil, ErrRecordNotFound default: return nil, err } } return &user, nil } func (m UserModel) GetByEmail(email string) (*User, error) { query :=` SELECT id, email, phone, country, name, alt_name, date_of_birth, address, created, last_login, activated, version FROM users WHERE email = $1` var user User ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second) defer cancel() err := m.DB.QueryRowContext(ctx, query, email).Scan( &user.ID, &user.Email, &user.PhoneNumber, &user.Country, &user.Name, &user.AltName, &user.DateOfBirth, &user.Address, &user.Created, &user.LastLogin, &user.Activated, &user.Version, ) if err != nil { switch { case errors.Is(err, sql.ErrNoRows): return nil, ErrRecordNotFound default: return nil, err } } return &user, nil } func (m UserModel) GetForToken(tokenScope, tokenPlaintext string) (*User, error) { // Calculate the SHA-256 hash of the plaintext token provided by the client. // Remember that this returns a byte *array* with length 32, not a slice. tokenHash := sha256.Sum256([]byte(tokenPlaintext)) // Set up the SQL query. query :=` SELECT users.id, users.email, user.phone_number, user.country, users.name, users.date_of_birth, users.address, users.created, users.activated, users.version FROM users INNER JOIN tokens ON users.id = tokens.user_id WHERE tokens.hash = $1 AND tokens.scope = $2 AND tokens.expiry > $3` // Create a slice containing the query arguments. Notice how we use the [:] operator // to get a slice containing the token hash, rather than passing in the array (which // is not supported by the pq driver), and that we pass the current time as the // value to check against the token expiry. args := []interface{}{tokenHash[:], tokenScope, time.Now()} var user User ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() // Execute the query, scanning the return values into a User struct. If no matching // record is found we return an ErrRecordNotFound error. err := m.DB.QueryRowContext(ctx, query, args...).Scan( &user.ID, &user.Email, &user.PhoneNumber, &user.Country, &user.Name, &user.DateOfBirth, &user.Address, &user.Created, &user.Activated, &user.Version, ) if err != nil { switch { case errors.Is(err, sql.ErrNoRows): return nil, ErrRecordNotFound default: return nil, err } } // Return the matching user. return &user, nil } func (m UserModel) Update(user *User) error { query := ` UPDATE users SET email = $1, phone_number = $2, country = $3, name = $4, alt_name = $5, date_of_birth = $6, address = $7, activated = $8, version = version + 1 WHERE id = $9 AND version = $10 RETURNING version` // Create an args slice containing the values for the placeholder parameters. args := []interface{}{ user.Email, user.PhoneNumber, user.Country, user.Name, user.AltName, user.DateOfBirth, user.Address, user.Activated, user.ID, user.Version, } ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version) if err != nil { switch { case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`: return ErrDuplicateEmail case errors.Is(err, sql.ErrNoRows): return ErrEditConflict default: return err } } return nil } func (m UserModel) Delete(id int64) error { if id < 1 { return ErrRecordNotFound } // Construct the SQL query to delete the record. query := ` DELETE FROM users WHERE id = $1` // Execute the SQL query using the Exec() method, passing in the id variable as // the value for the placeholder parameter. The Exec() method returns a sql.Result // object. result, err := m.DB.Exec(query, id) if err != nil { return err } // Call the RowsAffected() method on the sql.Result object to get the number of rows // affected by the query. rowsAffected, err := result.RowsAffected() if err != nil { return err } // If no rows were affected, we know that the issues table didn't contain a record // with the provided ID at the moment we tried to delete it. In that case we // return an ErrRecordNotFound error. if rowsAffected == 0 { return ErrRecordNotFound } return nil }