/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.fn.splittabledofn;

import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress;
import org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;

/** Support utilities for interacting with {@link RestrictionTracker RestrictionTrackers}. */
@SuppressWarnings({
  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public class RestrictionTrackers {

  /** Interface allowing a runner to observe the calls to {@link RestrictionTracker#tryClaim}. */
  public interface ClaimObserver<PositionT> {
    /** Called when {@link RestrictionTracker#tryClaim} returns true. */
    void onClaimed(PositionT position);

    /** Called when {@link RestrictionTracker#tryClaim} returns false. */
    void onClaimFailed(PositionT position);
  }

  /**
   * A {@link RestrictionTracker} which forwards all calls to the delegate {@link
   * RestrictionTracker}.
   */
  @ThreadSafe
  private static class RestrictionTrackerObserver<RestrictionT, PositionT>
      extends RestrictionTracker<RestrictionT, PositionT> {
    protected final RestrictionTracker<RestrictionT, PositionT> delegate;
    protected ReentrantLock lock = new ReentrantLock();
    protected volatile boolean hasInitialProgress = false;
    private final ClaimObserver<PositionT> claimObserver;

    protected RestrictionTrackerObserver(
        RestrictionTracker<RestrictionT, PositionT> delegate,
        ClaimObserver<PositionT> claimObserver) {
      this.delegate = delegate;
      this.claimObserver = claimObserver;
    }

    @Override
    public boolean tryClaim(PositionT position) {
      lock.lock();
      try {
        if (delegate.tryClaim(position)) {
          claimObserver.onClaimed(position);
          return true;
        } else {
          claimObserver.onClaimFailed(position);
          return false;
        }
      } finally {
        lock.unlock();
      }
    }

    @Override
    public RestrictionT currentRestriction() {
      lock.lock();
      try {
        return delegate.currentRestriction();
      } finally {
        lock.unlock();
      }
    }

    @Override
    public SplitResult<RestrictionT> trySplit(double fractionOfRemainder) {
      lock.lock();
      try {
        SplitResult<RestrictionT> result = delegate.trySplit(fractionOfRemainder);
        return result;
      } finally {
        lock.unlock();
      }
    }

    @Override
    public void checkDone() throws IllegalStateException {
      lock.lock();
      try {
        delegate.checkDone();
      } finally {
        lock.unlock();
      }
    }

    @Override
    public IsBounded isBounded() {
      return delegate.isBounded();
    }

    /** Evaluate progress if requested. */
    protected Progress getProgressBlocking() {
      lock.lock();
      try {
        return ((HasProgress) delegate).getProgress();
      } finally {
        lock.unlock();
      }
    }
  }

  /**
   * A {@link RestrictionTracker} which forwards all calls to the delegate progress reporting {@link
   * RestrictionTracker}.
   */
  @ThreadSafe
  static class RestrictionTrackerObserverWithProgress<RestrictionT, PositionT>
      extends RestrictionTrackerObserver<RestrictionT, PositionT> implements HasProgress {
    private static final int FIRST_PROGRESS_TIMEOUT_SEC = 60;

    protected RestrictionTrackerObserverWithProgress(
        RestrictionTracker<RestrictionT, PositionT> delegate,
        ClaimObserver<PositionT> claimObserver) {
      super(delegate, claimObserver);
    }

    @Override
    public Progress getProgress() {
      return getProgress(FIRST_PROGRESS_TIMEOUT_SEC);
    }

    @VisibleForTesting
    Progress getProgress(int timeOutSec) {
      if (!hasInitialProgress) {
        Progress progress = Progress.NONE;
        try {
          // lock can be held long by long-running tryClaim/trySplit. We tolerate this scenario
          // by returning zero progress when initial progress never evaluated before due to lock
          // timeout.
          if (lock.tryLock(timeOutSec, TimeUnit.SECONDS)) {
            try {
              progress = getProgressBlocking();
              hasInitialProgress = true;
            } finally {
              lock.unlock();
            }
          }
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
        }
        return progress;
      } else {
        return getProgressBlocking();
      }
    }
  }

  /**
   * Returns a thread safe {@link RestrictionTracker} which reports all claim attempts to the
   * specified {@link ClaimObserver}.
   */
  public static <RestrictionT, PositionT> RestrictionTracker<RestrictionT, PositionT> observe(
      RestrictionTracker<RestrictionT, PositionT> restrictionTracker,
      ClaimObserver<PositionT> claimObserver) {
    if (restrictionTracker instanceof RestrictionTracker.HasProgress) {
      return new RestrictionTrackerObserverWithProgress<>(restrictionTracker, claimObserver);
    } else {
      return new RestrictionTrackerObserver<>(restrictionTracker, claimObserver);
    }
  }

  public static <RestrictionT, PositionT> RestrictionTracker<RestrictionT, PositionT> synchronize(
      RestrictionTracker<RestrictionT, PositionT> restrictionTracker) {
    if (restrictionTracker instanceof RestrictionTracker.HasProgress) {
      return new RestrictionTrackerObserverWithProgress<>(
          restrictionTracker, (ClaimObserver<PositionT>) NOOP_CLAIM_OBSERVER);
    } else {
      return new RestrictionTrackerObserver<>(
          restrictionTracker, (ClaimObserver<PositionT>) NOOP_CLAIM_OBSERVER);
    }
  }

  static class NoopClaimObserver<PositionT> implements ClaimObserver<PositionT> {
    @Override
    public void onClaimed(PositionT position) {}

    @Override
    public void onClaimFailed(PositionT position) {}
  }

  private static final NoopClaimObserver<Object> NOOP_CLAIM_OBSERVER = new NoopClaimObserver<>();
}
