party/internal/data/blind_signs.go

110 lines
2.1 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package data
import (
	"context"
	"crypto/rand"
	"database/sql"
	"errors"
	"fmt"
	"math/big"
	"time"

	"github.com/lib/pq"
)
// BlindSign is one row of the blind_signs table: a (user, issue) pair plus the
// row's creation time. NOTE(review): presumably a row records that the user has
// already been issued a blind signature for the issue — confirm against the
// handlers that call BlindSignModel.Insert/Get.
type BlindSign struct {
// UserID holds the user_id column.
UserID int64 `json:"user_id"`
// IssueID holds the issue_id column.
IssueID int64 `json:"issue_id"`
// Created holds the created column; BlindSignModel.Insert fills it from the
// database via RETURNING created rather than setting it in Go.
Created time.Time `json:"created"`
}
// BlindSignModel groups the blind_signs queries (Get, Insert) and the
// signing of blinded votes with an issue's stored RSA private key (BlindSign).
type BlindSignModel struct {
// DB is the database handle every query in this model runs against.
DB *sql.DB
}
// Get fetches the blind_signs row identified by userID and issueID.
//
// It returns ErrRecordNotFound when no such row exists; any other database
// error is returned unchanged. The query is bounded by a 3-second timeout (the
// same bound BlindSign uses) so a stalled database cannot block the caller
// indefinitely.
func (m BlindSignModel) Get(userID int64, issueID int64) (*BlindSign, error) {
	query := `
		SELECT user_id, issue_id, created
		FROM blind_signs
		WHERE user_id = $1 AND issue_id = $2`

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	var blindSign BlindSign
	err := m.DB.QueryRowContext(ctx, query, userID, issueID).Scan(
		&blindSign.UserID,
		&blindSign.IssueID,
		&blindSign.Created,
	)
	if err != nil {
		switch {
		case errors.Is(err, sql.ErrNoRows):
			return nil, ErrRecordNotFound
		default:
			return nil, err
		}
	}

	return &blindSign, nil
}
// Insert adds a blind_signs row for blindSign.UserID and blindSign.IssueID and
// writes the database-assigned creation time back into blindSign.Created.
//
// It returns ErrDuplicateBlindSign when the insert violates a unique constraint
// (PostgreSQL SQLSTATE 23505, unique_violation) — presumably the
// (user_id, issue_id) key; confirm against the schema. Any other error is
// returned unchanged. The query is bounded by a 3-second timeout.
func (m BlindSignModel) Insert(blindSign *BlindSign) error {
	query := `
		INSERT INTO blind_signs (user_id, issue_id)
		VALUES ($1, $2)
		RETURNING created`

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	err := m.DB.QueryRowContext(ctx, query, blindSign.UserID, blindSign.IssueID).
		Scan(&blindSign.Created)
	if err != nil {
		// errors.As (rather than a bare type assertion) also matches a
		// *pq.Error that has been wrapped by an intermediate layer.
		var pgErr *pq.Error
		if errors.As(err, &pgErr) && pgErr.Code == "23505" {
			return ErrDuplicateBlindSign
		}
		return err
	}

	return nil
}
// BlindSign signs a client-supplied blinded vote with the RSA private key
// stored for the given issue.
//
// The key is loaded from issues.rsa_private_pem (3-second query timeout) and
// decoded with parsePrivateKey. blindedVoteBytes is read as a big-endian
// unsigned integer msg, which must satisfy 1 <= msg < n. The result is the raw
// RSA signature s = msg^d mod n, encoded with big.Int.Bytes (minimal
// big-endian, so leading zero bytes are not preserved).
//
// Errors:
//   - ErrRecordNotFound if issueID < 1 or no issue row exists;
//   - ErrInvalidBlindedVote if msg is outside [1, n-1];
//   - a wrapped error if the key cannot be parsed or signing fails;
//   - any other database error unchanged.
func (m BlindSignModel) BlindSign(issueID int64, blindedVoteBytes []byte) ([]byte, error) {
	if issueID < 1 {
		return nil, ErrRecordNotFound
	}

	query := `SELECT rsa_private_pem FROM issues WHERE id = $1`

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	var pemBytes []byte
	err := m.DB.QueryRowContext(ctx, query, issueID).Scan(&pemBytes)
	if err != nil {
		switch {
		case errors.Is(err, sql.ErrNoRows):
			return nil, ErrRecordNotFound
		default:
			return nil, err
		}
	}

	key, err := parsePrivateKey(pemBytes)
	if err != nil {
		return nil, fmt.Errorf("parse private key: %w", err)
	}

	msg := new(big.Int).SetBytes(blindedVoteBytes)
	// Validate range: msg must be in [1, n-1]. SetBytes never yields a
	// negative value, so Sign() <= 0 means msg == 0.
	if msg.Sign() <= 0 || msg.Cmp(key.N) >= 0 {
		return nil, ErrInvalidBlindedVote
	}

	// msg is attacker-controlled and big.Int.Exp is not constant-time, so the
	// private-key exponentiation is performed behind RSA base blinding.
	sig, err := signBlindedRSA(msg, key.D, key.N, key.E)
	if err != nil {
		return nil, fmt.Errorf("sign blinded vote: %w", err)
	}

	return sig.Bytes(), nil
}

// signBlindedRSA returns msg^d mod n without exposing msg itself to the
// exponentiation. It draws a random r invertible mod n, computes
// ((msg * r^e)^d mod n) * r^-1 mod n, and relies on r^(e*d) = r (mod n) for a
// matching RSA key pair (n, e, d). The result is checked with sig^e = msg
// (mod n); a mismatch means e and d do not belong together and yields an error
// instead of a bogus signature.
func signBlindedRSA(msg, d, n *big.Int, e int) (*big.Int, error) {
	if e < 1 {
		return nil, errors.New("invalid RSA public exponent")
	}

	// For a real RSA modulus almost every r is invertible; the bound only
	// guards against a degenerate n looping forever.
	const maxAttempts = 64

	var r, rInv *big.Int
	for attempt := 0; attempt < maxAttempts && rInv == nil; attempt++ {
		var err error
		r, err = rand.Int(rand.Reader, n)
		if err != nil {
			return nil, fmt.Errorf("generate blinding factor: %w", err)
		}
		// ModInverse returns nil when r is not coprime to n (including r == 0).
		rInv = new(big.Int).ModInverse(r, n)
	}
	if rInv == nil {
		return nil, errors.New("no invertible blinding factor found")
	}

	bigE := big.NewInt(int64(e))

	blinded := new(big.Int).Exp(r, bigE, n)
	blinded.Mul(blinded, msg).Mod(blinded, n)

	sig := new(big.Int).Exp(blinded, d, n)
	sig.Mul(sig, rInv).Mod(sig, n)

	if new(big.Int).Exp(sig, bigE, n).Cmp(msg) != 0 {
		return nil, errors.New("signature check failed: inconsistent RSA key")
	}

	return sig, nil
}