diff --git a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala index 81b228b4304..504ee067ce5 100644 --- a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala +++ b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala @@ -7,6 +7,7 @@ import com.google.cloud.billing.v1._ import com.typesafe.config.Config import com.typesafe.scalalogging.LazyLogging import common.util.StringUtil.EnhancedToStringable +import common.validation.ErrorOr import common.validation.ErrorOr._ import common.validation.ErrorOr.ErrorOr import cromwell.services.ServiceRegistryActor.ServiceRegistryMessage @@ -18,9 +19,9 @@ import scala.jdk.CollectionConverters.IterableHasAsScala import java.time.temporal.ChronoUnit.SECONDS import scala.util.Using -case class CostCatalogKey(machineType: MachineType, +case class CostCatalogKey(resourceInfo: ResourceInfo, usageType: UsageType, - machineCustomization: MachineCustomization, + machineCustomization: Option[MachineCustomization], resourceType: ResourceType, region: String ) @@ -38,28 +39,47 @@ object CostCatalogKey { final val expectedSku = (".*?N1 Predefined Instance (Core|Ram) .*|" + ".*?N2 Custom Instance (Core|Ram) .*|" + - ".*?N2D AMD Custom Instance (Core|Ram) .*").r + ".*?N2D AMD Custom Instance (Core|Ram) .*|" + + "Nvidia Tesla V100 GPU .*|" + + "Nvidia Tesla P100 GPU .*|" + + "Nvidia Tesla P4 GPU .*|" + + "Nvidia Tesla T4 GPU .*").r + // TODO: seems like it will probably still match GPU strings with extra stuff in front - + // it just won't take any of those preceding characters + // What is the point of the .*? ?? def apply(sku: Sku): List[CostCatalogKey] = for { _ <- expectedSku.findFirstIn(sku.getDescription).toList - machineType <- MachineType.fromSku(sku).toList + resourceInfo <- ResourceInfo.fromSku(sku).toList resourceType <- ResourceType.fromSku(sku).toList usageType <- UsageType.fromSku(sku).toList - machineCustomization <- MachineCustomization.fromSku(sku).toList region <- sku.getServiceRegionsList.asScala.toList - } yield CostCatalogKey(machineType, usageType, machineCustomization, resourceType, region) + machineCustomization = if (resourceType == Gpu) None else Some(MachineCustomization.fromCpuOrRamSku(sku)) + } yield CostCatalogKey(resourceInfo, usageType, machineCustomization, resourceType, region) def apply(instantiatedVmInfo: InstantiatedVmInfo, resourceType: ResourceType): ErrorOr[CostCatalogKey] = - MachineType.fromGoogleMachineTypeString(instantiatedVmInfo.machineType).map { mType => - CostCatalogKey( - mType, + if (resourceType == Gpu) + for { + gpuInfo <- ErrorOr(instantiatedVmInfo.gpuInfo.get) // TODO: improve error message (default: "None.get") + gpuType <- GpuType.fromGpuInfo(gpuInfo) + } yield CostCatalogKey( + gpuType, UsageType.fromBoolean(instantiatedVmInfo.preemptible), - MachineCustomization.fromMachineTypeString(instantiatedVmInfo.machineType), - resourceType, + None, + Gpu, instantiatedVmInfo.region ) - } + else + MachineType.fromGoogleMachineTypeString(instantiatedVmInfo.machineType).map { mType => + CostCatalogKey( + mType, + UsageType.fromBoolean(instantiatedVmInfo.preemptible), + Some(MachineCustomization.fromMachineTypeString(instantiatedVmInfo.machineType)), + resourceType, + instantiatedVmInfo.region + ) + } } case class GcpCostLookupRequest(vmInfo: InstantiatedVmInfo, replyTo: ActorRef) extends ServiceRegistryMessage { @@ -116,6 +136,9 @@ object GcpCostCatalogService { s"Expected usage units of RAM to be 'GiBy.h'. Got ${usageUnit}".invalidNel } } + + // TODO: implement this + def calculateGpuPricePerHour(gpuSku: Sku, gpuCount: Long): ErrorOr[BigDecimal] = BigDecimal(1).validNel } /** @@ -200,8 +223,8 @@ class GcpCostCatalogService(serviceConfig: Config, globalConfig: Config, service // As of Sept 2024 the cost catalog does not contain entries for custom N1 machines. If we're using N1, attempt // to fall back to predefined. lazy val n1PredefinedKey = - (key.machineType, key.machineCustomization) match { - case (N1, Custom) => Option(key.copy(machineCustomization = Predefined)) + (key.resourceInfo, key.machineCustomization) match { + case (N1, Some(Custom)) => Option(key.copy(machineCustomization = Some(Predefined))) case _ => None } val sku = getSku(key).orElse(n1PredefinedKey.flatMap(getSku)).map(_.catalogObject) @@ -212,23 +235,47 @@ class GcpCostCatalogService(serviceConfig: Config, globalConfig: Config, service } // TODO consider caching this, answers won't change until we reload the SKUs - def calculateVmCostPerHour(instantiatedVmInfo: InstantiatedVmInfo): ErrorOr[BigDecimal] = - for { + def calculateVmCostPerHour(instantiatedVmInfo: InstantiatedVmInfo): ErrorOr[BigDecimal] = { + val cpuPricingInfoErrorOr = for { cpuSku <- lookUpSku(instantiatedVmInfo, Cpu) coreCount <- MachineType.extractCoreCountFromMachineTypeString(instantiatedVmInfo.machineType) cpuPricePerHour <- GcpCostCatalogService.calculateCpuPricePerHour(cpuSku, coreCount) + } yield (cpuSku, coreCount, cpuPricePerHour) + + val ramPricingInfoErrorOr = for { ramSku <- lookUpSku(instantiatedVmInfo, Ram) ramMbCount <- MachineType.extractRamMbFromMachineTypeString(instantiatedVmInfo.machineType) ramGbCount = ramMbCount / 1024d // need sub-integer resolution ramPricePerHour <- GcpCostCatalogService.calculateRamPricePerHour(ramSku, ramGbCount) - totalCost = cpuPricePerHour + ramPricePerHour + } yield (ramSku, ramGbCount, ramPricePerHour) + + val gpuPricingInfoErrorOr = instantiatedVmInfo.gpuInfo match { + case None => (None, 0, BigDecimal(0)).validNel + case Some(gpuInfo) => + for { + gpuSku <- lookUpSku(instantiatedVmInfo, Gpu) + gpuCount = gpuInfo.count + gpuPricePerHour <- GcpCostCatalogService.calculateGpuPricePerHour(gpuSku, gpuCount) + } yield (Some(gpuSku), gpuCount, gpuPricePerHour) + } + + for { + cpuPricingInfo <- cpuPricingInfoErrorOr + (cpuSku, coreCount, cpuPricePerHour) = cpuPricingInfo + ramPricingInfo <- ramPricingInfoErrorOr + (ramSku, ramGbCount, ramPricePerHour) = ramPricingInfo + gpuPricingInfo <- gpuPricingInfoErrorOr + (gpuSku, gpuCount, gpuPricePerHour) = gpuPricingInfo + totalCost = cpuPricePerHour + ramPricePerHour + gpuPricePerHour _ = logger.info( s"Calculated vmCostPerHour of ${totalCost} " + s"(CPU ${cpuPricePerHour} for ${coreCount} cores [${cpuSku.getDescription}], " + - s"RAM ${ramPricePerHour} for ${ramGbCount} Gb [${ramSku.getDescription}]) " + + s"RAM ${ramPricePerHour} for ${ramGbCount} Gb [${ramSku.getDescription}], " + + s"GPU ${gpuPricePerHour} for ${gpuCount} GPUs [${gpuSku.map(_.getDescription)}]) " + s"for ${instantiatedVmInfo}" ) } yield totalCost + } def serviceRegistryActor: ActorRef = serviceRegistry override def receive: Receive = { diff --git a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala index d189de43f1b..9a03c258a16 100644 --- a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala +++ b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala @@ -6,29 +6,39 @@ import common.validation.ErrorOr.ErrorOr import java.util.regex.{Matcher, Pattern} +case class GpuInfo(count: Long, gpuType: String) + /* * Case class that contains information retrieved from Google about a VM that cromwell has started */ -case class InstantiatedVmInfo(region: String, machineType: String, preemptible: Boolean) +case class InstantiatedVmInfo(region: String, machineType: String, gpuInfo: Option[GpuInfo], preemptible: Boolean) /* * These types reflect hardcoded strings found in a google cost catalog. */ -sealed trait MachineType { def machineTypeName: String } -case object N1 extends MachineType { override val machineTypeName = "n1" } -case object N2 extends MachineType { override val machineTypeName = "n2" } -case object N2d extends MachineType { override val machineTypeName = "n2d" } +sealed trait ResourceInfo -object MachineType { - def fromSku(sku: Sku): Option[MachineType] = { +object ResourceInfo { + def fromSku(sku: Sku): Option[ResourceInfo] = { val tokenizedDescription = sku.getDescription.toLowerCase.split(" ") if (tokenizedDescription.contains(N1.machineTypeName)) Some(N1) else if (tokenizedDescription.contains(N2.machineTypeName)) Some(N2) else if (tokenizedDescription.contains(N2d.machineTypeName)) Some(N2d) + else if (tokenizedDescription.contains(NvidiaTeslaV100.gpuTypeName)) Some(NvidiaTeslaV100) + else if (tokenizedDescription.contains(NvidiaTeslaP100.gpuTypeName)) Some(NvidiaTeslaP100) + else if (tokenizedDescription.contains(NvidiaTeslaP4.gpuTypeName)) Some(NvidiaTeslaP4) + else if (tokenizedDescription.contains(NvidiaTeslaT4.gpuTypeName)) Some(NvidiaTeslaT4) else Option.empty } +} +sealed trait MachineType extends ResourceInfo { def machineTypeName: String } +case object N1 extends MachineType { override val machineTypeName = "n1" } +case object N2 extends MachineType { override val machineTypeName = "n2" } +case object N2d extends MachineType { override val machineTypeName = "n2d" } + +object MachineType { // expects a string that looks something like "n1-standard-1" or "custom-1-4096" def fromGoogleMachineTypeString(machineTypeString: String): ErrorOr[MachineType] = { val mType = machineTypeString.toLowerCase @@ -61,6 +71,24 @@ object MachineType { } } +sealed trait GpuType extends ResourceInfo { def gpuTypeName: String } +case object NvidiaTeslaV100 extends GpuType { override val gpuTypeName = "v100" } +case object NvidiaTeslaP100 extends GpuType { override val gpuTypeName = "p100" } +case object NvidiaTeslaP4 extends GpuType { override val gpuTypeName = "p4" } +case object NvidiaTeslaT4 extends GpuType { override val gpuTypeName = "t4" } + +object GpuType { + // expects GpuInfo with a GPU type that looks something like "nvidia-tesla-v100" + def fromGpuInfo(gpuInfo: GpuInfo): ErrorOr[GpuType] = { + val gpuType = gpuInfo.gpuType.toLowerCase + if (gpuType.endsWith("-v100")) NvidiaTeslaV100.validNel + else if (gpuType.endsWith("-p100")) NvidiaTeslaP100.validNel + else if (gpuType.endsWith("-p4")) NvidiaTeslaP4.validNel + else if (gpuType.endsWith("-t4")) NvidiaTeslaT4.validNel + else s"Unrecognized GPU type: $gpuType".invalidNel + } +} + sealed trait UsageType { def typeName: String } case object OnDemand extends UsageType { override val typeName = "ondemand" } case object Preemptible extends UsageType { override val typeName = "preemptible" } @@ -76,7 +104,6 @@ object UsageType { case true => Preemptible case false => OnDemand } - } sealed trait MachineCustomization { def customizationName: String } @@ -94,21 +121,22 @@ object MachineCustomization { - For non-N1 machines, both custom and predefined SKUs are included, custom ones include "Custom" in their description strings and predefined SKUs are only identifiable by the absence of "Custom." */ - def fromSku(sku: Sku): Option[MachineCustomization] = { + def fromCpuOrRamSku(sku: Sku): MachineCustomization = { val tokenizedDescription = sku.getDescription.toLowerCase.split(" ") // ex. "N1 Predefined Instance Core running in Montreal" - if (tokenizedDescription.contains(Predefined.customizationName)) Some(Predefined) + if (tokenizedDescription.contains(Predefined.customizationName)) Predefined // ex. "N2 Custom Instance Core running in Paris" - else if (tokenizedDescription.contains(Custom.customizationName)) Some(Custom) + else if (tokenizedDescription.contains(Custom.customizationName)) Custom // ex. "N2 Instance Core running in Paris" - else Some(Predefined) + else Predefined } } sealed trait ResourceType { def groupName: String } case object Cpu extends ResourceType { override val groupName = "cpu" } case object Ram extends ResourceType { override val groupName = "ram" } +case object Gpu extends ResourceType { override val groupName = "gpu" } object ResourceType { def fromSku(sku: Sku): Option[ResourceType] = { @@ -116,6 +144,7 @@ object ResourceType { sku.getCategory.getResourceGroup.toLowerCase match { case Cpu.groupName => Some(Cpu) case Ram.groupName => Some(Ram) + case Gpu.groupName => Some(Gpu) case "n1standard" if tokenizedDescription.contains("ram") => Some(Ram) case "n1standard" if tokenizedDescription.contains("core") => Some(Cpu) case _ => Option.empty diff --git a/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala b/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala index 3f295cb189b..39e4e5ed95f 100644 --- a/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala +++ b/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala @@ -80,9 +80,9 @@ class GcpCostCatalogServiceSpec it should "cache catalogs properly" in { val testLookupKey = CostCatalogKey( - machineType = N2, + resourceInfo = N2, usageType = Preemptible, - machineCustomization = Predefined, + machineCustomization = Some(Predefined), resourceType = Cpu, region = "europe-west9" ) @@ -110,30 +110,30 @@ class GcpCostCatalogServiceSpec it should "find CPU and RAM skus for all supported machine types" in { val lookupRows = Table( ("machineType", "usage", "customization", "resource", "region", "exists"), - (N1, Preemptible, Predefined, Cpu, "us-west1", true), - (N1, Preemptible, Predefined, Ram, "us-west1", true), - (N1, OnDemand, Predefined, Cpu, "us-west1", true), - (N1, OnDemand, Predefined, Ram, "us-west1", true), - (N1, Preemptible, Custom, Cpu, "us-west1", false), - (N1, Preemptible, Custom, Ram, "us-west1", false), - (N1, OnDemand, Custom, Cpu, "us-west1", false), - (N1, OnDemand, Custom, Ram, "us-west1", false), - (N2, Preemptible, Predefined, Cpu, "us-west1", false), - (N2, Preemptible, Predefined, Ram, "us-west1", false), - (N2, OnDemand, Predefined, Cpu, "us-west1", false), - (N2, OnDemand, Predefined, Ram, "us-west1", false), - (N2, Preemptible, Custom, Cpu, "us-west1", true), - (N2, Preemptible, Custom, Ram, "us-west1", true), - (N2, OnDemand, Custom, Cpu, "us-west1", true), - (N2, OnDemand, Custom, Ram, "us-west1", true), - (N2d, Preemptible, Predefined, Cpu, "us-west1", false), - (N2d, Preemptible, Predefined, Ram, "us-west1", false), - (N2d, OnDemand, Predefined, Cpu, "us-west1", false), - (N2d, OnDemand, Predefined, Ram, "us-west1", false), - (N2d, Preemptible, Custom, Cpu, "us-west1", true), - (N2d, Preemptible, Custom, Ram, "us-west1", true), - (N2d, OnDemand, Custom, Cpu, "us-west1", true), - (N2d, OnDemand, Custom, Ram, "us-west1", true) + (N1, Preemptible, Some(Predefined), Cpu, "us-west1", true), + (N1, Preemptible, Some(Predefined), Ram, "us-west1", true), + (N1, OnDemand, Some(Predefined), Cpu, "us-west1", true), + (N1, OnDemand, Some(Predefined), Ram, "us-west1", true), + (N1, Preemptible, Some(Custom), Cpu, "us-west1", false), + (N1, Preemptible, Some(Custom), Ram, "us-west1", false), + (N1, OnDemand, Some(Custom), Cpu, "us-west1", false), + (N1, OnDemand, Some(Custom), Ram, "us-west1", false), + (N2, Preemptible, Some(Predefined), Cpu, "us-west1", false), + (N2, Preemptible, Some(Predefined), Ram, "us-west1", false), + (N2, OnDemand, Some(Predefined), Cpu, "us-west1", false), + (N2, OnDemand, Some(Predefined), Ram, "us-west1", false), + (N2, Preemptible, Some(Custom), Cpu, "us-west1", true), + (N2, Preemptible, Some(Custom), Ram, "us-west1", true), + (N2, OnDemand, Some(Custom), Cpu, "us-west1", true), + (N2, OnDemand, Some(Custom), Ram, "us-west1", true), + (N2d, Preemptible, Some(Predefined), Cpu, "us-west1", false), + (N2d, Preemptible, Some(Predefined), Ram, "us-west1", false), + (N2d, OnDemand, Some(Predefined), Cpu, "us-west1", false), + (N2d, OnDemand, Some(Predefined), Ram, "us-west1", false), + (N2d, Preemptible, Some(Custom), Cpu, "us-west1", true), + (N2d, Preemptible, Some(Custom), Ram, "us-west1", true), + (N2d, OnDemand, Some(Custom), Cpu, "us-west1", true), + (N2d, OnDemand, Some(Custom), Ram, "us-west1", true) ) forAll(lookupRows) { case (machineType, usage, customization, resource, region, exists: Boolean) => @@ -146,64 +146,67 @@ class GcpCostCatalogServiceSpec it should "find the skus for a VM when appropriate" in { val lookupRows = Table( ("instantiatedVmInfo", "resource", "skuDescription"), - (InstantiatedVmInfo("europe-west9", "custom-16-32768", false), + (InstantiatedVmInfo("europe-west9", "custom-16-32768", None, false), Cpu, "N1 Predefined Instance Core running in Paris" ), - (InstantiatedVmInfo("europe-west9", "custom-16-32768", false), + (InstantiatedVmInfo("europe-west9", "custom-16-32768", None, false), Ram, "N1 Predefined Instance Ram running in Paris" ), - (InstantiatedVmInfo("us-central1", "custom-4-4096", true), + (InstantiatedVmInfo("us-central1", "custom-4-4096", None, true), Cpu, "Spot Preemptible N1 Predefined Instance Core running in Americas" ), - (InstantiatedVmInfo("us-central1", "custom-4-4096", true), + (InstantiatedVmInfo("us-central1", "custom-4-4096", None, true), Ram, "Spot Preemptible N1 Predefined Instance Ram running in Americas" ), - (InstantiatedVmInfo("europe-west9", "n1-custom-16-32768", false), + (InstantiatedVmInfo("europe-west9", "n1-custom-16-32768", None, false), Cpu, "N1 Predefined Instance Core running in Paris" ), - (InstantiatedVmInfo("europe-west9", "n1-custom-16-32768", false), + (InstantiatedVmInfo("europe-west9", "n1-custom-16-32768", None, false), Ram, "N1 Predefined Instance Ram running in Paris" ), - (InstantiatedVmInfo("us-central1", "n1-custom-4-4096", true), + (InstantiatedVmInfo("us-central1", "n1-custom-4-4096", None, true), Cpu, "Spot Preemptible N1 Predefined Instance Core running in Americas" ), - (InstantiatedVmInfo("us-central1", "n1-custom-4-4096", true), + (InstantiatedVmInfo("us-central1", "n1-custom-4-4096", None, true), Ram, "Spot Preemptible N1 Predefined Instance Ram running in Americas" ), - (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", true), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", None, true), Cpu, "Spot Preemptible N2 Custom Instance Core running in Americas" ), - (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", true), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", None, true), Ram, "Spot Preemptible N2 Custom Instance Ram running in Americas" ), - (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", false), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", None, false), Cpu, "N2 Custom Instance Core running in Americas" ), - (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", false), Ram, "N2 Custom Instance Ram running in Americas"), - (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", true), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", None, false), + Ram, + "N2 Custom Instance Ram running in Americas" + ), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", None, true), Cpu, "Spot Preemptible N2D AMD Custom Instance Core running in Americas" ), - (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", true), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", None, true), Ram, "Spot Preemptible N2D AMD Custom Instance Ram running in Americas" ), - (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", false), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", None, false), Cpu, "N2D AMD Custom Instance Core running in Americas" ), - (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", false), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", None, false), Ram, "N2D AMD Custom Instance Ram running in Americas" ) @@ -219,21 +222,21 @@ class GcpCostCatalogServiceSpec it should "fail to find the skus for a VM when appropriate" in { val lookupRows = Table( ("instantiatedVmInfo", "resource", "errors"), - (InstantiatedVmInfo("us-central1", "custooooooom-4-4096", true), + (InstantiatedVmInfo("us-central1", "custooooooom-4-4096", None, true), Cpu, List("Unrecognized machine type: custooooooom-4-4096") ), - (InstantiatedVmInfo("us-central1", "n2custom-4-4096", true), + (InstantiatedVmInfo("us-central1", "n2custom-4-4096", None, true), Cpu, List("Unrecognized machine type: n2custom-4-4096") ), - (InstantiatedVmInfo("us-central1", "standard-4-4096", true), + (InstantiatedVmInfo("us-central1", "standard-4-4096", None, true), Cpu, List("Unrecognized machine type: standard-4-4096") ), - (InstantiatedVmInfo("planet-mars1", "custom-4-4096", true), + (InstantiatedVmInfo("planet-mars1", "custom-4-4096", None, true), Cpu, - List("Failed to look up Cpu SKU for InstantiatedVmInfo(planet-mars1,custom-4-4096,true)") + List("Failed to look up Cpu SKU for InstantiatedVmInfo(planet-mars1,custom-4-4096,None,true)") ) ) @@ -249,18 +252,18 @@ class GcpCostCatalogServiceSpec // Create BigDecimals from strings to avoid inequality due to floating point shenanigans val lookupRows = Table( ("instantiatedVmInfo", "costPerHour"), - (InstantiatedVmInfo("us-central1", "custom-4-4096", true), BigDecimal(".0361")), - (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", true), BigDecimal(".04254400000000000480")), - (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", true), BigDecimal(".02371600000000000040")), - (InstantiatedVmInfo("us-central1", "custom-4-4096", false), BigDecimal(".143392")), - (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", false), BigDecimal(".150561600")), - (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", false), BigDecimal(".130989600000000012")), - (InstantiatedVmInfo("europe-west9", "custom-4-4096", true), BigDecimal(".035018080000000004")), - (InstantiatedVmInfo("europe-west9", "n2-custom-4-4096", true), BigDecimal("0.049532000000000004")), - (InstantiatedVmInfo("europe-west9", "n2d-custom-4-4096", true), BigDecimal("0.030608000000000004")), - (InstantiatedVmInfo("europe-west9", "custom-4-4096", false), BigDecimal(".1663347200000000040")), - (InstantiatedVmInfo("europe-west9", "n2-custom-4-4352", false), BigDecimal(".17594163050")), - (InstantiatedVmInfo("europe-west9", "n2d-custom-4-4096", false), BigDecimal(".151947952")) + (InstantiatedVmInfo("us-central1", "custom-4-4096", None, true), BigDecimal(".0361")), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", None, true), BigDecimal(".04254400000000000480")), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", None, true), BigDecimal(".02371600000000000040")), + (InstantiatedVmInfo("us-central1", "custom-4-4096", None, false), BigDecimal(".143392")), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", None, false), BigDecimal(".150561600")), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", None, false), BigDecimal(".130989600000000012")), + (InstantiatedVmInfo("europe-west9", "custom-4-4096", None, true), BigDecimal(".035018080000000004")), + (InstantiatedVmInfo("europe-west9", "n2-custom-4-4096", None, true), BigDecimal("0.049532000000000004")), + (InstantiatedVmInfo("europe-west9", "n2d-custom-4-4096", None, true), BigDecimal("0.030608000000000004")), + (InstantiatedVmInfo("europe-west9", "custom-4-4096", None, false), BigDecimal(".1663347200000000040")), + (InstantiatedVmInfo("europe-west9", "n2-custom-4-4352", None, false), BigDecimal(".17594163050")), + (InstantiatedVmInfo("europe-west9", "n2d-custom-4-4096", None, false), BigDecimal(".151947952")) ) forAll(lookupRows) { case (instantiatedVmInfo: InstantiatedVmInfo, expectedCostPerHour: BigDecimal) => @@ -274,24 +277,24 @@ class GcpCostCatalogServiceSpec val lookupRows = Table( ("instantiatedVmInfo", "errors"), - (InstantiatedVmInfo("us-central1", "custooooooom-4-4096", true), + (InstantiatedVmInfo("us-central1", "custooooooom-4-4096", None, true), List("Unrecognized machine type: custooooooom-4-4096") ), - (InstantiatedVmInfo("us-central1", "n2_custom_4_4096", true), + (InstantiatedVmInfo("us-central1", "n2_custom_4_4096", None, true), List("Unrecognized machine type: n2_custom_4_4096") ), - (InstantiatedVmInfo("us-central1", "custom-foo-4096", true), + (InstantiatedVmInfo("us-central1", "custom-foo-4096", None, true), List("Could not extract core count from custom-foo-4096") ), - (InstantiatedVmInfo("us-central1", "custom-16-bar", true), + (InstantiatedVmInfo("us-central1", "custom-16-bar", None, true), List("Could not extract Ram MB count from custom-16-bar") ), - (InstantiatedVmInfo("us-central1", "123-456-789", true), List("Unrecognized machine type: 123-456-789")), - (InstantiatedVmInfo("us-central1", "n2-16-4096", true), - List("Failed to look up Cpu SKU for InstantiatedVmInfo(us-central1,n2-16-4096,true)") + (InstantiatedVmInfo("us-central1", "123-456-789", None, true), List("Unrecognized machine type: 123-456-789")), + (InstantiatedVmInfo("us-central1", "n2-16-4096", None, true), + List("Failed to look up Cpu SKU for InstantiatedVmInfo(us-central1,n2-16-4096,None,true)") ), - (InstantiatedVmInfo("planet-mars1", "n2-custom-4-4096", true), - List("Failed to look up Cpu SKU for InstantiatedVmInfo(planet-mars1,n2-custom-4-4096,true)") + (InstantiatedVmInfo("planet-mars1", "n2-custom-4-4096", None, true), + List("Failed to look up Cpu SKU for InstantiatedVmInfo(planet-mars1,n2-custom-4-4096,None,true)") ) ) diff --git a/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/api/request/BatchRequestExecutor.scala b/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/api/request/BatchRequestExecutor.scala index 38ebc66cf49..f2ce7c53a98 100644 --- a/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/api/request/BatchRequestExecutor.scala +++ b/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/api/request/BatchRequestExecutor.scala @@ -145,7 +145,7 @@ object BatchRequestExecutor { // Get instances that can be created with this AllocationPolicy, only instances[0] is supported val instancePolicy = allocationPolicy.getInstances(0).getPolicy val machineType = instancePolicy.getMachineType - val preemtible = instancePolicy.getProvisioningModelValue == ProvisioningModel.PREEMPTIBLE.getNumber + val preemptible = instancePolicy.getProvisioningModelValue == ProvisioningModel.PREEMPTIBLE.getNumber // location list = [regions/us-central1, zones/us-central1-b], region is the first element val location = allocationPolicy.getLocation.getAllowedLocationsList.get(0) @@ -155,7 +155,8 @@ object BatchRequestExecutor { else location.split("/").last - val instantiatedVmInfo = Some(InstantiatedVmInfo(region, machineType, preemtible)) + // TODO: include GPU info + val instantiatedVmInfo = Some(InstantiatedVmInfo(region, machineType, None, preemptible)) if (job.getStatus.getState == JobStatus.State.SUCCEEDED) { RunStatus.Success(events, instantiatedVmInfo) diff --git a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/Deserialization.scala b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/Deserialization.scala index 8699983ca63..c41316b5ef4 100644 --- a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/Deserialization.scala +++ b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/Deserialization.scala @@ -117,6 +117,22 @@ private[api] object Deserialization { case (Some(f), number: Number) if f.getType == classOf[java.lang.Double] => newT.set(key, number.doubleValue()) case (Some(f), number: Number) if f.getType == classOf[java.lang.Float] => newT.set(key, number.floatValue()) case (Some(f), number: Number) if f.getType == classOf[java.lang.Long] => newT.set(key, number.longValue()) + // AN-144, 12/12/2024 + // These cases handle the possibility that a field of a numeric type has a corresponding numeric value + // represented as a string. This can happen if the value was represented as a string in the original JSON. + // Currently, only the last of these 4 cases is used: Google stores the count field of the Accelerator class as + // a long, but the JSON value of that field is a number represented as a string. The other cases were added to + // provide similar support for other numeric types in case Google adds new fields represented in a similar way + // in the future. If a field of a numeric type has a string value that cannot be parsed to the correct numeric + // type, it will be skipped (see comment below for rationale). + case (Some(f), string: String) if f.getType == classOf[java.lang.Integer] => + string.toIntOption.foreach(intValue => newT.set(key, intValue)) + case (Some(f), string: String) if f.getType == classOf[java.lang.Double] => + string.toDoubleOption.foreach(doubleValue => newT.set(key, doubleValue)) + case (Some(f), string: String) if f.getType == classOf[java.lang.Float] => + string.toFloatOption.foreach(floatValue => newT.set(key, floatValue)) + case (Some(f), string: String) if f.getType == classOf[java.lang.Long] => + string.toLongOption.foreach(longValue => newT.set(key, longValue)) // If either the key is not an attribute of T, or we can't assign it - just skip it // The only effect is that an attribute might not be populated and would be null. // We would only notice if we do look at this attribute though, which we only do with the purpose of populating metadata diff --git a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala index 4a1d780685f..a0e5d4d2cfa 100644 --- a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala +++ b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala @@ -22,7 +22,7 @@ import cromwell.backend.google.pipelines.v2beta.api.Deserialization._ import cromwell.backend.google.pipelines.v2beta.api.request.ErrorReporter._ import cromwell.cloudsupport.gcp.auth.GoogleAuthMode import cromwell.core.ExecutionEvent -import cromwell.services.cost.InstantiatedVmInfo +import cromwell.services.cost.{GpuInfo, InstantiatedVmInfo} import cromwell.services.metadata.CallMetadataKeys import io.grpc.Status import org.apache.commons.lang3.exception.ExceptionUtils @@ -115,9 +115,30 @@ trait GetRequestHandler { this: RequestHandler => if (lastDashIndex != -1) zoneString.substring(0, lastDashIndex) else zoneString } + val gpuInfo: Option[GpuInfo] = for { + pipelineValue <- pipeline + resources <- Option(pipelineValue.getResources) + virtualMachine <- Option(resources.getVirtualMachine) + gpusList <- Option(virtualMachine.getAccelerators) + gpus <- { + if (gpusList.size > 1) { + // TODO: Improve this warning + // - Log appears repeatedly while task is running + // - Improve formatting of accelerator info + // - Include workflow/task ID? + logger.warn( + s"Multiple GPU types present ($gpusList) for a single task. Only the first will be used for cost calculations." + ) + } + gpusList.asScala.headOption + } + } yield GpuInfo(gpus.getCount, gpus.getType) + + // Unlike with region and machineType, gpuInfo's being None does not indicate an invalid + // result - it just means no GPUs are being used by the VM val instantiatedVmInfo: Option[InstantiatedVmInfo] = (region, machineType) match { case (Some(instantiatedRegion), Some(instantiatedMachineType)) => - Option(InstantiatedVmInfo(instantiatedRegion, instantiatedMachineType, preemptible)) + Option(InstantiatedVmInfo(instantiatedRegion, instantiatedMachineType, gpuInfo, preemptible)) case _ => Option.empty } if (operation.getDone) { diff --git a/supportedBackends/google/pipelines/v2beta/src/test/scala/cromwell/backend/google/pipelines/v2beta/api/DeserializationSpec.scala b/supportedBackends/google/pipelines/v2beta/src/test/scala/cromwell/backend/google/pipelines/v2beta/api/DeserializationSpec.scala index c2e8920301b..76517c87826 100644 --- a/supportedBackends/google/pipelines/v2beta/src/test/scala/cromwell/backend/google/pipelines/v2beta/api/DeserializationSpec.scala +++ b/supportedBackends/google/pipelines/v2beta/src/test/scala/cromwell/backend/google/pipelines/v2beta/api/DeserializationSpec.scala @@ -85,7 +85,13 @@ class DeserializationSpec extends AnyFlatSpec with CromwellTimeoutSpec with Matc "projectId" -> "project", "virtualMachine" -> Map[String, Any]( "machineType" -> "custom-1-1024", - "preemptible" -> false + "preemptible" -> false, + "accelerators" -> List[java.util.Map[String, String]]( + Map[String, String]( + "type" -> "nvidia-tesla-t4", + "count" -> "2" + ).asJava + ).asJava ).asJava ).asJava ).asJava @@ -100,6 +106,7 @@ class DeserializationSpec extends AnyFlatSpec with CromwellTimeoutSpec with Matc val virtualMachine = deserializedPipeline.getResources.getVirtualMachine virtualMachine.getMachineType shouldBe "custom-1-1024" virtualMachine.getPreemptible shouldBe false + virtualMachine.getAccelerators.get(0).getCount shouldBe 2 } // https://github.com/broadinstitute/cromwell/issues/4772 @@ -192,4 +199,31 @@ class DeserializationSpec extends AnyFlatSpec with CromwellTimeoutSpec with Matc } } + it should "deserialize numbers represented as strings while skipping invalid string values" in { + val numericFieldsToStringsMap = Map[String, Object]( + "validIntegerValue" -> "5", + "validDoubleValue" -> "6", + "validFloatValue" -> "7", + "validLongValue" -> "8", + "invalidIntegerValue" -> "5!", + "invalidDoubleValue" -> "pi", + "invalidFloatValue" -> "3 point 1 4", + "invalidLongValue" -> "looooooooooong" + ).asJava + + val deserialized = Deserialization.deserializeTo[StringToNumberDeserializationTestClass](numericFieldsToStringsMap) + deserialized match { + case Success(deserializedSuccess) => + deserializedSuccess.validIntegerValue shouldBe 5 + deserializedSuccess.validDoubleValue shouldBe 6d + deserializedSuccess.validFloatValue shouldBe 7f + deserializedSuccess.validLongValue shouldBe 8L + deserializedSuccess.invalidIntegerValue shouldBe null + deserializedSuccess.invalidDoubleValue shouldBe null + deserializedSuccess.invalidFloatValue shouldBe null + deserializedSuccess.invalidLongValue shouldBe null + case Failure(f) => + fail("Bad deserialization", f) + } + } } diff --git a/supportedBackends/google/pipelines/v2beta/src/test/scala/cromwell/backend/google/pipelines/v2beta/api/StringToNumberDeserializationTestClass.java b/supportedBackends/google/pipelines/v2beta/src/test/scala/cromwell/backend/google/pipelines/v2beta/api/StringToNumberDeserializationTestClass.java new file mode 100644 index 00000000000..11e1883bdf6 --- /dev/null +++ b/supportedBackends/google/pipelines/v2beta/src/test/scala/cromwell/backend/google/pipelines/v2beta/api/StringToNumberDeserializationTestClass.java @@ -0,0 +1,29 @@ +package cromwell.backend.google.pipelines.v2beta.api; + +import com.google.api.client.json.GenericJson; + +public class StringToNumberDeserializationTestClass extends GenericJson { + @com.google.api.client.util.Key + public java.lang.Integer validIntegerValue; + + @com.google.api.client.util.Key + public java.lang.Double validDoubleValue; + + @com.google.api.client.util.Key + public java.lang.Float validFloatValue; + + @com.google.api.client.util.Key + public java.lang.Long validLongValue; + + @com.google.api.client.util.Key + public java.lang.Integer invalidIntegerValue; + + @com.google.api.client.util.Key + public java.lang.Double invalidDoubleValue; + + @com.google.api.client.util.Key + public java.lang.Float invalidFloatValue; + + @com.google.api.client.util.Key + public java.lang.Long invalidLongValue; +}