/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.tracecompass.incubator.internal.rocm.core.analysis.dependency;

import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import org.eclipse.tracecompass.analysis.os.linux.core.model.HostThread;
import org.eclipse.tracecompass.incubator.internal.rocm.core.Activator;
import org.eclipse.tracecompass.incubator.internal.rocm.core.analysis.dependency.AbstractDependencyMaker;
import org.eclipse.tracecompass.incubator.internal.rocm.core.analysis.handlers.old.AbstractGpuEventHandler;
import org.eclipse.tracecompass.incubator.internal.rocm.core.analysis.handlers.old.ApiEventHandler;
import org.eclipse.tracecompass.incubator.internal.rocm.core.analysis.handlers.old.HostThreadIdentifier;
import org.eclipse.tracecompass.statesystem.core.ITmfStateSystemBuilder;
import org.eclipse.tracecompass.tmf.core.event.ITmfEvent;
import org.eclipse.tracecompass.tmf.core.util.Pair;

/**
 * Dependency maker that links HIP API calls ({@code hip_api} events) to the GPU
 * activities they dispatch ({@code hip_activity} events), and links in-flight GPU
 * activities to host threads blocked in {@code hipDeviceSynchronize}.
 * <p>
 * {@code hipLaunchKernel} and {@code hipMemcpy} API events are remembered by
 * correlation id until the matching activity event is processed. Matched
 * (API event, activity event) pairs stay in an "in-flight" queue, ordered by the
 * activity's end time, until an event at or after that end time is processed.
 * <p>
 * NOTE(review): instances hold unsynchronized mutable state; presumably events are
 * fed from a single thread in timestamp order — confirm with the caller.
 */
public class HipApiHipActivityDependencyMaker extends AbstractDependencyMaker {

    /** {@code hipLaunchKernel} argument index parsed as the correlation id. */
    private static final int KERNEL_LAUNCH_CORRELATION_ARG = 7;
    /** {@code hipMemcpy} argument index parsed as the correlation id. */
    private static final int MEMCPY_CORRELATION_ARG = 4;
    /** API-event argument index parsed as the HIP stream id. */
    private static final int STREAM_ID_ARG = 4;
    /**
     * Device whose in-flight activities a {@code hipDeviceSynchronize} waits on.
     * NOTE(review): hard-coded to device 0; the waiting thread's current device is
     * not tracked here.
     */
    private static final int SYNCHRONIZED_DEVICE_ID = 0;

    /** Dispatching API events, keyed by correlation id, whose activity has not been seen yet. */
    private final Map<Long, ITmfEvent> fApiEventCorrelationMap = new HashMap<>();
    /**
     * Wait API events that may still gain dependencies on activities matched later.
     * NOTE(review): despite the name this is a single list shared by all threads.
     */
    private final List<ITmfEvent> fWaitEventPerThread = new LinkedList<>();
    /** (API event, activity event) pairs whose activity has not ended yet, earliest end first. */
    private final Queue<Pair<ITmfEvent, ITmfEvent>> fInFlightEvents = new PriorityQueue<>(new EventComparator());

    /**
     * Processes one trace event: first retires in-flight activities that ended at
     * or before this event, then dispatches on the event name.
     *
     * @param event
     *            the event to process ({@code hip_api} or {@code hip_activity};
     *            other names are ignored apart from retiring in-flight activities)
     * @param ssb
     *            the state system the dependency arrows are written to
     */
    @Override
    public void processEvent(ITmfEvent event, ITmfStateSystemBuilder ssb) {
        removeInFlightEvents(event);
        switch (event.getName()) {
            case "hip_api": {
                String apiName = ApiEventHandler.getFunctionApiName(event);
                if ("hipLaunchKernel".equals(apiName)) {
                    addGpuActivityDispatch(event, KERNEL_LAUNCH_CORRELATION_ARG);
                } else if ("hipMemcpy".equals(apiName)) {
                    addGpuActivityDispatch(event, MEMCPY_CORRELATION_ARG);
                } else if ("hipDeviceSynchronize".equals(apiName)) {
                    addWaitDependencies(event, ssb);
                }
                break;
            }
            case "hip_activity": {
                String activityName = event.getContent().getFieldValue(String.class, "name");
                if ("KernelExecution".equals(activityName)) {
                    addKernelDependency(event, ssb);
                } else if (activityName != null && activityName.startsWith("Copy")) {
                    addMemoryDependency(event, ssb);
                }
                // The activity has been handled; its dispatching API event is no longer needed.
                Long correlationId = event.getContent().getFieldValue(Long.class, "correlation_id");
                fApiEventCorrelationMap.remove(correlationId);
                break;
            }
            default:
                break;
        }
    }

