BF16 Edge Cases When Porting CUDA-Oriented TTS Models to MPS and MLX-Audio


When porting CUDA-oriented TTS models to Apple Silicon, one subtle but important issue is BF16 behavior on Metal. It has shown up in my DiaTTS work on MPS, and the same class of issue appears relevant to ports such as Microsoft's VibeVoice and to my still-unreleased look at converting MisoTTS to MPS or MLX-Audio.

TL;DR

BF16 subnormal values that remain nonzero on CPU can be flushed to zero on Metal/MPS/MLX.

This does not automatically mean the model will fail. But CUDA or CPU behavior is not always a reliable reference for BF16 inference on Apple GPUs, especially around normalization, attention, logits, softmax tails, sampling, KV caches, and quantization calibration. In practice it can push you toward running in 32-bit mode on Metal (as with VibeVoice and DiaTTS), which kinda sucks for speed.

The Core Issue

BF16 has a normal range and a subnormal range. Values whose magnitude falls below the smallest normal BF16 value (about 1.1755e-38) can still be represented as subnormals on some platforms. On Apple Metal paths, those BF16 subnormal values may be flushed to zero.
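To make the two ranges concrete, here is a minimal sketch that classifies a value as BF16 normal, subnormal, or zero by inspecting its bits. It uses only the standard library and emulates the float32-to-BF16 rounding (round to nearest, ties to even) in pure Python. The helper names `to_bf16_bits` and `classify_bf16` are made up for illustration and are not part of PyTorch or MLX.

```python
import struct

BF16_MIN_NORMAL = 2.0 ** -126     # about 1.1755e-38, same exponent range as float32
BF16_MIN_SUBNORMAL = 2.0 ** -133  # 7 mantissa bits below the smallest normal


def to_bf16_bits(x: float) -> int:
    """Round x to BF16 (nearest, ties to even) and return the 16 raw bits.

    Hypothetical helper; NaN handling is deliberately left out of this sketch.
    """
    f32_bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # BF16 keeps the top 16 bits of a float32; the bias rounds the dropped half.
    rounding_bias = 0x7FFF + ((f32_bits >> 16) & 1)
    return ((f32_bits + rounding_bias) >> 16) & 0xFFFF


def classify_bf16(x: float) -> str:
    """Return 'zero', 'subnormal', 'normal', or 'inf_or_nan' for x as BF16."""
    bits = to_bf16_bits(x)
    exponent = (bits >> 7) & 0xFF  # 8 exponent bits
    mantissa = bits & 0x7F         # 7 mantissa bits
    if exponent == 0:
        return "zero" if mantissa == 0 else "subnormal"
    if exponent == 0xFF:
        return "inf_or_nan"
    return "normal"


for value in (2.9387e-39, 2.3510e-38, 3.7616e-37):
    print(f"{value:.4e} -> {classify_bf16(value)}")
```

A value classified as "subnormal" here is the kind that may survive on CPU and be flushed to zero on Metal.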

A local reproduction showed the pattern directly (more on that in the next section):

[cpu] cast nonzero=7
[mps] cast nonzero=4
[mlx] cast matches MPS behavior

Representative operation results showed the same thing across operation families:

[cpu] embedding_sum [2.9387e-39, 2.3510e-38, 1.8808e-37, 3.7616e-37]
[mps] embedding_sum [0.0000e+00, 0.0000e+00, 0.0000e+00, 3.7616e-37]
[mlx] embedding_sum [0.0000e+00, 0.0000e+00, 0.0000e+00, 3.7616e-37]

And for softmax tails:

[cpu] softmax [..., 8.2652e-40]
[mps] softmax [..., 0.0000e+00]
[mlx] softmax [..., 0.0000e+00]

So the issue is not just theoretical. BF16 values that survive on CPU can disappear on Metal.
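One possible reading of the embedding_sum rows (an assumption, since the probe's inputs are not shown here) is that the summed terms were themselves subnormal. In that case a sum that lands in the normal range on CPU can still come out as zero on a backend that flushes each term first. The sketch below emulates that in pure Python. `round_to_bf16`, `flush_subnormal`, and `summed_embedding` are hypothetical helpers, and the point at which a real Metal kernel flushes is not something this emulation establishes.

```python
import struct

BF16_MIN_NORMAL = 2.0 ** -126


def round_to_bf16(x: float) -> float:
    """Round x to the nearest BF16 value (ties to even), returned as a float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def flush_subnormal(x: float) -> float:
    """Emulate flush-to-zero: nonzero magnitudes below min normal become 0."""
    return 0.0 if 0.0 < abs(x) < BF16_MIN_NORMAL else x


def summed_embedding(terms, flush: bool) -> float:
    """Accumulate terms in emulated BF16, optionally flushing at each step."""
    total = 0.0
    for term in terms:
        term = round_to_bf16(term)
        if flush:
            term = flush_subnormal(term)
        total = round_to_bf16(total + term)
        if flush:
            total = flush_subnormal(total)
    return total


# Hypothetical inputs: eight subnormal terms of 2**-128 each.
terms = [2.0 ** -128] * 8
preserved = summed_embedding(terms, flush=False)
flushed = summed_embedding(terms, flush=True)
print(f"subnormals preserved: {preserved:.4e}")
print(f"subnormals flushed:   {flushed:.4e}")
```

With subnormals preserved, the eight terms add up to 2**-125, which is a normal BF16 value. With flushing, every term is already zero before the sum happens.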

Testing the Full MisoTTS BF16 Path

I wanted to see whether the full MisoTTS model would catastrophically collapse under BF16 on MPS. This is something I've experienced firsthand working with Dia in my BomDia project.

To test that, we used a tiny full-model forward probe. The goal was not to run full generation. It was to instantiate the real model, bind the real checkpoint, run a minimal training-style forward pass, and compare FP32, CPU BF16, and MPS BF16 behavior.

The probe intentionally avoided loading the entire model in the naive way. A naive load would create a full set of randomly initialized weights and then load the checkpoint on top, and I wanted to keep the memory impact to a minimum. Instead, the script used PyTorch's meta device and load_state_dict(assign=True).

The basic loading pattern was:

with torch.device("meta"):
    model = Model(config)

state_dict = load_state_dict_for_dtype(checkpoint, dtype)
model.load_state_dict(state_dict, assign=True)

With assign=True, instead of copying checkpoint tensors into already-allocated parameter tensors, we bind the checkpoint tensors into the module.

Because the model is first built on the meta device, some non-persistent buffers need to be rebuilt afterward. In this case, the relevant buffers were RoPE-related:

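import torch

def rebuild_rope_buffers(model: torch.nn.Module) -> int:
    """Re-run rope_init() on every submodule that defines it; return how many ran."""
    count = 0
    for module in model.modules():
        # Non-persistent buffers are not in the checkpoint, so they are still on meta.
        if hasattr(module, "rope_init") and callable(module.rope_init):
            module.rope_init()
            count += 1
    return count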

For FP32, the checkpoint can be loaded directly:

from safetensors.torch import load_file

state_dict = load_file(checkpoint, device="cpu")

For BF16, the script streamed tensors from the safetensors file and converted them one at a time:

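import torch
from safetensors import safe_open
from safetensors.torch import load_file

def load_state_dict_for_dtype(checkpoint: str, dtype: torch.dtype) -> dict[str, torch.Tensor]:
    # FP32 needs no conversion, so load the file in one call.
    if dtype == torch.float32:
        return load_file(checkpoint, device="cpu")

    # Otherwise read and convert one tensor at a time, so only a single
    # FP32 tensor is held alongside the converted copies.
    converted = {}
    with safe_open(checkpoint, framework="pt", device="cpu") as f:
        keys = list(f.keys())
        for key in keys:
            converted[key] = f.get_tensor(key).to(dtype=dtype)
    return converted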

This is still a full checkpoint load, but it avoids unnecessary extra initialization cost and makes the memory behavior easier to observe.

The probe also included explicit memory logging:

import os
import psutil
import resource

def mem(label: str) -> None:
    proc = psutil.Process(os.getpid())
    vm = psutil.virtual_memory()

    rss = proc.memory_info().rss / 1e9
    avail = vm.available / 1e9
    # ru_maxrss is in bytes on macOS (the target here) but in kilobytes on Linux.
    high = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e9

    print(
        f"[mem] {label}: "
        f"rss={rss:.2f}GB "
        f"available={avail:.2f}GB "
        f"highwater={high:.2f}GB",
        flush=True,
    )

That was useful because the test was partly about correctness and partly about whether a full BF16 MPS probe could be run locally without blowing up memory.

One implementation issue surfaced during the probe. The training-style Model.forward path could produce a dtype mismatch: decoder_h may come back as fp32 while audio_head is BF16.

The inference path already handles this by casting decoder output back to the model dtype before applying audio_head. The probe mirrored that behavior:

dtype = next(model.parameters()).dtype

decoder_h = model.depth_decoder(...)
decoder_h = decoder_h.to(dtype=dtype)

logits = model.audio_head(decoder_h)

That cast is not the BF16 subnormal issue itself. It is just a practical dtype consistency fix needed to make the BF16 forward path match the inference path.

The actual test matrix was:

python probe_full_misotts_fp32.py \
  --repo-dir . \
  --checkpoint ../../hfmodel/MisoTTS/model.safetensors \
  --seq-len 1 \
  --s-amortized 1

python probe_full_misotts_fp32.py \
  --repo-dir . \
  --checkpoint ../../hfmodel/MisoTTS/model.safetensors \
  --seq-len 1 \
  --s-amortized 1 \
  --model-dtype bfloat16 \
  --run-device cpu \
  --decoder-output-cast

python probe_full_misotts_fp32.py \
  --repo-dir . \
  --checkpoint ../../hfmodel/MisoTTS/model.safetensors \
  --seq-len 1 \
  --s-amortized 1 \
  --model-dtype bfloat16 \
  --run-device mps \
  --decoder-output-cast

The small synthetic probe had already shown that BF16 subnormal behavior differs across CPU, MPS, and MLX. The full-model tiny forward then showed that MisoTTS BF16 on MPS does not automatically collapse catastrophically. The model can run, at least for the tiny forward probe, when loaded carefully and when the decoder output dtype is handled correctly.

That means the porting concern is not “BF16 MPS cannot run this model” but rather:

BF16 MPS can run the model, but numerically sensitive runtime paths still need guardrails, especially logits, sampling, RoPE, RMSNorm, KV cache behavior, and quantization calibration.

Why This Matters for TTS Models

Modern neural TTS models often combine several numerically sensitive components:

  • summed codebook embeddings

  • RoPE

  • RMSNorm

  • attention softmax

  • BF16 KV caches

  • logits sampling

  • quantization calibration statistics

These are exactly the kinds of paths where small values can matter, or at least where unexpected zeroing can create parity problems during a port.

This is especially relevant when a model was originally developed, validated, or tuned in a CUDA-heavy environment. CUDA behavior, CPU behavior, and Metal behavior may not match exactly in BF16 edge cases.

My Dia-Style Guardrails

In my customized Dia code for Bomdia, I worked around this by doing the sensitive steps in fp32:

  • RoPE computes timescales and rotation in float32, then casts back.

  • RMSNorm is handled in float32.

  • Encoder positions are created as float32.

  • Decoder logits are returned as float32.

  • Attention masks and softmax paths include MPS compatibility handling.

These are good practices, but they cannot guarantee fidelity:

If a BF16 value has already been flushed to zero before an fp32 upcast, fp32 math cannot recover it.

So “just cast to fp32 later” is not always enough (and it wasn't with Dia). The cast has to happen before the vulnerable operation, not after the value has already collapsed. In that project I got enough fp32 casting in place that the model could run in BF16, usually with the same or at least similar results. But if I wasn't getting the audio I needed, I would fall back to 32-bit on CPU (which is deadly slow, of course).
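The ordering argument can be shown with a small pure-Python emulation. Python floats stand in for fp32 here, `round_to_bf16` and `flush_subnormal` are hypothetical helpers, and the sketch assumes the upcast itself preserves the value. Treat it as an illustration of the reasoning, not as a model of any particular Metal kernel.

```python
import struct

BF16_MIN_NORMAL = 2.0 ** -126


def round_to_bf16(x: float) -> float:
    """Round x to the nearest BF16 value (ties to even), returned as a float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def flush_subnormal(x: float) -> float:
    """Emulate a BF16 op that flushes subnormal values to zero."""
    return 0.0 if 0.0 < abs(x) < BF16_MIN_NORMAL else x


tiny = round_to_bf16(2.0 ** -128)  # a BF16 subnormal
scale = 1024.0                     # hypothetical later scaling step

# Late cast: the value goes through a flushing BF16 op, then is upcast and scaled.
late_cast = flush_subnormal(tiny) * scale

# Early cast: the value is upcast before any flushing op touches it.
early_cast = tiny * scale

print(f"late cast:  {late_cast:.4e}")
print(f"early cast: {early_cast:.4e}")
```

Once the value has been flushed, the later multiplication only scales a zero. The early cast keeps the information because the flush never sees the subnormal.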

For MisoTTS

MisoTTS has the same sensitive operation classes:

  • audio_embeddings are summed across 32 codebooks.

  • RoPE internally casts to fp32, then returns the original dtype.

  • RMSNorm computes in fp32, then returns the original dtype.

  • KV caches are allocated in the model dtype.

  • Attention uses scaled dot-product attention.

  • sample_topk runs log_softmax and softmax on the logits it receives.

The most concerning Miso-specific path is sampling.

Dia returns logits as fp32. Miso returns logits from BF16 model layers and samples immediately. If logits have large gaps, tail probabilities can fall into the BF16 subnormal range. CPU BF16 may preserve those tiny probabilities, while Metal or MLX may flush them to zero.

That does not necessarily break generation. Tiny tail probabilities may be irrelevant after top-k or temperature filtering. But it can cause parity differences, and if the affected values are not merely harmless tails, sampling behavior can diverge.
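How large is a "large gap"? A softmax probability is roughly exp(-gap) relative to the top logit, so a tail entry drops below the smallest normal BF16 value once the gap exceeds -ln(2^-126), which is about 87.3. The sketch below computes that threshold and flags at-risk positions. It runs the softmax in ordinary Python floats, so it only identifies which probabilities would land below the normal range; it does not emulate BF16 arithmetic. `at_risk_tokens` and the example logits are made up for illustration.

```python
import math

BF16_MIN_NORMAL = 2.0 ** -126


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def subnormal_gap_threshold() -> float:
    """Logit gap beyond which a tail probability falls below BF16 min normal."""
    return -math.log(BF16_MIN_NORMAL)  # 126 * ln(2)


def at_risk_tokens(logits):
    """Indices whose probability is nonzero but below the BF16 normal range."""
    probs = softmax(logits)
    return [i for i, p in enumerate(probs) if 0.0 < p < BF16_MIN_NORMAL]


threshold = subnormal_gap_threshold()
example_logits = [28.0, 0.0, -62.0]  # hypothetical row with a gap of 90
risky = at_risk_tokens(example_logits)
print(f"gap threshold: {threshold:.2f}")
print(f"at-risk indices: {risky}")
```

For scale, the logits_c1_plus range reported in the probe result later in this post (min=-70.5, max=28.125) spans 98.625, which is above the threshold. Whether those two extremes ever occur in the same softmax row is not something the reported numbers establish.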

Checkpoint Weights May Not Be the Problem

A streaming scan of the public MisoTTS FP32 checkpoint found no BF16-subnormal-risk weights:

bf16_min_normal=1.17549435082228751e-38
total_edge_count=0
total_zero_count=33
smallest_positive_weight=1.21650289125994904e-13

That means the checkpoint itself does not appear to contain weights that would collapse to zero when converted to BF16 on Metal.

The likely risk is elsewhere:

runtime activations, logits tails, KV-cache dynamics, and quantization calibration/statistics.

This is an important distinction. The model can load cleanly, the weights can look safe, and the port can still have BF16-sensitive runtime behavior.
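The scan script itself is not shown here, so the following is only a sketch of an equivalent check over a state dict that is already in memory. It counts zeros, counts nonzero weights below the BF16 min normal, and tracks the smallest positive magnitude. The function name and result keys are chosen to mirror the output above and are assumptions. A streaming version could call the same logic one tensor at a time.

```python
import torch

BF16_MIN_NORMAL = 2.0 ** -126  # matches bf16_min_normal in the scan output


def scan_state_dict(state_dict: dict) -> dict:
    """Count zero and BF16-subnormal-risk weights in floating-point tensors."""
    edge_count = 0
    zero_count = 0
    smallest_positive = float("inf")
    for tensor in state_dict.values():
        if not tensor.is_floating_point():
            continue
        magnitudes = tensor.detach().to(torch.float32).abs().flatten()
        zero_count += int((magnitudes == 0).sum())
        positive = magnitudes[magnitudes > 0]
        if positive.numel() > 0:
            smallest_positive = min(smallest_positive, float(positive.min()))
            # "Edge" weights are nonzero but below the BF16 normal range.
            edge_count += int((positive < BF16_MIN_NORMAL).sum())
    return {
        "total_edge_count": edge_count,
        "total_zero_count": zero_count,
        "smallest_positive_weight": smallest_positive,
    }


# Hypothetical toy weights: one zero, one subnormal-risk value, two safe values.
toy = {"layer.weight": torch.tensor([0.0, 1e-39, 1e-13, 1.0])}
report = scan_state_dict(toy)
print(report)
```

A total_edge_count of zero, as in the MisoTTS scan, says the stored weights are safe. It says nothing about the activations those weights produce at runtime.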

Full-Model Probe Result

A tiny full-model probe did not reproduce catastrophic collapse.

The FP32 model ran successfully. CPU BF16 ran successfully. MPS BF16 also ran successfully when loaded carefully.

The MPS BF16 tiny forward result was healthy:

forward_elapsed=3.9s
c0_logits finite=True min=-3.85938 max=8.6875
logits_c1_plus finite=True min=-70.5 max=28.125
c0_loss=2.93695
c1_plus_loss=5.26826e-06
loss=0.0917847

As far as I know, there is no way to avoid these expensive fp32 casts, but selective casting is still going to be much faster than running fp32 throughout or, worse, falling back to CPU.

BF16 on Metal can run these models, but ports need explicit numerical guardrails and parity checks.
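As a starting point for such parity checks, here is a minimal sketch that compares a reference output (for example CPU BF16) against a candidate (for example MPS BF16) after both have been copied into plain Python lists. It reports the largest absolute difference, how many entries were nonzero in the reference but zero in the candidate, and whether the argmax agrees. The function name and the report fields are assumptions for illustration, not part of any library.

```python
def parity_report(reference, candidate) -> dict:
    """Compare two equal-length lists of floats from different backends."""
    if len(reference) != len(candidate):
        raise ValueError("reference and candidate must have the same length")
    pairs = list(zip(reference, candidate))
    max_abs_diff = max((abs(r - c) for r, c in pairs), default=0.0)
    # Entries that survived on the reference backend but vanished on the other.
    flushed_to_zero = sum(1 for r, c in pairs if r != 0.0 and c == 0.0)
    argmax_match = (
        reference.index(max(reference)) == candidate.index(max(candidate))
        if pairs
        else True
    )
    return {
        "max_abs_diff": max_abs_diff,
        "flushed_to_zero": flushed_to_zero,
        "argmax_match": argmax_match,
    }


# The embedding_sum rows reported earlier in this post.
cpu_row = [2.9387e-39, 2.3510e-38, 1.8808e-37, 3.7616e-37]
mps_row = [0.0, 0.0, 0.0, 3.7616e-37]
result = parity_report(cpu_row, mps_row)
print(result)
```

An absolute tolerance alone would pass these rows, since the differences are tiny. Counting flushed entries separately is what makes the zeroing visible.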