Skip to content

Commit

Permalink
Add script to run FHE programs on GCP TPU via HEIR+Jaxite path.
Browse files Browse the repository at this point in the history
Currently this PR:
- Ports over Ivan's gcloud scripts to HEIR repo.
- Includes an example Jaxite program to test the setup.

TODO:
- Package HEIR compiled jaxite code and run on TPU.
- Run on a test gcp account.

Sending this out for initial review to get thoughts on how best to get this started.

PiperOrigin-RevId: 702221303
  • Loading branch information
code-perspective authored and copybara-github committed Dec 20, 2024
1 parent ce83b38 commit 27141cd
Show file tree
Hide file tree
Showing 7 changed files with 692 additions and 0 deletions.
70 changes: 70 additions & 0 deletions scripts/gcp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# HEIR on Google Cloud TPU

This document provides a brief introduction to running FHE programs compiled
with [HEIR](https://heir.dev) on Cloud TPU. The Jaxite library is used to run
the programs on the TPU.

### Setting up the GCP project

Before you follow this quickstart, you must create a Google Cloud Platform
account, install the Google Cloud CLI, and configure the `gcloud` command.
For more information, see [Set up an account and a Cloud TPU project](https://cloud.google.com/tpu/docs/setup-gcp-account).

#### Install the Google Cloud CLI
The Google Cloud CLI contains tools and libraries for interfacing with Google
Cloud products and services. For more information
see [Installing the Google Cloud CLI](https://cloud.google.com/sdk/docs/install)

#### Configure the gcloud command
Run the following commands to configure `gcloud` to use your Google Cloud
account and project.

```sh
$ gcloud config set account your-email-account
$ gcloud config set project your-project-id
```

#### Enable the Cloud TPU API
Enable the Cloud TPU API and create a service identity.

```sh
$ gcloud services enable tpu.googleapis.com
$ gcloud beta services identity create --service tpu.googleapis.com
```

## Provision a TPU
* Clone HEIR repo and install dependencies

```sh
$ git clone git@github.com:google/heir.git
$ cd heir/scripts/gcp
$ pip install -r requirements.txt
```

* Create a TPU

Create a new TPU called "heir_tpu", along with the infrastructure it requires.

```sh
$ ./tool provision heir_tpu
```

## Execute an FHE program on the TPU

```sh
$ ./tool run examples/jaxite_example.py
```

## Execute a HEIR program on the TPU
Compile a HEIR program and run it on the TPU:
```sh
$ bazel run //heir/tools:heir-opt -- --tosa-to-boolean-jaxite=entry-function=test_add_one_lut3
$ bazel run //heir/tools:heir-translate --emit-jaxite
$ ./tool run --files="/bazel-bin/heir/tools/add_one_lut3_jaxite.py" --main="/bazel-bin/heir/tools/add_one_lut3_jaxite.py"
```

## Pricing
Although a stopped TPU VM does not incur any cost, the disk attached to the VM
does. The disk costs the same whether the VM is running or stopped. See
[Disk and image pricing](https://cloud.google.com/compute/disks-image-pricing#disk)
for more details.
98 changes: 98 additions & 0 deletions scripts/gcp/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Core functions for interacting with GCP TPUs."""

import functools
from typing import Optional

import google.auth
from google.cloud import tpu_v2


@functools.lru_cache()
def tpu_client():
    """Return the Cloud TPU API client.

    The function is wrapped in `functools.lru_cache`, so the client is
    constructed on the first call and the same object is handed back on
    every later call.
    """
    client = tpu_v2.TpuClient()
    return client


@functools.lru_cache()
def default_project() -> str:
    """Return the project ID reported by `google.auth.default()`.

    The result is memoized with `functools.lru_cache`, so the credentials
    lookup happens at most once per successful call.

    Returns:
        The default GCP project ID.

    Raises:
        RuntimeError: If the environment does not define a default project.
    """
    _, project = google.auth.default()
    # google.auth.default() reports the project as None when it cannot be
    # determined; fail here with an actionable message instead of letting
    # "projects/None" resource names reach the TPU API.
    if not project:
        raise RuntimeError(
            "Could not determine a default GCP project. Pass --project or run"
            " `gcloud config set project <project-id>`."
        )
    return project


class Core:
    """Core GCP TPU provisioning functionality.

    Attributes:
        project: The GCP project to use.
        zone: The GCP zone to use.
            https://cloud.google.com/tpu/docs/regions-zones#europe
        region: The zone with its last two characters removed
            (e.g. "us-central1-c" -> "us-central1").
        parent: Resource name "projects/{project}/locations/{zone}" used as
            the parent of TPU node requests.
    """

    def __init__(self, project: Optional[str], zone: Optional[str]) -> None:
        """Resolves the project and zone and validates the zone.

        Args:
            project: GCP project ID. When None, `default_project()` is used.
            zone: GCP zone. When None, "us-central1-c" is used.

        Raises:
            ValueError: If the zone is not among the locations listed for
                the project.
        """
        if project is None:
            project = default_project()
        self.project = project

        if zone is None:
            zone = "us-central1-c"
        if zone not in self.available_zones():
            raise ValueError(
                f"{zone=} is not available in {project=}. Available zones are:"
                f" {self.available_zones().keys()}"
            )
        self.zone = zone
        # Drop the trailing "-<letter>" zone suffix to obtain the region.
        self.region = zone[:-2]

        self.parent = f"projects/{self.project}/locations/{self.zone}"
        print(f"Using {project=} {zone=}")

    @staticmethod
    def add_args(parser) -> None:
        """Adds project and zone arguments to the parser."""
        parser.add_argument("--project", type=str, required=False)
        parser.add_argument("--zone", type=str, required=False)

    @classmethod
    def from_args(cls, args) -> "Core":
        """Creates a Core object from the given args.

        This is a classmethod so that it can be used as a factory
        (`Core.from_args(args)`); calling it on an existing instance still
        works.

        Args:
            args: Parsed arguments exposing `project` and `zone` attributes,
                as registered by `add_args`.
        """
        return cls(project=args.project, zone=args.zone)

    def available_zones(self) -> dict[str, str]:
        """Returns a mapping of location ID to location resource name.

        The locations are fetched once per instance and cached on the
        instance itself, so the cache does not outlive the object.
        """
        zones = self.__dict__.get("_available_zones")
        if zones is None:
            resp = tpu_client().list_locations(
                {"name": f"projects/{self.project}", "page_size": 1000}
            )
            zones = {loc.location_id: loc.name for loc in resp.locations}
            self._available_zones = zones
        return zones

    def list_nodes(self, all_zones: bool = False) -> "list[tpu_v2.Node]":
        """Returns all nodes in the given zones.

        Args:
            all_zones: When True, nodes from every available zone of the
                project are returned; otherwise only nodes in `self.zone`.
        """
        parents = (
            list(self.available_zones().values()) if all_zones else [self.parent]
        )

        nodes = []
        for parent in parents:
            # Iterate the pager itself (rather than reading `.nodes`, which
            # only exposes the first page) so that every page is collected.
            nodes.extend(
                tpu_client().list_nodes(
                    request={"parent": parent, "page_size": 1000}
                )
            )
        return nodes


def get_node(parent: str, node_id: str) -> tpu_v2.Node:
    """Fetch the TPU node `node_id` that lives under `parent`.

    Args:
        parent: Parent resource name; the node is looked up at
            "{parent}/nodes/{node_id}".
        node_id: ID of the node to fetch.
    """
    node_name = f"{parent}/nodes/{node_id}"
    request = {"name": node_name}
    return tpu_client().get_node(request=request)


def start_node(node: tpu_v2.Node) -> tpu_v2.Node:
    """Start a TPU node and return the node produced by the operation.

    Progress is reported on stdout before the request is issued and after
    the operation's `result()` has been obtained.
    """
    name = node.name
    print(f"Starting {name} ...")
    operation = tpu_client().start_node(request={"name": name})
    started = operation.result()
    print(f"Started {started.name}")
    return started


def stop_node(node: tpu_v2.Node) -> tpu_v2.Node:
    """Stop a TPU node and return the node produced by the operation.

    Progress is reported on stdout before the request is issued and after
    the operation's `result()` has been obtained.
    """
    name = node.name
    print(f"Stopping {name} ...")
    operation = tpu_client().stop_node(request={"name": name})
    stopped = operation.result()
    print(f"Stopped {stopped.name}")
    return stopped
49 changes: 49 additions & 0 deletions scripts/gcp/examples/jaxite_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Example script showing how to use Jaxite."""

import timeit

from jaxite.jaxite_bool import jaxite_bool


bool_params = jaxite_bool.bool_params

# Note: the fixed seeds below are for this example only. Real applications
# need to use a cryptographically secure seed.
lwe_rng = bool_params.get_lwe_rng_for_128_bit_security(seed=1)
rlwe_rng = bool_params.get_rlwe_rng_for_128_bit_security(seed=1)
params = bool_params.get_params_for_128_bit_security()

# Client-side key material.
cks = jaxite_bool.ClientKeySet(params, lwe_rng=lwe_rng, rlwe_rng=rlwe_rng)
print("Client keygen done")

# Server-side key material, derived from the client keys.
sks = jaxite_bool.ServerKeySet(
    cks, params, lwe_rng=lwe_rng, rlwe_rng=rlwe_rng, bootstrap_callback=None
)
print("Server keygen done.")

# Encrypt one True and one False bit to use as the gate inputs.
ct_true = jaxite_bool.encrypt(True, cks, lwe_rng)
ct_false = jaxite_bool.encrypt(False, cks, lwe_rng)

# Call the gate once before timing it so that compile time is not included
# in the timing metrics. This result is the one decrypted at the end.
and_gate = jaxite_bool.and_(ct_false, ct_true, sks, params)


def timed_fn():
    """Evaluate the AND gate once and block until its result is ready."""
    jaxite_bool.and_(ct_false, ct_true, sks, params).block_until_ready()


# A single timed run; `repeat` returns a list with one entry per repetition.
execution_time = timeit.Timer(timed_fn).repeat(repeat=1, number=1)
print("And gate execution time: ", execution_time)

# False AND True is expected to decrypt to False.
actual = jaxite_bool.decrypt(and_gate, cks)
expected = False
print(f"{actual=}, {expected=}")
3 changes: 3 additions & 0 deletions scripts/gcp/examples/noop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Example script to test the GCP setup."""

# The script only prints a fixed greeting, so seeing this line in the output
# shows that the script itself was executed.
print("Hello, World from GCP")
Loading

0 comments on commit 27141cd

Please sign in to comment.