    /**
     * Remembers a dispatching API event under the correlation id found in one of
     * its arguments.
     *
     * @param event
     *            the {@code hip_api} event
     * @param argPosition
     *            index of the argument holding the correlation id
     */
    private void addGpuActivityDispatch(ITmfEvent event, int argPosition) {
        Long correlationId = Long.parseLong(ApiEventHandler.getArg(event.getContent(), argPosition));
        fApiEventCorrelationMap.put(correlationId, event);
    }

    /**
     * Draws arrows from the thread that launched a kernel to the kernel's queue and
     * stream, and registers the kernel as in flight. Does nothing if the activity
     * lacks a queue id, correlation id or device id, or if no matching API event
     * was recorded.
     */
    private void addKernelDependency(ITmfEvent hipActivityEvent, ITmfStateSystemBuilder ssb) {
        Integer queueId = hipActivityEvent.getContent().getFieldValue(Integer.class, "queue_id");
        Long correlationId = hipActivityEvent.getContent().getFieldValue(Long.class, "correlation_id");
        Long gpuId = readActivityDeviceId(hipActivityEvent);
        if (correlationId == null || queueId == null || gpuId == null) {
            return;
        }
        ITmfEvent hipEvent = fApiEventCorrelationMap.get(correlationId);
        if (hipEvent == null) {
            return;
        }
        addInFlightEvent(hipEvent, hipActivityEvent, ssb);
        Integer tid = hipEvent.getContent().getFieldValue(Integer.class, "tid");
        if (tid == null) {
            return;
        }
        Long apiEndTs = AbstractGpuEventHandler.getEndTime(hipEvent);
        if (apiEndTs == null) {
            return;
        }
        int hipStreamId = Integer.parseInt(ApiEventHandler.getArg(hipEvent.getContent(), STREAM_ID_ARG));
        String hostId = hipEvent.getTrace().getHostId();
        // Each identifier's hash code serves as the synthetic tid of its HostThread.
        HostThread src = new HostThread(hostId, new HostThreadIdentifier(hipEvent, tid).hashCode());
        HostThread destStream = new HostThread(hostId,
                new HostThreadIdentifier(hipStreamId, HostThreadIdentifier.KERNEL_CATEGORY.STREAM, gpuId.intValue()).hashCode());
        HostThread destQueue = new HostThread(hostId,
                new HostThreadIdentifier(queueId.intValue(), HostThreadIdentifier.KERNEL_CATEGORY.QUEUE, gpuId.intValue()).hashCode());
        // Arrows start just before the API call ends and end when the activity starts.
        long arrowStart = apiEndTs - 1L;
        long arrowEnd = hipActivityEvent.getTimestamp().getValue();
        int arrowId = Math.toIntExact(correlationId);
        addArrow(ssb, arrowStart, arrowEnd, arrowId, src, destQueue);
        addArrow(ssb, arrowStart, arrowEnd, arrowId, src, destStream);
    }

    /**
     * Draws an arrow from the thread that issued a memory copy to the memory-copy
     * destination, and registers the copy as in flight.
     */
    private void addMemoryDependency(ITmfEvent hipActivityEvent, ITmfStateSystemBuilder ssb) {
        Long correlationId = hipActivityEvent.getContent().getFieldValue(Long.class, "correlation_id");
        if (correlationId == null) {
            return;
        }
        ITmfEvent hipEvent = fApiEventCorrelationMap.get(correlationId);
        if (hipEvent == null) {
            return;
        }
        addInFlightEvent(hipEvent, hipActivityEvent, ssb);
        Integer tid = hipEvent.getContent().getFieldValue(Integer.class, "tid");
        if (tid == null) {
            return;
        }
        String hostId = hipEvent.getTrace().getHostId();
        HostThread src = new HostThread(hostId, new HostThreadIdentifier(hipEvent, tid).hashCode());
        // NOTE(review): the no-arg identifier appears to be the shared memory-copy row — confirm.
        HostThread dst = new HostThread(hostId, new HostThreadIdentifier().hashCode());
        addArrow(ssb, hipEvent.getTimestamp().getValue(), hipActivityEvent.getTimestamp().getValue(),
                Math.toIntExact(correlationId), src, dst);
    }

