package data import ( "golang.org/x/crypto/bcrypt" "errors" "party.at/party/internal/validator" "database/sql" "context" "time" "crypto/sha256" ) type UserIdentity struct { ID int64 `json:"id"` ProviderID int64 `json:"provider_id"` UserID int64 `json:"user_id"` ProviderUserID string `json:"provider_user_id"` Password password `json:"-"` Version int32 `json:"version"` } type password struct { plaintext *string hash []byte } func (p *password) Set(plaintextPassword string) error { hash, err := bcrypt.GenerateFromPassword([]byte(plaintextPassword), 12) if err != nil { return err } p.plaintext = &plaintextPassword p.hash = hash return nil } func (p *password) Matches(plaintextPassword string) (bool, error) { err := bcrypt.CompareHashAndPassword(p.hash, []byte(plaintextPassword)) if err != nil { switch { case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword): return false, nil default: return false, err } } return true, nil } func ValidatePasswordPlaintext(v *validator.Validator, password string) { v.Check(password != "", "password", "must be provided") v.Check(len(password) >= 8, "password", "must be at least 8 bytes long") v.Check(len(password) <= 72, "password", "must not be more than 72 bytes long") } func ValidateUserIdentity(v *validator.Validator, userIdentity *UserIdentity) { v.Check(userIdentity.ProviderID == int64(ProviderLocal), "provider_id", "must be 1"); if userIdentity.ProviderID == int64(ProviderLocal) { v.Check(userIdentity.ProviderUserID != "", "username", "must be provided") v.Check(len(userIdentity.ProviderUserID) <= 500, "username", "must not be more than 500 bytes long") } if userIdentity.Password.plaintext != nil { ValidatePasswordPlaintext(v, *userIdentity.Password.plaintext) } if userIdentity.Password.hash == nil { panic("missing password hash for user") } } type UserIdentityModel struct { DB *sql.DB } // func (m UserIdentityModel) Insert(userIdentity *UserIdentity) error { // query := ` // INSERT INTO user_identities (provider_id, user_id, provider_user, password) // VALUES ($1, $2, $3, $4) // RETURNING id, version` // args := []interface{}{ // userIdentity.ProviderID, // userIdentity.UserID, // userIdentity.ProviderUserID, // userIdentity.Password.hash, // } // ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // defer cancel() // err := m.DB.QueryRowContext(ctx, query, args...).Scan(&userIdentity.ID, &userIdentity.Version) // if err != nil { // switch { // case err.Error() == // `pq: duplicate key value violates unique constraint "users_email_key"`: // return ErrDuplicateEmail // default: // return err // } // } // return nil // } func (m UserIdentityModel) Get(id int64) (*UserIdentity, error) { if id < 1 { return nil, ErrRecordNotFound } query :=` SELECT id, provider_id, user_id, provider_user, password, version FROM user_identities WHERE id = $1` // Declare a User struct to hold the data returned by the query. var userIdentity UserIdentity // Execute the query using the QueryRow() method, passing in the provided id value // as a placeholder parameter, and scan the response data into the fields of the // User struct. Importantly, notice that we need to convert the scan target for the // genres column using the pq.Array() adapter function again. err := m.DB.QueryRow(query, id).Scan( &userIdentity.ID, &userIdentity.ProviderID, &userIdentity.UserID, &userIdentity.ProviderUserID, &userIdentity.Password, &userIdentity.Version, ) if err != nil { switch { case errors.Is(err, sql.ErrNoRows): return nil, ErrRecordNotFound default: return nil, err } } // Otherwise, return a pointer to the User struct. return &userIdentity, nil } func (m UserIdentityModel) GetByUser(user_id int64) ([]*UserIdentity, error) { if user_id < 1 { return nil, ErrRecordNotFound } query :=` SELECT identity.id, identity.provider_id, identity.user_id, identity.provider_user_id, identity.password, identity.version FROM user_identities identity JOIN users u on identity.user_id = u.id WHERE u.id = $1` // Execute the query using the QueryRow() method, passing in the provided id value // as a placeholder parameter, and scan the response data into the fields of the // User struct. Importantly, notice that we need to convert the scan target for the // genres column using the pq.Array() adapter function again. rows, err := m.DB.Query(query, user_id) if err != nil { switch { case errors.Is(err, sql.ErrNoRows): return nil, ErrRecordNotFound default: return nil, err } } defer rows.Close() var identities []*UserIdentity for rows.Next() { var userIdentity UserIdentity err := rows.Scan( &userIdentity.ID, &userIdentity.ProviderID, &userIdentity.UserID, &userIdentity.ProviderUserID, &userIdentity.Password.hash, &userIdentity.Version, ) if err != nil { return nil, err } identities = append(identities, &userIdentity) } if len(identities) == 0 { return nil, ErrRecordNotFound } // Otherwise, return a pointer to the User struct. return identities, nil } func (m UserIdentityModel) GetForToken(tokenScope, tokenPlaintext string) (*UserIdentity, error) { tokenHash := sha256.Sum256([]byte(tokenPlaintext)) query := ` SELECT ui.id, ui.provider_id, ui.user_id, ui.provider_user_id, ui.password, ui.version FROM user_identities ui INNER JOIN tokens ON ui.id = tokens.user_identity_id WHERE tokens.hash = $1 AND tokens.scope = $2 AND tokens.expiry > $3` args := []interface{}{tokenHash[:], tokenScope, time.Now()} var userIdentity UserIdentity ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() err := m.DB.QueryRowContext(ctx, query, args...).Scan( &userIdentity.ID, &userIdentity.ProviderID, &userIdentity.UserID, &userIdentity.ProviderUserID, &userIdentity.Password.hash, &userIdentity.Version, ) if err != nil { switch { case errors.Is(err, sql.ErrNoRows): return nil, ErrRecordNotFound default: return nil, err } } return &userIdentity, nil } func (m UserIdentityModel) Update(user *UserIdentity) error { query := ` UPDATE user_identities SET provider_user_id = $1, password = $2, version = version + 1 WHERE id = $3 AND version = $4 RETURNING version` // Create an args slice containing the values for the placeholder parameters. args := []interface{}{ user.ProviderUserID, user.Password.hash, user.ID, user.Version, } ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version) if err != nil { switch { case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`: return ErrDuplicateEmail case errors.Is(err, sql.ErrNoRows): return ErrEditConflict default: return err } } return nil } func (m UserIdentityModel) Delete(id int64) error { if id < 1 { return ErrRecordNotFound } // Construct the SQL query to delete the record. query := ` DELETE FROM user_identities WHERE id = $1` // Execute the SQL query using the Exec() method, passing in the id variable as // the value for the placeholder parameter. The Exec() method returns a sql.Result // object. result, err := m.DB.Exec(query, id) if err != nil { return err } // Call the RowsAffected() method on the sql.Result object to get the number of rows // affected by the query. rowsAffected, err := result.RowsAffected() if err != nil { return err } // If no rows were affected, we know that the issues table didn't contain a record // with the provided ID at the moment we tried to delete it. In that case we // return an ErrRecordNotFound error. if rowsAffected == 0 { return ErrRecordNotFound } return nil }