Skip to content

Commit

Permalink
more tests + add support for missing location
Browse files Browse the repository at this point in the history
  • Loading branch information
lucymcnatt committed Dec 6, 2024
1 parent 2af8d35 commit 13aed93
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 51 deletions.
23 changes: 0 additions & 23 deletions runConfigurations/Repo template_ Cromwell server GCPBATCH.run.xml

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,20 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters)

override def extractEarliestEventTimeFromRunState(pollStatus: RunStatus): Option[OffsetDateTime] =
pollStatus.eventList.minByOption(_.offsetDateTime).map(e => e.offsetDateTime)
override def extractStartTimeFromRunState(pollStatus: RunStatus): Option[OffsetDateTime] =
override def extractStartTimeFromRunState(pollStatus: RunStatus): Option[OffsetDateTime] = {
pollStatus.eventList.collectFirst {
case event if event.name == CallMetadataKeys.VmStartTime => event.offsetDateTime
}
}

override def extractEndTimeFromRunState(pollStatus: RunStatus): Option[OffsetDateTime] =
pollStatus.eventList.collectFirst {
case event if event.name == CallMetadataKeys.VmEndTime => event.offsetDateTime
}

override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] =
override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] = {
pollStatus.instantiatedVmInfo
}

override def handleVmCostLookup(vmInfo: InstantiatedVmInfo) = {
val request = GcpCostLookupRequest(vmInfo, self)
Expand All @@ -67,6 +69,7 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters)
)
BigDecimal(-1)
}
params.logger.info(s"vmCostPerHour for ${costLookupResponse.vmInfo} is $cost")
vmCostPerHour = Option(cost)
tellMetadata(Map(CallMetadataKeys.VmCostPerHour -> cost))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,18 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
} yield status
}

override val pollingResultMonitorActor: Option[ActorRef] = Option(
context.actorOf(
BatchPollResultMonitorActor.props(serviceRegistryActor,
workflowDescriptor,
jobDescriptor,
validatedRuntimeAttributes,
platform,
jobLogger
)
)
)

