-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add script to run FHE programs on GCP TPU via HEIR+Jaxite path.
Currently this PR: - Ports over Ivan's gcloud scripts to HEIR repo. - It has an example jaxite_program to test out TODO: - Package HEIR compiled jaxite code and run on TPU. - Run on a test gcp account. Sending this out for initial review to get thoughts on how best to get this started. PiperOrigin-RevId: 702221303
- Loading branch information
1 parent
ce83b38
commit 27141cd
Showing
7 changed files
with
692 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# HEIR on Google cloud TPU | ||
|
||
This document provides a brief introduction to running FHE programs compiled | ||
with [HEIR](https://heir.dev) on cloud TPU. Jaxite library is used to run | ||
programs on TPU. | ||
|
||
### Setting up the GCP project | ||
|
||
Before you follow this quickstart you must create a Google Cloud Platform | ||
account, install the Google Cloud CLI and configure the ```gcloud``` command. | ||
For more information see [Set up an account and a Cloud TPU project](https://cloud.google.com/tpu/docs/setup-gcp-account) | ||
|
||
#### Install the Google Cloud CLI | ||
The Google Cloud CLI contains tools and libraries for interfacing with Google | ||
Cloud products and services. For more information | ||
see [Installing the Google Cloud CLI](https://cloud.google.com/sdk/docs/install) | ||
|
||
#### Configure the gcloud command | ||
Run the following commands to configure ```gcloud``` to use your Google Cloud | ||
project. | ||
|
||
```sh | ||
$ gcloud config set account your-email-account | ||
$ gcloud config set project your-project-id | ||
``` | ||
|
||
#### Enable the cloud TPU API | ||
Enable the Cloud TPU API and create a service identity. | ||
|
||
```sh | ||
$ gcloud services enable tpu.googleapis.com | ||
$ gcloud beta services identity create --service tpu.googleapis.com | ||
``` | ||
|
||
## Provision a TPU | ||
* Clone HEIR repo and install dependencies | ||
|
||
```sh | ||
$ git clone [email protected]:google/heir.git | ||
$ cd heir/scripts/gcp | ||
$ pip install -r requirements.txt | ||
``` | ||
|
||
* Create a TPU | ||
|
||
Create a new TPU called "heir_tpu" and the required infrastructure | ||
|
||
```sh | ||
$ ./tool provision heir_tpu | ||
``` | ||
|
||
## Execute an FHE program on the TPU | ||
|
||
```sh | ||
$ ./tool run examples/jaxite_example.py | ||
``` | ||
|
||
## Execute a HEIR program on the TPU | ||
Compile a HEIR program and run on TPU | ||
```sh | ||
$ bazel run //heir/tools:heir-opt -- --tosa-to-boolean-jaxite=entry-function=test_add_one_lut3 | ||
$ bazel run //heir/tools:heir-translate --emit-jaxite | ||
$ ./tool run --files="/bazel-bin/heir/tools/add_one_lut3_jaxite.py" --main="/bazel-bin/heir/tools/add_one_lut3_jaxite.py" | ||
``` | ||
|
||
## Pricing | ||
Though stopped TPU VM does not incur any cost, the disk attached to the VM does. | ||
The cost is the same as the cost of the disk when the VM is running. See | ||
[Disk and image pricing](https://cloud.google.com/compute/disks-image-pricing#disk) | ||
for more details. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
"""Core functions for interacting with GCP TPUs.""" | ||
|
||
import functools | ||
from typing import Optional | ||
|
||
import google.auth | ||
from google.cloud import tpu_v2 | ||
|
||
|
||
@functools.lru_cache() | ||
def tpu_client(): | ||
return tpu_v2.TpuClient() | ||
|
||
|
||
@functools.lru_cache() | ||
def default_project() -> str: | ||
_, project = google.auth.default() | ||
return project | ||
|
||
|
||
class Core: | ||
"""Core GCP TPU provisioningfunctionality. | ||
project: The GCP project to use. | ||
zone: The GCP zone to use. | ||
https://cloud.google.com/tpu/docs/regions-zones#europe | ||
""" | ||
|
||
def __init__(self, project: Optional[str], zone: Optional[str]) -> None: | ||
if project is None: | ||
project = default_project() | ||
self.project = project | ||
|
||
if zone is None: | ||
zone = "us-central1-c" | ||
if zone not in self.available_zones(): | ||
raise ValueError( | ||
f"{zone=} is not available in {project=}. Available zones are:" | ||
f" {self.available_zones().keys()}" | ||
) | ||
self.zone = zone | ||
self.region = zone[:-2] | ||
|
||
self.parent = f"projects/{self.project}/locations/{self.zone}" | ||
print(f"Using {project=} {zone=}") | ||
|
||
@staticmethod | ||
def add_args(parser) -> None: | ||
"""Adds project and zone arguments to the parser.""" | ||
parser.add_argument("--project", type=str, required=False) | ||
parser.add_argument("--zone", type=str, required=False) | ||
|
||
def from_args(self, args) -> "Core": | ||
"""Creates a Core object from the given args.""" | ||
return Core(project=args.project, zone=args.zone) | ||
|
||
@functools.lru_cache() | ||
def available_zones(self) -> dict[str, str]: | ||
client = tpu_v2.TpuClient() | ||
resp = client.list_locations( | ||
{"name": f"projects/{self.project}", "page_size": 1000} | ||
) | ||
return {l.location_id: l.name for l in resp.locations} | ||
|
||
def list_nodes(self, all_zones: bool = False) -> list[tpu_v2.Node]: | ||
"""Returns all nodes in the given zones.""" | ||
parents = ( | ||
list(self.available_zones().values()) if all_zones else [self.parent] | ||
) | ||
|
||
nodes = [] | ||
for parent in parents: | ||
nodes.extend( | ||
tpu_client() | ||
.list_nodes(request={"parent": parent, "page_size": 1000}) | ||
.nodes | ||
) | ||
return nodes | ||
|
||
|
||
def get_node(parent: str, node_id: str) -> tpu_v2.Node: | ||
return tpu_client().get_node(request={"name": f"{parent}/nodes/{node_id}"}) | ||
|
||
|
||
def start_node(node: tpu_v2.Node) -> tpu_v2.Node: | ||
"""Starts a TPU node.""" | ||
print(f"Starting {node.name} ...") | ||
node = tpu_client().start_node(request={"name": node.name}).result() | ||
print(f"Started {node.name}") | ||
return node | ||
|
||
|
||
def stop_node(node: tpu_v2.Node) -> tpu_v2.Node: | ||
"""Stops a TPU node.""" | ||
print(f"Stopping {node.name} ...") | ||
node = tpu_client().stop_node(request={"name": node.name}).result() | ||
print(f"Stopped {node.name}") | ||
return node |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""Example script showing how to use Jaxite.""" | ||
|
||
import timeit | ||
|
||
from jaxite.jaxite_bool import jaxite_bool | ||
|
||
|
||
bool_params = jaxite_bool.bool_params | ||
|
||
# Note: In real applications, a cryptographically secure seed needs to be | ||
# used. | ||
lwe_rng = bool_params.get_lwe_rng_for_128_bit_security(seed=1) | ||
rlwe_rng = bool_params.get_rlwe_rng_for_128_bit_security(seed=1) | ||
params = bool_params.get_params_for_128_bit_security() | ||
|
||
cks = jaxite_bool.ClientKeySet( | ||
params, | ||
lwe_rng=lwe_rng, | ||
rlwe_rng=rlwe_rng, | ||
) | ||
print("Client keygen done") | ||
|
||
sks = jaxite_bool.ServerKeySet( | ||
cks, | ||
params, | ||
lwe_rng=lwe_rng, | ||
rlwe_rng=rlwe_rng, | ||
bootstrap_callback=None, | ||
) | ||
print("Server keygen done.") | ||
|
||
ct_true = jaxite_bool.encrypt(True, cks, lwe_rng) | ||
ct_false = jaxite_bool.encrypt(False, cks, lwe_rng) | ||
|
||
# Calling function once before timing it so compile-time doesn't get | ||
# included in timing metircs. | ||
and_gate = jaxite_bool.and_(ct_false, ct_true, sks, params) | ||
|
||
# Using Timeit | ||
def timed_fn(): | ||
and_gate = jaxite_bool.and_(ct_false, ct_true, sks, params) | ||
and_gate.block_until_ready() | ||
timer = timeit.Timer(timed_fn) | ||
execution_time = timer.repeat(repeat=1, number=1) | ||
print("And gate execution time: ", execution_time) | ||
|
||
actual = jaxite_bool.decrypt(and_gate, cks) | ||
expected = False | ||
print(f"{actual=}, {expected=}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
"""Example script to test the GCP setup.""" | ||
|
||
print("Hello, World from GCP") |
Oops, something went wrong.