From 27141cda4a568ad412718cba5518c3ab7dbd1705 Mon Sep 17 00:00:00 2001
From: Shruthi Gorantala
Date: Mon, 2 Dec 2024 23:54:20 -0800
Subject: [PATCH] Add script to run FHE programs on GCP TPU via HEIR+Jaxite
 path.

Currently this PR:
- Ports over Ivan's gcloud scripts to HEIR repo.
- Adds an example jaxite program to test the setup.

TODO:
- Package HEIR compiled jaxite code and run on TPU.
- Run on a test gcp account.

Sending this out for initial review to get thoughts on how best to get this
started.

PiperOrigin-RevId: 702221303
---
 scripts/gcp/README.md                  |  70 +++++++
 scripts/gcp/core.py                    |  98 +++++++++
 scripts/gcp/examples/jaxite_example.py |  49 +++++
 scripts/gcp/examples/noop.py           |   3 +
 scripts/gcp/provision.py               | 205 +++++++++++++++++++
 scripts/gcp/requirements.txt           |   2 +
 scripts/gcp/tool                       | 265 +++++++++++++++++++++++++
 7 files changed, 692 insertions(+)
 create mode 100644 scripts/gcp/README.md
 create mode 100644 scripts/gcp/core.py
 create mode 100644 scripts/gcp/examples/jaxite_example.py
 create mode 100644 scripts/gcp/examples/noop.py
 create mode 100644 scripts/gcp/provision.py
 create mode 100644 scripts/gcp/requirements.txt
 create mode 100644 scripts/gcp/tool

