110 lines
2.1 KiB
Go
110 lines
2.1 KiB
Go
package data
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"errors"
|
||
"fmt"
|
||
"math/big"
|
||
"time"
|
||
|
||
"github.com/lib/pq"
|
||
)
|
||
|
||
type (
	// BlindSign mirrors one row of the blind_signs table: the user who was
	// issued a blind signature, the issue it was issued for, and the creation
	// timestamp that the database assigns when the row is inserted.
	BlindSign struct {
		UserID  int64     `json:"user_id"`
		IssueID int64     `json:"issue_id"`
		Created time.Time `json:"created"`
	}

	// BlindSignModel groups the data-access methods for blind signatures
	// around a shared database handle.
	BlindSignModel struct {
		DB *sql.DB
	}
)
|
||
|
||
func (m BlindSignModel) Get(userID int64, issueID int64) (*BlindSign, error) {
|
||
query :=
|
||
`SELECT user_id, issue_id, created
|
||
FROM blind_signs
|
||
WHERE user_id = $1 AND issue_id = $2`
|
||
|
||
args := []interface{}{
|
||
userID,
|
||
issueID,
|
||
}
|
||
|
||
var blindSign BlindSign
|
||
|
||
err := m.DB.QueryRow(query, args...).Scan(&blindSign.UserID, &blindSign.IssueID, &blindSign.Created)
|
||
|
||
if err != nil {
|
||
switch {
|
||
case errors.Is(err, sql.ErrNoRows):
|
||
return nil, ErrRecordNotFound
|
||
default:
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
return &blindSign, nil
|
||
}
|
||
|
||
func (m BlindSignModel) Insert(blind_sign *BlindSign) error {
|
||
query := `
|
||
INSERT INTO blind_signs (user_id, issue_id)
|
||
VALUES ($1, $2)
|
||
RETURNING created`
|
||
|
||
args := []interface{}{
|
||
blind_sign.UserID,
|
||
blind_sign.IssueID,
|
||
}
|
||
|
||
err := m.DB.QueryRow(query, args...).Scan(&blind_sign.Created)
|
||
if pgErr, ok := err.(*pq.Error); ok {
|
||
if pgErr.Code == "23505" {
|
||
return ErrDuplicateBlindSign
|
||
}
|
||
}
|
||
return err
|
||
}
|
||
|
||
func (m BlindSignModel) BlindSign(issueID int64, blindedVoteBytes []byte) ([]byte, error) {
|
||
if issueID < 1 {
|
||
return nil, ErrRecordNotFound
|
||
}
|
||
|
||
query := `SELECT rsa_private_pem FROM issues WHERE id = $1`
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 3 * time.Second)
|
||
defer cancel()
|
||
|
||
var pemBytes []byte
|
||
err := m.DB.QueryRowContext(ctx, query, issueID).Scan(&pemBytes)
|
||
if err != nil {
|
||
switch {
|
||
case errors.Is(err, sql.ErrNoRows):
|
||
return nil, ErrRecordNotFound
|
||
default:
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
key, err := parsePrivateKey(pemBytes)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("parse private key: %w", err)
|
||
}
|
||
|
||
m_ := new(big.Int).SetBytes(blindedVoteBytes)
|
||
|
||
// Validate range: m′ must be in [1, n-1]
|
||
one := big.NewInt(1)
|
||
if m_.Cmp(one) < 0 || m_.Cmp(key.N) >= 0 {
|
||
return nil, ErrInvalidBlindedVote
|
||
}
|
||
|
||
// s′ = m′^d mod n
|
||
sig := new(big.Int).Exp(m_, key.D, key.N)
|
||
|
||
return sig.Bytes(), nil
|
||
}
|