Skip to content

Commit 2a61239

Browse files
committed
Refactor CI/CD pipeline to align with MaxText and support PyPI release
- Consolidate requirements management to use generated requirements as the source of truth. - Update pyproject.toml, setup.sh, and internal scripts to source dependencies from the new path. - Refactor Docker build process to remove JAX AI image dependencies. - Add GitHub workflow for automated PyPI package publication via OIDC. - Streamline UploadDockerImages.yml for stable and nightly image builds.
1 parent 2a74af1 commit 2a61239

23 files changed

+650
-565
lines changed

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ jobs:
4242
python-version: '3.12'
4343
- name: Install dependencies
4444
run: |
45-
pip install -e .
46-
pip uninstall jax jaxlib libtpu-nightly libtpu -y
4745
bash setup.sh MODE=stable
4846
export PATH=$PATH:$HOME/.local/bin
4947
pip install ruff
@@ -66,4 +64,4 @@ jobs:
6664
# checks: read
6765
# pull-requests: write
6866
# needs: build
69-
# uses: ./.github/workflows/AddLabel.yml
67+
# uses: ./.github/workflows/AddLabel.yml

.github/workflows/UploadDockerImages.yml

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,12 @@ jobs:
2828
build-image:
2929
runs-on: ["self-hosted", "e2", "cpu"]
3030
steps:
31-
- uses: actions/checkout@v3
31+
- uses: actions/checkout@v5
3232
- name: Cleanup old docker images
3333
run: docker system prune --all --force
34-
- name: build maxdiffusion jax ai image
34+
- name: build maxdiffusion stable image
3535
run: |
36-
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
37-
- name: build maxdiffusion w/ nightly jax ai image
38-
run: |
39-
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_nightly MODE=jax_ai_image PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu/jax_nightly:latest
40-
- name: build maxdiffusion jax nightly image
36+
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable MODE=stable PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable
37+
- name: build maxdiffusion nightly image
4138
run: |
4239
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly

.github/workflows/pypi_release.yml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2025 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# This workflow will build, test and automatically release MaxDiffusion package to PyPI using Trusted Publishing (OIDC).
16+
17+
name: Publish MaxDiffusion to PyPI
18+
19+
# Triggers when a new "release" is published in the GitHub UI
20+
on:
21+
release:
22+
types: [published]
23+
workflow_dispatch:
24+
25+
permissions:
26+
contents: read
27+
id-token: write
28+
29+
jobs:
30+
build_and_publish:
31+
name: Build and Publish MaxDiffusion Package
32+
runs-on: ubuntu-latest
33+
steps:
34+
- uses: actions/checkout@v5
35+
- name: Set up Python
36+
uses: actions/setup-python@v5
37+
with:
38+
python-version: '3.12'
39+
- name: Install build dependencies
40+
run: |
41+
python -m pip install --upgrade pip
42+
pip install build hatchling hatch-requirements-txt
43+
- name: Build package
44+
run: python -m build
45+
- name: Publish package
46+
uses: pypa/gh-action-pypi-publish@release/v1
47+
with:
48+
packages-dir: dist/

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ modified_only_fixup:
1818
# Update src/maxdiffusion/dependency_versions_table.py
1919

2020
deps_table_update:
21-
@python setup.py deps_table_update
21+
@python utils/update_dependency_table.py
2222

