I'm Patrick Toulme
Software engineer on the Google TPU XLA team in New York, working on XLA and JAX compilation for TPU. I also created and open sourced pyptx, a Python DSL for writing NVIDIA PTX kernels on Hopper and Blackwell, callable from JAX and PyTorch. Previously on the MTIA compiler team at Meta and at AWS Neuron, working on the Trainium compiler and NKI kernels. I write about AI compilers, AI systems, and JAX/PyTorch on my blog, Just a Byte.
Where I've Worked
Software Engineer
Working on the XLA compiler for TPU in New York, focused on high-performance compilation of JAX workloads down to TPU.
Compiler Engineer
Worked on GenAI inference compilation in vLLM and on high-performance lowering from FX graph IR to the MTIA ISA. Also worked on the bringup of new MTIA silicon.
Machine Learning Engineer II
Led the bringup of the JAX backend for Trainium and made GSPMD run performantly on Trainium. Worked on native code generation and the NKI compiler, contributed to the Trainium2 bringup, and wrote the first collective matmul for Trainium (open sourced here).
Software Engineer
Worked on the early Bedrock training service. Trained very large LLMs on GPUs.
Just a Byte
I write a blog about AI compilers, custom silicon, and AI systems, exploring JAX and PyTorch targeting custom hardware with a focus on high-performance LLMs.
Frontier Pretraining Infrastructure Is Already Open Source: GPT-OSS on TPU with MaxText
XLA automatically generates optimized training infrastructure — kernel fusion, collective overlap, and MoE routing — from high-level JAX code. Frontier-quality pretraining is already open source.
CuTile on Blackwell: NVIDIA's Compiler Moat Is Already Built
Tracing a Mixture of Experts kernel through CuTile's compilation stages. 86 lines of Python expand into 180KB of shared memory, tcgen05 instructions, and orchestration patterns that form NVIDIA's deepening compiler moat.
When XLA Isn't Enough: From Pallas to VLIW with Splash Attention on TPU
When does XLA hit its limits? How do you write the TPU Pallas kernel that the compiler cannot automatically find? Why can't XLA generate Splash Attention?
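The Splash Attention kernel itself is far too large for a snippet, but as a flavor of the Pallas programming model the post works in, here is a minimal hand-written kernel. It is an illustrative element-wise add, run in interpret mode so it executes without a TPU; all names here are this sketch's own, not from the post:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs are mutable views into on-chip memory; write the output in place.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # emulate the kernel on CPU for portability
    )(x, y)

print(add(jnp.ones(8), jnp.arange(8.0)))
```

A real TPU kernel like Splash Attention additionally specifies a grid and BlockSpecs to tile the computation across memory spaces, which is exactly the control XLA's automatic fusion does not expose.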
From JAX to VLIW: Tracing a Computation Through the TPU Compiler Stack
Eight lines of JAX code become 250 VLIW bundles across 5 fused kernels. A deep dive into what happens between jax.jit(f)(x) and electrons moving through a TPU.
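The post traces its own eight-line function; as a hypothetical stand-in, this sketch shows the first step of that journey, tracing a small JAX function and dumping the StableHLO that XLA receives before fusion and VLIW scheduling:

```python
import jax
import jax.numpy as jnp

def f(x):
    # A few element-wise ops plus a reduction; XLA will fuse these.
    y = jnp.tanh(x) * 2.0
    return jnp.sum(y ** 2)

x = jnp.arange(4.0)
print(jax.jit(f)(x))                  # the compiled result
print(jax.jit(f).lower(x).as_text())  # the StableHLO handed to XLA
```

From here, XLA's fusion and scheduling passes take over before the TPU backend emits VLIW bundles; `jax.jit(f).lower(x).compile()` exposes the compiled executable for further inspection.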
Things I've Built From Scratch
A Python DSL for writing NVIDIA PTX kernels, and a custom tensor processing unit — from RTL hardware to an MLIR compiler and PJRT runtime — that runs JAX and executes Llama as a single fused megakernel.
pyptx
A Python DSL where the function body is the PTX instruction stream — one PTX instruction = one Python call. Direct Hopper + Blackwell ISA support: wgmma, TMA, tcgen05, mbarriers, cluster barriers. Callable from JAX (via typed XLA FFI) and PyTorch (eager + torch.compile + a C++ extension fast path). Includes maintained GEMM, RMSNorm, SwiGLU, and grouped GEMM kernels for both architectures, plus a PTX → Python transpiler that round-trips byte-identical on 218+ real-world kernels (CUTLASS, Triton, DeepGEMM, ThunderKittens, fast.cu). 815 TFLOPS bf16 GEMM on H100 (beats cuBLAS at ≥6K), 1314 TFLOPS on B200.
Custom TPU
A fully custom tensor processing unit built from the ground up: Verilog RTL hardware, an MLIR-based compiler (JAX → HLO → MLIR → ASM → VLIW → Binary), and a PJRT runtime that plugs directly into JAX. No CUDA, no hand-written kernels — pure compiler codegen. The compiler fuses entire models into single megakernel binaries, and it now runs Llama end-to-end.
Academic Background
Georgia Institute of Technology
University of Virginia
Thomas Jefferson High School
Research
Marcus: A Chatbot for Depression Screening Based on the PHQ-9 Assessment
A study comparing the effectiveness of screenings by "Marcus," a BERT-based chatbot, with traditional PHQ-9 assessments. Developed a prototype application integrating BERT for linguistic analysis with the DialogFlow and Kommunicate APIs.
Technical Skills
Languages & Compilers
- Python, C++
- PyTorch, JAX
- XLA, GSPMD, PJRT
- MLIR, LLVM
- FX Graph IR
Hardware & Kernels
- MTIA (Meta)
- GPU, TPU
- Neuron Core (AWS)
- Triton, Pallas
- NKI
Machine Learning
- Transformer Models
- Distributed Training
- LLM Pretraining
- Mixture of Experts
- RLHF