    /**
     * Registers a matched (API event, activity event) pair as in flight and
     * re-evaluates pending wait events that began at or after the API call.
     * Pending waits that ended before the API call began (or whose end time is
     * unknown) are dropped.
     */
    private void addInFlightEvent(ITmfEvent hipApiEvent, ITmfEvent hipActivityEvent, ITmfStateSystemBuilder ssb) {
        // Activities without an end time can never be retired nor drawn; keep them out of the queue.
        if (AbstractGpuEventHandler.getEndTime(hipActivityEvent) != null) {
            fInFlightEvents.add(new Pair<>(hipApiEvent, hipActivityEvent));
        }
        long beginTs = hipApiEvent.getTimestamp().getValue();
        Long endTs = AbstractGpuEventHandler.getEndTime(hipApiEvent);
        Iterator<ITmfEvent> waitEventIterator = fWaitEventPerThread.iterator();
        while (waitEventIterator.hasNext()) {
            ITmfEvent waitEvent = waitEventIterator.next();
            long waitBeginTs = waitEvent.getTimestamp().getValue();
            Long waitEndTs = AbstractGpuEventHandler.getEndTime(waitEvent);
            if (waitEndTs == null || beginTs > waitEndTs) {
                waitEventIterator.remove();
                continue;
            }
            if (beginTs > waitBeginTs) {
                // The API call started after the wait began; the wait cannot depend on it.
                continue;
            }
            if (beginTs < waitBeginTs && endTs != null && endTs > waitBeginTs) {
                // An API call overlapping the start of a wait is not handled by this design.
                Activator.getInstance().logError("If you see this message, the wait dependencies behavior should be changed.");
            }
            // NOTE(review): this redraws arrows for every in-flight activity on the device,
            // not only the new one — confirm addArrow tolerates duplicates.
            addWaitDependencies(waitEvent, ssb);
        }
    }

    /**
     * Records {@code hipWaitEvent} as a pending wait and, for
     * {@code hipDeviceSynchronize}, draws wait arrows from every in-flight activity
     * on {@link #SYNCHRONIZED_DEVICE_ID} to the waiting thread.
     */
    private void addWaitDependencies(ITmfEvent hipWaitEvent, ITmfStateSystemBuilder ssb) {
        Integer waitTid = hipWaitEvent.getContent().getFieldValue(Integer.class, "tid");
        if (waitTid == null) {
            return;
        }
        if (!fWaitEventPerThread.contains(hipWaitEvent)) {
            fWaitEventPerThread.add(hipWaitEvent);
        }
        if (!"hipDeviceSynchronize".equals(ApiEventHandler.getFunctionApiName(hipWaitEvent))) {
            return;
        }
        for (Pair<ITmfEvent, ITmfEvent> inFlight : fInFlightEvents) {
            ITmfEvent apiEvent = inFlight.getFirst();
            ITmfEvent activityEvent = inFlight.getSecond();
            Long deviceId = readActivityDeviceId(activityEvent);
            if (deviceId == null || deviceId.longValue() != SYNCHRONIZED_DEVICE_ID) {
                continue;
            }
            // addWaitArrow draws nothing without a queue id; skip early so the stream
            // argument is not parsed for such activities.
            if (activityEvent.getContent().getFieldValue(Integer.class, "queue_id") == null) {
                continue;
            }
            // NOTE(review): for hipMemcpy API events argument STREAM_ID_ARG is the
            // correlation id (see MEMCPY_CORRELATION_ARG), not a stream id — confirm.
            int hipStreamId = Integer.parseInt(ApiEventHandler.getArg(apiEvent.getContent(), STREAM_ID_ARG));
            addWaitArrow(ssb, activityEvent, hipWaitEvent, waitTid, SYNCHRONIZED_DEVICE_ID, hipStreamId);
        }
    }

