Hi all!
I was trying to compare the performance of tfhe-rs and Jaxite. I expected Jaxite to be much faster than tfhe-rs because it uses the GPU, but I found that Jaxite was far slower. I want to know whether my configuration is wrong or Jaxite is not fully developed yet.
I tested both the Jaxite and tfhe-rs transpilers using the hello_world example. I did not use `bazel run` when testing Jaxite, because `bazel run` cannot initialize CUDA. (As far as I checked here, the GPU/TPU tests do not seem to be publicly available in the Bazel setup.) Instead, I ran the script directly with Python.
Jaxite takes about 10,000 seconds per evaluation (and the script crashed after the first iteration), while tfhe-rs takes about 30 seconds per evaluation.
user@gpu05:/home/user/fully-homomorphic-encryption$ python3 transpiler/tensorflow/examples/hello_world/hello_world_testbench_python.py
Generating keys
I0909 17:04:13.349542 139736576647680 xla_bridge.py:622] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA
I0909 17:04:13.350527 139736576647680 xla_bridge.py:622] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
Quantized x = [-128 -42 43 127]
Running FHE circuit
FHE circuit took 9849.094140 seconds
f(-128) = 79
Traceback (most recent call last):
File "/home/user/.local/lib/python3.10/site-packages/jax/_src/api_util.py", line 581, in shaped_abstractify
return _shaped_abstractify_handlers[type(x)](x)
KeyError: <class 'jax.Array'>
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/user/fully-homomorphic-encryption/transpiler/tensorflow/examples/hello_world/hello_world_testbench_python.py", line 98, in <module>
app.run(main)
File "/home/user/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/user/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/user/fully-homomorphic-encryption/transpiler/tensorflow/examples/hello_world/hello_world_testbench_python.py", line 93, in main
quantized_result = jnp.append(quantized_result, result_cleartext)
File "/home/user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 253, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 161, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/user/.local/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 477, in common_infer_params
avals.append(shaped_abstractify(a))
File "/home/user/.local/lib/python3.10/site-packages/jax/_src/api_util.py", line 583, in shaped_abstractify
return _shaped_abstractify_slow(x)
File "/home/user/.local/lib/python3.10/site-packages/jax/_src/api_util.py", line 572, in _shaped_abstractify_slow
raise TypeError(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Cannot interpret value of type <class 'jax.Array'> as an abstract array; it does not have a dtype attribute
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/user/fully-homomorphic-encryption/transpiler/tensorflow/examples/hello_world/hello_world_testbench_python.py", line 98, in <module>
app.run(main)
File "/home/user/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/user/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/user/fully-homomorphic-encryption/transpiler/tensorflow/examples/hello_world/hello_world_testbench_python.py", line 93, in main
quantized_result = jnp.append(quantized_result, result_cleartext)
TypeError: Cannot interpret value of type <class 'jax.Array'> as an abstract array; it does not have a dtype attribute
user@gpu05:/home/user/fully-homomorphic-encryption$ ./bazel-bin/transpiler/tensorflow/examples/hello_world/./hello_world_testbench
Inferring sine for 0
FHE computation in 33 s
Sine value: -0.01
Inferring sine for 1.5707964
FHE computation in 33 s
Sine value: 0.99
Inferring sine for 3.1415927
FHE computation in 32 s
Sine value: -0.01
Inferring sine for 4.712389
FHE computation in 32 s
Sine value: -1.09
Inferring sine for 6.2831855
FHE computation in 32 s
Sine value: -0.12
Both programs were generated from the same netlist file: they went through identical steps and differed only in the final transpiler backend.
# Lower the TOSA model to arith, writing the HEIR metadata and the Verilog
# from the same heir-opt output (tee fans it out to both heir-translate calls).
heir-opt --heir-tosa-to-arith ${INPUT_TOSA} | tee >(heir-translate --emit-metadata -o ${OUTPUT_METADATA}) | heir-translate --emit-verilog -o ${OUTPUT_VERILOG}
# Synthesize the Verilog with yosys into a flattened netlist of 3-input LUT
# cells (`abc -lut 3`, `lut3`), written to ${OUTPUT_NETLIST}.
YOSYS_SCRIPT="read_verilog ${OUTPUT_VERILOG}; hierarchy -check -top main; techmap; opt; splitnets -ports for_*; abc -lut 3; opt_clean -purge; techmap -map ${LUTMAP_SCRIPT}; opt_clean -purge; flatten; hierarchy -generate lut3 o:Y i:P* i:A i:B i:C; opt_expr; opt; opt_clean -purge; rename -hide */w:*; rename -enumerate */w:*; rename -top ${MODEL_NAME}; clean; write_verilog -noattr ${OUTPUT_NETLIST}"
yosys -p "${YOSYS_SCRIPT}"
# Emit the tfhe-rs (Rust) program from the netlist.
transpiler --ir_path ${OUTPUT_NETLIST} --liberty_path ${LIBERTY_CELLS} --heir_metadata_path ${OUTPUT_METADATA} --parallelism=0 --rs_out ${OUTPUT_RUST}
# Emit the Jaxite (Python) program from the same netlist.
# NOTE(review): unlike the tfhe-rs invocation above, this one adds
# --optimizer=yosys and passes the HEIR metadata via --metadata_path instead
# of --heir_metadata_path; confirm both backends really consume the same
# netlist and metadata format before comparing their results or timings.
transpiler --ir_path ${OUTPUT_NETLIST} --optimizer=yosys --liberty_path ${LIBERTY_CELLS} --metadata_path ${OUTPUT_METADATA} --parallelism=0 --py_out ${OUTPUT_PY}
This is the testbench I used for Jaxite:
"""A jaxite testbench for hello_world tensorflow code."""
from collections.abc import Sequence
import functools
from absl import app
from jaxite.jaxite_bool import bool_params
from jaxite.jaxite_bool import jaxite_bool
from jax import Array as ndarray
from jax import numpy as jnp
import timeit
from transpiler.tensorflow.examples.hello_world import hello_world_fhe_lib_python
def bit_slice_to_int(bit_slice: list[bool]) -> int:
    """Pack a little-endian list of bits into an unsigned integer.

    Element ``i`` of ``bit_slice`` becomes bit ``i`` of the result, so the
    first element is the least significant bit.

    Args:
      bit_slice: Bits ordered from least to most significant.

    Returns:
      The integer whose binary digits are ``bit_slice`` (0 for an empty list).
    """
    return functools.reduce(
        lambda acc, indexed_bit: acc | (int(indexed_bit[1]) << indexed_bit[0]),
        enumerate(bit_slice),
        0,
    )
def int_to_bit_slice(input_int: int, width: int = 8) -> list[bool]:
    """Return the low ``width`` bits of an integer, least significant first.

    Args:
      input_int: Integer to decompose. It is converted with ``int()`` first,
        so JAX/NumPy integer scalars (e.g. elements of the int8 array returned
        by ``quantize``) yield plain Python bools rather than 0-d arrays.
      width: Number of low-order bits to emit. Defaults to 8, matching the
        previous hard-coded behaviour.

    Returns:
      A list of ``width`` bools where element ``i`` is bit ``i`` of the input.
      Negative inputs produce their two's-complement bits, because Python's
      ``>>`` on ints is an arithmetic (sign-preserving) shift.
    """
    value = int(input_int)
    return [((value >> i) & 1) != 0 for i in range(width)]
def quantize(arr: ndarray) -> ndarray:
    """Map float32 inputs onto the model's int8 input grid.

    Each value is divided by the input scale 0.024480115622282, shifted down
    by 128, and cast with ``astype(jnp.int8)``.

    Args:
      arr: Array of jnp.float32 values.

    Returns:
      Array of jnp.int8 values.
    """
    input_scale = 0.024480115622282
    input_shift = 128.0
    scaled = arr / input_scale
    return (scaled - input_shift).astype(jnp.int8)
def dequantize(arr: ndarray) -> ndarray:
    """Map int8 circuit outputs back to float32 values.

    The input is cast to float32, offset by -5, and multiplied by the output
    scale 0.00829095672816038.

    Args:
      arr: Array of jnp.int8 values.

    Returns:
      Array of jnp.float32 values.
    """
    offset = 5
    scale = 0.00829095672816038
    as_float = arr.astype(jnp.float32)
    return (as_float - offset) * scale
@functools.cache
def setup():
    """Generate the Jaxite key material once and memoize it.

    Uses jaxite's 128-bit-security boolean parameter preset. The LWE and RLWE
    RNGs are both obtained with argument ``1`` (presumably a seed — confirm
    against ``bool_params``). They are created in that order (LWE first) and
    then passed to the key-set constructors.

    Returns:
      A tuple ``(boolean_params, lwe_rng, cks, sks)``: the parameter set, the
      LWE RNG (the testbench reuses it for encryption), the client key set and
      the server key set.
    """
    print('Generating keys')
    params = bool_params.get_params_for_128_bit_security()
    lwe_generator = bool_params.get_lwe_rng_for_128_bit_security(1)
    rlwe_generator = bool_params.get_rlwe_rng_for_128_bit_security(1)
    # Parameter values noted by the original author for this preset
    # (not verified against jaxite here):
    #   lwe_dimension=800, rlwe_dimension=2, plaintext_modulus=2^32,
    #   polynomial_modulus_degree=512,
    #   bsk log_base=4, level_count=6; ksk log_base=4, level_count=5
    client_keys = jaxite_bool.ClientKeySet(params, lwe_generator, rlwe_generator)
    server_keys = jaxite_bool.ServerKeySet(
        client_keys, params, lwe_generator, rlwe_generator
    )
    return params, lwe_generator, client_keys, server_keys
def main(argv: Sequence[str]) -> None:
    """Run the hello_world sine model end to end under Jaxite FHE.

    Quantizes four evenly spaced points in [0, 2*pi], encrypts each one bit by
    bit, evaluates the transpiled circuit, decrypts the output bits, and
    finally dequantizes and prints the collected results.

    Args:
      argv: Remaining command-line arguments; unused.
    """
    del argv  # Unused.
    (boolean_params, lwe_rng, cks, sks) = setup()

    pi = 3.14159265358979323846
    x_vals = jnp.float32(jnp.linspace(0, 2.0 * pi, 4))
    quantized_x_vals = quantize(x_vals)
    # `jax.Array` is an abstract type. Calling it (`ndarray()`) yields an object
    # with no dtype, which made `jnp.append` raise "Cannot interpret value of
    # type <class 'jax.Array'> as an abstract array". Collect plain Python ints
    # instead and build one int8 array after the loop.
    quantized_results: list[int] = []
    print("Quantized x = ", quantized_x_vals)
    for x in quantized_x_vals:
        # Convert the JAX scalar to a Python int so the plaintext bits are
        # plain Python bools.
        x_int = int(x)
        x_cleartext = int_to_bit_slice(x_int)
        x_ciphertext = [jaxite_bool.encrypt(z, cks, lwe_rng) for z in x_cleartext]

        print('Running FHE circuit')
        # NOTE(review): if the transpiled circuit is jit-compiled, the first
        # call's timing may include tracing/compilation; consider timing a
        # second call separately before comparing against tfhe-rs.
        start = timeit.default_timer()
        result_ciphertext = hello_world_fhe_lib_python.hello_world(
            x_ciphertext,
            sks,
            boolean_params,
        )
        end = timeit.default_timer()
        print(f'FHE circuit took {end - start:1f} seconds')

        result_bits = [jaxite_bool.decrypt(z, cks) for z in result_ciphertext]
        result_cleartext = bit_slice_to_int(result_bits)
        # bit_slice_to_int returns an unsigned value, but dequantize() expects
        # int8, so sign-extend the two's-complement output. Without this,
        # negative outputs decode as large positives. Assumes the circuit
        # output is a single signed integer spanning all returned bits (8 for
        # int8) — TODO confirm against the transpiler metadata.
        width = len(result_bits)
        if width and result_cleartext >= 1 << (width - 1):
            result_cleartext -= 1 << width
        print(f'f({x_int}) = {result_cleartext}')
        quantized_results.append(result_cleartext)

    quantized_result = jnp.array(quantized_results, dtype=jnp.int8)
    result = dequantize(quantized_result)
    print('Dequantized result = ', result)
if __name__ == "__main__":
    # Hand control to absl's app runner, which invokes main().
    app.run(main)
This is the BUILD file I modified to create the py_library and py_binary targets:
# Hello World example: tfhe-rs (Rust) and Jaxite (Python) builds of the same
# transpiled hello_world circuit.

# NOTE(review): `bzl_library` is not used by any target in this file.
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
load("@rules_python//python:defs.bzl", "py_binary", "py_library")
load("@rules_rust//rust:defs.bzl", "rust_binary", "rust_library")

package(
    default_applicable_licenses = ["@com_google_fully_homomorphic_encryption//:license"],
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])
# Generated tfhe-rs circuit library.
rust_library(
    name = "hello_world_fhe_lib_rust",
    srcs = ["hello_world_fhe_lib_rust.rs"],
    disable_pipelining = True,
    # Compiles with `--cfg lut` (presumably enables the LUT-based code path in
    # the generated Rust — confirm against the transpiler output).
    rustc_flags = [
        "--cfg",
        "lut",
    ],
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        "@crate_index//:rayon",
        "@crate_index//:tfhe",
    ],
)
# tfhe-rs testbench linking the generated circuit library.
rust_binary(
    name = "hello_world_testbench_rust",
    srcs = ["hello_world_testbench_rust.rs"],
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        ":hello_world_fhe_lib_rust",
        "@crate_index//:rayon",
        "@crate_index//:tfhe",
    ],
)
# Generated Jaxite circuit module (the `--py_out` output of the transpiler).
py_library(
    name = "hello_world_fhe_lib_python",
    srcs = ["hello_world_fhe_lib_python.py"],
    tags = [
        "manual",
        "notap",
    ],
    deps = ["@transpiler_pip_deps//pypi__jaxite"],
)
# Jaxite testbench driving the generated Python circuit module.
# NOTE(review): the testbench imports `jax` directly but only reaches it
# transitively through jaxite. Also, `bazel run` reportedly could not
# initialize CUDA, so check whether the jaxlib pulled in via these pip deps is
# a CUDA-enabled build.
py_binary(
    name = "hello_world_testbench_python",
    srcs = ["hello_world_testbench_python.py"],
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        ":hello_world_fhe_lib_python",
        "@com_google_absl_py//absl:app",
        "@transpiler_pip_deps//pypi__jaxite",
    ],
)
I'm using Python 3.10.13, an NVIDIA V100 GPU, and CUDA 11.8. Please let me know if my testbench or configuration is wrong.