// party/internal/data/user_identities.go

package data
import (
"golang.org/x/crypto/bcrypt"
"errors"
"party.at/party/internal/validator"
"database/sql"
"context"
"time"
"crypto/sha256"
)
// UserIdentity mirrors one row of the user_identities table.
//
// Version is bumped by every successful Update and is compared in that
// update's WHERE clause, so a stale copy cannot overwrite newer data.
// Password is excluded from JSON output via the `json:"-"` tag.
type UserIdentity struct {
	ID             int64    `json:"id"`
	ProviderID     int64    `json:"provider_id"`
	UserID         int64    `json:"user_id"`
	ProviderUserID string   `json:"provider_user_id"`
	Password       password `json:"-"`
	Version        int32    `json:"version"`
}
// password pairs a bcrypt hash with the plaintext it came from.
//
// plaintext is only populated when the value was produced by Set in this
// process; identities read back from the database get just the hash.
type password struct {
	plaintext *string
	hash      []byte
}
// Set hashes plaintextPassword with bcrypt at cost 12 and stores both the
// resulting hash and a pointer to the plaintext on p. The plaintext is kept
// so that ValidateUserIdentity can check it later. If hashing fails, p is left
// unchanged and the bcrypt error is returned.
func (p *password) Set(plaintextPassword string) error {
	digest, err := bcrypt.GenerateFromPassword([]byte(plaintextPassword), 12)
	if err == nil {
		p.plaintext, p.hash = &plaintextPassword, digest
	}
	return err
}
// Matches reports whether plaintextPassword corresponds to the stored hash.
//
// A plain mismatch is not an error: it yields (false, nil). Any other error
// reported by bcrypt is passed back to the caller with a false result.
func (p *password) Matches(plaintextPassword string) (bool, error) {
	switch err := bcrypt.CompareHashAndPassword(p.hash, []byte(plaintextPassword)); {
	case err == nil:
		return true, nil
	case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
		return false, nil
	default:
		return false, err
	}
}
// ValidatePasswordPlaintext records a "password" error on v for every rule the
// plaintext breaks. It must be non-empty, at least 8 bytes and at most 72
// bytes long. Lengths are measured in bytes, not runes. The 72-byte cap is
// presumably there because bcrypt only uses the first 72 bytes; confirm
// before changing.
func ValidatePasswordPlaintext(v *validator.Validator, password string) {
	size := len(password)
	v.Check(password != "", "password", "must be provided")
	v.Check(size >= 8, "password", "must be at least 8 bytes long")
	v.Check(size <= 72, "password", "must not be more than 72 bytes long")
}
// ValidateUserIdentity records validation failures for userIdentity on v:
//   - ProviderID must equal ProviderLocal (the only provider accepted today);
//   - a plaintext password, when one is present, must satisfy
//     ValidatePasswordPlaintext.
//
// A missing password hash is treated as a programming error rather than a
// validation failure: the function panics instead of recording it on v.
func ValidateUserIdentity(v *validator.Validator, userIdentity *UserIdentity) {
	isLocalProvider := userIdentity.ProviderID == int64(ProviderLocal)
	v.Check(isLocalProvider, "provider_id", "must be 1")

	// TODO: username rules for local identities are currently disabled:
	//   v.Check(userIdentity.ProviderUserID != "", "username", "must be provided")
	//   v.Check(len(userIdentity.ProviderUserID) <= 500, "username", "must not be more than 500 bytes long")

	if plain := userIdentity.Password.plaintext; plain != nil {
		ValidatePasswordPlaintext(v, *plain)
	}

	if userIdentity.Password.hash == nil {
		panic("missing password hash for user")
	}
}
// UserIdentityModel runs queries against the user_identities table using the
// wrapped *sql.DB handle.
type UserIdentityModel struct {
	DB *sql.DB
}
// func (m UserIdentityModel) Insert(userIdentity *UserIdentity) error {
// query := `
// INSERT INTO user_identities (provider_id, user_id, provider_user, password)
// VALUES ($1, $2, $3, $4)
// RETURNING id, version`
// args := []interface{}{
// userIdentity.ProviderID,
// userIdentity.UserID,
// userIdentity.ProviderUserID,
// userIdentity.Password.hash,
// }
// ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
// defer cancel()
// err := m.DB.QueryRowContext(ctx, query, args...).Scan(&userIdentity.ID, &userIdentity.Version)
// if err != nil {
// switch {
// case err.Error() ==
// `pq: duplicate key value violates unique constraint "users_email_key"`:
// return ErrDuplicateEmail
// default:
// return err
// }
// }
// return nil
// }
// Get fetches the user identity with the given ID.
//
// It returns ErrRecordNotFound when id is not positive or when no row
// matches. Only the password hash is loaded; Password.plaintext stays nil.
func (m UserIdentityModel) Get(id int64) (*UserIdentity, error) {
	if id < 1 {
		return nil, ErrRecordNotFound
	}

	// The column is provider_user_id, as in every other query on this table.
	query := `
		SELECT id, provider_id, user_id, provider_user_id, password, version
		FROM user_identities
		WHERE id = $1`

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	var userIdentity UserIdentity
	// Scan the password column into the raw hash bytes: the password struct
	// does not implement sql.Scanner, so it cannot be a Scan target itself.
	err := m.DB.QueryRowContext(ctx, query, id).Scan(
		&userIdentity.ID,
		&userIdentity.ProviderID,
		&userIdentity.UserID,
		&userIdentity.ProviderUserID,
		&userIdentity.Password.hash,
		&userIdentity.Version,
	)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrRecordNotFound
		}
		return nil, err
	}

	return &userIdentity, nil
}
// GetByUser returns every identity belonging to the user with the given ID.
//
// It returns ErrRecordNotFound when userID is not positive or when the user
// has no identities. Only password hashes are loaded; Password.plaintext
// stays nil on every result.
func (m UserIdentityModel) GetByUser(userID int64) ([]*UserIdentity, error) {
	if userID < 1 {
		return nil, ErrRecordNotFound
	}

	query := `
		SELECT identity.id, identity.provider_id, identity.user_id, identity.provider_user_id, identity.password, identity.version
		FROM user_identities identity
		JOIN users u ON identity.user_id = u.id
		WHERE u.id = $1`

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// Query never reports sql.ErrNoRows; an empty result is detected below.
	rows, err := m.DB.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var identities []*UserIdentity
	for rows.Next() {
		var userIdentity UserIdentity
		if err := rows.Scan(
			&userIdentity.ID,
			&userIdentity.ProviderID,
			&userIdentity.UserID,
			&userIdentity.ProviderUserID,
			&userIdentity.Password.hash,
			&userIdentity.Version,
		); err != nil {
			return nil, err
		}
		identities = append(identities, &userIdentity)
	}
	// rows.Next returns false both when the result set is exhausted and when
	// iteration failed; rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	if len(identities) == 0 {
		return nil, ErrRecordNotFound
	}
	return identities, nil
}
// GetForToken returns the identity that owns an unexpired token of the given
// scope. The plaintext token is hashed with SHA-256 and only the digest is
// compared against tokens.hash. Only the password hash is loaded.
//
// It returns ErrRecordNotFound when no matching, unexpired token exists.
func (m UserIdentityModel) GetForToken(tokenScope, tokenPlaintext string) (*UserIdentity, error) {
	const query = "\nSELECT ui.id, ui.provider_id, ui.user_id, ui.provider_user_id, ui.password, ui.version" +
		"\nFROM user_identities ui" +
		"\nINNER JOIN tokens ON ui.id = tokens.user_identity_id" +
		"\nWHERE tokens.hash = $1" +
		"\nAND tokens.scope = $2" +
		"\nAND tokens.expiry > $3"

	digest := sha256.Sum256([]byte(tokenPlaintext))
	now := time.Now()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	identity := new(UserIdentity)
	row := m.DB.QueryRowContext(ctx, query, digest[:], tokenScope, now)
	switch err := row.Scan(
		&identity.ID,
		&identity.ProviderID,
		&identity.UserID,
		&identity.ProviderUserID,
		&identity.Password.hash,
		&identity.Version,
	); {
	case err == nil:
		return identity, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrRecordNotFound
	default:
		return nil, err
	}
}
// Update writes ProviderUserID and the password hash back to the database.
// It uses optimistic locking: the row is only changed if its version still
// equals userIdentity.Version. On success the incremented version is stored
// back into userIdentity.Version.
//
// It returns ErrEditConflict when no row matched the ID/version pair.
func (m UserIdentityModel) Update(userIdentity *UserIdentity) error {
	const query = "\nUPDATE user_identities" +
		"\nSET provider_user_id = $1, password = $2, version = version + 1" +
		"\nWHERE id = $3 AND version = $4" +
		"\nRETURNING version"

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	row := m.DB.QueryRowContext(ctx, query,
		userIdentity.ProviderUserID,
		userIdentity.Password.hash,
		userIdentity.ID,
		userIdentity.Version,
	)
	err := row.Scan(&userIdentity.Version)

	switch {
	case err == nil:
		return nil
	// NOTE(review): this constraint name looks like it belongs to the users
	// table rather than user_identities — confirm it can ever fire here.
	case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
		return ErrDuplicateEmail
	case errors.Is(err, sql.ErrNoRows):
		return ErrEditConflict
	default:
		return err
	}
}
// Delete removes the user identity with the given ID.
//
// It returns ErrRecordNotFound when id is not positive or when no row with
// that ID existed at the time of the delete.
func (m UserIdentityModel) Delete(id int64) error {
	if id < 1 {
		return ErrRecordNotFound
	}

	query := `
		DELETE FROM user_identities
		WHERE id = $1`

	// Bound the statement the same way Update and GetForToken do, so a stalled
	// database cannot block the caller indefinitely.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	result, err := m.DB.ExecContext(ctx, query, id)
	if err != nil {
		return err
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	// Zero affected rows means user_identities had no row with this ID.
	if rowsAffected == 0 {
		return ErrRecordNotFound
	}

	return nil
}