AI Compilers & Systems

I'm Patrick Toulme

Software engineer on the Google TPU XLA team in New York, working on XLA and JAX compilation for TPU. I also created and open-sourced pyptx, a Python DSL for writing NVIDIA PTX kernels on Hopper and Blackwell, callable from JAX and PyTorch. Previously on the MTIA compiler team at Meta, and at AWS Neuron working on the Trainium compiler and NKI kernels. I write about AI compilers, AI systems, and JAX/PyTorch on my blog, JustAByte.

Location New York, NY
Education Georgia Tech MS, UVA BA

Where I've Worked

2026 — Present

Software Engineer

Google — TPU XLA Team

Working on the XLA compiler for TPU in New York, focused on high-performance compilation of JAX workloads down to TPU.

XLA JAX TPU MLIR
2025 — 2026

Compiler Engineer

Meta — MTIA Compiler Performance Team

Worked on GenAI inference compilation in vLLM and high-performance inference compilation from FX IR down to the MTIA ISA. Also worked on the bringup of new MTIA silicon.

PyTorch FX IR MLIR Custom Silicon
2023 — 2025

Machine Learning Engineer II

Amazon AWS Neuron — Annapurna Labs

Led the bringup of the JAX backend for Trainium, and brought up GSPMD on Trainium, including making it performant. Worked on native code generation, the NKI compiler, and the bringup of Trainium2. Also wrote the first collective matmul for Trainium — open sourced here.

JAX XLA PJRT GSPMD NKI
2022 — 2023

Software Engineer

Amazon AWS AI Bedrock — Titan Model Training

Worked on the early Bedrock training service. Trained very large LLMs on GPUs.

PyTorch NeMo Distributed Training A100

Things I've Built From Scratch

A Python DSL for writing NVIDIA PTX kernels, and a custom tensor processing unit — from RTL hardware to an MLIR compiler and PJRT runtime — that runs JAX and executes Llama as a single fused megakernel.

pyptx

Python DSL → NVIDIA PTX (Hopper + Blackwell)

A Python DSL where the function body is the PTX instruction stream — one PTX instruction = one Python call. Direct Hopper + Blackwell ISA support: wgmma, TMA, tcgen05, mbarriers, cluster barriers. Callable from JAX (via typed XLA FFI) and PyTorch (eager + torch.compile + a C++ extension fast path). Includes maintained GEMM, RMSNorm, SwiGLU, and grouped GEMM kernels for both architectures, plus a PTX → Python transpiler that round-trips byte-identical on 218+ real-world kernels (CUTLASS, Triton, DeepGEMM, ThunderKittens, fast.cu). 815 TFLOPS bf16 GEMM on H100 (beats cuBLAS at sizes ≥6K), 1314 TFLOPS on B200.

Python PTX Hopper Blackwell JAX PyTorch DSL Compiler
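The "one PTX instruction = one Python call" design can be illustrated with a toy emitter. This sketch is hypothetical — the class and method names below are invented for illustration and are not pyptx's actual API — but it shows the shape of the idea: each Python call appends exactly one PTX instruction to the kernel body.

```python
# Hypothetical sketch of the "one PTX instruction = one Python call" idea.
# PtxKernel and its methods are illustrative names, NOT pyptx's real API.

class PtxKernel:
    """Collects one PTX instruction per method call into an instruction stream."""

    def __init__(self, name):
        self.name = name
        self.instructions = []

    def _emit(self, text):
        self.instructions.append(text)

    # Each method maps 1:1 to a single PTX instruction.
    def mov_u32(self, dst, src):
        self._emit(f"mov.u32 {dst}, {src};")

    def add_u32(self, dst, a, b):
        self._emit(f"add.u32 {dst}, {a}, {b};")

    def ret(self):
        self._emit("ret;")

    def to_ptx(self):
        """Render the collected stream as a PTX entry function."""
        body = "\n    ".join(self.instructions)
        return f".visible .entry {self.name}()\n{{\n    {body}\n}}"

k = PtxKernel("add_one")
k.mov_u32("%r1", "41")
k.add_u32("%r2", "%r1", "1")
k.ret()
print(k.to_ptx())
```

Because the function body *is* the instruction stream, a PTX → Python transpiler like the one described above only has to map each instruction back to one call, which is what makes byte-identical round-tripping feasible.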

Custom TPU

JAX → PJRT Runtime → MLIR Compiler → RTL

A fully custom tensor processing unit built from the ground up: Verilog RTL hardware, an MLIR-based compiler (JAX → HLO → MLIR → ASM → VLIW → Binary), and a PJRT runtime that plugs directly into JAX. No CUDA and no hand-written kernels — pure compiler codegen. The compiler fuses entire models into single megakernel binaries, and it now runs Llama end-to-end.

Verilog MLIR PJRT JAX VLIW Llama
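The megakernel idea — lowering a whole op graph into one flat instruction stream instead of one kernel launch per op — can be sketched in a few lines. Everything here is invented for illustration (the tuple IR, the `MXU`/`VPU` mnemonics, the `lower_to_megakernel` function); the real pipeline goes JAX → HLO → MLIR → ASM → VLIW → Binary.

```python
# Toy sketch of megakernel fusion: lower a small op graph into one linear
# instruction listing, so the whole model becomes a single binary.
# The IR and the ISA mnemonics here are hypothetical, for illustration only.

def lower_to_megakernel(graph):
    """Lower a list of (op, dst, operands) tuples into one flat ASM listing."""
    asm = []
    for op, dst, operands in graph:
        if op == "matmul":
            asm.append(f"MXU.MATMUL {dst}, {operands[0]}, {operands[1]}")
        elif op == "add":
            asm.append(f"VPU.ADD {dst}, {operands[0]}, {operands[1]}")
        elif op == "relu":
            asm.append(f"VPU.MAX {dst}, {operands[0]}, ZERO")
        else:
            raise ValueError(f"unhandled op: {op}")
    asm.append("HALT")  # one binary for the whole model, so one final halt
    return "\n".join(asm)

# y = relu(x @ w + b), compiled as a single fused instruction stream
graph = [
    ("matmul", "t0", ("x", "w")),
    ("add",    "t1", ("t0", "b")),
    ("relu",   "y",  ("t1",)),
]
print(lower_to_megakernel(graph))
```

The design payoff is that intermediate results (`t0`, `t1`) stay in on-chip registers or scratchpad between instructions rather than round-tripping through memory between separate kernel launches.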

Academic Background

Georgia Institute of Technology

M.S. Computer Science — Artificial Intelligence
Graduated December 2025

University of Virginia

B.A. Computer Science
Graduated May 2022

Thomas Jefferson High School for Science and Technology

Graduated May 2018

Research

🏆 Best Paper Award

Marcus: A Chatbot for Depression Screening Based on the PHQ-9 Assessment

A study comparing the effectiveness of screenings by "Marcus," a BERT-based chatbot, against traditional PHQ-9 assessments. Developed a prototype application integrating BERT for linguistic analysis with the Dialogflow and Kommunicate APIs.

ACHI 2023 — The Sixteenth International Conference on Advances in Computer-Human Interactions

Technical Skills

Languages & Compilers

  • Python, C++
  • PyTorch, JAX
  • XLA, GSPMD, PJRT
  • MLIR, LLVM
  • FX Graph IR

Hardware & Kernels

  • MTIA (Meta)
  • GPU, TPU
  • Neuron Core (AWS)
  • Triton, Pallas
  • NKI

Machine Learning

  • Transformer Models
  • Distributed Training
  • LLM Pretraining
  • Mixture of Experts
  • RLHF