

OpenXLA PJRT Plugin

🚨 This repo is in the process of being deprecated. 🚨

This repository contains an experimental PJRT plugin library which can bridge Jax (and TensorFlow in the future) to OpenXLA/IREE.

Developing

Support for dynamically loaded PJRT plugins is brand new as of 12/21/2022, and there are still sharp edges. The procedures below are the ones currently used for development.

There are multiple development workflows, ranked from easiest to hardest (but most powerful).

Setup options

The instructions below presume that you have a compatible Jax/Jaxlib installed. Since PJRT plugin support is moving fast, released versions are rarely appropriate. See "Building Jax from Source" below.

If you are building without CUDA, you may still need to install IREE's CUDA deps for the bazel build below:

export IREE_CUDA_DEPS_DIR=${HOME?}/.iree_cuda_deps
../iree/build_tools/docker/context/fetch_cuda_deps.sh ${IREE_CUDA_DEPS_DIR?}

Option 0: Pip install (non-dev)

pip install jax openxla_pjrt_plugin_cpu \
  -f https://openxla.github.io/openxla-pjrt-plugin/pip-release-links.html \
  -f https://openxla.github.io/iree/pip-release-links.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Then verify the installation with:

$ python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
Platform 'iree_cpu' is experimental and not all JAX functionality may be correctly supported!
[IREE-PJRT] DEBUG: Using IREE compiler binary: /tmp/.venv/lib/python3.11/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so
[IREE-PJRT] DEBUG: Compiler Version: 20230813.612 @ b56ac23bd85f0b9f4a9939c9e87fe83e629f8566 (API version 1.4)
[IREE-PJRT] DEBUG: Partitioner was not enabled. The partitioner can be enabled by setting the 'PARTITIONER_LIB_PATH' config var ('IREE_PJRT_PARTITIONER_LIB_PATH' env var)
[IREE-PJRT] DEBUG: CPU driver created
[ 2  4  6  8 10 12 14 16 18]
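The partitioner log line above names a config var (`PARTITIONER_LIB_PATH`) together with the environment variable that sets it (`IREE_PJRT_PARTITIONER_LIB_PATH`). A minimal sketch for launching a child process with such settings, assuming (from that single example only) that other config vars follow the same `IREE_PJRT_` prefix convention; `config_env_var` and `with_plugin_config` are hypothetical helpers, not part of the plugin:

```python
from __future__ import annotations

import os

# ASSUMPTION: generalized from the single pairing in the log above
# ('PARTITIONER_LIB_PATH' config var <-> 'IREE_PJRT_PARTITIONER_LIB_PATH' env var).
ENV_PREFIX = "IREE_PJRT_"


def config_env_var(config_var: str) -> str:
    """Return the env var name assumed to carry the given plugin config var."""
    return ENV_PREFIX + config_var


def with_plugin_config(
    config: dict[str, str], base_env: dict[str, str] | None = None
) -> dict[str, str]:
    """Copy base_env (default: os.environ) and set each config var via its env var."""
    env = dict(os.environ if base_env is None else base_env)
    for name, value in config.items():
        env[config_env_var(name)] = value
    return env
```

For instance, `subprocess.run([sys.executable, "test/test_simple.py"], env=with_plugin_config({"PARTITIONER_LIB_PATH": "/path/to/partitioner.so"}))`, where the library path is a placeholder.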

Option 1: Synchronize to a nightly IREE release

python ./sync_deps.py
python -m pip install -U -r requirements.txt
python ./configure.py --cc=clang --cxx=clang++ --cuda-sdk-dir=$CUDA_SDK_DIR

# Source environment variables to run interactively.
# The above generates a .env and .env.sh file with key setup vars.
source .env.sh

# Build.
bazel build iree/integrations/pjrt/...

# Run a sample.
JAX_PLATFORMS=iree_cpu python test/test_simple.py
JAX_PLATFORMS=iree_cuda python test/test_simple.py
# When multiple CUDA devices are installed, pick one by setting CUDA_VISIBLE_DEVICES=<n>.
CUDA_VISIBLE_DEVICES=0 JAX_PLATFORMS=iree_cuda python test/test_simple.py
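The same sample runs can be scripted from Python. The sketch below loads the generated .env file into the child environment and then sets JAX_PLATFORMS and, optionally, CUDA_VISIBLE_DEVICES. It assumes .env holds simple KEY=VALUE lines (optionally prefixed with `export`); the exact format written by configure.py is not documented here, and `load_env_file`, `plugin_env` and `run_sample` are hypothetical helpers:

```python
from __future__ import annotations

import os
import subprocess
import sys
from pathlib import Path


def load_env_file(path: str | os.PathLike[str]) -> dict[str, str]:
    """Parse simple KEY=VALUE lines (ASSUMED format of the generated .env)."""
    env: dict[str, str] = {}
    for raw in Path(path).read_text().splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue  # Skip blank lines and comments.
        if line.startswith("export "):
            line = line[len("export "):].lstrip()
        key, sep, value = line.partition("=")
        if not sep:
            continue  # Not an assignment; ignore it.
        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
            value = value[1:-1]  # Drop matching surrounding quotes.
        env[key.strip()] = value
    return env


def plugin_env(platform: str, cuda_device: int | None = None,
               env_file: str | None = ".env",
               base_env: dict[str, str] | None = None) -> dict[str, str]:
    """Build the environment for one sample run."""
    env = dict(os.environ if base_env is None else base_env)
    if env_file is not None and Path(env_file).exists():
        env.update(load_env_file(env_file))
    env["JAX_PLATFORMS"] = platform
    if cuda_device is not None:
        env["CUDA_VISIBLE_DEVICES"] = str(cuda_device)
    return env


def run_sample(platform: str, cuda_device: int | None = None,
               script: str = "test/test_simple.py") -> int:
    """Like `[CUDA_VISIBLE_DEVICES=n] JAX_PLATFORMS=<platform> python <script>`."""
    env = plugin_env(platform, cuda_device)
    return subprocess.run([sys.executable, script], env=env).returncode
```

Calling `run_sample("iree_cuda", cuda_device=0)` from the repo root mirrors the last command above.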

Option 2: Set up for a full at-head dev rig

mkdir openxla
cd openxla
python -m venv .env
source .env/bin/activate || echo "Could not activate venv" >&2

pip install git+https://github.com/openxla/openxla-devtools.git
openxla-workspace init
openxla-workspace checkout --sync openxla-pjrt-plugin

cd jax
pip install build numpy wheel
python build/build.py \
    --bazel_options=--override_repository=xla=$PWD/../xla \
     && pip3 install dist/*.whl --force-reinstall
pip install -e .

cd ../iree
cmake -GNinja -B ../iree-build/ -S . \
    -DCMAKE_BUILD_TYPE=RelWithDebInfo \
    -DIREE_ENABLE_ASSERTIONS=ON \
    -DCMAKE_C_COMPILER=clang \
    -DCMAKE_CXX_COMPILER=clang++ \
    -DIREE_ENABLE_LLD=ON -DIREE_ENABLE_CCACHE=ON
cd ../iree-build
ninja libIREECompiler.so
export DYLIB_PATH=$PWD

cd ../openxla-pjrt-plugin
python ./configure.py --cc=clang --cxx=clang++ --iree-compiler-dylib=$DYLIB_PATH/lib/libIREECompiler.so
source .env.sh
bazel build iree/integrations/pjrt/cpu/...

# Do simple smoke test.
JAX_PLATFORMS=iree_cpu python test/test_simple.py

Building Jax from Source

Install Jax with Python sources:

# Starting in the openxla-pjrt-plugin repo, download JAX and sync to a
# compatible commit.
python ./sync_deps.py
python -m pip install -e ../jax

Build a compatible jaxlib:

cd ../jax
# NOTE: Try running `bazel clean --expunge` if you run into undeclared inclusion
# error(s).
python build/build.py \
  --bazel_options=--override_repository=xla=$PWD/../xla
# Install the version of jaxlib you just built.
python -m pip install dist/*.whl --force-reinstall
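`pip install dist/*.whl` installs every wheel in `dist/`, so stale wheels from earlier builds can end up in the same pip invocation. A minimal Python sketch that installs only the most recently built jaxlib wheel, assuming wheel filenames start with `jaxlib-` and that modification time identifies the newest build; `newest_wheel` and `install_newest_jaxlib` are hypothetical helpers, not part of this repo:

```python
from __future__ import annotations

import subprocess
import sys
from pathlib import Path


def newest_wheel(dist_dir: str | Path, prefix: str = "jaxlib-") -> Path:
    """Return the most recently modified `<prefix>*.whl` in dist_dir."""
    wheels = [p for p in Path(dist_dir).glob(f"{prefix}*.whl") if p.is_file()]
    if not wheels:
        raise FileNotFoundError(f"no {prefix}*.whl found in {dist_dir}")
    return max(wheels, key=lambda p: p.stat().st_mtime)


def install_newest_jaxlib(dist_dir: str | Path = "dist") -> None:
    """Force-reinstall only the newest jaxlib wheel into the current interpreter."""
    wheel = newest_wheel(dist_dir)
    subprocess.run(
        [sys.executable, "-m", "pip", "install", str(wheel), "--force-reinstall"],
        check=True,
    )
```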

Generating runtime traces

The plugins can be built with tracing enabled by adding the bazel build flag --iree_enable_runtime_tracing. With this flag, if a profiler is running, instrumentation will be sent to it. It can be useful to set the environment variable TRACY_NO_EXIT=1 in order to block termination of one-shot programs that exit too quickly to stream all events.
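A sketch of the two tracing steps described above (build with the tracing flag, then run a one-shot program with TRACY_NO_EXIT=1) expressed in Python. The default bazel target and the `iree_cpu` platform are assumptions borrowed from the setup commands earlier; `traced_build_cmd`, `tracing_env` and `run_traced` are hypothetical helpers:

```python
from __future__ import annotations

import os
import subprocess
import sys


def traced_build_cmd(target: str = "iree/integrations/pjrt/...") -> list[str]:
    """bazel command line with the runtime-tracing flag described above."""
    return ["bazel", "build", "--iree_enable_runtime_tracing", target]


def tracing_env(platform: str = "iree_cpu",
                base_env: dict[str, str] | None = None) -> dict[str, str]:
    """Env for a one-shot traced run; TRACY_NO_EXIT=1 blocks exit until events stream."""
    env = dict(os.environ if base_env is None else base_env)
    env["JAX_PLATFORMS"] = platform
    env["TRACY_NO_EXIT"] = "1"
    return env


def run_traced(script: str = "test/test_simple.py", platform: str = "iree_cpu") -> int:
    """Run a script against the traced plugin; start the profiler before calling this."""
    return subprocess.run([sys.executable, script], env=tracing_env(platform)).returncode
```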

Generating compile_commands.json

compile_commands.json can be generated by the following command.

bazel run @hedron_compile_commands//:refresh_all
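Once generated, compile_commands.json is a Clang JSON compilation database: a list of entries, each with a `directory`, a `file` (possibly relative to `directory`) and either `command` or `arguments`. A quick sketch for checking which translation units the database covers, for example to confirm the plugin sources were picked up; `translation_units` is a hypothetical helper:

```python
from __future__ import annotations

import json
from pathlib import Path


def translation_units(db_path: str | Path = "compile_commands.json") -> list[str]:
    """Return the sorted, de-duplicated source paths listed in a compilation database."""
    entries = json.loads(Path(db_path).read_text())
    # Joining with "/" keeps absolute "file" values as-is and resolves relative
    # ones against the entry's "directory".
    return sorted({(Path(e["directory"]) / e["file"]).as_posix() for e in entries})
```

For example, `[f for f in translation_units() if "integrations/pjrt" in f]` lists the plugin sources the database knows about.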

ASAN

Developing with ASAN is recommended but requires some special steps because we need to arrange for the plugin to be able to link with undefined symbols and load the ASAN runtime library.

  • Edit out the "-Wl,--no-undefined" from build_defs.bzl
  • Set env var LD_PRELOAD=$(clang-12 -print-file-name=libclang_rt.asan-x86_64.so) (this assumes you are compiling with clang-12; see configured.bazelrc in the IREE repo).
  • Set env var ASAN_OPTIONS=detect_leaks=0 (Python allocates a bunch of stuff that it never frees. TODO: Make this more fine-grained so we can detect leaks in plugin code).
  • Build with --config=asan.
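The environment-variable steps above can be bundled into a small Python helper for launching tests against an ASAN build. This is a sketch assuming clang-12 on x86-64 Linux, as in the list; `asan_runtime` and `asan_env` are hypothetical names, and the build-file edit and `--config=asan` still have to be done by hand:

```python
from __future__ import annotations

import os
import subprocess


def asan_runtime(clang: str = "clang-12") -> str:
    """Ask clang where its ASAN runtime lives (same query as the LD_PRELOAD step)."""
    out = subprocess.run(
        [clang, "-print-file-name=libclang_rt.asan-x86_64.so"],
        check=True, capture_output=True, text=True,
    ).stdout.strip()
    # If the file is not found, clang echoes the bare name back instead of a path.
    if os.sep not in out:
        raise FileNotFoundError(f"{clang} could not locate the ASAN runtime")
    return out


def asan_env(runtime_lib: str, base_env: dict[str, str] | None = None) -> dict[str, str]:
    """Env for running Python against an ASAN-built plugin."""
    env = dict(os.environ if base_env is None else base_env)
    env["LD_PRELOAD"] = runtime_lib  # Overwrites any existing LD_PRELOAD.
    env["ASAN_OPTIONS"] = "detect_leaks=0"  # Python never frees some allocations.
    return env
```

Usage: `subprocess.run([sys.executable, "test/test_simple.py"], env=asan_env(asan_runtime()))`.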

This can be improved and made more systematic but should work.

Running the Jax test suite

The JAX test suite can be run with pytest. We recommend using pytest-xdist as it spawns tests in workers which can be restarted in the event of individual test case crashes.

Setup:

# Install pytest
pip install pytest pytest-xdist

# Install the ctstools package from this repo (`-e` makes it editable).
pip install -e ctstools

Example of running tests:

JAX_PLATFORMS=iree_cuda pytest -n4 --max-worker-restart=9999 \
  -p openxla_pjrt_artifacts --openxla-pjrt-artifact-dir=/tmp/foobar \
  ~/src/jax/tests/nn_test.py

Note that you will typically want a small number of workers (-n4 above) for CUDA; a larger number can be tolerated for CPU.

The plugin openxla_pjrt_artifacts is in the ctstools directory and performs additional manipulation of the environment in order to save compilation artifacts, reproducers, etc.
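A sketch that assembles the pytest invocation above from Python, assuming the 4-worker count from the example for `iree_cuda` and, as a guess for CPU, one worker per core. `pytest_cmd` is a hypothetical helper, and the artifact directory default is just the example's /tmp/foobar:

```python
from __future__ import annotations

import os
import sys


def pytest_cmd(platform: str, test_file: str, artifact_dir: str = "/tmp/foobar",
               workers: int | None = None) -> tuple[list[str], dict[str, str]]:
    """Return (argv, env) for a pytest-xdist run with the artifacts plugin loaded."""
    if workers is None:
        # ASSUMPTION: few workers for CUDA (as in the example), one per core otherwise.
        workers = 4 if platform == "iree_cuda" else (os.cpu_count() or 4)
    argv = [
        sys.executable, "-m", "pytest",
        f"-n{workers}", "--max-worker-restart=9999",
        "-p", "openxla_pjrt_artifacts",
        f"--openxla-pjrt-artifact-dir={artifact_dir}",
        test_file,
    ]
    env = {**os.environ, "JAX_PLATFORMS": platform}
    return argv, env
```

Usage: `argv, env = pytest_cmd("iree_cuda", os.path.expanduser("~/src/jax/tests/nn_test.py"))`, then `subprocess.run(argv, env=env)`.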

Project Maintenance

This section is a work in progress describing various project maintenance tasks.

Pre-requisite: Install openxla-devtools

pip install git+https://github.com/openxla/openxla-devtools.git

Sync all deps to pinned versions

This updates the git repositories and upgrades Python packages.

openxla-workspace sync
python -m pip install -U -r requirements.txt

Update to latest nightlies

This updates the pinned revisions to track upstream nightlies. Note that the roll action will upgrade Python packages implicitly.

# Updates the sync_deps.py metadata.
openxla-workspace roll nightly
# Brings all dependencies to pinned versions.
openxla-workspace sync

Update just IREE to its latest nightly

This updates only the IREE compiler and source pins to IREE's latest nightly. It is useful when some issue is blocking a jax/xla upgrade but progress is still desired. Note that the roll action will upgrade Python packages implicitly.

# Updates the sync_deps.py metadata.
openxla-workspace roll iree_nightly
# Brings all dependencies to pinned versions.
openxla-workspace sync

Alternatively, just the IREE source dep (runtime and APIs) can be pinned to head:

# Updates the sync_deps.py metadata.
openxla-workspace roll iree
# Brings all dependencies to pinned versions.
openxla-workspace sync
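Each roll variant above is the same two-step sequence with a different target. A minimal sketch, limited to the three roll targets shown in this section; `roll_and_sync` is a hypothetical helper, and the injectable `run` callback exists only so the command sequence can be inspected without invoking openxla-workspace:

```python
from __future__ import annotations

import subprocess
from typing import Callable, Sequence

# Only the roll targets that appear in this README.
ROLL_TARGETS = ("nightly", "iree_nightly", "iree")


def roll_and_sync(
    target: str, run: Callable[[Sequence[str]], object] | None = None
) -> list[list[str]]:
    """Update the sync_deps.py pins for `target`, then sync all deps to them."""
    if target not in ROLL_TARGETS:
        raise ValueError(f"unknown roll target {target!r}; expected one of {ROLL_TARGETS}")
    # Default runner raises CalledProcessError, so `sync` is skipped if `roll` fails.
    runner = run or (lambda cmd: subprocess.run(list(cmd), check=True))
    cmds = [["openxla-workspace", "roll", target], ["openxla-workspace", "sync"]]
    for cmd in cmds:
        runner(cmd)
    return cmds
```

For example, `roll_and_sync("iree_nightly")` runs the pair of commands from "Update just IREE to its latest nightly".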

Pin current versions of all deps

This can be done if local, cross-project changes have been made and landed. It snapshots the state of all deps as actually checked out and updates the metadata.

openxla-workspace pin

Contacts

  • GitHub issues: Feature requests, bugs, and other work tracking
  • OpenXLA discord: Daily development discussions with the core team and collaborators

License

OpenXLA PJRT plugin is licensed under the terms of the Apache 2.0 License with LLVM Exceptions. See LICENSE for more information.

Contributors

gmngeoffrey, iree-github-actions-bot, jackwolfard, jpienaar, okkwon, phoenix-meadowlark, ramiro050, rocallahan, rsuderman, stellaraccident, trevor-m