diff --git a/scripts/gcp/README.md b/scripts/gcp/README.md
new file mode 100644
index 000000000..de901e736
--- /dev/null
+++ b/scripts/gcp/README.md
@@ -0,0 +1,70 @@
+# HEIR on Google cloud TPU
+
+This document provides a brief introduction to running FHE programs compiled
+with [HEIR](https://heir.dev) on cloud TPU. The Jaxite library is used to run
+programs on TPU.
+
+### Setting up the GCP project
+
+Before you follow this quickstart you must create a Google Cloud Platform
+account, install the Google Cloud CLI and configure the ```gcloud``` command.
+For more information see [Set up an account and a Cloud TPU project](https://cloud.google.com/tpu/docs/setup-gcp-account)
+
+#### Install the Google Cloud CLI
+The Google Cloud CLI contains tools and libraries for interfacing with Google
+Cloud products and services. For more information
+see [Installing the Google Cloud CLI](https://cloud.google.com/sdk/docs/install)
+
+#### Configure the gcloud command
+Run the following commands to configure ```gcloud``` to use your Google Cloud
+project.
+
+```sh
+$ gcloud config set account your-email-account
+$ gcloud config set project your-project-id
+```
+
+#### Enable the cloud TPU API
+Enable the Cloud TPU API and create a service identity.
+
+```sh
+$ gcloud services enable tpu.googleapis.com
+$ gcloud beta services identity create --service tpu.googleapis.com
+```
+
+## Provision a TPU
+* Clone HEIR repo and install dependencies
+
+```sh
+$ git clone git@github.com:google/heir.git
+$ cd heir/scripts/gcp
+$ pip install -r requirements.txt
+```
+
+* Create a TPU
+
+Create a new TPU called "heir_tpu" and the required infrastructure
+
+```sh
+$ ./tool provision heir_tpu
+```
+
+## Execute an FHE program on the TPU
+
+```sh
+$ ./tool run --files=examples/jaxite_example.py --main=examples/jaxite_example.py
+```
+
+## Execute a HEIR program on the TPU
+Compile a HEIR program and run on TPU
+```sh
+$ bazel run //heir/tools:heir-opt -- --tosa-to-boolean-jaxite=entry-function=test_add_one_lut3
+$ bazel run //heir/tools:heir-translate --emit-jaxite
+$ ./tool run --files="/bazel-bin/heir/tools/add_one_lut3_jaxite.py" --main="/bazel-bin/heir/tools/add_one_lut3_jaxite.py"
+```
+
+## Pricing
+Though a stopped TPU VM does not incur any cost, the disk attached to the VM
+does. The cost is the same as the cost of the disk when the VM is running. See
+[Disk and image pricing](https://cloud.google.com/compute/disks-image-pricing#disk)
+for more details.
diff --git a/scripts/gcp/core.py b/scripts/gcp/core.py
new file mode 100644
index 000000000..99d3e3d2b
--- /dev/null
+++ b/scripts/gcp/core.py
@@ -0,0 +1,107 @@
+"""Core functions for interacting with GCP TPUs."""
+
+import functools
+from typing import Optional
+
+import google.auth
+from google.cloud import tpu_v2
+
+
+@functools.lru_cache()
+def tpu_client():
+  """Returns a lazily-created, process-wide TPU API client."""
+  return tpu_v2.TpuClient()
+
+
+@functools.lru_cache()
+def default_project() -> str:
+  """Returns the project id from the ambient gcloud credentials."""
+  _, project = google.auth.default()
+  return project
+
+
+class Core:
+  """Core GCP TPU provisioning functionality.
+
+  Attributes:
+    project: The GCP project to use.
+    zone: The GCP zone to use.
+      https://cloud.google.com/tpu/docs/regions-zones#europe
+  """
+
+  def __init__(self, project: Optional[str], zone: Optional[str]) -> None:
+    if project is None:
+      project = default_project()
+    self.project = project
+
+    if zone is None:
+      zone = "us-central1-c"
+    if zone not in self.available_zones():
+      raise ValueError(
+          f"{zone=} is not available in {project=}. Available zones are:"
+          f" {self.available_zones().keys()}"
+      )
+    self.zone = zone
+    # A zone is "<region>-<suffix>", e.g. "us-central1-c" -> "us-central1".
+    self.region = zone[:-2]
+
+    self.parent = f"projects/{self.project}/locations/{self.zone}"
+    print(f"Using {project=} {zone=}")
+
+  @staticmethod
+  def add_args(parser) -> None:
+    """Adds project and zone arguments to the parser."""
+    parser.add_argument("--project", type=str, required=False)
+    parser.add_argument("--zone", type=str, required=False)
+
+  @staticmethod
+  def from_args(args) -> "Core":
+    """Creates a Core object from parsed args.
+
+    Static: callers invoke this as Core.from_args(args); as an instance
+    method, that call would raise a TypeError (args would bind to self).
+    """
+    return Core(project=args.project, zone=args.zone)
+
+  @functools.lru_cache()
+  def available_zones(self) -> dict[str, str]:
+    """Maps zone id to fully-qualified location name for this project.
+
+    NOTE: lru_cache on an instance method keeps this Core object alive for
+    the cache's lifetime; acceptable for a short-lived CLI process.
+    """
+    resp = tpu_client().list_locations(
+        {"name": f"projects/{self.project}", "page_size": 1000}
+    )
+    return {l.location_id: l.name for l in resp.locations}
+
+  def list_nodes(self, all_zones: bool = False) -> list[tpu_v2.Node]:
+    """Returns all nodes in the given zones."""
+    parents = (
+        list(self.available_zones().values()) if all_zones else [self.parent]
+    )
+
+    nodes = []
+    for parent in parents:
+      nodes.extend(
+          tpu_client()
+          .list_nodes(request={"parent": parent, "page_size": 1000})
+          .nodes
+      )
+    return nodes
+
+
+def get_node(parent: str, node_id: str) -> tpu_v2.Node:
+  """Fetches a TPU node by id under `parent`; callers catch NotFound."""
+  return tpu_client().get_node(request={"name": f"{parent}/nodes/{node_id}"})
+
+
+def start_node(node: tpu_v2.Node) -> tpu_v2.Node:
+  """Starts a TPU node; blocks until the operation completes."""
+  print(f"Starting {node.name} ...")
+  node = tpu_client().start_node(request={"name": node.name}).result()
+  print(f"Started {node.name}")
+  return node
+
+
+def stop_node(node: tpu_v2.Node) -> tpu_v2.Node:
+  """Stops a TPU node; blocks until the operation completes."""
+  print(f"Stopping {node.name} ...")
+  node = tpu_client().stop_node(request={"name": node.name}).result()
+  print(f"Stopped {node.name}")
+  return node
diff --git a/scripts/gcp/examples/jaxite_example.py b/scripts/gcp/examples/jaxite_example.py
new file mode 100644
index 000000000..a1f910f9c
--- /dev/null
+++ b/scripts/gcp/examples/jaxite_example.py
@@ -0,0 +1,49 @@
+"""Example script showing how to use Jaxite."""
+
+import timeit
+
+from jaxite.jaxite_bool import jaxite_bool
+
+
+bool_params = jaxite_bool.bool_params
+
+# Note: In real applications, a cryptographically secure seed needs to be
+# used.
+lwe_rng = bool_params.get_lwe_rng_for_128_bit_security(seed=1)
+rlwe_rng = bool_params.get_rlwe_rng_for_128_bit_security(seed=1)
+params = bool_params.get_params_for_128_bit_security()
+
+cks = jaxite_bool.ClientKeySet(
+    params,
+    lwe_rng=lwe_rng,
+    rlwe_rng=rlwe_rng,
+)
+print("Client keygen done")
+
+sks = jaxite_bool.ServerKeySet(
+    cks,
+    params,
+    lwe_rng=lwe_rng,
+    rlwe_rng=rlwe_rng,
+    bootstrap_callback=None,
+)
+print("Server keygen done.")
+
+ct_true = jaxite_bool.encrypt(True, cks, lwe_rng)
+ct_false = jaxite_bool.encrypt(False, cks, lwe_rng)
+
+# Calling function once before timing it so compile-time doesn't get
+# included in timing metrics.
+and_gate = jaxite_bool.and_(ct_false, ct_true, sks, params) + +# Using Timeit +def timed_fn(): + and_gate = jaxite_bool.and_(ct_false, ct_true, sks, params) + and_gate.block_until_ready() +timer = timeit.Timer(timed_fn) +execution_time = timer.repeat(repeat=1, number=1) +print("And gate execution time: ", execution_time) + +actual = jaxite_bool.decrypt(and_gate, cks) +expected = False +print(f"{actual=}, {expected=}") diff --git a/scripts/gcp/examples/noop.py b/scripts/gcp/examples/noop.py new file mode 100644 index 000000000..eac04d358 --- /dev/null +++ b/scripts/gcp/examples/noop.py @@ -0,0 +1,3 @@ +"""Example script to test the GCP setup.""" + +print("Hello, World from GCP") diff --git a/scripts/gcp/provision.py b/scripts/gcp/provision.py new file mode 100644 index 000000000..38d2a806e --- /dev/null +++ b/scripts/gcp/provision.py @@ -0,0 +1,205 @@ +"""Provisions and destroys TPU resources on GCP.""" + +from typing import Optional + +from core import Core +from core import get_node +from core import stop_node +from core import tpu_client +import google.api_core.exceptions as core_exceptions +from google.cloud import compute_v1 +from google.cloud import tpu_v2 + + +def provision_subnet( + core: Core, name: str, subnet: Optional[str] = None +) -> compute_v1.Subnetwork: + """Provisions a subnet.""" + if subnet is not None: # fetch subnet + return compute_v1.SubnetworksClient().get( + project=core.project, region=core.region, subnetwork=subnet + ) + + print(f"Provisioning network {name=} ", end=" - ") + try: + compute_v1.NetworksClient().insert( + project=core.project, + network_resource=compute_v1.Network( + name=name, + auto_create_subnetworks=False, + routing_config=compute_v1.NetworkRoutingConfig( + routing_mode="REGIONAL" + ), + ), + ).result() + except core_exceptions.Conflict: + print("already exists") + else: + print("done") + nw = compute_v1.NetworksClient().get(project=core.project, network=name) + + print(f"Provisioning subnetwork {name=}", end=" - ", 
flush=True) + sn_op = None + try: + sn_op = compute_v1.SubnetworksClient().insert( + project=core.project, + region=core.region, + subnetwork_resource=compute_v1.Subnetwork( + name=name, + network=nw.self_link, + ip_cidr_range="10.0.0.0/16", + stack_type="IPV4_ONLY", + ), + ) + except core_exceptions.Conflict: + print("already exists") + else: + print("done") + + fw_name = f"{name}-allow-ssh" + print(f"Provisioning firewall rule {fw_name}", end=" - ", flush=True) + try: + compute_v1.FirewallsClient().insert( + project=core.project, + firewall_resource=compute_v1.Firewall( + name=fw_name, + network=nw.self_link, + direction="INGRESS", + allowed=[compute_v1.Allowed(I_p_protocol="tcp", ports=["22"])], + source_ranges=["0.0.0.0/0"], + priority=1000, + log_config=compute_v1.FirewallLogConfig(enable=False), + ), + ).result() + except core_exceptions.Conflict: + print("already exists") + else: + print("done") + + if sn_op is not None: + sn_op.result() # wait for subnet to be created + return compute_v1.SubnetworksClient().get( + project=core.project, region=core.region, subnetwork=name + ) + + +def provision( + core: Core, + name: str, + runtime_version: str = None, + accelerator_type: str = None, + subnet_url: Optional[str] = None, + keep_running: bool = False, +) -> None: + """Provisions a TPU node. + + Args: + core: + name: Name of the TPU node. + runtime_version: TPU runtime version. + accelerator_type: TPU accelerator type. + subnet_url: URL of the subnet to use. If not provided, a new subnet will be + created. + keep_running: If True, the TPU node will not be stopped after creation. 
+ """ + if runtime_version is None or accelerator_type is None: + print("Runtime version and accelerator type are required but not provided") + + subnet = provision_subnet(core, name=name, subnet=subnet_url) + + req = tpu_v2.CreateNodeRequest( + parent=core.parent, + node_id=name, + node=tpu_v2.Node( + # https://cloud.google.com/php/docs/reference/cloud-tpu/1.1.1/V2.Node + name=name, + description="Created automagically for HEIR experiments", + accelerator_type=accelerator_type, + runtime_version=runtime_version, + shielded_instance_config={}, + network_config=tpu_v2.NetworkConfig( + network=subnet.network, + subnetwork=subnet.self_link, + enable_external_ips=True, + can_ip_forward=True, + ), + metadata={ + "startup-script": """ +#!/bin/bash +set -eux +pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install jaxite +""", + }, + # cidr_block !!! + # accelerator_config - already supplied accelerator_type ? + # service_account, scheduling_config, labels, tags, data_disks + ), + ) + + print(f"Creating TPU VM {name=}") + tpu_vm_client = tpu_v2.TpuClient() + try: + op = tpu_vm_client.create_node(request=req) + except core_exceptions.ResourceExhausted as e: + print(f"Resource exhausted: {e}") + print( + "Try later or different zone. 
Available zones:" + f" {list(core.available_zones().keys())}" + ) + return + + print("Waiting for operation to complete") + node = op.result() + print("TPU node has been created") + if not keep_running: + stop_node(node) + + +def destroy_tpu(core: Core, name: str) -> None: + print(f"Deleting TPU {name}", end=" - ", flush=True) + try: + node = get_node(core.parent, name) + except core_exceptions.NotFound: + print("TPU node not found") + return + tpu_client().delete_node(name=node.name).result() + print("done") + + +def destroy_network(core: Core, name: str) -> None: + """Deletes the network, subnetwork and firewall rule.""" + try: + fw_name = f"{name}-allow-ssh" + print(f"Deleting firewall rule {fw_name}", end=" - ", flush=True) + compute_v1.FirewallsClient().delete( + project=core.project, firewall=fw_name + ).result() + print("done") + except core_exceptions.NotFound: + print("not found") + + try: + print(f"Deleting subnetwork {name}", end=" - ", flush=True) + compute_v1.SubnetworksClient().delete( + project=core.project, region=core.region, subnetwork=name + ).result() + print("done") + except core_exceptions.NotFound: + print("not found") + + try: + print(f"Deleting network {name}", end=" - ", flush=True) + compute_v1.NetworksClient().delete( + project=core.project, network=name + ).result() + except core_exceptions.NotFound: + print("not found") + + +def destroy(core: Core, name: str, keep_network: bool = False) -> None: + try: + destroy_tpu(core, name) + finally: + if not keep_network: + destroy_network(core, name) diff --git a/scripts/gcp/requirements.txt b/scripts/gcp/requirements.txt new file mode 100644 index 000000000..3a0b08405 --- /dev/null +++ b/scripts/gcp/requirements.txt @@ -0,0 +1,2 @@ +google-cloud-compute +google-cloud-tpu diff --git a/scripts/gcp/tool b/scripts/gcp/tool new file mode 100644 index 000000000..8da7ef5cd --- /dev/null +++ b/scripts/gcp/tool @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 + +import argparse +import subprocess +import sys 
+import time + +from core import Core +from core import get_node +from core import start_node +from core import stop_node +from google.cloud import tpu_v2 +from provision import destroy +from provision import provision + + +tpu_vm_cmd_pref = ["gcloud", "compute", "tpus", "tpu-vm"] + + +def _extract_short_name(node: tpu_v2.Node) -> str: + return node.name.split("/")[-1] + + +def _parse_node_name(node: tpu_v2.Node) -> tuple[str, str, str]: + _, project, _, zone, _, name = node.name.split("/") + return project, zone, name + + +def vm_run_ssh(node: tpu_v2.Node, cmd: str) -> None: + """ssh onto a TPU VM and run a command.""" + project, zone, name = _parse_node_name(node) + start = time.time() + print(f"Running '{cmd}' on {name} ...\n") + res = subprocess.run( + [ + *tpu_vm_cmd_pref, + "ssh", + name, + f"--project={project}", + f"--zone={zone}", + f"--command={cmd}", + ], + stderr=subprocess.PIPE, + ) + if res.returncode != 0: + print(res.stderr.decode()) + raise RuntimeError(f"Execution returned code is {res.returncode}") + print(f"\nDone, elapsed time: {time.time() - start} seconds") + + +def vm_scp(node: tpu_v2.Node, src_path: str, tgt_path: str) -> None: + project, zone, name = _parse_node_name(node) + print(f"Copying {src_path} to {name} ...") + subprocess.run( + [ + *tpu_vm_cmd_pref, + "scp", + src_path, + f"{name}:{tgt_path}", + f"--project={project}", + f"--zone={zone}", + ], + check=True, + ) + + +class UxError(Exception): + pass + + +def _parse_with_core( + parser: argparse.ArgumentParser, +) -> tuple[argparse.Namespace, Core]: + Core.add_args(parser) + args = parser.parse_args(sys.argv[2:]) + return args, Core.from_args(args) + + +class Tool: + """Tool to provision TPUs and run FHE programs on them.""" + + def __init__(self): + self._default_im_sa = "tpursuit-deployer" + + parser = argparse.ArgumentParser( + description="Tool to run executables on TPUs", + usage="""tool [] + run Run an executable on a TPU VM + list List all TPU VMs + + Managing infrastructure: 
+  provision  Provision a TPU VM and supporting infrastructure
+  destroy    Destroy a TPU VM and supporting infrastructure
+
+  Managing TPU VMs state:
+  stop       Stop a TPU VM
+  start      Start a TPU VM
+""",
+    )
+    parser.add_argument("command", help="Subcommand to run")
+    args = parser.parse_args(sys.argv[1:2])
+    if not hasattr(self, args.command):
+      print("Unrecognized command")
+      parser.print_help()
+      exit(1)
+    # Dispatch to the method named after the subcommand.
+    getattr(self, args.command)()
+
+  def list(self):
+    """Lists TPU VMs in the configured zone (or every zone with --all)."""
+    parser = argparse.ArgumentParser(description="List all TPU VMs")
+    parser.add_argument(
+        "--all",
+        action="store_true",
+        default=False,
+        help="List all available zones",
+    )
+    args, core = _parse_with_core(parser)
+
+    self._print_list(core.list_nodes(all_zones=args.all))
+
+  def _print_list(self, nodes) -> None:
+    """Prints one tab-separated summary line per node."""
+    if not nodes:
+      print("No VMs found")
+      return
+
+    print("Name\tStatus\tAccelerator\tZone\tSubnet")
+    for node in nodes:
+      _, _, _, zone, _, name = node.name.split("/")
+      *_, sn = node.network_config.subnetwork.split("/")
+      print(f"{name}\t{node.state.name}\t{node.accelerator_type}\t{zone}\t{sn}")
+
+  def run(self):
+    """Copies --files to a TPU VM and runs --main there over ssh."""
+    parser = argparse.ArgumentParser(
+        description="Run an executable on a TPU VM"
+    )
+    # Both flags are required: the body dereferences args.files and
+    # args.main unconditionally. Note a duplicate --main registration was
+    # removed here — argparse raises ArgumentError on duplicate options,
+    # which crashed every `tool run` invocation.
+    parser.add_argument(
+        "--files", required=True, help="Files required to run fhe code"
+    )
+    parser.add_argument("--main", required=True, help="Main file")
+    parser.add_argument(
+        "--dest_dir",
+        help="Destination directory on the TPU VM to copy the files to",
+        default="fhe_code",
+    )
+    parser.add_argument("--vm", type=str, help="Name of the TPU VM to run on")
+    parser.add_argument(
+        "--keep-running",
+        action="store_true",
+        default=False,
+        help="Keep the VM running after the script finishes",
+    )
+    args, core = _parse_with_core(parser)
+
+    nodes = core.list_nodes(all_zones=False)
+    if args.vm is None:
+      if len(nodes) == 0:
+        raise UxError("No VMs found")
+      if len(nodes) > 1:
+        self._print_list(nodes)
+        raise UxError("Multiple VMs found, please specify one")
+
+      node = nodes[0]
print(f"Using only VM in zone={core.zone}: {_extract_short_name(node)}") + else: + match = [node for node in nodes if args.vm == _extract_short_name(node)] + if not match: + self._print_list(nodes) + raise UxError(f"VM {args.vm} not found in zone={core.zone}") + node = match[0] + + src_files_to_copy = args.files.split(",") + if args.main not in src_files_to_copy: + src_files_to_copy.append(args.main) + py_command = "python3 " + args.dest_dir + "/" + args.main.split("/")[-1] + + if node.state.name != "READY": # TODO:Handle states properly + start_node(node) + try: + for f in src_files_to_copy: + vm_scp(node, f, args.dest_dir + "/" + f.split("/")[-1]) + vm_run_ssh(node, py_command) + finally: + if not args.keep_running: + stop_node(node) + + def stop(self): + parser = argparse.ArgumentParser(description="Stop a TPU VM") + parser.add_argument("vm", type=str, help="Name of the TPU VM to stop") + args, core = _parse_with_core(parser) + + stop_node(get_node(core.parent, args.vm)) + + def start(self): + parser = argparse.ArgumentParser(description="Start a TPU VM") + parser.add_argument("vm", type=str, help="Name of the TPU VM to start") + args, core = _parse_with_core(parser) + + start_node(get_node(core.parent, args.vm)) + + def provision(self): + """Provision a TPU VM and supporting infrastructure.""" + parser = argparse.ArgumentParser( + description="Provision a TPU VM and supporting infrastructure" + ) + parser.add_argument("vm", type=str, help="Name of the TPU VM to provision") + parser.add_argument( + "--runtime-version", + type=str, + # https://cloud.google.com/tpu/docs/runtimes + default="v2-alpha-tpuv5-lite", + help="Runtime version", + ) + parser.add_argument( + "--accelerator-type", + type=str, + # https://cloud.google.com/tpu/docs/system-architecture-tpu-vm + default="v5litepod-8", + help="Accelerator type", + ) + parser.add_argument( + "--keep-running", + action="store_true", + default=False, + help="Keep the VM running after the script finishes", + ) + 
parser.add_argument( + "--subnet", + type=str, + help=( + "Subnet to use, if none provided a new network infrastructure will" + " be created" + ), + ) + args, core = _parse_with_core(parser) + + provision( + core=core, + name=args.vm, + runtime_version=args.runtime_version, + accelerator_type=args.accelerator_type, + subnet_url=args.subnet, + keep_running=args.keep_running, + ) + + def destroy(self): + """Destroys TPU VM and supporting infrastructure.""" + parser = argparse.ArgumentParser( + description="Destroy a TPU VM and supporting infrastructure" + ) + parser.add_argument("vm", type=str, help="Name of the TPU VM to destroy") + parser.add_argument( + "--keep-network", + action="store_true", + default=False, + help="Keep the network infrastructure", + ) + args, core = _parse_with_core(parser) + + destroy(core, args.vm, keep_network=args.keep_network) + + +if __name__ == "__main__": + try: + Tool() + except UxError as e: + print(e) + exit(1)