307 lines
7.8 KiB
Go
307 lines
7.8 KiB
Go
package data
|
|
|
|
import (
|
|
"golang.org/x/crypto/bcrypt"
|
|
"errors"
|
|
"party.at/party/internal/validator"
|
|
"database/sql"
|
|
"context"
|
|
"time"
|
|
"crypto/sha256"
|
|
)
|
|
|
|
type UserIdentity struct {
|
|
ID int64 `json:"id"`
|
|
ProviderID int64 `json:"provider_id"`
|
|
UserID int64 `json:"user_id"`
|
|
ProviderUserID string `json:"provider_user_id"`
|
|
Password password `json:"-"`
|
|
Version int32 `json:"version"`
|
|
}
|
|
|
|
// password keeps a user's password as a bcrypt hash and, when it was set in
// this process, a pointer to the original plaintext.
type password struct {
	// plaintext is non-nil only after Set; values loaded from the database
	// populate hash alone.
	plaintext *string

	// hash is the bcrypt digest that is persisted.
	hash []byte
}
|
|
|
|
func (p *password) Set(plaintextPassword string) error {
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(plaintextPassword), 12)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
p.plaintext = &plaintextPassword
|
|
p.hash = hash
|
|
return nil
|
|
}
|
|
|
|
func (p *password) Matches(plaintextPassword string) (bool, error) {
|
|
err := bcrypt.CompareHashAndPassword(p.hash, []byte(plaintextPassword))
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
|
|
return false, nil
|
|
default:
|
|
return false, err
|
|
}
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func ValidatePasswordPlaintext(v *validator.Validator, password string) {
|
|
v.Check(password != "", "password", "must be provided")
|
|
v.Check(len(password) >= 8, "password", "must be at least 8 bytes long")
|
|
v.Check(len(password) <= 72, "password", "must not be more than 72 bytes long")
|
|
}
|
|
|
|
func ValidateUserIdentity(v *validator.Validator, userIdentity *UserIdentity) {
|
|
v.Check(userIdentity.ProviderID == int64(ProviderLocal), "provider_id", "must be 1");
|
|
|
|
if userIdentity.ProviderID == int64(ProviderLocal) {
|
|
v.Check(userIdentity.ProviderUserID != "", "username", "must be provided")
|
|
v.Check(len(userIdentity.ProviderUserID) <= 500, "username", "must not be more than 500 bytes long")
|
|
}
|
|
|
|
if userIdentity.Password.plaintext != nil {
|
|
ValidatePasswordPlaintext(v,
|
|
*userIdentity.Password.plaintext)
|
|
}
|
|
|
|
if userIdentity.Password.hash == nil {
|
|
panic("missing password hash for user")
|
|
}
|
|
}
|
|
|
|
// UserIdentityModel groups the queries against the user_identities table
// around a shared database handle.
type UserIdentityModel struct {
	// DB is the handle every query method runs against.
	DB *sql.DB
}
|
|
|
|
// func (m UserIdentityModel) Insert(userIdentity *UserIdentity) error {
|
|
// query := `
|
|
// INSERT INTO user_identities (provider_id, user_id, provider_user, password)
|
|
// VALUES ($1, $2, $3, $4)
|
|
// RETURNING id, version`
|
|
|
|
// args := []interface{}{
|
|
// userIdentity.ProviderID,
|
|
// userIdentity.UserID,
|
|
// userIdentity.ProviderUserID,
|
|
// userIdentity.Password.hash,
|
|
// }
|
|
|
|
// ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
// defer cancel()
|
|
|
|
// err := m.DB.QueryRowContext(ctx, query, args...).Scan(&userIdentity.ID, &userIdentity.Version)
|
|
// if err != nil {
|
|
// switch {
|
|
// case err.Error() ==
|
|
// `pq: duplicate key value violates unique constraint "users_email_key"`:
|
|
// return ErrDuplicateEmail
|
|
// default:
|
|
// return err
|
|
// }
|
|
// }
|
|
|
|
// return nil
|
|
// }
|
|
|
|
func (m UserIdentityModel) Get(id int64) (*UserIdentity, error) {
|
|
if id < 1 {
|
|
return nil, ErrRecordNotFound
|
|
}
|
|
|
|
query :=`
|
|
SELECT id, provider_id, user_id, provider_user, password, version
|
|
FROM user_identities
|
|
WHERE id = $1`
|
|
|
|
// Declare a User struct to hold the data returned by the query.
|
|
var userIdentity UserIdentity
|
|
|
|
// Execute the query using the QueryRow() method, passing in the provided id value
|
|
// as a placeholder parameter, and scan the response data into the fields of the
|
|
// User struct. Importantly, notice that we need to convert the scan target for the
|
|
// genres column using the pq.Array() adapter function again.
|
|
err := m.DB.QueryRow(query, id).Scan(
|
|
&userIdentity.ID,
|
|
&userIdentity.ProviderID,
|
|
&userIdentity.UserID,
|
|
&userIdentity.ProviderUserID,
|
|
&userIdentity.Password,
|
|
&userIdentity.Version,
|
|
)
|
|
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return nil, ErrRecordNotFound
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
// Otherwise, return a pointer to the User struct.
|
|
return &userIdentity, nil
|
|
}
|
|
|
|
func (m UserIdentityModel) GetByUser(user_id int64) ([]*UserIdentity, error) {
|
|
if user_id < 1 {
|
|
return nil, ErrRecordNotFound
|
|
}
|
|
|
|
query :=`
|
|
SELECT identity.id, identity.provider_id, identity.user_id, identity.provider_user_id, identity.password, identity.version
|
|
FROM user_identities identity
|
|
JOIN users u on identity.user_id = u.id
|
|
WHERE u.id = $1`
|
|
|
|
// Execute the query using the QueryRow() method, passing in the provided id value
|
|
// as a placeholder parameter, and scan the response data into the fields of the
|
|
// User struct. Importantly, notice that we need to convert the scan target for the
|
|
// genres column using the pq.Array() adapter function again.
|
|
rows, err := m.DB.Query(query, user_id)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return nil, ErrRecordNotFound
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
defer rows.Close()
|
|
|
|
var identities []*UserIdentity
|
|
|
|
for rows.Next() {
|
|
var userIdentity UserIdentity
|
|
|
|
err := rows.Scan(
|
|
&userIdentity.ID,
|
|
&userIdentity.ProviderID,
|
|
&userIdentity.UserID,
|
|
&userIdentity.ProviderUserID,
|
|
&userIdentity.Password.hash,
|
|
&userIdentity.Version,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
identities = append(identities, &userIdentity)
|
|
}
|
|
|
|
if len(identities) == 0 {
|
|
return nil, ErrRecordNotFound
|
|
}
|
|
|
|
// Otherwise, return a pointer to the User struct.
|
|
return identities, nil
|
|
}
|
|
|
|
func (m UserIdentityModel) GetForToken(tokenScope, tokenPlaintext string) (*UserIdentity, error) {
|
|
tokenHash := sha256.Sum256([]byte(tokenPlaintext))
|
|
|
|
query := `
|
|
SELECT ui.id, ui.provider_id, ui.user_id, ui.provider_user_id, ui.password, ui.version
|
|
FROM user_identities ui
|
|
INNER JOIN tokens ON ui.id = tokens.user_identity_id
|
|
WHERE tokens.hash = $1
|
|
AND tokens.scope = $2
|
|
AND tokens.expiry > $3`
|
|
|
|
args := []interface{}{tokenHash[:], tokenScope, time.Now()}
|
|
|
|
var userIdentity UserIdentity
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
err := m.DB.QueryRowContext(ctx, query, args...).Scan(
|
|
&userIdentity.ID,
|
|
&userIdentity.ProviderID,
|
|
&userIdentity.UserID,
|
|
&userIdentity.ProviderUserID,
|
|
&userIdentity.Password.hash,
|
|
&userIdentity.Version,
|
|
)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return nil, ErrRecordNotFound
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return &userIdentity, nil
|
|
}
|
|
|
|
func (m UserIdentityModel) Update(user *UserIdentity) error {
|
|
query := `
|
|
UPDATE user_identities
|
|
SET provider_user_id = $1, password = $2, version = version + 1
|
|
WHERE id = $3 AND version = $4
|
|
RETURNING version`
|
|
|
|
// Create an args slice containing the values for the placeholder parameters.
|
|
args := []interface{}{
|
|
user.ProviderUserID,
|
|
user.Password.hash,
|
|
user.ID,
|
|
user.Version,
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version)
|
|
if err != nil {
|
|
switch {
|
|
case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
|
|
return ErrDuplicateEmail
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return ErrEditConflict
|
|
default:
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m UserIdentityModel) Delete(id int64) error {
|
|
if id < 1 {
|
|
return ErrRecordNotFound
|
|
}
|
|
|
|
// Construct the SQL query to delete the record.
|
|
query := `
|
|
DELETE FROM user_identities
|
|
WHERE id = $1`
|
|
|
|
// Execute the SQL query using the Exec() method, passing in the id variable as
|
|
// the value for the placeholder parameter. The Exec() method returns a sql.Result
|
|
// object.
|
|
result, err := m.DB.Exec(query, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Call the RowsAffected() method on the sql.Result object to get the number of rows
|
|
// affected by the query.
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// If no rows were affected, we know that the issues table didn't contain a record
|
|
// with the provided ID at the moment we tried to delete it. In that case we
|
|
// return an ErrRecordNotFound error.
|
|
if rowsAffected == 0 {
|
|
return ErrRecordNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|