2323
deps_table_check_updated:
2424
@md5sum src/maxdiffusion/dependency_versions_table.py > md5sum.saved
25-
@python setup.py deps_table_update
25+
@python utils/update_dependency_table.py
2626
@md5sum -c --quiet md5sum.saved || (printf "\nError: the version dependency table is outdated.\nPlease run 'make fixup' or 'make style' and commit the changes.\n\n" && exit 1)
2727
@rm md5sum.saved
2828

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
--extra-index-url https://download.pytorch.org/whl/cpu
2+
absl-py
3+
aqtp
4+
chex
5+
datasets
6+
einops
7+
flax
8+
ftfy
9+
google-cloud-storage
10+
grain
11+
hf_transfer
12+
huggingface_hub
13+
imageio-ffmpeg
14+
imageio
15+
jax
16+
jaxlib
17+
Jinja2
18+
opencv-python-headless
19+
optax
20+
orbax-checkpoint
21+
parameterized
22+
Pillow
23+
pyink
24+
pylint
25+
pytest
26+
ruff
27+
scikit-image
28+
sentencepiece
29+
tensorboard-plugin-profile
30+
tensorboard
31+
tensorboardx
32+
tensorflow-datasets
33+
tensorflow
34+
tokamax
35+
tokenizers
36+
transformers<5.0.0
37+
38+
# pinning torch and torchvision to specific versions to avoid
39+
# installing GPU versions from PyPI when running seed-env
40+
torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
41+
torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
42+
qwix @ https://github.com/google/qwix/archive/408a0f48f988b6c5b180e07f0cb1d05997bf0dcc.zip
43+
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
# Generated by seed-env. Do not edit manually.
2+
# If you need to modify dependencies, please do so in the host requirements file and run seed-env again.
3+
4+
absl-py>=2.3.1
5+
aiofiles>=25.1.0
6+
aiohappyeyeballs>=2.6.1
7+
aiohttp>=3.13.3
8+
aiosignal>=1.4.0
9+
annotated-types>=0.7.0
10+
anyio>=4.13.0
11+
aqtp>=0.9.0
12+
array-record>=0.8.3 ; sys_platform != 'win32'
13+
astroid>=4.0.4
14+
astunparse>=1.6.3
15+
attrs>=25.4.0
16+
auditwheel>=6.6.0
17+
black>=25.12.0
18+
build>=1.4.0
19+
certifi>=2026.1.4
20+
cffi>=2.0.0 ; platform_python_implementation != 'PyPy'
21+
charset-normalizer>=3.4.4
22+
cheroot>=11.1.2
23+
chex>=0.1.91
24+
click>=8.3.1
25+
cloudpickle>=3.1.2
26+
colorama>=0.4.6
27+
contourpy>=1.3.3
28+
cryptography>=46.0.6
29+
cycler>=0.12.1
30+
dataclasses-json>=0.6.7
31+
datasets>=4.8.4
32+
decorator>=5.2.1
33+
dill>=0.4.1
34+
dm-tree>=0.1.9
35+
docstring-parser>=0.17.0
36+
einops>=0.8.2
37+
einshape>=1.0
38+
etils>=1.13.0
39+
execnet>=2.1.2
40+
filelock>=3.20.3
41+
flatbuffers>=25.12.19
42+
flax>=0.12.6
43+
fonttools>=4.61.1
44+
frozenlist>=1.8.0
45+
fsspec>=2026.1.0
46+
ftfy>=6.3.1
47+
gast>=0.7.0
48+
gcsfs>=2026.1.0
49+
google-api-core>=2.30.0
50+
google-auth-oauthlib>=1.3.0
51+
google-auth>=2.49.1
52+
google-cloud-core>=2.5.0
53+
google-cloud-storage-control>=1.11.0
54+
google-cloud-storage>=3.10.1
55+
google-crc32c>=1.8.0
56+
google-pasta>=0.2.0
57+
google-resumable-media>=2.8.0
58+
googleapis-common-protos>=1.73.1
59+
grain>=0.2.16
60+
grpc-google-iam-v1>=0.14.3
61+
grpcio-status>=1.76.0
62+
grpcio>=1.76.0
63+
gviz-api>=1.10.0
64+
h11>=0.16.0
65+
h5py>=3.15.1
66+
hf-transfer>=0.1.9
67+
hf-xet>=1.4.2 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
68+
httpcore>=1.0.9
69+
httpx>=0.28.1
70+
huggingface-hub>=0.36.2
71+
humanize>=4.15.0
72+
hypothesis>=6.142.1
73+
idna>=3.11
74+
imageio-ffmpeg>=0.6.0
75+
imageio>=2.37.3
76+
immutabledict>=4.3.1
77+
importlib-resources>=6.5.2
78+
iniconfig>=2.3.0
79+
isort>=8.0.1
80+
jaraco-functools>=4.4.0
81+
jax>=0.9.0
82+
jaxlib>=0.9.0
83+
jaxtyping>=0.3.9
84+
jinja2>=3.1.6
85+
keras>=3.13.1
86+
kiwisolver>=1.4.9
87+
lazy-loader>=0.5
88+
libclang>=18.1.1
89+
libtpu>=0.0.34 ; platform_machine == 'x86_64' and sys_platform == 'linux'
90+
markdown-it-py>=4.0.0
91+
markdown>=3.10.1
92+
markupsafe>=3.0.3
93+
marshmallow>=3.26.2
94+
matplotlib>=3.10.8
95+
mccabe>=0.7.0
96+
mdurl>=0.1.2
97+
ml-dtypes>=0.5.4
98+
more-itertools>=10.8.0
99+
mpmath>=1.3.0
100+
msgpack>=1.1.2
101+
multidict>=6.7.1
102+
multiprocess>=0.70.19
103+
mypy-extensions>=1.1.0
104+
namex>=0.1.0
105+
networkx>=3.6.1
106+
numpy-typing-compat>=20251206.2.0
107+
numpy>=2.0.2
108+
nvidia-cuda-cccl>=13.1.115
109+
oauthlib>=3.3.1
110+
opencv-python-headless>=4.13.0.92
111+
opt-einsum>=3.4.0
112+
optax>=0.2.8
113+
optree>=0.18.0
114+
optype>=0.15.0
115+
orbax-checkpoint>=0.11.33
116+
orbax-export>=0.0.8
117+
packaging>=26.0
118+
pandas>=3.0.1
119+
parameterized>=0.9.0
120+
pathspec>=1.0.4
121+
pillow>=12.1.0
122+
platformdirs>=4.9.4
123+
pluggy>=1.6.0
124+
portpicker>=1.6.0
125+
promise>=2.3
126+
propcache>=0.4.1
127+
proto-plus>=1.27.2
128+
protobuf>=6.33.6
129+
psutil>=7.2.1
130+
pyarrow>=23.0.1
131+
pyasn1-modules>=0.4.2
132+
pyasn1>=0.6.3
133+
pycparser>=3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
134+
pydantic-core>=2.41.5
135+
pydantic>=2.12.5
136+
pyelftools>=0.32
137+
pygments>=2.19.2
138+
pyink>=25.12.0
139+
pylint>=4.0.5
140+
pyparsing>=3.3.2
141+
pyproject-hooks>=1.2.0
142+
pytest-xdist>=3.8.0
143+
pytest>=8.4.2
144+
python-dateutil>=2.9.0.post0
145+
pytokens>=0.4.1
146+
pyyaml>=6.0.3
147+
qwix @ https://github.com/google/qwix/archive/408a0f48f988b6c5b180e07f0cb1d05997bf0dcc.zip
148+
regex>=2026.2.28
149+
requests-oauthlib>=2.0.0
150+
requests>=2.32.5
151+
rich>=14.2.0
152+
ruff>=0.15.8
153+
safetensors>=0.7.0
154+
scikit-image>=0.26.0
155+
scipy-stubs>=1.17.0.1
156+
scipy>=1.17.0
157+
sentencepiece>=0.2.1
158+
setuptools>=80.10.1
159+
simple-parsing>=0.1.8
160+
simplejson>=3.20.2
161+
six>=1.17.0
162+
sortedcontainers>=2.4.0
163+
sympy>=1.14.0
164+
tensorboard-data-server>=0.7.2
165+
tensorboard-plugin-profile>=2.22.0
166+
tensorboard>=2.20.0
167+
tensorboardx>=2.6.4
168+
tensorflow-datasets>=4.9.9
169+
tensorflow-metadata>=1.17.3
170+
tensorflow>=2.20.0
171+
tensorstore>=0.1.80
172+
termcolor>=3.3.0
173+
tifffile>=2026.3.3
174+
tokamax>=0.0.10
175+
tokenizers>=0.22.2
176+
toml>=0.10.2
177+
tomlkit>=0.14.0
178+
toolz>=1.1.0
179+
torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
180+
torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
181+
tqdm>=4.67.3
182+
transformers>=4.57.6
183+
treescope>=0.1.10
184+
typeguard>=2.13.3
185+
typing-extensions>=4.15.0
186+
typing-inspect>=0.9.0
187+
typing-inspection>=0.4.2
188+
tzdata>=2025.3 ; sys_platform == 'emscripten' or sys_platform == 'win32'
189+
urllib3>=2.6.3
190+
uvloop>=0.22.1
191+
wadler-lindig>=0.1.7
192+
wcwidth>=0.6.0
193+
werkzeug>=3.1.5
194+
wheel>=0.46.2
195+
wrapt>=2.1.2
196+
xprof>=2.22.0
197+
xxhash>=3.6.0
198+
yarl>=1.23.0
199+
zipp>=3.23.0
200+
zstandard>=0.25.0

