Applying ASVS, Part 1: Safe Concurrency
ASVSv5 15.4.1 and 15.4.2, demonstrated through a worked Go example.
When I joined the OWASP ASVS working group, I wanted to contribute deeper guidance in areas where the standard had room to grow. I gravitated towards safe concurrency, because race conditions tend to get misunderstood as a single class of bug or as something the language handles for you. In my experience, though, the underlying logic bugs aren’t any easier to avoid just because concurrency syntax is easier to write.
Race conditions are logic bugs in which the broken logic is timing: what happens, and in what order. The most concrete kind is a data race, where two threads access the same piece of memory at the same moment and at least one of them is writing; this is what most code reviewers are trained to look for. But most race conditions are more abstract than that. In a time-of-check to time-of-use (TOCTOU) race, for example, a program reads some state, decides what to do based on it, and then acts, while another request changes that state in the gap between the check and the action. This class has produced critical bugs in production systems.
A reviewer watching only for the memory-level kind misses all of these, and they are often the more serious ones. A low-level race usually crashes a program or corrupts its data, loudly enough to notice. A logic race can leave it running exactly as intended while the wrong person gets in, or learns something they shouldn’t. Code that is clean of every low-level race can still act on stale data, leading to even more unforeseen vulnerabilities.
The existing requirement in 4.0.3 was:
V4.0.3-11.1.6: Verify that the application doesn’t suffer from TOCTOU issues or other race conditions for sensitive operations.
That sentence reads as shorthand for the already-initiated, pointing at TOCTOU and a wider class of race conditions while leaving them undefined for the reader. But there’s more than one way for bad timing to wreck your code, and the old requirement from ASVSv4 treats them all as one catch-all category.
5.0 splits the original into four discrete requirements we felt would be most impactful for code reviewers:
- shared objects and the primitives that protect them (15.4.1)
- atomic compound operations, with TOCTOU named explicitly (15.4.2)
- deadlock and livelock prevention (15.4.3)
- thread starvation prevention (15.4.4)
This is the first in a series of worked examples I’m planning to write, each one grounded in the kind of code a reader might plausibly encounter in a production system. I’ll focus on the first two concurrency requirements here; the rest will get their own treatment another day.
The Service
The codebase is a small Go service for issuing and verifying stateful one-time passwords. Stateful here means the server picks the code at issuance time and persists it for the validity window, as opposed to stateless schemes like TOTP where both sides derive the code from a shared secret and the clock. Picture an SMS or email second factor: the server generates a six-digit code, sends it to the user, waits for the user to submit it back, and decides whether it matches. The service is built around three components: a challenge store that holds the active codes, a rate limiter that caps how many guesses a given user can make within a window, and a validator that wires them together.
Go makes sense for a service like this because its concurrency model is one of the language’s strongest features. Goroutines and channels are first-class, the standard library leans heavily on them, and writing concurrent code requires very little ceremony compared to most other languages. None of that makes races impossible, though, as we’ll see.
I’ll describe the service from the perspective of a hypothetical small dev team shipping this OTP service alongside several others they own. They work in an organization where the platform team runs shared infrastructure, including Redis, but they have a standing preference for keeping their Redis footprint lean where they can justify it. They aren’t building for scale they don’t have yet.
The Threat Model
Before writing any code, the team threat-modeled the service. Of the risks in their threat model, the one we’ll follow here is online brute force against the six-digit code (OTP-1): an attacker guessing values until one works. The team addressed it by limiting how many wrong guesses each user gets before the service stops accepting them; we’ll review how they built that limiter. For the full picture, including the risks they accepted as tradeoffs, see the OTP service threat model.
| ID | Risk | ROAM | STRIDE | Notes |
|---|---|---|---|---|
| OTP-3 | Code prediction / weak RNG | Owned | Spoofing | Predictable codes let an attacker derive a valid second factor without observing it |
| OTP-9 | Cross-user code confusion | Owned | Spoofing, Elevation of Privilege | A code valid for one user is accepted for another, granting access to a different account |
| OTP-1 | Online brute force of the code | Owned | Spoofing | Attacker impersonates the legitimate user by guessing their second factor |
| OTP-8 | Replay of consumed codes | Owned | Spoofing, Elevation of Privilege | An already-used code is presented again to authenticate a second time |
| OTP-7 | Timing attacks on code comparison | Owned | Information Disclosure | Response latency leaks information about correct prefixes of the code |
| OTP-6 | Email account compromise | Accepted | Spoofing | Attacker controls the delivery channel and receives codes intended for the legitimate user |
| OTP-2 | Issuance flooding / SMS bombing | Accepted | Denial of Service | Exhausts SMS budget and disrupts the user's device with unsolicited codes |
| OTP-11 | Multi-replica state divergence | Accepted | Tampering | Routing requests across replicas circumvents the rate limit counter on any single replica |
| OTP-10 | Service restart wipes rate limit state | Accepted | Tampering | The rate limit counter is effectively reset by an out-of-band action, weakening the control |
| OTP-4 | SIM swapping | Accepted | Spoofing | Attacker takes control of the delivery channel to receive codes intended for the legitimate user |
| OTP-5 | SS7 interception | Accepted | Information Disclosure | Code is exposed to a network-level adversary in transit |
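The arithmetic behind OTP-1 shows why the limit carries so much weight. Assuming the code is drawn uniformly from the $10^6$ six-digit values and the attacker submits $k$ distinct guesses against one active challenge (in this implementation a wrong guess does not consume the challenge), the probability that one of them is correct is:

```latex
P(\text{success}) = \frac{k}{10^{6}},
\qquad
k = 5 \;\Rightarrow\; P = \frac{5}{10^{6}} = 5 \times 10^{-6} = \frac{1}{200\,000}
```

The odds grow linearly with every guess that gets past the limiter, so a limiter that admits ten times as many guesses as intended hands the attacker ten times the chance.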
The Naive Implementation
The limiter is backed by a map from user identifier to a small entry struct that tracks the failure count and when the current window ends. Locked reads the entry for a key and returns whether the count has reached the limit. RecordFailure reads the entry, checks the count against the limit, and increments it if there’s room.
With the two methods in place, the team wired the limiter into the validator and started writing tests. The threat model had named parallelized brute force as an Owned risk, so they wrote a test that fires that exact load at the limiter: fifty goroutines concurrently record a failed attempt against the same user. With the limit set to five, the team expected the limiter to admit five failures and reject the other forty-five.
package main
import (
"context"
"errors"
"flag"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/cronchie/totpauth/internal/auth"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/httpapi"
"github.com/cronchie/totpauth/internal/ratelimit"
)
func main() {
addr := flag.String("addr", ":8080", "listen address")
flag.Parse()
log := slog.New(slog.NewJSONHandler(os.Stdout, nil))
notifier := auth.NotifierFunc(func(_ context.Context, userID, code string) error {
log.Warn("OTP code generated; configure a real Notifier", "user_id", userID, "code", code)
return nil
})
limiter := ratelimit.New()
challenges := challenge.New()
validator := auth.NewValidator(limiter, challenges, notifier)
handler := httpapi.Recover(log)(httpapi.NewHandler(validator, log).Routes())
srv := &http.Server{
Addr: *addr,
Handler: handler,
ReadHeaderTimeout: 5 * time.Second,
}
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
go func() {
log.Info("authd listening", "addr", srv.Addr)
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Error("listen failed", "err", err)
stop()
}
}()
<-ctx.Done()
log.Info("authd shutting down")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Error("shutdown failed", "err", err)
os.Exit(1)
}
}
module github.com/cronchie/totpauth
go 1.23
package auth
import (
"errors"
"github.com/cronchie/totpauth/internal/auth/challenge"
)
var ErrLockedOut = errors.New("auth: too many failed attempts")
var (
ErrInvalidCode = challenge.ErrInvalidCode
ErrNoChallenge = challenge.ErrNoChallenge
)
package challenge
import (
"context"
"errors"
"time"
)
const DefaultTTL = 5 * time.Minute
var (
ErrInvalidCode = errors.New("challenge: invalid code")
ErrNoChallenge = errors.New("challenge: no active challenge")
)
type Challenge struct {
Code string
ExpiresAt time.Time
}
type Store interface {
Issue(ctx context.Context, userID, code string, ttl time.Duration) error
ConsumeIfMatches(ctx context.Context, userID, candidate string) error
}
package challenge_test
import (
"context"
"errors"
"fmt"
"os"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cronchie/totpauth/internal/auth/challenge"
)
func TestStore_SingleConsumeUnderBurst(t *testing.T) {
if os.Getenv("CHALLENGE_RACE_DEMO") == "" {
t.Skip("set CHALLENGE_RACE_DEMO=1")
}
const (
iterations = 50
burst = 200
ttl = time.Minute
)
store := challenge.New()
ctx := context.Background()
var totalMatched int64
var other atomic.Value
for iter := range iterations {
userID := fmt.Sprintf("user-%d", iter)
code := fmt.Sprintf("%06d", iter)
if err := store.Issue(ctx, userID, code, ttl); err != nil {
t.Fatalf("Issue: %v", err)
}
var wg sync.WaitGroup
start := make(chan struct{})
for range burst {
wg.Add(1)
go func() {
defer wg.Done()
<-start
switch err := store.ConsumeIfMatches(ctx, userID, code); {
case err == nil:
atomic.AddInt64(&totalMatched, 1)
case errors.Is(err, challenge.ErrNoChallenge),
errors.Is(err, challenge.ErrInvalidCode):
default:
other.Store(err)
}
}()
}
close(start)
wg.Wait()
}
if err, _ := other.Load().(error); err != nil {
t.Fatalf("unexpected store error: %v", err)
}
if totalMatched != iterations {
t.Errorf("totalMatched = %d across %d iterations of %d-goroutine bursts, want exactly %d",
totalMatched, iterations, burst, iterations)
}
}
func TestStore_RejectsWrongCode(t *testing.T) {
const userID = "user-42"
store := challenge.New()
ctx := context.Background()
if err := store.Issue(ctx, userID, "123456", time.Minute); err != nil {
t.Fatalf("Issue: %v", err)
}
err := store.ConsumeIfMatches(ctx, userID, "000000")
if !errors.Is(err, challenge.ErrInvalidCode) {
t.Fatalf("ConsumeIfMatches(wrong) = %v, want ErrInvalidCode", err)
}
if err := store.ConsumeIfMatches(ctx, userID, "123456"); err != nil {
t.Fatalf("ConsumeIfMatches(correct after wrong) = %v, want nil", err)
}
}
func TestStore_ExpiredChallenge(t *testing.T) {
const (
userID = "user-42"
code = "123456"
ttl = time.Minute
)
now := time.Date(2026, 5, 21, 12, 0, 0, 0, time.UTC)
store := challenge.New(challenge.WithClock(func() time.Time { return now }))
ctx := context.Background()
if err := store.Issue(ctx, userID, code, ttl); err != nil {
t.Fatalf("Issue: %v", err)
}
now = now.Add(2 * time.Minute)
err := store.ConsumeIfMatches(ctx, userID, code)
if !errors.Is(err, challenge.ErrNoChallenge) {
t.Fatalf("expired Consume = %v, want ErrNoChallenge", err)
}
}
package store
import "time"
type Settings struct {
Now func() time.Time
}
type Option func(*Settings)
func Apply(opts []Option) Settings {
s := Settings{Now: time.Now}
for _, opt := range opts {
opt(&s)
}
return s
}
func WithClock(now func() time.Time) Option {
return func(s *Settings) { s.Now = now }
}
package challenge
import (
"time"
"github.com/cronchie/totpauth/internal/auth/challenge/internal/store"
)
type Option = store.Option
func WithClock(now func() time.Time) Option { return store.WithClock(now) }
package challenge
import (
"context"
"crypto/subtle"
"time"
"github.com/cronchie/totpauth/internal/auth/challenge/internal/store"
)
type mapStore struct {
challenges map[string]Challenge
now func() time.Time
}
func New(opts ...Option) Store {
s := store.Apply(opts)
return &mapStore{
challenges: make(map[string]Challenge),
now: s.Now,
}
}
func (s *mapStore) Issue(_ context.Context, userID, code string, ttl time.Duration) error {
s.challenges[userID] = Challenge{
Code: code,
ExpiresAt: s.now().Add(ttl),
}
return nil
}
func (s *mapStore) ConsumeIfMatches(_ context.Context, userID, candidate string) error {
ch, ok := s.challenges[userID]
if !ok || s.now().After(ch.ExpiresAt) {
delete(s.challenges, userID)
return ErrNoChallenge
}
if subtle.ConstantTimeCompare([]byte(ch.Code), []byte(candidate)) != 1 {
return ErrInvalidCode
}
delete(s.challenges, userID)
return nil
}
package auth
import "context"
type Notifier interface {
Notify(ctx context.Context, userID, code string) error
}
type NotifierFunc func(ctx context.Context, userID, code string) error
func (f NotifierFunc) Notify(ctx context.Context, userID, code string) error {
return f(ctx, userID, code)
}
package auth
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"time"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/ratelimit"
)
const DefaultMaxAttempts = 5
const codeDigits = 6
var codeSpace = big.NewInt(1_000_000)
type Validator struct {
limiter ratelimit.Limiter
challenges challenge.Store
notifier Notifier
maxAttempts int
challengeTTL time.Duration
}
type Option func(*Validator)
func WithMaxAttempts(n int) Option {
return func(v *Validator) { v.maxAttempts = n }
}
func WithChallengeTTL(d time.Duration) Option {
return func(v *Validator) { v.challengeTTL = d }
}
func NewValidator(limiter ratelimit.Limiter, challenges challenge.Store, notifier Notifier, opts ...Option) *Validator {
v := &Validator{
limiter: limiter,
challenges: challenges,
notifier: notifier,
maxAttempts: DefaultMaxAttempts,
challengeTTL: challenge.DefaultTTL,
}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *Validator) Issue(ctx context.Context, userID string) error {
code, err := generateCode()
if err != nil {
return fmt.Errorf("generating OTP code: %w", err)
}
if err := v.challenges.Issue(ctx, userID, code, v.challengeTTL); err != nil {
return fmt.Errorf("storing challenge: %w", err)
}
if err := v.notifier.Notify(ctx, userID, code); err != nil {
return fmt.Errorf("dispatching OTP: %w", err)
}
return nil
}
func (v *Validator) Verify(ctx context.Context, userID, code string) error {
if err := v.limiter.Locked(ctx, userID, v.maxAttempts); err != nil {
if errors.Is(err, ratelimit.ErrLimitExceeded) {
return ErrLockedOut
}
return fmt.Errorf("rate limiter: %w", err)
}
err := v.challenges.ConsumeIfMatches(ctx, userID, code)
switch {
case err == nil:
return nil
case errors.Is(err, challenge.ErrInvalidCode), errors.Is(err, challenge.ErrNoChallenge):
if rfErr := v.limiter.RecordFailure(ctx, userID, v.maxAttempts); rfErr != nil {
if errors.Is(rfErr, ratelimit.ErrLimitExceeded) {
return ErrLockedOut
}
return fmt.Errorf("rate limiter: %w", rfErr)
}
return err
default:
return fmt.Errorf("consuming challenge: %w", err)
}
}
func generateCode() (string, error) {
n, err := rand.Int(rand.Reader, codeSpace)
if err != nil {
return "", err
}
return fmt.Sprintf("%0*d", codeDigits, n.Int64()), nil
}
package auth_test
import (
"context"
"errors"
"os"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cronchie/totpauth/internal/auth"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/ratelimit"
)
const testUser = "user-42"
type capturingNotifier struct {
mu sync.Mutex
last string
}
func (n *capturingNotifier) Notify(_ context.Context, _, code string) error {
n.mu.Lock()
defer n.mu.Unlock()
n.last = code
return nil
}
func (n *capturingNotifier) code() string {
n.mu.Lock()
defer n.mu.Unlock()
return n.last
}
func newTestValidator(t *testing.T) (*auth.Validator, *capturingNotifier) {
t.Helper()
notifier := &capturingNotifier{}
v := auth.NewValidator(
ratelimit.New(),
challenge.New(),
notifier,
auth.WithMaxAttempts(3),
)
return v, notifier
}
func issueAndCode(t *testing.T, v *auth.Validator, n *capturingNotifier, userID string) string {
t.Helper()
if err := v.Issue(context.Background(), userID); err != nil {
t.Fatalf("Issue: %v", err)
}
code := n.code()
if len(code) != 6 {
t.Fatalf("issued code = %q, want 6 digits", code)
}
return code
}
func TestVerify_AcceptsIssuedCode(t *testing.T) {
v, n := newTestValidator(t)
code := issueAndCode(t, v, n, testUser)
if err := v.Verify(context.Background(), testUser, code); err != nil {
t.Fatalf("Verify(issued) = %v, want nil", err)
}
}
func TestVerify_RejectsWrongCode(t *testing.T) {
v, n := newTestValidator(t)
_ = issueAndCode(t, v, n, testUser)
err := v.Verify(context.Background(), testUser, "000000")
if !errors.Is(err, auth.ErrInvalidCode) {
t.Fatalf("Verify(wrong) = %v, want ErrInvalidCode", err)
}
}
func TestVerify_RejectsAfterConsumption(t *testing.T) {
v, n := newTestValidator(t)
code := issueAndCode(t, v, n, testUser)
ctx := context.Background()
if err := v.Verify(ctx, testUser, code); err != nil {
t.Fatalf("first Verify: %v", err)
}
err := v.Verify(ctx, testUser, code)
if !errors.Is(err, auth.ErrNoChallenge) {
t.Fatalf("replay Verify = %v, want ErrNoChallenge", err)
}
}
func TestVerify_NoChallengeIssued(t *testing.T) {
v, _ := newTestValidator(t)
err := v.Verify(context.Background(), testUser, "000000")
if !errors.Is(err, auth.ErrNoChallenge) {
t.Fatalf("Verify(no challenge) = %v, want ErrNoChallenge", err)
}
}
func TestVerify_LocksOutAfterMaxAttempts(t *testing.T) {
v, n := newTestValidator(t)
_ = issueAndCode(t, v, n, testUser)
ctx := context.Background()
for range 3 {
err := v.Verify(ctx, testUser, "000000")
if errors.Is(err, auth.ErrLockedOut) {
t.Fatal("locked out during budget")
}
}
err := v.Verify(ctx, testUser, "000000")
if !errors.Is(err, auth.ErrLockedOut) {
t.Fatalf("after budget: got %v, want ErrLockedOut", err)
}
}
func TestVerify_ConsumeRaceUnderBurst(t *testing.T) {
if os.Getenv("CHALLENGE_RACE_DEMO") == "" {
t.Skip("set CHALLENGE_RACE_DEMO=1")
}
const burst = 50
notifier := &capturingNotifier{}
v := auth.NewValidator(
ratelimit.New(),
challenge.New(),
notifier,
auth.WithMaxAttempts(burst),
)
if err := v.Issue(context.Background(), testUser); err != nil {
t.Fatalf("Issue: %v", err)
}
code := notifier.code()
var wg sync.WaitGroup
start := make(chan struct{})
var successes int64
for range burst {
wg.Add(1)
go func() {
defer wg.Done()
<-start
if err := v.Verify(context.Background(), testUser, code); err == nil {
atomic.AddInt64(&successes, 1)
}
}()
}
close(start)
wg.Wait()
if successes != 1 {
t.Errorf("successes = %d, want exactly 1 (the code is single-use)", successes)
}
}
func TestVerify_ExpiredChallenge(t *testing.T) {
now := time.Date(2026, 5, 21, 12, 0, 0, 0, time.UTC)
store := challenge.New(
challenge.WithClock(func() time.Time { return now }),
)
notifier := &capturingNotifier{}
v := auth.NewValidator(
ratelimit.New(),
store,
notifier,
auth.WithChallengeTTL(time.Minute),
)
if err := v.Issue(context.Background(), testUser); err != nil {
t.Fatalf("Issue: %v", err)
}
code := notifier.code()
now = now.Add(2 * time.Minute)
err := v.Verify(context.Background(), testUser, code)
if !errors.Is(err, auth.ErrNoChallenge) {
t.Fatalf("expired Verify = %v, want ErrNoChallenge", err)
}
}
package httpapi
import (
"encoding/json"
"errors"
"log/slog"
"net/http"
"github.com/cronchie/totpauth/internal/auth"
)
type Handler struct {
validator *auth.Validator
log *slog.Logger
}
func NewHandler(validator *auth.Validator, log *slog.Logger) *Handler {
if log == nil {
log = slog.Default()
}
return &Handler{validator: validator, log: log}
}
func (h *Handler) Routes() *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("POST /v1/otp/issue", h.issueOTP)
mux.HandleFunc("POST /v1/otp/verify", h.verifyOTP)
return mux
}
type issueRequest struct {
UserID string `json:"user_id"`
}
func (h *Handler) issueOTP(w http.ResponseWriter, r *http.Request) {
var req issueRequest
if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<10)).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.UserID == "" {
writeError(w, http.StatusBadRequest, "user_id is required")
return
}
if err := h.validator.Issue(r.Context(), req.UserID); err != nil {
h.log.ErrorContext(r.Context(), "otp issue failed",
"user_id", req.UserID, "err", err)
writeError(w, http.StatusInternalServerError, "internal error")
return
}
w.WriteHeader(http.StatusAccepted)
}
type verifyRequest struct {
UserID string `json:"user_id"`
Code string `json:"code"`
}
func (h *Handler) verifyOTP(w http.ResponseWriter, r *http.Request) {
var req verifyRequest
if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<10)).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.UserID == "" || req.Code == "" {
writeError(w, http.StatusBadRequest, "user_id and code are required")
return
}
switch err := h.validator.Verify(r.Context(), req.UserID, req.Code); {
case err == nil:
w.WriteHeader(http.StatusNoContent)
case errors.Is(err, auth.ErrLockedOut):
writeError(w, http.StatusTooManyRequests, "too many failed attempts")
case errors.Is(err, auth.ErrInvalidCode), errors.Is(err, auth.ErrNoChallenge):
writeError(w, http.StatusUnauthorized, "invalid credentials")
default:
h.log.ErrorContext(r.Context(), "otp verify failed",
"user_id", req.UserID, "err", err)
writeError(w, http.StatusInternalServerError, "internal error")
}
}
func writeError(w http.ResponseWriter, status int, msg string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(map[string]string{"error": msg})
}
package httpapi_test
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/cronchie/totpauth/internal/auth"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/httpapi"
"github.com/cronchie/totpauth/internal/ratelimit"
)
const testUser = "user-42"
type capturingNotifier struct {
mu sync.Mutex
last string
}
func (n *capturingNotifier) Notify(_ context.Context, _, code string) error {
n.mu.Lock()
defer n.mu.Unlock()
n.last = code
return nil
}
func (n *capturingNotifier) code() string {
n.mu.Lock()
defer n.mu.Unlock()
return n.last
}
func newTestServer(t *testing.T, maxAttempts int) (*httptest.Server, *capturingNotifier) {
t.Helper()
notifier := &capturingNotifier{}
validator := auth.NewValidator(
ratelimit.New(),
challenge.New(),
notifier,
auth.WithMaxAttempts(maxAttempts),
)
srv := httptest.NewServer(httpapi.NewHandler(validator, nil).Routes())
t.Cleanup(srv.Close)
return srv, notifier
}
func postJSON(t *testing.T, srv *httptest.Server, path string, body any) *http.Response {
t.Helper()
buf, err := json.Marshal(body)
if err != nil {
t.Fatalf("marshal: %v", err)
}
resp, err := http.Post(srv.URL+path, "application/json", bytes.NewReader(buf))
if err != nil {
t.Fatalf("POST %s: %v", path, err)
}
t.Cleanup(func() { resp.Body.Close() })
return resp
}
func TestIssue_Accepted(t *testing.T) {
srv, n := newTestServer(t, 5)
resp := postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusAccepted)
}
if got := n.code(); len(got) != 6 {
t.Fatalf("notifier received %q, want 6-digit code", got)
}
}
func TestVerify_Success(t *testing.T) {
srv, n := newTestServer(t, 5)
_ = postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": n.code(),
})
if resp.StatusCode != http.StatusNoContent {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNoContent)
}
}
func TestVerify_WrongCode(t *testing.T) {
srv, _ := newTestServer(t, 5)
_ = postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusUnauthorized)
}
}
func TestVerify_NoChallengeMasked(t *testing.T) {
srv, _ := newTestServer(t, 5)
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("status = %d, want %d (must not leak whether a challenge was issued)",
resp.StatusCode, http.StatusUnauthorized)
}
}
func TestVerify_LockoutReturns429(t *testing.T) {
srv, _ := newTestServer(t, 2)
_ = postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
for range 2 {
_ = postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
}
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
if resp.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("status = %d, body = %s, want 429", resp.StatusCode, body)
}
}
func TestVerify_BadRequest(t *testing.T) {
srv, _ := newTestServer(t, 5)
resp, err := http.Post(srv.URL+"/v1/otp/verify", "application/json", bytes.NewBufferString("not json"))
if err != nil {
t.Fatalf("POST: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("status = %d, want 400", resp.StatusCode)
}
}
package httpapi
import (
"log/slog"
"net/http"
"runtime/debug"
)
func Recover(log *slog.Logger) func(http.Handler) http.Handler {
if log == nil {
log = slog.Default()
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := &responseRecorder{ResponseWriter: w}
defer func() {
rec := recover()
if rec == nil {
return
}
if rec == http.ErrAbortHandler {
panic(rec)
}
log.ErrorContext(r.Context(), "panic recovered",
"panic", rec,
"method", r.Method,
"path", r.URL.Path,
"stack", string(debug.Stack()),
)
if !rw.wroteHeader {
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusInternalServerError)
_, _ = rw.Write([]byte(`{"error":"internal error"}`))
}
}()
next.ServeHTTP(rw, r)
})
}
}
type responseRecorder struct {
http.ResponseWriter
wroteHeader bool
status int
}
func (r *responseRecorder) WriteHeader(status int) {
if r.wroteHeader {
return
}
r.wroteHeader = true
r.status = status
r.ResponseWriter.WriteHeader(status)
}
func (r *responseRecorder) Write(b []byte) (int, error) {
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
return r.ResponseWriter.Write(b)
}
package httpapi_test
import (
"bytes"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/cronchie/totpauth/internal/httpapi"
)
func TestRecover_CatchesPanicAndReturns500(t *testing.T) {
var buf bytes.Buffer
log := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelError}))
handler := httpapi.Recover(log)(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
panic("kaboom")
}))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/totp/verify", nil))
if rec.Code != http.StatusInternalServerError {
t.Errorf("status = %d, want 500", rec.Code)
}
var body map[string]string
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("body not JSON: %v (body=%q)", err, rec.Body.String())
}
if body["error"] != "internal error" {
t.Errorf("body[error] = %q, want %q", body["error"], "internal error")
}
logs := buf.String()
for _, want := range []string{"panic recovered", "kaboom", "/totp/verify"} {
if !strings.Contains(logs, want) {
t.Errorf("log missing %q; got: %s", want, logs)
}
}
}
func TestRecover_PassesThroughWhenNoPanic(t *testing.T) {
log := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := httpapi.Recover(log)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("hi"))
}))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if rec.Code != http.StatusTeapot {
t.Errorf("status = %d, want %d", rec.Code, http.StatusTeapot)
}
if rec.Body.String() != "hi" {
t.Errorf("body = %q, want %q", rec.Body.String(), "hi")
}
}
func TestRecover_PanicAfterWriteKeepsEarlyStatus(t *testing.T) {
log := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := httpapi.Recover(log)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("partial"))
panic("too late")
}))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want 200 (panic after Write must not overwrite status)", rec.Code)
}
if !strings.HasPrefix(rec.Body.String(), "partial") {
t.Errorf("body = %q, want it to start with %q", rec.Body.String(), "partial")
}
}
func TestRecover_RepanicsErrAbortHandler(t *testing.T) {
log := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := httpapi.Recover(log)(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
panic(http.ErrAbortHandler)
}))
defer func() {
switch rec := recover(); rec {
case http.ErrAbortHandler:
case nil:
t.Error("middleware swallowed http.ErrAbortHandler; want re-panic")
default:
t.Errorf("re-panicked with %v, want http.ErrAbortHandler", rec)
}
}()
handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
}
package store
import "time"
const DefaultEntryTTL = 10 * time.Minute
type Entry struct {
Count int
ExpiresAt time.Time
}
func Expired(e *Entry, now time.Time) bool {
return e == nil || now.After(e.ExpiresAt)
}
func Lookup(m map[string]*Entry, key string, now time.Time) (*Entry, bool) {
e, ok := m[key]
if ok && Expired(e, now) {
delete(m, key)
return nil, false
}
return e, ok
}
func Ensure(m map[string]*Entry, key string, ttl time.Duration, now time.Time) *Entry {
if e, ok := Lookup(m, key, now); ok {
return e
}
e := &Entry{ExpiresAt: now.Add(ttl)}
m[key] = e
return e
}
package store
import "time"
type Settings struct {
TTL time.Duration
Now func() time.Time
}
type Option func(*Settings)
func Apply(opts []Option) Settings {
s := Settings{TTL: DefaultEntryTTL, Now: time.Now}
for _, opt := range opts {
opt(&s)
}
return s
}
func WithTTL(d time.Duration) Option {
return func(s *Settings) { s.TTL = d }
}
func WithClock(now func() time.Time) Option {
return func(s *Settings) { s.Now = now }
}
package ratelimit

import (
	"context"
	"time"

	"github.com/cronchie/totpauth/internal/ratelimit/internal/store"
)

// mapLimiter tracks failed attempts per key in an in-memory map.
type mapLimiter struct {
	attempts map[string]*store.Entry
	ttl      time.Duration
	now      func() time.Time
}

func New(opts ...Option) Limiter {
	s := store.Apply(opts)
	return &mapLimiter{
		attempts: make(map[string]*store.Entry),
		ttl:      s.TTL,
		now:      s.Now,
	}
}

// Locked returns ErrLimitExceeded if key has already used up its failure
// budget for the current window.
func (l *mapLimiter) Locked(_ context.Context, key string, limit int) error {
	e, ok := store.Lookup(l.attempts, key, l.now())
	if !ok {
		return nil
	}
	if e.Count >= limit {
		return ErrLimitExceeded
	}
	return nil
}

// RecordFailure counts one failed attempt for key, or returns
// ErrLimitExceeded if the budget for the current window is already spent.
func (l *mapLimiter) RecordFailure(_ context.Context, key string, limit int) error {
	e := store.Ensure(l.attempts, key, l.ttl, l.now())
	if e.Count >= limit {
		return ErrLimitExceeded
	}
	e.Count++
	return nil
}
package ratelimit
import (
"time"
"github.com/cronchie/totpauth/internal/ratelimit/internal/store"
)
const DefaultEntryTTL = store.DefaultEntryTTL
type Option = store.Option
func WithTTL(d time.Duration) Option { return store.WithTTL(d) }
func WithClock(now func() time.Time) Option { return store.WithClock(now) }
package ratelimit
import (
"context"
"errors"
)
var ErrLimitExceeded = errors.New("ratelimit: limit exceeded")
type Limiter interface {
Locked(ctx context.Context, key string, limit int) error
RecordFailure(ctx context.Context, key string, limit int) error
}
package ratelimit_test
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cronchie/totpauth/internal/ratelimit"
)
func TestLimiter_LockoutHoldsUnderConcurrentFailures(t *testing.T) {
const (
limit = 5
goroutines = 50
key = "user-42"
)
limiter := ratelimit.New()
ctx := context.Background()
var wg sync.WaitGroup
var admitted int64
for range goroutines {
wg.Add(1)
go func() {
defer wg.Done()
if err := limiter.RecordFailure(ctx, key, limit); err == nil {
atomic.AddInt64(&admitted, 1)
}
}()
}
wg.Wait()
if got := int(admitted); got != limit {
t.Errorf("admitted %d failures past the limit of %d", got, limit)
}
}
func TestLimiter_WindowExpiry(t *testing.T) {
const (
limit = 2
key = "user-42"
ttl = time.Minute
)
now := time.Date(2026, 5, 21, 12, 0, 0, 0, time.UTC)
advance := func(d time.Duration) {
now = now.Add(d)
}
limiter := ratelimit.New(
ratelimit.WithTTL(ttl),
ratelimit.WithClock(func() time.Time { return now }),
)
ctx := context.Background()
for range limit {
if err := limiter.RecordFailure(ctx, key, limit); err != nil {
t.Fatalf("failure should be recorded: %v", err)
}
}
if err := limiter.Locked(ctx, key, limit); !errors.Is(err, ratelimit.ErrLimitExceeded) {
t.Fatalf("want ErrLimitExceeded after budget exhausted, got %v", err)
}
advance(ttl + time.Second)
if err := limiter.Locked(ctx, key, limit); err != nil {
t.Fatalf("key should be unlocked after window expiry: %v", err)
}
}
When they ran it, the test failed in a way that took them a minute to understand. Forty-seven of the fifty attempts had been admitted before the limiter started rejecting, against a limit of five.
They ran it again with Go’s race detector turned on. The -race flag instruments the program at compile time so the runtime can watch for goroutines reading and writing the same memory without coordinating, and report what it sees. It reported eight separate data races inside RecordFailure, all on the same pattern: one goroutine reading the entry while another wrote to it, or two goroutines writing to it at once.
Reaching for sync.Map
With the race detector’s output in hand, the team could see the structure of the problem. Inside RecordFailure, the read of e.Count happens at one moment, the comparison against limit happens just after, and the increment writes back at a third moment. Between those steps, nothing prevents another goroutine from doing the same thing to the same entry. Two goroutines could both read e.Count as 2, both see that 2 is under the limit of 5, and both write back 3. The entry kept whichever write landed last, so the count lagged behind the number of failures that had actually occurred. The same gap sits one level up, in store.Ensure: goroutines that all miss in Lookup each create a fresh entry and assign it to the map, overwriting one another’s entries.
This is what v5.0.0-15.4.1, one of the new requirements in ASVS 5.0, calls out:
Verify that shared objects in multi-threaded code (such as caches, files, or in-memory objects accessed by multiple threads) are accessed safely by using thread-safe types and synchronization mechanisms like locks or semaphores to avoid race conditions and data corruption.
The team reached for sync.Map. The choice made sense. sync.Map’s documentation singles out the case where multiple goroutines read, write, and overwrite entries for disjoint sets of keys. That is what a per-user rate limiter does, with one key per user. The methods on a sync.Map are safe to call from any number of goroutines at once. So the team swapped the map for a sync.Map, adjusted the call sites in RecordFailure and Locked to use the Load, Store, and Delete methods, and reran the tests under -race. The race detector stayed silent and the test passed, so the team considered the bug fixed.
package main
import (
"context"
"errors"
"flag"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/cronchie/totpauth/internal/auth"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/httpapi"
"github.com/cronchie/totpauth/internal/ratelimit"
)
func main() {
addr := flag.String("addr", ":8080", "listen address")
flag.Parse()
log := slog.New(slog.NewJSONHandler(os.Stdout, nil))
notifier := auth.NotifierFunc(func(_ context.Context, userID, code string) error {
log.Warn("OTP code generated; configure a real Notifier", "user_id", userID, "code", code)
return nil
})
limiter := ratelimit.New()
challenges := challenge.New()
validator := auth.NewValidator(limiter, challenges, notifier)
handler := httpapi.Recover(log)(httpapi.NewHandler(validator, log).Routes())
srv := &http.Server{
Addr: *addr,
Handler: handler,
ReadHeaderTimeout: 5 * time.Second,
}
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
go func() {
log.Info("authd listening", "addr", srv.Addr)
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Error("listen failed", "err", err)
stop()
}
}()
<-ctx.Done()
log.Info("authd shutting down")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Error("shutdown failed", "err", err)
os.Exit(1)
}
}
module github.com/cronchie/totpauth
go 1.23
package auth
import (
"errors"
"github.com/cronchie/totpauth/internal/auth/challenge"
)
var ErrLockedOut = errors.New("auth: too many failed attempts")
var (
ErrInvalidCode = challenge.ErrInvalidCode
ErrNoChallenge = challenge.ErrNoChallenge
)
package challenge
import (
"context"
"errors"
"time"
)
const DefaultTTL = 5 * time.Minute
var (
ErrInvalidCode = errors.New("challenge: invalid code")
ErrNoChallenge = errors.New("challenge: no active challenge")
)
type Challenge struct {
Code string
ExpiresAt time.Time
}
type Store interface {
Issue(ctx context.Context, userID, code string, ttl time.Duration) error
ConsumeIfMatches(ctx context.Context, userID, candidate string) error
}
package challenge_test
import (
"context"
"errors"
"fmt"
"os"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cronchie/totpauth/internal/auth/challenge"
)
func TestStore_SingleConsumeUnderBurst(t *testing.T) {
if os.Getenv("CHALLENGE_RACE_DEMO") == "" {
t.Skip("set CHALLENGE_RACE_DEMO=1")
}
const (
iterations = 50
burst = 200
ttl = time.Minute
)
store := challenge.New()
ctx := context.Background()
var totalMatched int64
var other atomic.Value
for iter := range iterations {
userID := fmt.Sprintf("user-%d", iter)
code := fmt.Sprintf("%06d", iter)
if err := store.Issue(ctx, userID, code, ttl); err != nil {
t.Fatalf("Issue: %v", err)
}
var wg sync.WaitGroup
start := make(chan struct{})
for range burst {
wg.Add(1)
go func() {
defer wg.Done()
<-start
switch err := store.ConsumeIfMatches(ctx, userID, code); {
case err == nil:
atomic.AddInt64(&totalMatched, 1)
case errors.Is(err, challenge.ErrNoChallenge),
errors.Is(err, challenge.ErrInvalidCode):
default:
other.Store(err)
}
}()
}
close(start)
wg.Wait()
}
if err, _ := other.Load().(error); err != nil {
t.Fatalf("unexpected store error: %v", err)
}
if totalMatched != iterations {
t.Errorf("totalMatched = %d across %d iterations of %d-goroutine bursts, want exactly %d",
totalMatched, iterations, burst, iterations)
}
}
func TestStore_RejectsWrongCode(t *testing.T) {
const userID = "user-42"
store := challenge.New()
ctx := context.Background()
if err := store.Issue(ctx, userID, "123456", time.Minute); err != nil {
t.Fatalf("Issue: %v", err)
}
err := store.ConsumeIfMatches(ctx, userID, "000000")
if !errors.Is(err, challenge.ErrInvalidCode) {
t.Fatalf("ConsumeIfMatches(wrong) = %v, want ErrInvalidCode", err)
}
if err := store.ConsumeIfMatches(ctx, userID, "123456"); err != nil {
t.Fatalf("ConsumeIfMatches(correct after wrong) = %v, want nil", err)
}
}
func TestStore_ExpiredChallenge(t *testing.T) {
const (
userID = "user-42"
code = "123456"
ttl = time.Minute
)
now := time.Date(2026, 5, 21, 12, 0, 0, 0, time.UTC)
store := challenge.New(challenge.WithClock(func() time.Time { return now }))
ctx := context.Background()
if err := store.Issue(ctx, userID, code, ttl); err != nil {
t.Fatalf("Issue: %v", err)
}
now = now.Add(2 * time.Minute)
err := store.ConsumeIfMatches(ctx, userID, code)
if !errors.Is(err, challenge.ErrNoChallenge) {
t.Fatalf("expired Consume = %v, want ErrNoChallenge", err)
}
}
package store
import "time"
type Settings struct {
Now func() time.Time
}
type Option func(*Settings)
func Apply(opts []Option) Settings {
s := Settings{Now: time.Now}
for _, opt := range opts {
opt(&s)
}
return s
}
func WithClock(now func() time.Time) Option {
return func(s *Settings) { s.Now = now }
}
package challenge
import (
"time"
"github.com/cronchie/totpauth/internal/auth/challenge/internal/store"
)
type Option = store.Option
func WithClock(now func() time.Time) Option { return store.WithClock(now) }
package challenge
import (
"context"
"crypto/subtle"
"sync"
"time"
"github.com/cronchie/totpauth/internal/auth/challenge/internal/store"
)
type syncMapStore struct {
challenges sync.Map
now func() time.Time
}
func New(opts ...Option) Store {
s := store.Apply(opts)
return &syncMapStore{now: s.Now}
}
func (s *syncMapStore) Issue(_ context.Context, userID, code string, ttl time.Duration) error {
s.challenges.Store(userID, Challenge{
Code: code,
ExpiresAt: s.now().Add(ttl),
})
return nil
}
func (s *syncMapStore) ConsumeIfMatches(_ context.Context, userID, candidate string) error {
val, ok := s.challenges.Load(userID)
if !ok {
return ErrNoChallenge
}
ch := val.(Challenge)
if s.now().After(ch.ExpiresAt) {
s.challenges.Delete(userID)
return ErrNoChallenge
}
if subtle.ConstantTimeCompare([]byte(ch.Code), []byte(candidate)) != 1 {
return ErrInvalidCode
}
s.challenges.Delete(userID)
return nil
}
package auth
import "context"
type Notifier interface {
Notify(ctx context.Context, userID, code string) error
}
type NotifierFunc func(ctx context.Context, userID, code string) error
func (f NotifierFunc) Notify(ctx context.Context, userID, code string) error {
return f(ctx, userID, code)
}
package auth
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"time"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/ratelimit"
)
const DefaultMaxAttempts = 5
const codeDigits = 6
var codeSpace = big.NewInt(1_000_000)
type Validator struct {
limiter ratelimit.Limiter
challenges challenge.Store
notifier Notifier
maxAttempts int
challengeTTL time.Duration
}
type Option func(*Validator)
func WithMaxAttempts(n int) Option {
return func(v *Validator) { v.maxAttempts = n }
}
func WithChallengeTTL(d time.Duration) Option {
return func(v *Validator) { v.challengeTTL = d }
}
func NewValidator(limiter ratelimit.Limiter, challenges challenge.Store, notifier Notifier, opts ...Option) *Validator {
v := &Validator{
limiter: limiter,
challenges: challenges,
notifier: notifier,
maxAttempts: DefaultMaxAttempts,
challengeTTL: challenge.DefaultTTL,
}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *Validator) Issue(ctx context.Context, userID string) error {
code, err := generateCode()
if err != nil {
return fmt.Errorf("generating OTP code: %w", err)
}
if err := v.challenges.Issue(ctx, userID, code, v.challengeTTL); err != nil {
return fmt.Errorf("storing challenge: %w", err)
}
if err := v.notifier.Notify(ctx, userID, code); err != nil {
return fmt.Errorf("dispatching OTP: %w", err)
}
return nil
}
func (v *Validator) Verify(ctx context.Context, userID, code string) error {
if err := v.limiter.Locked(ctx, userID, v.maxAttempts); err != nil {
if errors.Is(err, ratelimit.ErrLimitExceeded) {
return ErrLockedOut
}
return fmt.Errorf("rate limiter: %w", err)
}
err := v.challenges.ConsumeIfMatches(ctx, userID, code)
switch {
case err == nil:
return nil
case errors.Is(err, challenge.ErrInvalidCode), errors.Is(err, challenge.ErrNoChallenge):
if rfErr := v.limiter.RecordFailure(ctx, userID, v.maxAttempts); rfErr != nil {
if errors.Is(rfErr, ratelimit.ErrLimitExceeded) {
return ErrLockedOut
}
return fmt.Errorf("rate limiter: %w", rfErr)
}
return err
default:
return fmt.Errorf("consuming challenge: %w", err)
}
}
func generateCode() (string, error) {
n, err := rand.Int(rand.Reader, codeSpace)
if err != nil {
return "", err
}
return fmt.Sprintf("%0*d", codeDigits, n.Int64()), nil
}
package auth_test
import (
"context"
"errors"
"os"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cronchie/totpauth/internal/auth"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/ratelimit"
)
const testUser = "user-42"
type capturingNotifier struct {
mu sync.Mutex
last string
}
func (n *capturingNotifier) Notify(_ context.Context, _, code string) error {
n.mu.Lock()
defer n.mu.Unlock()
n.last = code
return nil
}
func (n *capturingNotifier) code() string {
n.mu.Lock()
defer n.mu.Unlock()
return n.last
}
func newTestValidator(t *testing.T) (*auth.Validator, *capturingNotifier) {
t.Helper()
notifier := &capturingNotifier{}
v := auth.NewValidator(
ratelimit.New(),
challenge.New(),
notifier,
auth.WithMaxAttempts(3),
)
return v, notifier
}
func issueAndCode(t *testing.T, v *auth.Validator, n *capturingNotifier, userID string) string {
t.Helper()
if err := v.Issue(context.Background(), userID); err != nil {
t.Fatalf("Issue: %v", err)
}
code := n.code()
if len(code) != 6 {
t.Fatalf("issued code = %q, want 6 digits", code)
}
return code
}
func TestVerify_AcceptsIssuedCode(t *testing.T) {
v, n := newTestValidator(t)
code := issueAndCode(t, v, n, testUser)
if err := v.Verify(context.Background(), testUser, code); err != nil {
t.Fatalf("Verify(issued) = %v, want nil", err)
}
}
func TestVerify_RejectsWrongCode(t *testing.T) {
v, n := newTestValidator(t)
_ = issueAndCode(t, v, n, testUser)
err := v.Verify(context.Background(), testUser, "000000")
if !errors.Is(err, auth.ErrInvalidCode) {
t.Fatalf("Verify(wrong) = %v, want ErrInvalidCode", err)
}
}
func TestVerify_RejectsAfterConsumption(t *testing.T) {
v, n := newTestValidator(t)
code := issueAndCode(t, v, n, testUser)
ctx := context.Background()
if err := v.Verify(ctx, testUser, code); err != nil {
t.Fatalf("first Verify: %v", err)
}
err := v.Verify(ctx, testUser, code)
if !errors.Is(err, auth.ErrNoChallenge) {
t.Fatalf("replay Verify = %v, want ErrNoChallenge", err)
}
}
func TestVerify_NoChallengeIssued(t *testing.T) {
v, _ := newTestValidator(t)
err := v.Verify(context.Background(), testUser, "000000")
if !errors.Is(err, auth.ErrNoChallenge) {
t.Fatalf("Verify(no challenge) = %v, want ErrNoChallenge", err)
}
}
func TestVerify_LocksOutAfterMaxAttempts(t *testing.T) {
v, n := newTestValidator(t)
_ = issueAndCode(t, v, n, testUser)
ctx := context.Background()
for range 3 {
err := v.Verify(ctx, testUser, "000000")
if errors.Is(err, auth.ErrLockedOut) {
t.Fatal("locked out during budget")
}
}
err := v.Verify(ctx, testUser, "000000")
if !errors.Is(err, auth.ErrLockedOut) {
t.Fatalf("after budget: got %v, want ErrLockedOut", err)
}
}
func TestVerify_ConsumeRaceUnderBurst(t *testing.T) {
if os.Getenv("CHALLENGE_RACE_DEMO") == "" {
t.Skip("set CHALLENGE_RACE_DEMO=1")
}
const burst = 50
notifier := &capturingNotifier{}
v := auth.NewValidator(
ratelimit.New(),
challenge.New(),
notifier,
auth.WithMaxAttempts(burst),
)
if err := v.Issue(context.Background(), testUser); err != nil {
t.Fatalf("Issue: %v", err)
}
code := notifier.code()
var wg sync.WaitGroup
start := make(chan struct{})
var successes int64
for range burst {
wg.Add(1)
go func() {
defer wg.Done()
<-start
if err := v.Verify(context.Background(), testUser, code); err == nil {
atomic.AddInt64(&successes, 1)
}
}()
}
close(start)
wg.Wait()
if successes != 1 {
t.Errorf("successes = %d, want exactly 1 (the code is single-use)", successes)
}
}
func TestVerify_ExpiredChallenge(t *testing.T) {
now := time.Date(2026, 5, 21, 12, 0, 0, 0, time.UTC)
store := challenge.New(
challenge.WithClock(func() time.Time { return now }),
)
notifier := &capturingNotifier{}
v := auth.NewValidator(
ratelimit.New(),
store,
notifier,
auth.WithChallengeTTL(time.Minute),
)
if err := v.Issue(context.Background(), testUser); err != nil {
t.Fatalf("Issue: %v", err)
}
code := notifier.code()
now = now.Add(2 * time.Minute)
err := v.Verify(context.Background(), testUser, code)
if !errors.Is(err, auth.ErrNoChallenge) {
t.Fatalf("expired Verify = %v, want ErrNoChallenge", err)
}
}
package httpapi
import (
"encoding/json"
"errors"
"log/slog"
"net/http"
"github.com/cronchie/totpauth/internal/auth"
)
type Handler struct {
validator *auth.Validator
log *slog.Logger
}
func NewHandler(validator *auth.Validator, log *slog.Logger) *Handler {
if log == nil {
log = slog.Default()
}
return &Handler{validator: validator, log: log}
}
func (h *Handler) Routes() *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("POST /v1/otp/issue", h.issueOTP)
mux.HandleFunc("POST /v1/otp/verify", h.verifyOTP)
return mux
}
type issueRequest struct {
UserID string `json:"user_id"`
}
func (h *Handler) issueOTP(w http.ResponseWriter, r *http.Request) {
var req issueRequest
if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<10)).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.UserID == "" {
writeError(w, http.StatusBadRequest, "user_id is required")
return
}
if err := h.validator.Issue(r.Context(), req.UserID); err != nil {
h.log.ErrorContext(r.Context(), "otp issue failed",
"user_id", req.UserID, "err", err)
writeError(w, http.StatusInternalServerError, "internal error")
return
}
w.WriteHeader(http.StatusAccepted)
}
type verifyRequest struct {
UserID string `json:"user_id"`
Code string `json:"code"`
}
func (h *Handler) verifyOTP(w http.ResponseWriter, r *http.Request) {
var req verifyRequest
if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<10)).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.UserID == "" || req.Code == "" {
writeError(w, http.StatusBadRequest, "user_id and code are required")
return
}
switch err := h.validator.Verify(r.Context(), req.UserID, req.Code); {
case err == nil:
w.WriteHeader(http.StatusNoContent)
case errors.Is(err, auth.ErrLockedOut):
writeError(w, http.StatusTooManyRequests, "too many failed attempts")
case errors.Is(err, auth.ErrInvalidCode), errors.Is(err, auth.ErrNoChallenge):
writeError(w, http.StatusUnauthorized, "invalid credentials")
default:
h.log.ErrorContext(r.Context(), "otp verify failed",
"user_id", req.UserID, "err", err)
writeError(w, http.StatusInternalServerError, "internal error")
}
}
func writeError(w http.ResponseWriter, status int, msg string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(map[string]string{"error": msg})
}
package httpapi_test
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/cronchie/totpauth/internal/auth"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/httpapi"
"github.com/cronchie/totpauth/internal/ratelimit"
)
const testUser = "user-42"
type capturingNotifier struct {
mu sync.Mutex
last string
}
func (n *capturingNotifier) Notify(_ context.Context, _, code string) error {
n.mu.Lock()
defer n.mu.Unlock()
n.last = code
return nil
}
func (n *capturingNotifier) code() string {
n.mu.Lock()
defer n.mu.Unlock()
return n.last
}
func newTestServer(t *testing.T, maxAttempts int) (*httptest.Server, *capturingNotifier) {
t.Helper()
notifier := &capturingNotifier{}
validator := auth.NewValidator(
ratelimit.New(),
challenge.New(),
notifier,
auth.WithMaxAttempts(maxAttempts),
)
srv := httptest.NewServer(httpapi.NewHandler(validator, nil).Routes())
t.Cleanup(srv.Close)
return srv, notifier
}
func postJSON(t *testing.T, srv *httptest.Server, path string, body any) *http.Response {
t.Helper()
buf, err := json.Marshal(body)
if err != nil {
t.Fatalf("marshal: %v", err)
}
resp, err := http.Post(srv.URL+path, "application/json", bytes.NewReader(buf))
if err != nil {
t.Fatalf("POST %s: %v", path, err)
}
t.Cleanup(func() { resp.Body.Close() })
return resp
}
func TestIssue_Accepted(t *testing.T) {
srv, n := newTestServer(t, 5)
resp := postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusAccepted)
}
if got := n.code(); len(got) != 6 {
t.Fatalf("notifier received %q, want 6-digit code", got)
}
}
func TestVerify_Success(t *testing.T) {
srv, n := newTestServer(t, 5)
_ = postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": n.code(),
})
if resp.StatusCode != http.StatusNoContent {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNoContent)
}
}
func TestVerify_WrongCode(t *testing.T) {
srv, _ := newTestServer(t, 5)
_ = postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusUnauthorized)
}
}
func TestVerify_NoChallengeMasked(t *testing.T) {
srv, _ := newTestServer(t, 5)
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("status = %d, want %d (must not leak whether a challenge was issued)",
resp.StatusCode, http.StatusUnauthorized)
}
}
func TestVerify_LockoutReturns429(t *testing.T) {
srv, _ := newTestServer(t, 2)
_ = postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
for range 2 {
_ = postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
}
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
if resp.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("status = %d, body = %s, want 429", resp.StatusCode, body)
}
}
func TestVerify_BadRequest(t *testing.T) {
srv, _ := newTestServer(t, 5)
resp, err := http.Post(srv.URL+"/v1/otp/verify", "application/json", bytes.NewBufferString("not json"))
if err != nil {
t.Fatalf("POST: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("status = %d, want 400", resp.StatusCode)
}
}
package httpapi
import (
"log/slog"
"net/http"
"runtime/debug"
)
func Recover(log *slog.Logger) func(http.Handler) http.Handler {
if log == nil {
log = slog.Default()
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := &responseRecorder{ResponseWriter: w}
defer func() {
rec := recover()
if rec == nil {
return
}
if rec == http.ErrAbortHandler {
panic(rec)
}
log.ErrorContext(r.Context(), "panic recovered",
"panic", rec,
"method", r.Method,
"path", r.URL.Path,
"stack", string(debug.Stack()),
)
if !rw.wroteHeader {
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusInternalServerError)
_, _ = rw.Write([]byte(`{"error":"internal error"}`))
}
}()
next.ServeHTTP(rw, r)
})
}
}
type responseRecorder struct {
http.ResponseWriter
wroteHeader bool
status int
}
func (r *responseRecorder) WriteHeader(status int) {
if r.wroteHeader {
return
}
r.wroteHeader = true
r.status = status
r.ResponseWriter.WriteHeader(status)
}
func (r *responseRecorder) Write(b []byte) (int, error) {
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
return r.ResponseWriter.Write(b)
}
package httpapi_test
import (
"bytes"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/cronchie/totpauth/internal/httpapi"
)
func TestRecover_CatchesPanicAndReturns500(t *testing.T) {
var buf bytes.Buffer
log := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelError}))
handler := httpapi.Recover(log)(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
panic("kaboom")
}))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/totp/verify", nil))
if rec.Code != http.StatusInternalServerError {
t.Errorf("status = %d, want 500", rec.Code)
}
var body map[string]string
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("body not JSON: %v (body=%q)", err, rec.Body.String())
}
if body["error"] != "internal error" {
t.Errorf("body[error] = %q, want %q", body["error"], "internal error")
}
logs := buf.String()
for _, want := range []string{"panic recovered", "kaboom", "/totp/verify"} {
if !strings.Contains(logs, want) {
t.Errorf("log missing %q; got: %s", want, logs)
}
}
}
func TestRecover_PassesThroughWhenNoPanic(t *testing.T) {
log := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := httpapi.Recover(log)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("hi"))
}))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if rec.Code != http.StatusTeapot {
t.Errorf("status = %d, want %d", rec.Code, http.StatusTeapot)
}
if rec.Body.String() != "hi" {
t.Errorf("body = %q, want %q", rec.Body.String(), "hi")
}
}
func TestRecover_PanicAfterWriteKeepsEarlyStatus(t *testing.T) {
log := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := httpapi.Recover(log)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("partial"))
panic("too late")
}))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want 200 (panic after Write must not overwrite status)", rec.Code)
}
if !strings.HasPrefix(rec.Body.String(), "partial") {
t.Errorf("body = %q, want it to start with %q", rec.Body.String(), "partial")
}
}
func TestRecover_RepanicsErrAbortHandler(t *testing.T) {
log := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := httpapi.Recover(log)(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
panic(http.ErrAbortHandler)
}))
defer func() {
switch rec := recover(); rec {
case http.ErrAbortHandler:
case nil:
t.Error("middleware swallowed http.ErrAbortHandler; want re-panic")
default:
t.Errorf("re-panicked with %v, want http.ErrAbortHandler", rec)
}
}()
handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
}
package store
import "time"
const DefaultEntryTTL = 10 * time.Minute
type Entry struct {
Count int
ExpiresAt time.Time
}
func Expired(e *Entry, now time.Time) bool {
return e == nil || now.After(e.ExpiresAt)
}
func Lookup(m map[string]*Entry, key string, now time.Time) (*Entry, bool) {
e, ok := m[key]
if ok && Expired(e, now) {
delete(m, key)
return nil, false
}
return e, ok
}
func Ensure(m map[string]*Entry, key string, ttl time.Duration, now time.Time) *Entry {
if e, ok := Lookup(m, key, now); ok {
return e
}
e := &Entry{ExpiresAt: now.Add(ttl)}
m[key] = e
return e
}
package store
import "time"
type Settings struct {
TTL time.Duration
Now func() time.Time
}
type Option func(*Settings)
func Apply(opts []Option) Settings {
s := Settings{TTL: DefaultEntryTTL, Now: time.Now}
for _, opt := range opts {
opt(&s)
}
return s
}
func WithTTL(d time.Duration) Option {
return func(s *Settings) { s.TTL = d }
}
func WithClock(now func() time.Time) Option {
return func(s *Settings) { s.Now = now }
}
package ratelimit
import (
"context"
"sync"
"time"
"github.com/cronchie/totpauth/internal/ratelimit/internal/store"
)
type syncMapLimiter struct {
attempts sync.Map
ttl time.Duration
now func() time.Time
}
func New(opts ...Option) Limiter {
s := store.Apply(opts)
return &syncMapLimiter{ttl: s.TTL, now: s.Now}
}
func (l *syncMapLimiter) load(key string, now time.Time) (store.Entry, bool) {
val, ok := l.attempts.Load(key)
if !ok {
return store.Entry{}, false
}
e := val.(store.Entry)
if store.Expired(&e, now) {
l.attempts.Delete(key)
return store.Entry{}, false
}
return e, true
}
func (l *syncMapLimiter) Locked(_ context.Context, key string, limit int) error {
e, ok := l.load(key, l.now())
if !ok {
return nil
}
if e.Count >= limit {
return ErrLimitExceeded
}
return nil
}
func (l *syncMapLimiter) RecordFailure(_ context.Context, key string, limit int) error {
now := l.now()
e, ok := l.load(key, now)
if !ok {
e = store.Entry{ExpiresAt: now.Add(l.ttl)}
}
if e.Count >= limit {
return ErrLimitExceeded
}
e.Count++
l.attempts.Store(key, e)
return nil
}
package ratelimit
import (
"time"
"github.com/cronchie/totpauth/internal/ratelimit/internal/store"
)
const DefaultEntryTTL = store.DefaultEntryTTL
type Option = store.Option
func WithTTL(d time.Duration) Option { return store.WithTTL(d) }
func WithClock(now func() time.Time) Option { return store.WithClock(now) }
package ratelimit
import (
"context"
"errors"
)
var ErrLimitExceeded = errors.New("ratelimit: limit exceeded")
type Limiter interface {
Locked(ctx context.Context, key string, limit int) error
RecordFailure(ctx context.Context, key string, limit int) error
}
package ratelimit_test
import (
"context"
"errors"
"os"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cronchie/totpauth/internal/ratelimit"
)
func TestLimiter_ConcurrentSafety(t *testing.T) {
if os.Getenv("RATELIMIT_RACE_DEMO") == "" {
t.Skip("set RATELIMIT_RACE_DEMO=1")
}
const (
limit = 5
goroutines = 50
key = "user-42"
)
limiter := ratelimit.New()
var (
wg sync.WaitGroup
start = make(chan struct{})
recorded int64
rejected int64
other atomic.Value
)
for range goroutines {
wg.Add(1)
go func() {
defer wg.Done()
<-start
switch err := limiter.RecordFailure(context.Background(), key, limit); {
case err == nil:
atomic.AddInt64(&recorded, 1)
case errors.Is(err, ratelimit.ErrLimitExceeded):
atomic.AddInt64(&rejected, 1)
default:
other.Store(err)
}
}()
}
close(start)
wg.Wait()
if err, _ := other.Load().(error); err != nil {
t.Fatalf("unexpected limiter error: %v", err)
}
if got := int(recorded); got != limit {
t.Errorf("recorded %d failures, want exactly %d (rejected=%d)",
got, limit, rejected)
}
if got := int(recorded + rejected); got != goroutines {
t.Errorf("accounted for %d of %d requests", got, goroutines)
}
}
func runFailureBurstAndProbeLockout(limiter ratelimit.Limiter, key string, limit, burst int) (locked bool) {
ctx := context.Background()
var wg sync.WaitGroup
start := make(chan struct{})
for range burst {
wg.Add(1)
go func() {
defer wg.Done()
<-start
_ = limiter.RecordFailure(ctx, key, limit)
}()
}
close(start)
wg.Wait()
return errors.Is(limiter.Locked(ctx, key, limit), ratelimit.ErrLimitExceeded)
}
func TestLimiter_LockoutSilentlyFailsAfterBurst(t *testing.T) {
if os.Getenv("RATELIMIT_RACE_DEMO") == "" {
t.Skip("set RATELIMIT_RACE_DEMO=1")
}
const (
iterations = 50
key = "user-42"
limit = 5
burst = 100
)
var unlockedCount int
for range iterations {
if !runFailureBurstAndProbeLockout(ratelimit.New(), key, limit, burst) {
unlockedCount++
}
}
if unlockedCount == 0 {
t.Errorf("syncmap impl: 0/%d iterations left key unlocked after %d-attempt bursts; "+
"expected at least 1 to demonstrate under-lock race",
iterations, burst)
}
t.Logf("syncmap under-locked %d of %d iterations (silent failure: account stays open after %d-attempt burst against limit %d)",
unlockedCount, iterations, burst, limit)
}
func TestLimiter_WindowExpiry(t *testing.T) {
const (
limit = 2
key = "user-42"
ttl = time.Minute
)
now := time.Date(2026, 5, 21, 12, 0, 0, 0, time.UTC)
advance := func(d time.Duration) {
now = now.Add(d)
}
limiter := ratelimit.New(
ratelimit.WithTTL(ttl),
ratelimit.WithClock(func() time.Time { return now }),
)
ctx := context.Background()
for range limit {
if err := limiter.RecordFailure(ctx, key, limit); err != nil {
t.Fatalf("failure should be recorded: %v", err)
}
}
if err := limiter.Locked(ctx, key, limit); !errors.Is(err, ratelimit.ErrLimitExceeded) {
t.Fatalf("want ErrLimitExceeded after budget exhausted, got %v", err)
}
advance(ttl + time.Second)
if err := limiter.Locked(ctx, key, limit); err != nil {
t.Fatalf("key should be unlocked after window expiry: %v", err)
}
}
A Second Test
Something about the work felt unfinished. The team’s test had checked that exactly five attempts were admitted, but that wasn’t quite the property their threat model cared about. The identified risk was whether an attacker could keep guessing, and whether the account would lock and stay locked. The team went back to the test and added a second assertion: after the burst of fifty concurrent failures, calling Locked on the key should return ErrLimitExceeded.
The new assertion failed even though the race detector stayed silent and the first assertion passed. After fifty concurrent failures against a limit of five, the stored count sat at one. Every goroutine had loaded the entry while its count was still zero, checked that zero was under five, incremented its own local copy to one, and stored it back. The fifty stores all wrote the same value, so the count converged on a single increment regardless of how many failures had occurred. The account wasn’t locked, so an attacker could come back and trigger another burst, and the count would climb by only one more.
To make matters worse, a counter stuck at one never locks the account, and an unlocked account sets off no alarms. An attacker can make dozens of OTP guesses while the team’s logs and service monitors show nothing amiss. In this case, the broken rate limiter was arguably worse than no limiter at all, because it left the team believing brute-force guessing was already handled.
Going Back to the Standard
Satisfying 15.4.1 had made the data structure safe, but the bug remained, and it pointed at something 15.4.2 names directly:
Verify that checks on a resource’s state, such as its existence or permissions, and the actions that depend on them are performed as a single atomic operation to prevent time-of-check to time-of-use (TOCTOU) race conditions. For example, checking if a file exists before opening it, or verifying a user’s access before granting it.
That’s exactly what RecordFailure does wrong. The check is “is e.Count under the limit?” The action that depends on it is incrementing e.Count and storing the result. The two steps aren’t a single atomic operation, even when the underlying map is thread-safe. Fifty goroutines can all check, all see room under the limit, all increment their local copy, and all store back the same value.
package main
import (
"context"
"errors"
"flag"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/cronchie/totpauth/internal/auth"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/httpapi"
"github.com/cronchie/totpauth/internal/ratelimit"
)
func main() {
addr := flag.String("addr", ":8080", "listen address")
flag.Parse()
log := slog.New(slog.NewJSONHandler(os.Stdout, nil))
notifier := auth.NotifierFunc(func(_ context.Context, userID, code string) error {
log.Warn("OTP code generated; configure a real Notifier", "user_id", userID, "code", code)
return nil
})
limiter := ratelimit.New()
challenges := challenge.New()
validator := auth.NewValidator(limiter, challenges, notifier)
handler := httpapi.Recover(log)(httpapi.NewHandler(validator, log).Routes())
srv := &http.Server{
Addr: *addr,
Handler: handler,
ReadHeaderTimeout: 5 * time.Second,
}
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
go func() {
log.Info("authd listening", "addr", srv.Addr)
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Error("listen failed", "err", err)
stop()
}
}()
<-ctx.Done()
log.Info("authd shutting down")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Error("shutdown failed", "err", err)
os.Exit(1)
}
}
module github.com/cronchie/totpauth
go 1.23
package auth
import (
"errors"
"github.com/cronchie/totpauth/internal/auth/challenge"
)
var ErrLockedOut = errors.New("auth: too many failed attempts")
var (
ErrInvalidCode = challenge.ErrInvalidCode
ErrNoChallenge = challenge.ErrNoChallenge
)
package challenge
import (
"context"
"errors"
"time"
)
const DefaultTTL = 5 * time.Minute
var (
ErrInvalidCode = errors.New("challenge: invalid code")
ErrNoChallenge = errors.New("challenge: no active challenge")
)
type Challenge struct {
Code string
ExpiresAt time.Time
}
type Store interface {
Issue(ctx context.Context, userID, code string, ttl time.Duration) error
ConsumeIfMatches(ctx context.Context, userID, candidate string) error
}
package challenge_test
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cronchie/totpauth/internal/auth/challenge"
)
func TestStore_SingleConsumeUnderBurst(t *testing.T) {
const (
iterations = 50
burst = 200
ttl = time.Minute
)
store := challenge.New()
ctx := context.Background()
var totalMatched int64
var other atomic.Value
for iter := range iterations {
userID := fmt.Sprintf("user-%d", iter)
code := fmt.Sprintf("%06d", iter)
if err := store.Issue(ctx, userID, code, ttl); err != nil {
t.Fatalf("Issue: %v", err)
}
var wg sync.WaitGroup
start := make(chan struct{})
for range burst {
wg.Add(1)
go func() {
defer wg.Done()
<-start
switch err := store.ConsumeIfMatches(ctx, userID, code); {
case err == nil:
atomic.AddInt64(&totalMatched, 1)
case errors.Is(err, challenge.ErrNoChallenge),
errors.Is(err, challenge.ErrInvalidCode):
default:
other.Store(err)
}
}()
}
close(start)
wg.Wait()
}
if err, _ := other.Load().(error); err != nil {
t.Fatalf("unexpected store error: %v", err)
}
if totalMatched != iterations {
t.Errorf("totalMatched = %d across %d iterations of %d-goroutine bursts, want exactly %d",
totalMatched, iterations, burst, iterations)
}
}
func TestStore_RejectsWrongCode(t *testing.T) {
const userID = "user-42"
store := challenge.New()
ctx := context.Background()
if err := store.Issue(ctx, userID, "123456", time.Minute); err != nil {
t.Fatalf("Issue: %v", err)
}
err := store.ConsumeIfMatches(ctx, userID, "000000")
if !errors.Is(err, challenge.ErrInvalidCode) {
t.Fatalf("ConsumeIfMatches(wrong) = %v, want ErrInvalidCode", err)
}
if err := store.ConsumeIfMatches(ctx, userID, "123456"); err != nil {
t.Fatalf("ConsumeIfMatches(correct after wrong) = %v, want nil", err)
}
}
func TestStore_ExpiredChallenge(t *testing.T) {
const (
userID = "user-42"
code = "123456"
ttl = time.Minute
)
now := time.Date(2026, 5, 21, 12, 0, 0, 0, time.UTC)
store := challenge.New(challenge.WithClock(func() time.Time { return now }))
ctx := context.Background()
if err := store.Issue(ctx, userID, code, ttl); err != nil {
t.Fatalf("Issue: %v", err)
}
now = now.Add(2 * time.Minute)
err := store.ConsumeIfMatches(ctx, userID, code)
if !errors.Is(err, challenge.ErrNoChallenge) {
t.Fatalf("expired Consume = %v, want ErrNoChallenge", err)
}
}
package store
import "time"
type Settings struct {
Now func() time.Time
}
type Option func(*Settings)
func Apply(opts []Option) Settings {
s := Settings{Now: time.Now}
for _, opt := range opts {
opt(&s)
}
return s
}
func WithClock(now func() time.Time) Option {
return func(s *Settings) { s.Now = now }
}
package challenge
import (
"time"
"github.com/cronchie/totpauth/internal/auth/challenge/internal/store"
)
type Option = store.Option
func WithClock(now func() time.Time) Option { return store.WithClock(now) }
package challenge
import (
"context"
"crypto/subtle"
"sync"
"time"
"github.com/cronchie/totpauth/internal/auth/challenge/internal/store"
)
type mutexStore struct {
mu sync.Mutex
challenges map[string]Challenge
now func() time.Time
}
func New(opts ...Option) Store {
s := store.Apply(opts)
return &mutexStore{
challenges: make(map[string]Challenge),
now: s.Now,
}
}
func (s *mutexStore) Issue(_ context.Context, userID, code string, ttl time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
s.challenges[userID] = Challenge{
Code: code,
ExpiresAt: s.now().Add(ttl),
}
return nil
}
func (s *mutexStore) ConsumeIfMatches(_ context.Context, userID, candidate string) error {
s.mu.Lock()
defer s.mu.Unlock()
ch, ok := s.challenges[userID]
if !ok || s.now().After(ch.ExpiresAt) {
delete(s.challenges, userID)
return ErrNoChallenge
}
if subtle.ConstantTimeCompare([]byte(ch.Code), []byte(candidate)) != 1 {
return ErrInvalidCode
}
delete(s.challenges, userID)
return nil
}
package auth
import "context"
type Notifier interface {
Notify(ctx context.Context, userID, code string) error
}
type NotifierFunc func(ctx context.Context, userID, code string) error
func (f NotifierFunc) Notify(ctx context.Context, userID, code string) error {
return f(ctx, userID, code)
}
package auth
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"time"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/ratelimit"
)
const DefaultMaxAttempts = 5
const codeDigits = 6
var codeSpace = big.NewInt(1_000_000)
type Validator struct {
limiter ratelimit.Limiter
challenges challenge.Store
notifier Notifier
maxAttempts int
challengeTTL time.Duration
}
type Option func(*Validator)
func WithMaxAttempts(n int) Option {
return func(v *Validator) { v.maxAttempts = n }
}
func WithChallengeTTL(d time.Duration) Option {
return func(v *Validator) { v.challengeTTL = d }
}
func NewValidator(limiter ratelimit.Limiter, challenges challenge.Store, notifier Notifier, opts ...Option) *Validator {
v := &Validator{
limiter: limiter,
challenges: challenges,
notifier: notifier,
maxAttempts: DefaultMaxAttempts,
challengeTTL: challenge.DefaultTTL,
}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *Validator) Issue(ctx context.Context, userID string) error {
code, err := generateCode()
if err != nil {
return fmt.Errorf("generating OTP code: %w", err)
}
if err := v.challenges.Issue(ctx, userID, code, v.challengeTTL); err != nil {
return fmt.Errorf("storing challenge: %w", err)
}
if err := v.notifier.Notify(ctx, userID, code); err != nil {
return fmt.Errorf("dispatching OTP: %w", err)
}
return nil
}
func (v *Validator) Verify(ctx context.Context, userID, code string) error {
if err := v.limiter.Locked(ctx, userID, v.maxAttempts); err != nil {
if errors.Is(err, ratelimit.ErrLimitExceeded) {
return ErrLockedOut
}
return fmt.Errorf("rate limiter: %w", err)
}
err := v.challenges.ConsumeIfMatches(ctx, userID, code)
switch {
case err == nil:
return nil
case errors.Is(err, challenge.ErrInvalidCode), errors.Is(err, challenge.ErrNoChallenge):
if rfErr := v.limiter.RecordFailure(ctx, userID, v.maxAttempts); rfErr != nil {
if errors.Is(rfErr, ratelimit.ErrLimitExceeded) {
return ErrLockedOut
}
return fmt.Errorf("rate limiter: %w", rfErr)
}
return err
default:
return fmt.Errorf("consuming challenge: %w", err)
}
}
func generateCode() (string, error) {
n, err := rand.Int(rand.Reader, codeSpace)
if err != nil {
return "", err
}
return fmt.Sprintf("%0*d", codeDigits, n.Int64()), nil
}
package auth_test
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cronchie/totpauth/internal/auth"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/ratelimit"
)
const testUser = "user-42"
type capturingNotifier struct {
mu sync.Mutex
last string
}
func (n *capturingNotifier) Notify(_ context.Context, _, code string) error {
n.mu.Lock()
defer n.mu.Unlock()
n.last = code
return nil
}
func (n *capturingNotifier) code() string {
n.mu.Lock()
defer n.mu.Unlock()
return n.last
}
func newTestValidator(t *testing.T) (*auth.Validator, *capturingNotifier) {
t.Helper()
notifier := &capturingNotifier{}
v := auth.NewValidator(
ratelimit.New(),
challenge.New(),
notifier,
auth.WithMaxAttempts(3),
)
return v, notifier
}
func issueAndCode(t *testing.T, v *auth.Validator, n *capturingNotifier, userID string) string {
t.Helper()
if err := v.Issue(context.Background(), userID); err != nil {
t.Fatalf("Issue: %v", err)
}
code := n.code()
if len(code) != 6 {
t.Fatalf("issued code = %q, want 6 digits", code)
}
return code
}
func TestVerify_AcceptsIssuedCode(t *testing.T) {
v, n := newTestValidator(t)
code := issueAndCode(t, v, n, testUser)
if err := v.Verify(context.Background(), testUser, code); err != nil {
t.Fatalf("Verify(issued) = %v, want nil", err)
}
}
func TestVerify_RejectsWrongCode(t *testing.T) {
v, n := newTestValidator(t)
_ = issueAndCode(t, v, n, testUser)
err := v.Verify(context.Background(), testUser, "000000")
if !errors.Is(err, auth.ErrInvalidCode) {
t.Fatalf("Verify(wrong) = %v, want ErrInvalidCode", err)
}
}
func TestVerify_RejectsAfterConsumption(t *testing.T) {
v, n := newTestValidator(t)
code := issueAndCode(t, v, n, testUser)
ctx := context.Background()
if err := v.Verify(ctx, testUser, code); err != nil {
t.Fatalf("first Verify: %v", err)
}
err := v.Verify(ctx, testUser, code)
if !errors.Is(err, auth.ErrNoChallenge) {
t.Fatalf("replay Verify = %v, want ErrNoChallenge", err)
}
}
func TestVerify_NoChallengeIssued(t *testing.T) {
v, _ := newTestValidator(t)
err := v.Verify(context.Background(), testUser, "000000")
if !errors.Is(err, auth.ErrNoChallenge) {
t.Fatalf("Verify(no challenge) = %v, want ErrNoChallenge", err)
}
}
func TestVerify_LocksOutAfterMaxAttempts(t *testing.T) {
v, n := newTestValidator(t)
_ = issueAndCode(t, v, n, testUser)
ctx := context.Background()
for range 3 {
err := v.Verify(ctx, testUser, "000000")
if errors.Is(err, auth.ErrLockedOut) {
t.Fatal("locked out during budget")
}
}
err := v.Verify(ctx, testUser, "000000")
if !errors.Is(err, auth.ErrLockedOut) {
t.Fatalf("after budget: got %v, want ErrLockedOut", err)
}
}
func TestVerify_ConsumeRaceUnderBurst(t *testing.T) {
const burst = 50
notifier := &capturingNotifier{}
v := auth.NewValidator(
ratelimit.New(),
challenge.New(),
notifier,
auth.WithMaxAttempts(burst),
)
if err := v.Issue(context.Background(), testUser); err != nil {
t.Fatalf("Issue: %v", err)
}
code := notifier.code()
var wg sync.WaitGroup
start := make(chan struct{})
var successes int64
for range burst {
wg.Add(1)
go func() {
defer wg.Done()
<-start
if err := v.Verify(context.Background(), testUser, code); err == nil {
atomic.AddInt64(&successes, 1)
}
}()
}
close(start)
wg.Wait()
if successes != 1 {
t.Errorf("successes = %d, want exactly 1 (the code is single-use)", successes)
}
}
func TestVerify_ExpiredChallenge(t *testing.T) {
now := time.Date(2026, 5, 21, 12, 0, 0, 0, time.UTC)
store := challenge.New(
challenge.WithClock(func() time.Time { return now }),
)
notifier := &capturingNotifier{}
v := auth.NewValidator(
ratelimit.New(),
store,
notifier,
auth.WithChallengeTTL(time.Minute),
)
if err := v.Issue(context.Background(), testUser); err != nil {
t.Fatalf("Issue: %v", err)
}
code := notifier.code()
now = now.Add(2 * time.Minute)
err := v.Verify(context.Background(), testUser, code)
if !errors.Is(err, auth.ErrNoChallenge) {
t.Fatalf("expired Verify = %v, want ErrNoChallenge", err)
}
}
package httpapi
import (
"encoding/json"
"errors"
"log/slog"
"net/http"
"github.com/cronchie/totpauth/internal/auth"
)
type Handler struct {
validator *auth.Validator
log *slog.Logger
}
func NewHandler(validator *auth.Validator, log *slog.Logger) *Handler {
if log == nil {
log = slog.Default()
}
return &Handler{validator: validator, log: log}
}
func (h *Handler) Routes() *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("POST /v1/otp/issue", h.issueOTP)
mux.HandleFunc("POST /v1/otp/verify", h.verifyOTP)
return mux
}
type issueRequest struct {
UserID string `json:"user_id"`
}
func (h *Handler) issueOTP(w http.ResponseWriter, r *http.Request) {
var req issueRequest
if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<10)).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.UserID == "" {
writeError(w, http.StatusBadRequest, "user_id is required")
return
}
if err := h.validator.Issue(r.Context(), req.UserID); err != nil {
h.log.ErrorContext(r.Context(), "otp issue failed",
"user_id", req.UserID, "err", err)
writeError(w, http.StatusInternalServerError, "internal error")
return
}
w.WriteHeader(http.StatusAccepted)
}
type verifyRequest struct {
UserID string `json:"user_id"`
Code string `json:"code"`
}
func (h *Handler) verifyOTP(w http.ResponseWriter, r *http.Request) {
var req verifyRequest
if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<10)).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.UserID == "" || req.Code == "" {
writeError(w, http.StatusBadRequest, "user_id and code are required")
return
}
switch err := h.validator.Verify(r.Context(), req.UserID, req.Code); {
case err == nil:
w.WriteHeader(http.StatusNoContent)
case errors.Is(err, auth.ErrLockedOut):
writeError(w, http.StatusTooManyRequests, "too many failed attempts")
case errors.Is(err, auth.ErrInvalidCode), errors.Is(err, auth.ErrNoChallenge):
writeError(w, http.StatusUnauthorized, "invalid credentials")
default:
h.log.ErrorContext(r.Context(), "otp verify failed",
"user_id", req.UserID, "err", err)
writeError(w, http.StatusInternalServerError, "internal error")
}
}
func writeError(w http.ResponseWriter, status int, msg string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(map[string]string{"error": msg})
}
package httpapi_test
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/cronchie/totpauth/internal/auth"
"github.com/cronchie/totpauth/internal/auth/challenge"
"github.com/cronchie/totpauth/internal/httpapi"
"github.com/cronchie/totpauth/internal/ratelimit"
)
const testUser = "user-42"
type capturingNotifier struct {
mu sync.Mutex
last string
}
func (n *capturingNotifier) Notify(_ context.Context, _, code string) error {
n.mu.Lock()
defer n.mu.Unlock()
n.last = code
return nil
}
func (n *capturingNotifier) code() string {
n.mu.Lock()
defer n.mu.Unlock()
return n.last
}
func newTestServer(t *testing.T, maxAttempts int) (*httptest.Server, *capturingNotifier) {
t.Helper()
notifier := &capturingNotifier{}
validator := auth.NewValidator(
ratelimit.New(),
challenge.New(),
notifier,
auth.WithMaxAttempts(maxAttempts),
)
srv := httptest.NewServer(httpapi.NewHandler(validator, nil).Routes())
t.Cleanup(srv.Close)
return srv, notifier
}
func postJSON(t *testing.T, srv *httptest.Server, path string, body any) *http.Response {
t.Helper()
buf, err := json.Marshal(body)
if err != nil {
t.Fatalf("marshal: %v", err)
}
resp, err := http.Post(srv.URL+path, "application/json", bytes.NewReader(buf))
if err != nil {
t.Fatalf("POST %s: %v", path, err)
}
t.Cleanup(func() { resp.Body.Close() })
return resp
}
func TestIssue_Accepted(t *testing.T) {
srv, n := newTestServer(t, 5)
resp := postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusAccepted)
}
if got := n.code(); len(got) != 6 {
t.Fatalf("notifier received %q, want 6-digit code", got)
}
}
func TestVerify_Success(t *testing.T) {
srv, n := newTestServer(t, 5)
_ = postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": n.code(),
})
if resp.StatusCode != http.StatusNoContent {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNoContent)
}
}
func TestVerify_WrongCode(t *testing.T) {
srv, _ := newTestServer(t, 5)
_ = postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusUnauthorized)
}
}
func TestVerify_NoChallengeMasked(t *testing.T) {
srv, _ := newTestServer(t, 5)
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("status = %d, want %d (must not leak whether a challenge was issued)",
resp.StatusCode, http.StatusUnauthorized)
}
}
func TestVerify_LockoutReturns429(t *testing.T) {
srv, _ := newTestServer(t, 2)
_ = postJSON(t, srv, "/v1/otp/issue", map[string]string{"user_id": testUser})
for range 2 {
_ = postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
}
resp := postJSON(t, srv, "/v1/otp/verify", map[string]string{
"user_id": testUser, "code": "000000",
})
if resp.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("status = %d, body = %s, want 429", resp.StatusCode, body)
}
}
func TestVerify_BadRequest(t *testing.T) {
srv, _ := newTestServer(t, 5)
resp, err := http.Post(srv.URL+"/v1/otp/verify", "application/json", bytes.NewBufferString("not json"))
if err != nil {
t.Fatalf("POST: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("status = %d, want 400", resp.StatusCode)
}
}
package httpapi
import (
"log/slog"
"net/http"
"runtime/debug"
)
func Recover(log *slog.Logger) func(http.Handler) http.Handler {
if log == nil {
log = slog.Default()
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := &responseRecorder{ResponseWriter: w}
defer func() {
rec := recover()
if rec == nil {
return
}
if rec == http.ErrAbortHandler {
panic(rec)
}
log.ErrorContext(r.Context(), "panic recovered",
"panic", rec,
"method", r.Method,
"path", r.URL.Path,
"stack", string(debug.Stack()),
)
if !rw.wroteHeader {
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusInternalServerError)
_, _ = rw.Write([]byte(`{"error":"internal error"}`))
}
}()
next.ServeHTTP(rw, r)
})
}
}
type responseRecorder struct {
http.ResponseWriter
wroteHeader bool
status int
}
func (r *responseRecorder) WriteHeader(status int) {
if r.wroteHeader {
return
}
r.wroteHeader = true
r.status = status
r.ResponseWriter.WriteHeader(status)
}
func (r *responseRecorder) Write(b []byte) (int, error) {
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
return r.ResponseWriter.Write(b)
}
package httpapi_test
import (
"bytes"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/cronchie/totpauth/internal/httpapi"
)
func TestRecover_CatchesPanicAndReturns500(t *testing.T) {
var buf bytes.Buffer
log := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelError}))
handler := httpapi.Recover(log)(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
panic("kaboom")
}))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/totp/verify", nil))
if rec.Code != http.StatusInternalServerError {
t.Errorf("status = %d, want 500", rec.Code)
}
var body map[string]string
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("body not JSON: %v (body=%q)", err, rec.Body.String())
}
if body["error"] != "internal error" {
t.Errorf("body[error] = %q, want %q", body["error"], "internal error")
}
logs := buf.String()
for _, want := range []string{"panic recovered", "kaboom", "/totp/verify"} {
if !strings.Contains(logs, want) {
t.Errorf("log missing %q; got: %s", want, logs)
}
}
}
func TestRecover_PassesThroughWhenNoPanic(t *testing.T) {
log := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := httpapi.Recover(log)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("hi"))
}))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if rec.Code != http.StatusTeapot {
t.Errorf("status = %d, want %d", rec.Code, http.StatusTeapot)
}
if rec.Body.String() != "hi" {
t.Errorf("body = %q, want %q", rec.Body.String(), "hi")
}
}
func TestRecover_PanicAfterWriteKeepsEarlyStatus(t *testing.T) {
log := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := httpapi.Recover(log)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("partial"))
panic("too late")
}))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want 200 (panic after Write must not overwrite status)", rec.Code)
}
if !strings.HasPrefix(rec.Body.String(), "partial") {
t.Errorf("body = %q, want it to start with %q", rec.Body.String(), "partial")
}
}
func TestRecover_RepanicsErrAbortHandler(t *testing.T) {
log := slog.New(slog.NewTextHandler(io.Discard, nil))
handler := httpapi.Recover(log)(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
panic(http.ErrAbortHandler)
}))
defer func() {
switch rec := recover(); rec {
case http.ErrAbortHandler:
case nil:
t.Error("middleware swallowed http.ErrAbortHandler; want re-panic")
default:
t.Errorf("re-panicked with %v, want http.ErrAbortHandler", rec)
}
}()
handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
}
package store
import "time"
const DefaultEntryTTL = 10 * time.Minute
type Entry struct {
Count int
ExpiresAt time.Time
}
func Expired(e *Entry, now time.Time) bool {
return e == nil || now.After(e.ExpiresAt)
}
func Lookup(m map[string]*Entry, key string, now time.Time) (*Entry, bool) {
e, ok := m[key]
if ok && Expired(e, now) {
delete(m, key)
return nil, false
}
return e, ok
}
func Ensure(m map[string]*Entry, key string, ttl time.Duration, now time.Time) *Entry {
if e, ok := Lookup(m, key, now); ok {
return e
}
e := &Entry{ExpiresAt: now.Add(ttl)}
m[key] = e
return e
}
package store
import "time"
type Settings struct {
TTL time.Duration
Now func() time.Time
}
type Option func(*Settings)
func Apply(opts []Option) Settings {
s := Settings{TTL: DefaultEntryTTL, Now: time.Now}
for _, opt := range opts {
opt(&s)
}
return s
}
func WithTTL(d time.Duration) Option {
return func(s *Settings) { s.TTL = d }
}
func WithClock(now func() time.Time) Option {
return func(s *Settings) { s.Now = now }
}
package ratelimit
import (
"context"
"sync"
"time"
"github.com/cronchie/totpauth/internal/ratelimit/internal/store"
)
type mutexLimiter struct {
mu sync.Mutex
attempts map[string]*store.Entry
ttl time.Duration
now func() time.Time
}
func New(opts ...Option) Limiter {
s := store.Apply(opts)
return &mutexLimiter{
attempts: make(map[string]*store.Entry),
ttl: s.TTL,
now: s.Now,
}
}
func (l *mutexLimiter) Locked(_ context.Context, key string, limit int) error {
l.mu.Lock()
defer l.mu.Unlock()
e, ok := store.Lookup(l.attempts, key, l.now())
if !ok {
return nil
}
if e.Count >= limit {
return ErrLimitExceeded
}
return nil
}
func (l *mutexLimiter) RecordFailure(_ context.Context, key string, limit int) error {
l.mu.Lock()
defer l.mu.Unlock()
e := store.Ensure(l.attempts, key, l.ttl, l.now())
if e.Count >= limit {
return ErrLimitExceeded
}
e.Count++
return nil
}
package ratelimit
import (
"time"
"github.com/cronchie/totpauth/internal/ratelimit/internal/store"
)
const DefaultEntryTTL = store.DefaultEntryTTL
type Option = store.Option
func WithTTL(d time.Duration) Option { return store.WithTTL(d) }
func WithClock(now func() time.Time) Option { return store.WithClock(now) }
package ratelimit
import (
"context"
"errors"
)
var ErrLimitExceeded = errors.New("ratelimit: limit exceeded")
type Limiter interface {
Locked(ctx context.Context, key string, limit int) error
RecordFailure(ctx context.Context, key string, limit int) error
}
package ratelimit_test
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cronchie/totpauth/internal/ratelimit"
)
func TestLimiter_ConcurrentSafety(t *testing.T) {
const (
limit = 5
goroutines = 50
key = "user-42"
)
limiter := ratelimit.New()
var (
wg sync.WaitGroup
start = make(chan struct{})
recorded int64
rejected int64
other atomic.Value
)
for range goroutines {
wg.Add(1)
go func() {
defer wg.Done()
<-start
switch err := limiter.RecordFailure(context.Background(), key, limit); {
case err == nil:
atomic.AddInt64(&recorded, 1)
case errors.Is(err, ratelimit.ErrLimitExceeded):
atomic.AddInt64(&rejected, 1)
default:
other.Store(err)
}
}()
}
close(start)
wg.Wait()
if err, _ := other.Load().(error); err != nil {
t.Fatalf("unexpected limiter error: %v", err)
}
if got := int(recorded); got != limit {
t.Errorf("recorded %d failures, want exactly %d (rejected=%d)",
got, limit, rejected)
}
if got := int(recorded + rejected); got != goroutines {
t.Errorf("accounted for %d of %d requests", got, goroutines)
}
}
func runFailureBurstAndProbeLockout(limiter ratelimit.Limiter, key string, limit, burst int) (isLocked bool) {
ctx := context.Background()
var wg sync.WaitGroup
start := make(chan struct{})
for range burst {
wg.Add(1)
go func() {
defer wg.Done()
<-start
_ = limiter.RecordFailure(ctx, key, limit)
}()
}
close(start)
wg.Wait()
return errors.Is(limiter.Locked(ctx, key, limit), ratelimit.ErrLimitExceeded)
}
func TestLimiter_LockoutHoldsAfterBurst(t *testing.T) {
const (
key = "user-42"
limit = 5
burst = 100
)
if !runFailureBurstAndProbeLockout(ratelimit.New(), key, limit, burst) {
t.Errorf("mutex impl: key NOT locked after %d-attempt burst against limit %d; "+
"a safe implementation must always lock", burst, limit)
}
}
func TestLimiter_WindowExpiry(t *testing.T) {
const (
limit = 2
key = "user-42"
ttl = time.Minute
)
now := time.Date(2026, 5, 21, 12, 0, 0, 0, time.UTC)
advance := func(d time.Duration) {
now = now.Add(d)
}
limiter := ratelimit.New(
ratelimit.WithTTL(ttl),
ratelimit.WithClock(func() time.Time { return now }),
)
ctx := context.Background()
for range limit {
if err := limiter.RecordFailure(ctx, key, limit); err != nil {
t.Fatalf("failure should be recorded: %v", err)
}
}
if err := limiter.Locked(ctx, key, limit); !errors.Is(err, ratelimit.ErrLimitExceeded) {
t.Fatalf("want ErrLimitExceeded after budget exhausted, got %v", err)
}
advance(ttl + time.Second)
if err := limiter.Locked(ctx, key, limit); err != nil {
t.Fatalf("key should be unlocked after window expiry: %v", err)
}
}
A single mutex around RecordFailure fixes the bug by serializing every call through one lock, so one request’s check and increment can no longer interleave with another’s. For this service, that costs nothing the team will feel: one instance, modest traffic, and a lock held just long enough to read a couple of fields and store one back. They dropped the sync.Map and put a regular map under one sync.Mutex instead. One global lock is enough until the service needs more throughput than a single lock can handle.
An alternative fix would have given each account its own lock, held across the check and the increment. Requests for different users would never wait on each other, while two requests for the same account would serialize, which is the one place the contention belongs.
The point worth ending on is that sync.Map was never the wrong tool. It did what its documentation promises: concurrent access to the map stays safe, and lock contention drops when goroutines work on disjoint keys. The bug lived a level above the map, in a check-and-act sequence that no collection is responsible for making atomic.
Final Thoughts
Two requirements caught two distinct bugs in the same method. 15.4.1 surfaced the data structure problem, which sync.Map fixed. 15.4.2 surfaced the atomicity problem, which the mutex fixed. The split between them isn’t arbitrary; a team can satisfy one without satisfying the other, which is what the team in this article ended up doing.
Starting with a mutex could have satisfied both requirements in this example, but that does not make “use a mutex” the general lesson. Developers often reach for concurrency-safe collections because they provide safe access without forcing every operation through one application-level lock. Go’s sync.Map, C#’s ConcurrentDictionary, and similar types can be good fits when the operation being performed is the individual load, store, add, update, or delete they protect.
I hope teams use these requirements to inspect the places where their code checks state, makes a security decision from that state, and then updates it. If two requests can make the same decision from the same stale answer, the bug isn’t gone just because the map stopped racing.
The Applying ASVS Series
In my time doing AppSec, the talks I’ve enjoyed most gave me problems to chew on: something concrete enough to follow, but open enough that I kept thinking about how it applied to my own work. That’s why I started building more code samples into my talks, and why I wanted this series to include runnable examples instead of just commentary (although my commentary has been, I’m assured, solid so far). I find that security guidance becomes easier to internalize when you can look at sample code, tweak it, and run it to watch the failure happen for yourself.
If any of these examples helps you review code or reminds you of a related security flaw, I’d like to hear about it. That feedback is what keeps the examples grounded in the kind of code I imagine you, the reader, are routinely writing and reviewing. My next article on applying ASVS (Part 2 for Safe Concurrency) is already in the works, and in true concurrent fashion, you’re welcome to appreciate it now while it’s still being written 🙂