/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.gpu;

import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.Locale;

/**
 * Emits PTX assembly text for a family of sorting kernels specialised for one
 * primitive element type.
 *
 * <p>One call to {@link #writeTo(int, char, OutputStream)} writes, in order: a
 * preamble, the entry points {@code other1} .. {@code other4}, the entry point
 * {@code first4}, and the entry point {@code phase9}. The text is encoded as
 * US-ASCII, one PTX statement per line.
 *
 * <p>All numeric formatting is done with {@link Locale#ROOT}: {@code %d} is
 * otherwise localised by {@link java.util.Formatter}, and a default locale with
 * a non-Latin numbering system would yield digits that cannot be encoded as
 * US-ASCII.
 *
 * <p>Instances are single-use and not thread-safe ({@code stepCount} is mutable
 * state shared by the emit methods).
 */
final class PtxKernelGenerator {
    /** Highest step count emitted for the {@code other<n>} kernels; also used for the {@code first<n>} kernel. */
    private static final int MAX_STEP_COUNT = 4;

    /** PTX literal used to fill registers whose element index is not below {@code length}. */
    private final String maxValue;
    /** Current step count; a kernel handles {@code 1 << stepCount} elements per thread. */
    private int stepCount;
    /** {@code true} for the integer types ('I', 'J'), {@code false} for the floating-point types ('D', 'F'). */
    private final boolean typeIsScalar;
    /** PTX type suffix including the leading dot, e.g. {@code .s32}. */
    private final String typeName;
    /** Element size in bytes. */
    private final int typeSize;
    private final OutputStreamWriter writer;

    /**
     * Generates the PTX text for {@code elementType} and writes it to {@code out}.
     * The stream is flushed but not closed.
     *
     * @param capability  values below 3 select {@code .target sm_20}, all others {@code sm_30}
     * @param elementType one of {@code 'D'}, {@code 'F'}, {@code 'I'}, {@code 'J'}
     * @param out         destination stream
     * @throws IOException              if writing to {@code out} fails
     * @throws IllegalArgumentException if {@code elementType} is not one of the supported codes
     */
    public static void writeTo(int capability, char elementType, OutputStream out) throws IOException {
        PtxKernelGenerator generator = new PtxKernelGenerator(out, elementType);
        generator.generate(capability);
    }

    private PtxKernelGenerator(OutputStream out, char type) {
        switch (type) {
            case 'D': {
                this.typeIsScalar = false;
                this.typeName = ".f64";
                this.typeSize = 8;
                // Bit pattern of a NaN.
                this.maxValue = "0dFFF8000000000000";
                break;
            }
            case 'F': {
                this.typeIsScalar = false;
                this.typeName = ".f32";
                this.typeSize = 4;
                // Bit pattern of a NaN.
                this.maxValue = "0f7FFFFFFF";
                break;
            }
            case 'I': {
                this.typeIsScalar = true;
                this.typeName = ".s32";
                this.typeSize = 4;
                this.maxValue = "0x" + Integer.toHexString(Integer.MAX_VALUE);
                break;
            }
            case 'J': {
                this.typeIsScalar = true;
                this.typeName = ".s64";
                this.typeSize = 8;
                this.maxValue = "0x" + Long.toHexString(Long.MAX_VALUE);
                break;
            }
            default: {
                throw new IllegalArgumentException(String.valueOf(type));
            }
        }
        this.stepCount = 0;
        this.writer = new OutputStreamWriter(out, StandardCharsets.US_ASCII);
    }

    /** Writes {@code line} verbatim followed by a newline. */
    private void append(String line) throws IOException {
        this.writer.append(line).append('\n');
    }

    /**
     * Formats one line and writes it via {@link #append(String)}.
     * Uses {@link Locale#ROOT} so that {@code %d} always produces ASCII digits.
     */
    private void format(String format, Object ... arguments) throws IOException {
        append(String.format(Locale.ROOT, format, arguments));
    }