override def isTerminal(runStatus: RunStatus): Boolean =
runStatus match {
case _: RunStatus.TerminalRunStatus => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import cromwell.backend.google.batch.api.{BatchApiRequestManager, BatchApiRespon
import cromwell.backend.google.batch.models.{GcpBatchExitCode, RunStatus}
import cromwell.core.ExecutionEvent
import cromwell.services.cost.InstantiatedVmInfo
import cromwell.services.metadata.CallMetadataKeys

import scala.annotation.unused
import scala.concurrent.{ExecutionContext, Future, Promise}
Expand Down Expand Up @@ -146,9 +147,14 @@ object BatchRequestExecutor {
val machineType = instancePolicy.getMachineType
val preemtible = instancePolicy.getProvisioningModelValue == ProvisioningModel.PREEMPTIBLE.getNumber

// Each location can be a region or a zone. Only one region is supported, ex: "regions/us-central1"
val location = allocationPolicy.getLocation.getAllowedLocations(0)
val region = location.split("/").last
// location list = [regions/us-central1, zones/us-central1-b], region is the first element
val location = allocationPolicy.getLocation.getAllowedLocationsList.get(0)
val region =
if (location.isEmpty)
"us-central1"
else
location.split("/").last

val instantiatedVmInfo = Some(InstantiatedVmInfo(region, machineType, preemtible))

if (job.getStatus.getState == JobStatus.State.SUCCEEDED) {
Expand All @@ -167,12 +173,20 @@ object BatchRequestExecutor {
GcpBatchExitCode.fromEventMessage(e.name.toLowerCase)
}.headOption

private def getEventList(events: List[StatusEvent]): List[ExecutionEvent] =
private def getEventList(events: List[StatusEvent]): List[ExecutionEvent] = {
val startedRegex = ".*SCHEDULED to RUNNING.*".r
val endedRegex = ".*RUNNING to.*".r // can be SUCCEEDED or FAILED
events.map { e =>
val time = java.time.Instant
.ofEpochSecond(e.getEventTime.getSeconds, e.getEventTime.getNanos.toLong)
.atOffset(java.time.ZoneOffset.UTC)
ExecutionEvent(name = e.getDescription, offsetDateTime = time)
val eventType = e.getDescription match {
case startedRegex() => CallMetadataKeys.VmStartTime
case endedRegex() => CallMetadataKeys.VmEndTime
case _ => e.getType
}
ExecutionEvent(name = eventType, offsetDateTime = time)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,10 @@ package cromwell.backend.google.batch.api.request

import akka.actor.ActorSystem
import akka.testkit.TestKit
import com.google.cloud.batch.v1.{
AllocationPolicy,
BatchServiceClient,
BatchServiceSettings,
GetJobRequest,
Job,
JobStatus
}
import com.google.cloud.batch.v1.AllocationPolicy.{
InstancePolicy,
InstancePolicyOrTemplate,
LocationPolicy,
ProvisioningModel
}
import com.google.cloud.batch.v1.{AllocationPolicy, BatchServiceClient, BatchServiceSettings, GetJobRequest, Job, JobStatus, StatusEvent}
import com.google.cloud.batch.v1.AllocationPolicy.{InstancePolicy, InstancePolicyOrTemplate, LocationPolicy, ProvisioningModel}
import com.google.cloud.batch.v1.JobStatus.State
import com.google.protobuf.Timestamp
import common.mock.MockSugar
import cromwell.backend.google.batch.api.BatchApiResponse
import cromwell.backend.google.batch.models.RunStatus
Expand All @@ -32,30 +22,57 @@ class BatchRequestExecutorSpec
with MockSugar
with PrivateMethodTester {

behavior of "BatchRequestExecutor"

it should "create instantiatedVmInfo correctly" in {

def setupBatchClient(machineType: String = "n1-standard-1",
location: String = "regions/us-central1",
jobState: State = JobStatus.State.SUCCEEDED
): BatchServiceClient = {
val instancePolicy = InstancePolicy
.newBuilder()
.setMachineType("n1-standard-1")
.setMachineType(machineType)
.setProvisioningModel(ProvisioningModel.PREEMPTIBLE)
.build()

val allocationPolicy = AllocationPolicy
.newBuilder()
.setLocation(LocationPolicy.newBuilder().addAllowedLocations("regions/us-central1"))
.setLocation(LocationPolicy.newBuilder().addAllowedLocations(location))
.addInstances(InstancePolicyOrTemplate.newBuilder().setPolicy(instancePolicy))
.build()

val jobStatus = JobStatus.newBuilder().setState(JobStatus.State.RUNNING).build()
val startStatusEvent = StatusEvent
.newBuilder()
.setType("STATUS_CHANGED")
.setEventTime(Timestamp.newBuilder().setSeconds(1).build())
.setDescription("Job state is set from SCHEDULED to RUNNING for job...")
.build()

val endStatusEvent = StatusEvent
.newBuilder()
.setType("STATUS_CHANGED")
.setEventTime(Timestamp.newBuilder().setSeconds(2).build())
.setDescription("Job state is set from RUNNING to SOME_OTHER_STATUS for job...")
.build()

val jobStatus = JobStatus
.newBuilder()
.setState(jobState)
.addStatusEvents(startStatusEvent)
.addStatusEvents(endStatusEvent)
.build()

val job = Job.newBuilder().setAllocationPolicy(allocationPolicy).setStatus(jobStatus).build()

val mockClient = mock[BatchServiceClient]
doReturn(job).when(mockClient).getJob(any[GetJobRequest])
doReturn(job).when(mockClient).getJob(any[String])

mockClient
}

behavior of "BatchRequestExecutor"

it should "create instantiatedVmInfo correctly" in {

val mockClient = setupBatchClient(jobState = JobStatus.State.RUNNING)
// Create the BatchRequestExecutor
val batchRequestExecutor = new BatchRequestExecutor.CloudImpl(BatchServiceSettings.newBuilder().build())

Expand All @@ -72,4 +89,90 @@ class BatchRequestExecutorSpec
case _ => fail("Expected RunStatus.Running with instantiatedVmInfo")
}
}

it should "create instantiatedVmInfo correctly with different location info" in {

val mockClient = setupBatchClient(location = "zones/us-central1-a", jobState = JobStatus.State.RUNNING)

// Create the BatchRequestExecutor
val batchRequestExecutor = new BatchRequestExecutor.CloudImpl(BatchServiceSettings.newBuilder().build())

// testing a private method see https://www.scalatest.org/user_guide/using_PrivateMethodTester
val internalGetHandler = PrivateMethod[BatchApiResponse.StatusQueried](Symbol("internalGetHandler"))
val result = batchRequestExecutor invokePrivate internalGetHandler(mockClient, GetJobRequest.newBuilder().build())

// Verify the instantiatedVmInfo
result.status match {
case RunStatus.Running(_, Some(instantiatedVmInfo)) =>
instantiatedVmInfo.region shouldBe "us-central1-a"
instantiatedVmInfo.machineType shouldBe "n1-standard-1"
instantiatedVmInfo.preemptible shouldBe true
case _ => fail("Expected RunStatus.Running with instantiatedVmInfo")
}
}

it should "create instantiatedVmInfo correctly with missing location info" in {

val mockClient = setupBatchClient(jobState = JobStatus.State.RUNNING)

// Create the BatchRequestExecutor
val batchRequestExecutor = new BatchRequestExecutor.CloudImpl(BatchServiceSettings.newBuilder().build())

// testing a private method see https://www.scalatest.org/user_guide/using_PrivateMethodTester
val internalGetHandler = PrivateMethod[BatchApiResponse.StatusQueried](Symbol("internalGetHandler"))
val result = batchRequestExecutor invokePrivate internalGetHandler(mockClient, GetJobRequest.newBuilder().build())

// Verify the instantiatedVmInfo
result.status match {
case RunStatus.Running(_, Some(instantiatedVmInfo)) =>
instantiatedVmInfo.region shouldBe "us-central1"
instantiatedVmInfo.machineType shouldBe "n1-standard-1"
instantiatedVmInfo.preemptible shouldBe true
case _ => fail("Expected RunStatus.Running with instantiatedVmInfo")
}
}

it should "send vmStartTime and vmEndTime metadata info when a workflow succeeds" in {

val mockClient = setupBatchClient()

// Create the BatchRequestExecutor
val batchRequestExecutor = new BatchRequestExecutor.CloudImpl(BatchServiceSettings.newBuilder().build())

// testing a private method see https://www.scalatest.org/user_guide/using_PrivateMethodTester
val internalGetHandler = PrivateMethod[BatchApiResponse.StatusQueried](Symbol("internalGetHandler"))
val result = batchRequestExecutor invokePrivate internalGetHandler(mockClient, GetJobRequest.newBuilder().build())

// Verify the events
result.status match {
case RunStatus.Success(events, _) =>
val eventNames = events.map(_.name)
val eventTimes = events.map(_.offsetDateTime.toString)
eventNames should contain theSameElementsAs List("vmStartTime", "vmEndTime")
eventTimes should contain theSameElementsAs List("1970-01-01T00:00:01Z", "1970-01-01T00:00:02Z")
case _ => fail("Expected RunStatus.Success with events")
}
}

it should "send vmStartTime and vmEndTime metadata info when a workflow fails" in {
val mockClient = setupBatchClient(jobState = JobStatus.State.FAILED)

// Create the BatchRequestExecutor
val batchRequestExecutor = new BatchRequestExecutor.CloudImpl(BatchServiceSettings.newBuilder().build())

// testing a private method see https://www.scalatest.org/user_guide/using_PrivateMethodTester
val internalGetHandler = PrivateMethod[BatchApiResponse.StatusQueried](Symbol("internalGetHandler"))
val result = batchRequestExecutor invokePrivate internalGetHandler(mockClient, GetJobRequest.newBuilder().build())

// Verify the events
result.status match {
case RunStatus.Failed(_, events, _) =>
val eventNames = events.map(_.name)
val eventTimes = events.map(_.offsetDateTime.toString)
eventNames should contain theSameElementsAs List("vmStartTime", "vmEndTime")
eventTimes should contain theSameElementsAs List("1970-01-01T00:00:01Z", "1970-01-01T00:00:02Z")
case _ => fail("Expected RunStatus.Success with events")
}
}

}

0 comments on commit 13aed93

Please sign in to comment.