diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhase.java b/server/src/main/java/org/opensearch/action/search/SearchPhase.java index 0890e9f5de8d4..8eab2ee8dedac 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhase.java @@ -37,6 +37,7 @@ import java.io.IOException; import java.util.Locale; import java.util.Objects; +import java.util.Optional; /** * Base class for all individual search phases like collecting distributed frequencies, fetching documents, querying shards. @@ -69,11 +70,15 @@ public String getName() { } /** - * Returns the SearchPhase name as {@link SearchPhaseName}. Exception will come if SearchPhase name is not defined - * in {@link SearchPhaseName} - * @return {@link SearchPhaseName} + * Returns an Optional of the SearchPhase name as {@link SearchPhaseName}. If there's not a matching SearchPhaseName, + * returns an empty Optional. + * @return {@link Optional} */ - public SearchPhaseName getSearchPhaseName() { - return SearchPhaseName.valueOf(name.toUpperCase(Locale.ROOT)); + public Optional getSearchPhaseName() { + try { + return Optional.of(SearchPhaseName.valueOf(name.toUpperCase(Locale.ROOT))); + } catch (IllegalArgumentException e) { + return Optional.empty(); + } } } diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java b/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java index a2722318ac599..88728436df847 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java @@ -73,32 +73,22 @@ public long getTookMetric() { @Override protected void onPhaseStart(SearchPhaseContext context) { - try { - phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName()).current.inc(); - } catch (IllegalArgumentException ignored) { - // Do nothing if the phase isn't found in SearchPhaseName. - } + context.getCurrentPhase().getSearchPhaseName().ifPresent(name -> phaseStatsMap.get(name).current.inc()); } @Override protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) { - try { - StatsHolder phaseStats = phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName()); + context.getCurrentPhase().getSearchPhaseName().ifPresent(name -> { + StatsHolder phaseStats = phaseStatsMap.get(name); phaseStats.current.dec(); phaseStats.total.inc(); phaseStats.timing.inc(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - context.getCurrentPhase().getStartTimeInNanos())); - } catch (IllegalArgumentException ignored) { - // Do nothing if the phase isn't found in SearchPhaseName. - } + }); } @Override protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) { - try { - phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec(); - } catch (IllegalArgumentException ignored) { - // Do nothing if the phase isn't found in SearchPhaseName. - } + context.getCurrentPhase().getSearchPhaseName().ifPresent(name -> phaseStatsMap.get(name).current.dec()); } @Override diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 27336e86e52b0..3bc9282e17fc4 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -399,29 +399,29 @@ public void testOnPhaseFailureAndVerifyListeners() { final List requestOperationListeners = List.of(testListener, assertingListener); SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners); action.start(); - assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName().get())); action.onPhaseFailure(new SearchPhase("test") { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName().get())); SearchDfsQueryThenFetchAsyncAction searchDfsQueryThenFetchAsyncAction = createSearchDfsQueryThenFetchAsyncAction( requestOperationListeners ); searchDfsQueryThenFetchAsyncAction.start(); - assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName().get())); searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase("test") { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName().get())); FetchSearchPhase fetchPhase = createFetchSearchPhase(); ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt()); @@ -430,15 +430,15 @@ public void run() { action.skipShard(searchShardIterator); action.start(); action.executeNextPhase(action, fetchPhase); - assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName().get())); action.onPhaseFailure(new SearchPhase("test") { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(fetchPhase.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseTotal(fetchPhase.getSearchPhaseName().get())); } public void testOnPhaseFailure() { @@ -722,7 +722,7 @@ public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedEx action.start(); // Verify queryPhase current metric - assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName().get())); TimeUnit.MILLISECONDS.sleep(delay); FetchSearchPhase fetchPhase = createFetchSearchPhase(); @@ -733,12 +733,12 @@ public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedEx action.executeNextPhase(action, fetchPhase); // Verify queryPhase total, current and latency metrics - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertThat(testListener.getPhaseMetric(action.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName().get())); + assertThat(testListener.getPhaseMetric(action.getSearchPhaseName().get()), greaterThanOrEqualTo(delay)); + assertEquals(1, testListener.getPhaseTotal(action.getSearchPhaseName().get())); // Verify fetchPhase current metric - assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName().get())); TimeUnit.MILLISECONDS.sleep(delay); ExpandSearchPhase expandPhase = createExpandSearchPhase(); @@ -746,18 +746,18 @@ public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedEx TimeUnit.MILLISECONDS.sleep(delay); // Verify fetchPhase total, current and latency metrics - assertThat(testListener.getPhaseMetric(fetchPhase.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(fetchPhase.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); + assertThat(testListener.getPhaseMetric(fetchPhase.getSearchPhaseName().get()), greaterThanOrEqualTo(delay)); + assertEquals(1, testListener.getPhaseTotal(fetchPhase.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName().get())); - assertEquals(1, testListener.getPhaseCurrent(expandPhase.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(expandPhase.getSearchPhaseName().get())); action.executeNextPhase(expandPhase, fetchPhase); action.onPhaseDone(); /* finish phase since we don't have reponse being sent */ - assertThat(testListener.getPhaseMetric(expandPhase.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(expandPhase.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseCurrent(expandPhase.getSearchPhaseName())); + assertThat(testListener.getPhaseMetric(expandPhase.getSearchPhaseName().get()), greaterThanOrEqualTo(delay)); + assertEquals(1, testListener.getPhaseTotal(expandPhase.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseCurrent(expandPhase.getSearchPhaseName().get())); } public void testOnPhaseListenersWithDfsType() throws InterruptedException { @@ -772,7 +772,7 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException { FetchSearchPhase fetchPhase = createFetchSearchPhase(); searchDfsQueryThenFetchAsyncAction.start(); - assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName().get())); TimeUnit.MILLISECONDS.sleep(delay); ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt()); SearchShardIterator searchShardIterator = new SearchShardIterator(null, shardId, Collections.emptyList(), OriginalIndices.NONE); @@ -786,9 +786,9 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException { null ); /* finalizing the fetch phase since we do adhoc phase lifecycle calls */ - assertThat(testListener.getPhaseMetric(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertThat(testListener.getPhaseMetric(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName().get()), greaterThanOrEqualTo(delay)); + assertEquals(1, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName().get())); } private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAction( diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java index 990ed95f1aebc..0b62d5a16427a 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java @@ -14,6 +14,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -30,18 +31,18 @@ public void testListenersAreExecuted() { @Override public void onPhaseStart(SearchPhaseContext context) { - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.inc(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName().get()).current.inc(); } @Override public void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) { - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec(); - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).total.inc(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName().get()).current.dec(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName().get()).total.inc(); } @Override public void onPhaseFailure(SearchPhaseContext context, Throwable cause) { - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName().get()).current.dec(); } }; @@ -61,7 +62,7 @@ public void onPhaseFailure(SearchPhaseContext context, Throwable cause) { for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) { when(ctx.getCurrentPhase()).thenReturn(searchPhase); - when(searchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(searchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); compositeListener.onPhaseStart(ctx); assertEquals(totalListeners, searchPhaseMap.get(searchPhaseName).current.count()); } diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java index 2b9a5992c0117..876bc395dcd52 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java @@ -16,6 +16,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Phaser; import java.util.concurrent.TimeUnit; @@ -68,7 +69,7 @@ public void testSearchRequestPhaseFailure() { when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) { - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); testRequestStats.onPhaseStart(ctx); assertEquals(1, testRequestStats.getPhaseCurrent(searchPhaseName)); testRequestStats.onPhaseFailure(ctx, new Throwable()); @@ -85,7 +86,7 @@ public void testSearchRequestStats() { when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) { - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); long tookTimeInMillis = randomIntBetween(1, 10); testRequestStats.onPhaseStart(ctx); long startTime = System.nanoTime() - TimeUnit.MILLISECONDS.toNanos(tookTimeInMillis); @@ -116,7 +117,7 @@ public void testSearchRequestStatsOnPhaseStartConcurrently() throws InterruptedE SearchPhaseContext ctx = mock(SearchPhaseContext.class); SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); for (int i = 0; i < numTasks; i++) { threads[i] = new Thread(() -> { phaser.arriveAndAwaitAdvance(); @@ -145,7 +146,7 @@ public void testSearchRequestStatsOnPhaseEndConcurrently() throws InterruptedExc SearchPhaseContext ctx = mock(SearchPhaseContext.class); SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); long tookTimeInMillis = randomIntBetween(1, 10); long startTime = System.nanoTime() - TimeUnit.MILLISECONDS.toNanos(tookTimeInMillis); when(mockSearchPhase.getStartTimeInNanos()).thenReturn(startTime); @@ -188,7 +189,7 @@ public void testSearchRequestStatsOnPhaseFailureConcurrently() throws Interrupte SearchPhaseContext ctx = mock(SearchPhaseContext.class); SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); for (int i = 0; i < numTasks; i++) { threads[i] = new Thread(() -> { phaser.arriveAndAwaitAdvance(); diff --git a/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java b/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java index 594700ea60b3e..519c937e348bc 100644 --- a/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java +++ b/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java @@ -44,6 +44,7 @@ import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -86,7 +87,7 @@ public void testShardLevelSearchGroupStats() throws Exception { SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); when(mockSearchPhase.getStartTimeInNanos()).thenReturn(System.nanoTime() - TimeUnit.SECONDS.toNanos(paramValue)); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); for (int iterator = 0; iterator < paramValue; iterator++) { onPhaseStart(testRequestStats, ctx); onPhaseEnd(testRequestStats, ctx);