    /**
     * Renders an integer immediate. The radix and zero-padding depend on the
     * current {@link #stepCount}: decimal for 1, then hexadecimal with minimum
     * widths of 1, 2 and 4 digits for 2, 3 and anything else.
     */
    private String constant(int value) {
        String pattern;
        switch (this.stepCount) {
            case 1:
                pattern = "%d";
                break;
            case 2:
                pattern = "0x%x";
                break;
            case 3:
                pattern = "0x%02x";
                break;
            default:
                pattern = "0x%04x";
                break;
        }
        return String.format(Locale.ROOT, pattern, value);
    }

    /**
     * Emits code that sets predicate {@code p0} when {@code vl<index0>} must be
     * exchanged with {@code vl<index1>}.
     *
     * <p>Integer types use a single {@code setp.gt}. Floating-point types first
     * test whether {@code vl<index1>} is a number and combine that with an
     * unordered greater-than; when that yields false, the raw bits of both
     * operands are compared as signed integers, gated on the operands comparing
     * equal. NOTE(review): presumably this keeps NaN operands last and orders
     * values that compare equal (such as -0.0 and +0.0) by their bits - confirm
     * against the PTX ISA.
     */
    private void compare(int index0, int index1) throws IOException {
        if (this.typeIsScalar) {
            format("setp.gt%s p0,vl%d,vl%d;", this.typeName, index0, index1);
            return;
        }
        int typeBits = this.typeSize * 8;
        format("testp.number%s p1,vl%d;", this.typeName, index1);
        format("setp.gtu.and%s p0,vl%d,vl%d,p1;", this.typeName, index0, index1);
        format("@!p0 mov.b%d vs0,vl%d;", typeBits, index0);
        format("@!p0 mov.b%d vs1,vl%d;", typeBits, index1);
        format("@!p0 setp.eq%s p1,vl%d,vl%d;", this.typeName, index0, index1);
        format("@!p0 setp.gt.and.s%d p0,vs0,vs1,p1;", typeBits);
    }

    /**
     * Emits a compare of {@code vl<index0>} and {@code vl<index1>} followed by a
     * predicated exchange through {@code tmp}. When the exchange happens, the
     * bits for both indices are set in the {@code moved} register.
     */
    private void compareAndSwap(int index0, int index1) throws IOException {
        compare(index0, index1);
        format("@p0 mov%s tmp,vl%d;", this.typeName, index0);
        format("@p0 mov%s vl%d,vl%d;", this.typeName, index0, index1);
        format("@p0 mov%s vl%d,tmp;", this.typeName, index1);
        format("@p0 or.b32 moved,moved,%s;", constant((1 << index0) | (1 << index1)));
    }

    /**
     * Emits the computation of the element indices {@code ix0 .. ix<groupSize-1>}
     * for the current thread.
     *
     * @param first when {@code true}, the upper half of the indices starts from
     *              {@code ix<half-1> xor (2*stride-1)} instead of continuing the
     *              additive sequence
     */
    private void computeIndices(boolean first) throws IOException {
        assert (1 <= this.stepCount && this.stepCount <= 5);
        int groupSize = 1 << this.stepCount;
        String step = "stride";
        if (this.stepCount != 1) {
            // The separate 'step' register only exists for step counts above 1 (see declareLocals).
            step = "step";
            format("shr.u32 step,stride,%d;", this.stepCount - 1);
        }
        format("sub.s32 mask,%s,1;", step);
        // threadId = (nctaid.x * ctaid.y + ctaid.x) * ntid.x + tid.x
        append("mov.u32 rt0,%nctaid.x;");
        append("mov.u32 rt1,%ctaid.y;");
        append("mov.u32 rt2,%ctaid.x;");
        append("mad.lo.u32 threadId,rt0,rt1,rt2;");
        append("mov.u32 rt0,%ntid.x;");
        append("mov.u32 rt1,%tid.x;");
        append("mad.lo.u32 threadId,threadId,rt0,rt1;");
        // ix0 = (threadId & ~mask) * groupSize + (threadId & mask)
        append("not.b32 rt0,mask;");
        append("and.b32 rt0,rt0,threadId;");
        append("and.b32 rt1,threadId,mask;");
        format("mad.lo.u32 ix0,rt0,%d,rt1;", groupSize);
        if (first) {
            int halfGroupSize = groupSize >> 1;
            emitIndexChain(1, halfGroupSize, step);
            append("mad.lo.u32 rt0,stride,2,-1;");
            format("xor.b32 ix%d,ix%d,rt0;", halfGroupSize, halfGroupSize - 1);
            emitIndexChain(halfGroupSize + 1, groupSize, step);
        } else {
            emitIndexChain(1, groupSize, step);
        }
    }