docker_build_dependency_image.sh

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@
2020
# Each time you update the base image via a "bash docker_maxdiffusion_image_upload.sh", there will be a slow upload process
2121
# (minutes). However, if you are simply changing local code and not updating dependencies, uploading just takes a few seconds.
2222

23-
# bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE={{JAX_AI_IMAGE BASEIMAGE FROM ARTIFACT REGISTRY}}
24-
# Note: The mode stable_stack is marked for deprecation, please use MODE=jax_ai_image instead
25-
# bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_IMAGE BASEIMAGE FROM ARTIFACT REGISTRY}}
26-
# bash docker_build_dependency_image.sh MODE=nightly
2723
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
2824
# bash docker_build_dependency_image.sh MODE=stable
2925

@@ -70,24 +66,14 @@ if [[ ${DEVICE} == "gpu" ]]; then
7066
export BASEIMAGE=ghcr.io/nvidia/jax:base
7167
fi
7268
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
73-
else
74-
if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then
75-
if [[ ! -v BASEIMAGE ]]; then
76-
echo "Erroring out because BASEIMAGE is unset, please set it!"
77-
exit 1
78-
fi
79-
docker build --no-cache \
80-
--build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \
81-
--build-arg COMMIT_HASH=${COMMIT_HASH} \
82-
--network=host \
83-
-t ${LOCAL_IMAGE_NAME} \
84-
-f maxdiffusion_jax_ai_image_tpu.Dockerfile .
85-
else
86-
docker build --no-cache \
87-
--network=host \
88-
--build-arg MODE=${MODE} \
89-
--build-arg JAX_VERSION=${JAX_VERSION} \
90-
-t ${LOCAL_IMAGE_NAME} \
91-
-f maxdiffusion_dependencies.Dockerfile .
92-
fi
69+
else
70+
# Default to maxdiffusion_dependencies.Dockerfile for non-GPU builds
71+
export BASEIMAGE=${BASEIMAGE:-python:3.12-slim-bullseye}
72+
docker build --no-cache \
73+
--network=host \
74+
--build-arg MODE=${MODE} \
75+
--build-arg JAX_VERSION=${JAX_VERSION} \
76+
--build-arg BASEIMAGE=${BASEIMAGE} \
77+
-t ${LOCAL_IMAGE_NAME} \
78+
-f maxdiffusion_dependencies.Dockerfile .
9379
fi

0 commit comments

Comments
 (0)