    /**
     * Draws arrows from an activity's stream and queue to the thread blocked in a
     * wait call, from just before the activity ends to just before the wait ends.
     * Does nothing if the activity lacks a queue id or correlation id, or if either
     * end time is unknown.
     */
    private static void addWaitArrow(ITmfStateSystemBuilder ssb, ITmfEvent deviceEvent, ITmfEvent waitEvent, int waitTid, int deviceId, int hipStreamId) {
        Integer queueId = deviceEvent.getContent().getFieldValue(Integer.class, "queue_id");
        Long correlationId = deviceEvent.getContent().getFieldValue(Long.class, "correlation_id");
        Long deviceEndTs = AbstractGpuEventHandler.getEndTime(deviceEvent);
        Long waitEndTs = AbstractGpuEventHandler.getEndTime(waitEvent);
        if (queueId == null || correlationId == null || deviceEndTs == null || waitEndTs == null) {
            return;
        }
        HostThread waitingThread = new HostThread(waitEvent.getTrace().getHostId(),
                new HostThreadIdentifier(waitEvent, waitTid).hashCode());
        String deviceHostId = deviceEvent.getTrace().getHostId();
        HostThread srcStream = new HostThread(deviceHostId,
                new HostThreadIdentifier(hipStreamId, HostThreadIdentifier.KERNEL_CATEGORY.STREAM, deviceId).hashCode());
        HostThread srcQueue = new HostThread(deviceHostId,
                new HostThreadIdentifier(queueId.intValue(), HostThreadIdentifier.KERNEL_CATEGORY.QUEUE, deviceId).hashCode());
        long arrowStart = deviceEndTs - 1L;
        long arrowEnd = waitEndTs - 1L;
        int arrowId = Math.toIntExact(correlationId);
        addArrow(ssb, arrowStart, arrowEnd, arrowId, srcStream, waitingThread);
        addArrow(ssb, arrowStart, arrowEnd, arrowId, srcQueue, waitingThread);
    }

    /**
     * Reads the {@code device_id} field of an activity event, accepting either a
     * {@code Long} or an {@code Integer} value so the kernel and wait paths agree
     * regardless of how the trace boxes the field.
     *
     * @return the device id, or {@code null} if it is absent or not an integer
     */
    private static Long readActivityDeviceId(ITmfEvent activityEvent) {
        Long asLong = activityEvent.getContent().getFieldValue(Long.class, "device_id");
        if (asLong != null) {
            return asLong;
        }
        Integer asInt = activityEvent.getContent().getFieldValue(Integer.class, "device_id");
        if (asInt == null) {
            return null;
        }
        return Long.valueOf(asInt.longValue());
    }

    /**
     * Retires every in-flight pair whose activity ended at or before the timestamp
     * of {@code event}. Relies on the queue being ordered by activity end time.
     */
    private void removeInFlightEvents(ITmfEvent event) {
        long currentTime = event.getTimestamp().getValue();
        Pair<ITmfEvent, ITmfEvent> head = fInFlightEvents.peek();
        while (head != null) {
            Long endTime = AbstractGpuEventHandler.getEndTime(head.getSecond());
            // Unknown end times sort last (see EventComparator); stop there rather than unbox null.
            if (endTime == null || currentTime < endTime) {
                break;
            }
            fInFlightEvents.poll();
            head = fInFlightEvents.peek();
        }
    }

    /**
     * @return the live map of dispatching API events keyed by correlation id
     *         (not a copy)
     */
    @Override
    public Map<Long, ITmfEvent> getApiEventCorrelationMap() {
        return fApiEventCorrelationMap;
    }

    /**
     * Orders (API event, activity event) pairs by the end time of the activity,
     * earliest first; pairs whose end time is unknown sort last.
     */
    public static class EventComparator implements Comparator<Pair<ITmfEvent, ITmfEvent>> {
        @Override
        public int compare(Pair<ITmfEvent, ITmfEvent> pair1, Pair<ITmfEvent, ITmfEvent> pair2) {
            Long endTime1 = AbstractGpuEventHandler.getEndTime(pair1.getSecond());
            Long endTime2 = AbstractGpuEventHandler.getEndTime(pair2.getSecond());
            if (endTime1 == null) {
                return endTime2 == null ? 0 : 1;
            }
            if (endTime2 == null) {
                return -1;
            }
            // Long.compare avoids the overflow of casting a long difference to int.
            return Long.compare(endTime1, endTime2);
        }
    }
}