    /** Emits {@code ix<i> = ix<i-1> + step} for every {@code i} in {@code [from, to)}. */
    private void emitIndexChain(int from, int to, String step) throws IOException {
        for (int i = from; i < to; ++i) {
            format("add.u32 ix%d,ix%d,%s;", i, i - 1, step);
        }
    }

    /** Emits the register declarations used by the {@code first<n>}/{@code other<n>} kernels. */
    private void declareLocals() throws IOException {
        int groupSize = 1 << this.stepCount;
        append(".reg .u64 data;");
        append(".reg .u32 length;");
        append(".reg .u32 stride;");
        append(".reg .u32 threadId;");
        append(".reg .u32 mask;");
        if (this.stepCount != 1) {
            append(".reg .u32 step;");
        }
        format(".reg %s tmp;", this.typeName);
        append(".reg .b32 moved;");
        append(".reg .b32 bit;");
        append(".reg .u32 rt<3>;");
        format(".reg .u32 ix<%d>;", groupSize);
        format(".reg %s vl<%d>;", this.typeName, groupSize);
        append(".reg .pred p<2>;");
        format(".reg .s%d vs<2>;", this.typeSize * 8);
        append(".reg .u64 ptr;");
    }

    /**
     * Emits the {@code phase9} entry point: it copies a 512-element tile into
     * {@code .shared} memory (filling positions at or beyond {@code length} with
     * {@link #maxValue}), runs 9 phases of compare-and-exchange steps on the
     * tile with a {@code bar.sync} before each step, and copies the in-range
     * elements back.
     *
     * <p>Immediates are rendered by {@link #constant(int)}, which reads
     * {@link #stepCount}; this method does not change {@code stepCount}, so the
     * rendering depends on the value left by the caller.
     */
    private void emitFirstPhases() throws IOException {
        final int phaseCount = 9;
        final int inputCount = 1 << phaseCount;
        final int workUnits = inputCount >> 1;

        append(".visible .entry");
        format("phase%d(.param .u64 _data,.param .u32 _length)", phaseCount);
        append(".maxntid 256,1,1");
        append("{");
        format(".shared .align %d %s _sharedData[%d];", this.typeSize, this.typeName, inputCount);
        append(".reg .u64 data;");
        append(".reg .u32 length;");
        append(".reg .u64 sharedData;");
        append(".reg .u64 dataPtr;");
        append(".reg .u64 sharedPtr<2>;");
        append(".reg .u32 baseIndex;");
        append(".reg .u32 blockDimX;");
        append(".reg .u32 globalIndex;");
        append(".reg .u32 workId;");
        append(".reg .pred p<2>;");
        format(".reg .s%d vs<2>;", this.typeSize * 8);
        append(".reg .u32 ix<2>;");
        append(".reg .u32 rt<3>;");
        format(".reg %s vl<2>;", this.typeName);
        append("ld.param.u64 data,[_data];");
        append("cvta.to.global.u64 data,data;");
        append("ld.param.u32 length,[_length];");
        append("mov.u64 sharedData,_sharedData;");
        append("mov.u32 blockDimX,%ntid.x;");
        // baseIndex = (nctaid.x * ctaid.y + ctaid.x) << phaseCount
        append("mov.u32 rt0,%nctaid.x;");
        append("mov.u32 rt1,%ctaid.y;");
        append("mov.u32 rt2,%ctaid.x;");
        append("mad.lo.u32 baseIndex,rt0,rt1,rt2;");
        format("shl.b32 baseIndex,baseIndex,%d;", phaseCount);

        emitSharedLoad(inputCount);
        for (int phase = 0; phase < phaseCount; ++phase) {
            for (int step = 0; step <= phase; ++step) {
                emitSharedStep(phase, step, workUnits);
            }
        }
        emitSharedStore(inputCount);
        append("}");
    }

