Fix thread-safety in mixed_precision resolving developer TODOs #308

Open
ibfavas wants to merge 1 commit into google-deepmind:v2 from ibfavas:fix-mixed-precision-thread-safety

ibfavas commented Apr 8, 2026

Description

This PR resolves the existing TODO(loreno) comments in sonnet/src/mixed_precision.py regarding thread-local storage and thread safety.

Currently, snt.mixed_precision relies on a plain, module-level Python variable (_mixed_precision_mode) with no locking and no thread-local isolation. In concurrent environments (most notably tf.distribute.MirroredStrategy), this can silently corrupt the precision setting: threads overwrite each other's state when a scope() context manager exits or when enable() is called.

Changes Proposed

  1. Moved global state to threading.local(): Fixes the scope() save/restore TOCTOU race and the enable() lost-update race. Each thread now independently manages its precision state.
  2. Locked the seen_none flag: Implemented a double-checked lock around the initialization guard to prevent multiple threads from concurrently bypassing the initialization logic and permanently dropping dynamic range.

Proof of Concept (Reproduction)

Before this patch, multiple threads using the scope() context manager or enable() could silently corrupt each other's state. The minimal, pure-Python script below reproduces the module's logic to isolate both race conditions (the scope() TOCTOU and the enable() lost update):

PoC Script & Output

poc_mixed_precision_race.py

#!/usr/bin/env python3
import contextlib
import threading
import time

# Shared module-level state, mirroring the unpatched module.
_mixed_precision_mode = None

def enable(dtype):
    global _mixed_precision_mode
    _mixed_precision_mode = dtype

def _get_mixed_precision_mode():
    return _mixed_precision_mode

@contextlib.contextmanager
def scope(dtype):
    old_mode = _get_mixed_precision_mode()
    enable(dtype)
    try:
        yield
    finally:
        enable(old_mode)  # TOCTOU: restores the value saved on entry, overwriting any write made by another thread since

print("[*] Starting Sonnet Mixed Precision Race Condition PoC...\n")

# Race 1: The scope() Context Manager Wipeout
print("[*] 1. Forcing the scope() save/restore TOCTOU...")
enable(None)
barrier = threading.Barrier(2)

def thread_a():
    # Thread A enters the scope, sets fp16, and takes a nap.
    with scope("fp16"):
        barrier.wait()
        time.sleep(0.05) 

def thread_b():
    barrier.wait()
    time.sleep(0.01)
    # While A is sleeping, Thread B sets its own global precision.
    print("    [Thread B] Setting global mode to 'fp32'")
    enable("fp32")

ta = threading.Thread(target=thread_a)
tb = threading.Thread(target=thread_b)

ta.start()
tb.start()
ta.join()
tb.join()

result = _get_mixed_precision_mode()
if result != "fp32":
    print(f"    [!] Expected 'fp32', but got: {result!r}")
    print("    [!] Thread A's finally block just silently paved over Thread B's setting.\n")
else:
    print("    [+] No race this run. (Non-deterministic, try again)\n")


# Race 2: The enable() Lost Update
print("[*] 2. Forcing the enable() lost-update with 80 concurrent threads...")
enable(None)
mismatches = []
mismatch_lock = threading.Lock()
dtypes_to_spam = ["fp16", "fp32", "bf16", None]

def worker(tid):
    my_dtype = dtypes_to_spam[tid % len(dtypes_to_spam)]

    # Write to the global
    enable(my_dtype)
    time.sleep(0)  # Yield just enough to let another thread step on our toes

    # Read it immediately back
    read_back = _get_mixed_precision_mode()

    if read_back != my_dtype:
        with mismatch_lock:
            mismatches.append((tid, my_dtype, read_back))

threads = [threading.Thread(target=worker, args=(i,)) for i in range(80)]
for t in threads:
    t.start()
for t in threads:
    t.join()

if mismatches:
    print(f"    [!] Mismatch: {len(mismatches)} out of 80 threads read back corrupted state.")
    for tid, wrote, read in mismatches[:5]: # Just show the first 5
        print(f"        Thread {tid:02d} | Wrote: {str(wrote):>4} -> Read Back: {str(read)}")
    print("        ...")
else:
    print("    [+] Survived. (Re-run to observe the collision)")

Pre-Patch Output (the Race 2 mismatch count and the affected threads vary between runs):

[*] Starting Sonnet Mixed Precision Race Condition PoC...

[*] 1. Forcing the scope() save/restore TOCTOU...
    [Thread B] Setting global mode to 'fp32'
    [!] Expected 'fp32', but got: None
    [!] Thread A's finally block just silently paved over Thread B's setting.

[*] 2. Forcing the enable() lost-update with 80 concurrent threads...
    [!] Mismatch: 25 out of 80 threads read back corrupted state.
        Thread 05 | Wrote: fp32 -> Read Back: bf16
        Thread 10 | Wrote: bf16 -> Read Back: None
        Thread 14 | Wrote: bf16 -> Read Back: None
        Thread 19 | Wrote: None -> Read Back: fp16
        Thread 21 | Wrote: fp32 -> Read Back: bf16
        ...
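
For comparison, here is a hedged sketch of the Race 2 worker run against a thread-local variant of the same logic. It is not taken from the patch: `_state`, `get_mode`, and `run_race_2` are hypothetical names, and the only assumption is that the mode lives in a `threading.local()` object. Because each thread reads back only what it wrote itself, the mismatch list should stay empty however the threads interleave.

```python
import threading
import time

# Hypothetical thread-local replacement for the module-level global.
_state = threading.local()


def enable(dtype):
    """Set the mode for the calling thread only."""
    _state.mode = dtype


def get_mode():
    """Return the calling thread's mode, defaulting to None."""
    return getattr(_state, "mode", None)


def run_race_2(num_threads=80):
    """Re-run the PoC's write/yield/read-back check; return any mismatches."""
    dtypes_to_spam = ["fp16", "fp32", "bf16", None]
    mismatches = []
    mismatch_lock = threading.Lock()

    def worker(tid):
        my_dtype = dtypes_to_spam[tid % len(dtypes_to_spam)]
        enable(my_dtype)
        time.sleep(0)  # Give other threads a chance to run in between
        read_back = get_mode()
        if read_back != my_dtype:
            with mismatch_lock:
                mismatches.append((tid, my_dtype, read_back))

    threads = [threading.Thread(target=worker, args=(i,))
               for i in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return mismatches


if __name__ == "__main__":
    print("mismatches:", len(run_race_2()))
```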

google-cla bot commented Apr 8, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.
