AI/ML-Native Framework for Composites Simulation Using JAX

Apollo 11 Postdoctoral Fellowships at Purdue - Proposal
Wenbin Yu, Professor


Project Summary

This project will implement and further develop the Mechanics of Structure Genome (MSG) in JAX, enabling differentiable, massively parallel, GPU/TPU‑ready simulations. This AI/ML‑native framework will (i) cut compute time by orders of magnitude without losing fidelity, (ii) fuse physics and data seamlessly via autodiff, and (iii) unlock rapid design optimization, uncertainty quantification, and learning‑augmented solvers. In two years, the postdoc will work with graduate students under Prof. Yu’s supervision to establish a reusable, open foundation for AI/ML‑assisted multiscale modeling with immediate, high-potential impact.

Introduction

Traditional engineering simulation codes are typically written in Fortran or C/C++ and aggressively optimized for serial execution. Even as parallel computing has become ubiquitous, many legacy algorithms struggle to exploit multicore CPUs, and they seldom leverage modern accelerators—GPUs and TPUs—that now drive rapid advances in AI/ML.

This project seeks to re-platform physics-based simulation into an AI/ML-native stack using JAX, a Python library for high-performance numerical computing and machine learning. JAX provides composable automatic differentiation (gradients/Hessians), functional vectorization, and just-in-time (JIT) compilation through XLA, delivering near-native performance while scaling from a laptop CPU to multi-GPU/TPU clusters with minimal code changes. This will enable a paradigm shift for multiscale modeling in the age of AI/ML.
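
To make the programming model concrete, the toy snippet below exercises the three primitives named above (autodiff, vectorization, and JIT) on a quadratic strain-energy function; it is illustrative only and not part of the MSG implementation.

    import jax
    import jax.numpy as jnp

    def strain_energy(strain, stiffness):
        # Quadratic strain energy 0.5 * e^T C e for a single material point.
        return 0.5 * strain @ stiffness @ strain

    # Autodiff: stress = d(energy)/d(strain), without hand-coded derivatives.
    stress_fn = jax.grad(strain_energy, argnums=0)

    # Vectorization + JIT: evaluate a whole batch of strain states in one call
    # compiled through XLA for CPU, GPU, or TPU.
    batched_stress = jax.jit(jax.vmap(stress_fn, in_axes=(0, None)))

    C = jnp.eye(6)                          # placeholder 6x6 stiffness matrix
    strains = 1e-3 * jnp.ones((1000, 6))    # 1000 strain states
    stresses = batched_stress(strains, C)   # shape (1000, 6)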

Research Goals

The postdoc is expected to quickly demonstrate the potential of this idea by achieving the following goals:

Goal 1 — JAX‑native MSG core: Implement linear MSG for microstructures represented by 1D/2D/3D structure genes (SGs), producing macroscopic models for solids, plates, shells, and beams. Provide unit tests, reference problems, and a public GitHub repository.
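
One possible shape for the JAX-native core, sketched here with hypothetical names (StructureGene, homogenize) that are not the project's actual API, is to express the homogenization as a pure function of arrays so that compilation, batching, and differentiation come for free; the function body is a runnable placeholder, not the MSG formulation.

    import jax
    import jax.numpy as jnp
    from typing import NamedTuple

    class StructureGene(NamedTuple):
        # Hypothetical container for a discretized SG; field names are illustrative.
        nodes: jnp.ndarray       # (n_nodes, dim) nodal coordinates
        elements: jnp.ndarray    # (n_elems, nodes_per_elem) connectivity
        materials: jnp.ndarray   # (n_elems, n_props) per-element material parameters

    def homogenize(sg: StructureGene) -> jnp.ndarray:
        # Placeholder for the linear MSG solve returning an effective 6x6 stiffness.
        # The real implementation would assemble and solve the SG unit-cell problem;
        # the dummy return value only keeps the sketch runnable.
        scale = jnp.mean(sg.materials)
        return scale * jnp.eye(6)

    # Because homogenize is a pure function of arrays, it can be jit-compiled,
    # vmapped over families of SGs, and differentiated end to end.
    sg = StructureGene(
        nodes=jnp.zeros((100, 2)),
        elements=jnp.zeros((180, 4), dtype=jnp.int32),
        materials=jnp.ones((180, 3)),
    )
    effective_C = jax.jit(homogenize)(sg)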

Goal 2 — Differentiable, accelerated workflows: Exploit JIT, vmap, and autodiff to enable gradient‑based design, sensitivity analysis, and UQ; demonstrate scaling from CPU to multi‑GPU/TPU with stable, consistent APIs.
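
The following sketch suggests what such a differentiable workflow could look like; the closed-form rule-of-mixtures stand-in replaces the actual MSG solve, so the functions and numbers are illustrative only.

    import jax
    import jax.numpy as jnp

    def effective_modulus(fiber_fraction, E_fiber, E_matrix):
        # Stand-in for a homogenized property: simple Voigt rule of mixtures.
        return fiber_fraction * E_fiber + (1.0 - fiber_fraction) * E_matrix

    # Sensitivity: derivative of the effective modulus w.r.t. the fiber volume fraction.
    dE_dvf = jax.grad(effective_modulus, argnums=0)
    sensitivity = dE_dvf(0.6, 230.0, 3.5)   # equals E_fiber - E_matrix for this toy model

    # Forward UQ: propagate scatter in the fiber fraction through the same code via vmap.
    key = jax.random.PRNGKey(0)
    vf_samples = 0.6 + 0.02 * jax.random.normal(key, (10_000,))
    E_samples = jax.vmap(effective_modulus, in_axes=(0, None, None))(vf_samples, 230.0, 3.5)
    mean, std = E_samples.mean(), E_samples.std()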

Goal 3 — Nonlinear MSG and data–physics fusion: Extend to nonlinear, path‑dependent behavior; integrate data‑assisted operators (e.g., learned correctors, neural operators) while keeping physics as the backbone.
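
A minimal sketch of this fusion pattern, assuming a placeholder linear baseline and a small illustrative corrector network, is given below; in the actual work the baseline would come from the MSG models and the corrector from data.

    import jax
    import jax.numpy as jnp

    def physics_stress(strain, stiffness):
        # Physics backbone: baseline response (placeholder for the MSG prediction).
        return strain @ stiffness

    def learned_corrector(params, strain):
        # Tiny MLP correcting the baseline; the architecture is illustrative only.
        h = jnp.tanh(strain @ params["W1"] + params["b1"])
        return h @ params["W2"] + params["b2"]

    def hybrid_stress(params, strain, stiffness):
        # Physics stays the backbone; the learned term only accounts for the residual.
        return physics_stress(strain, stiffness) + learned_corrector(params, strain)

    def loss(params, strains, measured, stiffness):
        pred = jax.vmap(hybrid_stress, in_axes=(None, 0, None))(params, strains, stiffness)
        return jnp.mean((pred - measured) ** 2)

    # The same autodiff used for design sensitivities also trains the corrector.
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    params = {"W1": 0.1 * jax.random.normal(k1, (6, 16)), "b1": jnp.zeros(16),
              "W2": 0.1 * jax.random.normal(k2, (16, 6)), "b2": jnp.zeros(6)}
    C = jnp.eye(6)
    strains = 1e-3 * jax.random.normal(k3, (64, 6))
    grads = jax.jit(jax.grad(loss))(params, strains, strains @ C, C)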

Goal 4 — Open, reusable research infrastructure: Release reproducible pipelines (benchmarks, profilers, scripts) supporting internal/external collaborations and proposal development.

This work directly complements ongoing efforts in fundamental MSG-based multiscale constitutive modeling, AI-assisted multiscale modeling, and CompositesAI, while expanding into a machine-learning-native simulation stack that leverages emerging hardware and differentiable programming, capabilities not available in legacy Fortran/C++ codes.

Expected Deliverables

Year 1 (Foundations & Linear MSG)

  • Code: Linear MSG (1D/2D/3D SGs → solid/plate/shell/beam) in JAX; open GitHub with docs, tests, and examples/tutorials.
  • Benchmarks: Accuracy vs. legacy implementations; end‑to‑end speedups; CPU vs. single‑/multi‑GPU scaling plots.
  • Publications: ≥2 conference talks; ≥2 journal manuscripts submitted.
  • Proposals: AFOSR white paper/full proposal leveraging preliminary results.
  • Training & Mentoring: Mentor 1–2 PhD students on the JAX toolchain and differentiable simulation workflows.

Year 2 (Nonlinear, Scale‑out, and Data–Physics Fusion)

  • Code: Nonlinear MSG (path dependence), larger demonstration problems (e.g., laminate/ply‑level nonlinearities, complex layups).
  • Methods: Gradient‑based design/UQ case studies; initial learned operators where beneficial while preserving physics guarantees.
  • Publications: ≥3 conference talks; ≥3 journal manuscripts submitted.
  • Proposals: NSF submission (e.g., CMMI) plus agency‑specific follow‑ons.
  • Community: Tutorials, recorded notebooks, and reproducibility packages to seed external users/collaborators.

Affiliated Faculty

Milton Clauser Professor Wenbin Yu
School of Aeronautics and Astronautics, Purdue University