    /** Emits the loop that copies {@code inputCount} elements from global to shared memory. */
    private void emitSharedLoad(int inputCount) throws IOException {
        append("mov.u32 workId,%tid.x;");
        append("bra loadTest;");
        append("loadLoop:");
        append("add.u32 globalIndex,baseIndex,workId;");
        // Default to the fill value; it is overwritten only when the index is in range.
        format("mov%s vl0,%s;", this.typeName, this.maxValue);
        append("setp.lt.u32 p0,globalIndex,length;");
        format("@p0 mad.wide.u32 dataPtr,globalIndex,%d,data;", this.typeSize);
        format("@p0 ld.global%s vl0,[dataPtr];", this.typeName);
        format("mad.wide.u32 sharedPtr0,workId,%d,sharedData;", this.typeSize);
        format("st.shared%s [sharedPtr0],vl0;", this.typeName);
        append("add.u32 workId,workId,blockDimX;");
        append("loadTest:");
        format("setp.lt.u32 p0,workId,%d;", inputCount);
        append("@p0 bra loadLoop;");
    }

    /**
     * Emits one barrier-separated compare-and-exchange loop over the shared tile.
     * Labels are numbered from 1 ({@code workLoop_<phase+1>_<step+1>}).
     */
    private void emitSharedStep(int phase, int step, int workUnits) throws IOException {
        append("bar.sync 0;");
        String workLoop = String.format(Locale.ROOT, "workLoop_%d_%d", phase + 1, step + 1);
        String workTest = String.format(Locale.ROOT, "workTest_%d_%d", phase + 1, step + 1);
        append("mov.u32 workId,%tid.x;");
        format("bra %s;", workTest);
        format("%s:", workLoop);
        append("shl.b32 ix0,workId,1;");
        if (step != phase) {
            format("and.b32 rt0,workId,%s;", constant((1 << (phase - step)) - 1));
            append("sub.u32 ix0,ix0,rt0;");
        }
        if (step == 0 && step != phase) {
            format("xor.b32 ix1,ix0,%s;", constant((2 << phase) - 1));
        } else {
            format("add.u32 ix1,ix0,%s;", constant(1 << (phase - step)));
        }
        format("mad.wide.u32 sharedPtr0,ix0,%d,sharedData;", this.typeSize);
        format("ld.shared%s vl0,[sharedPtr0];", this.typeName);
        format("mad.wide.u32 sharedPtr1,ix1,%d,sharedData;", this.typeSize);
        format("ld.shared%s vl1,[sharedPtr1];", this.typeName);
        compare(0, 1);
        format("@p0 st.shared%s [sharedPtr0],vl1;", this.typeName);
        format("@p0 st.shared%s [sharedPtr1],vl0;", this.typeName);
        append("add.u32 workId,workId,blockDimX;");
        format("%s:", workTest);
        format("setp.lt.u32 p0,workId,%d;", workUnits);
        format("@p0 bra %s;", workLoop);
    }

    /** Emits the loop that copies the in-range part of the shared tile back to global memory. */
    private void emitSharedStore(int inputCount) throws IOException {
        append("bar.sync 0;");
        append("mov.u32 workId,%tid.x;");
        append("bra storeTest;");
        append("storeLoop:");
        append("{");
        append("add.u32 globalIndex,baseIndex,workId;");
        append("setp.lt.u32 p0,globalIndex,length;");
        format("@p0 mad.wide.u32 sharedPtr0,workId,%d,sharedData;", this.typeSize);
        format("@p0 ld.shared%s vl0,[sharedPtr0];", this.typeName);
        format("@p0 mad.wide.u32 dataPtr,globalIndex,%d,data;", this.typeSize);
        format("@p0 st.global%s [dataPtr],vl0;", this.typeName);
        append("add.u32 workId,workId,blockDimX;");
        append("}");
        append("storeTest:");
        format("setp.lt.u32 p0,workId,%d;", inputCount);
        append("@p0 bra storeLoop;");
    }

