3232import io .grpc .Status ;
3333import java .io .IOException ;
3434import java .time .Clock ;
35+ import java .time .Instant ;
3536import java .util .ArrayList ;
37+ import java .util .EnumSet ;
38+ import java .util .function .Supplier ;
3639import java .util .List ;
3740import java .util .Random ;
3841import java .util .concurrent .CancellationException ;
4346import java .util .concurrent .atomic .AtomicBoolean ;
4447import java .util .concurrent .atomic .AtomicInteger ;
4548import java .util .concurrent .atomic .AtomicReference ;
46- import java .util .function .Supplier ;
4749import java .util .logging .Level ;
4850import java .util .logging .Logger ;
4951import javax .annotation .Nullable ;
@@ -196,7 +198,7 @@ private int pickEntryIndexLeastInFlight() {
196198
197199 for (int i = 0 ; i < localEntries .size (); i ++) {
198200 Entry entry = localEntries .get (i );
199- int rpcs = entry .outstandingRpcs . get ();
201+ int rpcs = entry .adjustedOutstandingRpcs ();
200202 if (rpcs < minRpcs ) {
201203 minRpcs = rpcs ;
202204 candidates .clear ();
@@ -222,7 +224,9 @@ private int pickEntryIndexPowerOfTwoLeastInFlight() {
222224
223225 Entry entry1 = localEntries .get (choice1 );
224226 Entry entry2 = localEntries .get (choice2 );
225- return entry1 .outstandingRpcs .get () < entry2 .outstandingRpcs .get () ? choice1 : choice2 ;
227+ return entry1 .adjustedOutstandingRpcs () < entry2 .adjustedOutstandingRpcs ()
228+ ? choice1
229+ : choice2 ;
226230 }
227231
228232 Channel getChannel (int index ) {
@@ -334,7 +338,7 @@ void resize() {
334338 List <Entry > localEntries = entries .get ();
335339 // Estimate the current RPC load by summing each channel's outstanding RPCs plus its fast-failure penalties
336340 int actualOutstandingRpcs =
337- localEntries .stream ().mapToInt (Entry ::getAndResetMaxOutstanding ).sum ();
341+ localEntries .stream ().mapToInt (Entry ::adjustedOutstandingRpcs ).sum ();
338342
339343 // Number of channels if each channel operated at max capacity
340344 int minChannels =
@@ -527,7 +531,29 @@ static class Entry {
527531 */
528532 @ VisibleForTesting final AtomicInteger outstandingRpcs = new AtomicInteger (0 );
529533
530- private final AtomicInteger maxOutstanding = new AtomicInteger ();
534+ /*
535+ * Server errors that are likely to succeed on a different server and fail within a short period
536+ * of time - these errors add a latency penalty to adjustedOutstandingRpcs. Note that
537+ * DEADLINE_EXCEEDED is not included as it's typically caused by a slow request.
538+ */
539+ EnumSet <Status .Code > FAST_SERVER_ERRORS = EnumSet .of (
540+ Status .Code .UNKNOWN ,
541+ Status .Code .UNIMPLEMENTED ,
542+ Status .Code .INTERNAL ,
543+ Status .Code .UNAVAILABLE ,
544+ Status .Code .DATA_LOSS );
545+ private static final java .time .Duration FAST_SERVER_ERROR_PENALTY =
546+ java .time .Duration .ofSeconds (1 );
547+
548+ /*
549+ * Sometimes it's useful to pretend that failed requests take longer than they do.
550+ * delayedOutstandingRequests keeps track of the simulated end time of these requests (in order).
551+ */
552+ private final ConcurrentLinkedQueue <Instant > delayedOutstandingRequests =
553+ new ConcurrentLinkedQueue <>();
554+
555+ /* Equivalent to delayedOutstandingRequests.size(), but with constant access time. */
556+ private final AtomicInteger delayedOutstandingRequestsSize = new AtomicInteger (0 );
531557
532558 /** Queue storing the last 5 minutes of probe results */
533559 @ VisibleForTesting
@@ -555,8 +581,26 @@ ManagedChannel getManagedChannel() {
555581 return this .channel ;
556582 }
557583
558- int getAndResetMaxOutstanding () {
559- return maxOutstanding .getAndSet (outstandingRpcs .get ());
584+ private void drainAdjustedOutstandingRpcs () {
585+ Instant now = Instant .now ();
586+ Instant oldest = delayedOutstandingRequests .peek ();
587+
588+ while (oldest != null && oldest .isBefore (now )) {
589+ // poll() returns null if the queue became empty between peek() and poll()
590+ if (delayedOutstandingRequests .poll () != null ) {
591+ delayedOutstandingRequestsSize .decrementAndGet ();
592+ }
593+ oldest = delayedOutstandingRequests .peek ();
594+ }
595+ }
596+
597+ /*
598+ * Number of RPCs that would be outstanding if requests with a {@link FAST_SERVER_ERRORS} had
599+ * extra latency of {@link FAST_SERVER_ERROR_PENALTY}. This is useful for load balancing.
600+ */
601+ int adjustedOutstandingRpcs () {
602+ drainAdjustedOutstandingRpcs ();
603+ return outstandingRpcs .get () + delayedOutstandingRequestsSize .get ();
560604 }
561605
562606 /**
@@ -567,17 +611,11 @@ int getAndResetMaxOutstanding() {
567611 */
568612 private boolean retain () {
569613 // register desire to start RPC
570- int currentOutstanding = outstandingRpcs .incrementAndGet ();
571-
572- // Rough bookkeeping
573- int prevMax = maxOutstanding .get ();
574- if (currentOutstanding > prevMax ) {
575- maxOutstanding .incrementAndGet ();
576- }
614+ outstandingRpcs .incrementAndGet ();
577615
578616 // abort if the channel is closing
579617 if (shutdownRequested .get ()) {
580- release ();
618+ release (Status . CANCELLED );
581619 return false ;
582620 }
583621 return true ;
@@ -587,11 +625,15 @@ private boolean retain() {
587625 * Notify the channel that the number of outstanding RPCs has decreased. If shutdown has been
588626 * previously requested, this method will shut down the channel if it's the last outstanding RPC.
589627 */
590- private void release () {
628+ private void release (Status status ) {
591629 int newCount = outstandingRpcs .decrementAndGet ();
592630 if (newCount < 0 ) {
593631 LOG .log (Level .WARNING , "Bug! Reference count is negative (" + newCount + ")!" );
594632 }
633+ if (FAST_SERVER_ERRORS .contains (status .getCode ())) {
634+ delayedOutstandingRequests .add (Instant .now ().plus (FAST_SERVER_ERROR_PENALTY ));
635+ delayedOutstandingRequestsSize .incrementAndGet ();
636+ }
595637
596638 // Must check outstandingRpcs after shutdownRequested (in reverse order of retain()) to ensure
597639 // mutual exclusion.
@@ -673,7 +715,7 @@ public void onClose(Status status, Metadata trailers) {
673715 super .onClose (status , trailers );
674716 } finally {
675717 if (wasReleased .compareAndSet (false , true )) {
676- entry .release ();
718+ entry .release (status );
677719 } else {
678720 LOG .log (
679721 Level .WARNING ,
@@ -687,7 +729,7 @@ public void onClose(Status status, Metadata trailers) {
687729 } catch (Exception e ) {
688730 // In case start failed, make sure to release
689731 if (wasReleased .compareAndSet (false , true )) {
690- entry .release ();
732+ entry .release (Status . fromThrowable ( e ) );
691733 } else {
692734 LOG .log (
693735 Level .WARNING ,
0 commit comments