    /**
     * Emits one complete entry point named {@code first<stepCount>} or
     * {@code other<stepCount>}: declarations, parameter loads, index
     * computation, gather, in-register compare-and-swap network, scatter.
     */
    private void emitKernel(boolean first) throws IOException {
        append(".visible .entry");
        format("%s%d(.param .u64 _data,.param .u32 _length,.param .u32 _stride)", first ? "first" : "other", this.stepCount);
        append(".maxntid 256,1,1");
        append("{");
        declareLocals();
        append("ld.param.u64 data,[_data];");
        append("cvta.to.global.u64 data,data;");
        append("ld.param.u32 length,[_length];");
        append("ld.param.u32 stride,[_stride];");
        computeIndices(first);
        gatherData();
        sortLocally(first);
        scatterData();
        append("}");
    }

    /** Emits the {@code .version}, {@code .target} and {@code .address_size} directives. */
    private void emitPreamble(int capability) throws IOException {
        append(".version 3.2");
        format(".target sm_%d", capability < 3 ? 20 : 30);
        append(".address_size 64");
    }

    /**
     * Emits the loads of {@code vl<i>} from {@code data[ix<i>]}; a register whose
     * index is not below {@code length} keeps the fill value {@link #maxValue}.
     */
    private void gatherData() throws IOException {
        int groupSize = 1 << this.stepCount;
        for (int i = 0; i < groupSize; ++i) {
            format("mov%s vl%d,%s;", this.typeName, i, this.maxValue);
            format("setp.lt.u32 p0,ix%d,length;", i);
            format("@p0 mad.wide.u32 ptr,ix%d,%d,data;", i, this.typeSize);
            format("@p0 ld.global%s vl%d,[ptr];", this.typeName, i);
        }
    }

    /** Emits the whole module and flushes the writer. */
    private void generate(int capability) throws IOException {
        emitPreamble(capability);
        for (int count = 1; count <= MAX_STEP_COUNT; ++count) {
            this.stepCount = count;
            emitKernel(false);
        }
        // stepCount is deliberately left at MAX_STEP_COUNT: the 'first' kernel is
        // emitted for that step count, and emitFirstPhases() renders its
        // immediates via constant(), which reads stepCount.
        emitKernel(true);
        emitFirstPhases();
        this.writer.flush();
    }

    /** Emits the stores of {@code vl<i>} back to {@code data[ix<i>]}, each guarded by bit {@code i} of {@code moved}. */
    private void scatterData() throws IOException {
        int groupSize = 1 << this.stepCount;
        for (int i = 0; i < groupSize; ++i) {
            format("and.b32 bit,moved,%s;", constant(1 << i));
            append("setp.ne.b32 p0,bit,0;");
            format("@p0 mad.wide.u32 ptr,ix%d,%d,data;", i, this.typeSize);
            format("@p0 st.global%s [ptr],vl%d;", this.typeName, i);
        }
    }

    /**
     * Emits the compare-and-swap network over the registers {@code vl0 ..
     * vl<groupSize-1>}, after clearing {@code moved}.
     *
     * @param first when {@code true}, the first step pairs each register in the
     *              lower half with its mirror ({@code index xor (groupSize-1)});
     *              the remaining steps pair registers {@code stepStride} apart,
     *              with the stride halving each step
     */
    private void sortLocally(boolean first) throws IOException {
        int groupSize = 1 << this.stepCount;
        int halfGroupSize = groupSize >> 1;
        int step = 0;
        append("mov.b32 moved,0;");
        if (first) {
            int groupMirror = groupSize - 1;
            for (int index0 = 0; index0 < halfGroupSize; ++index0) {
                compareAndSwap(index0, index0 ^ groupMirror);
            }
            ++step;
        }
        for (; step < this.stepCount; ++step) {
            int stepStride = groupSize >> (step + 1);
            int baseStep = stepStride << 1;
            for (int base = 0; base < groupSize; base += baseStep) {
                for (int index = 0; index < stepStride; ++index) {
                    int index0 = base + index;
                    compareAndSwap(index0, index0 + stepStride);
                }
            }
        }
    }
}

