package ai.onnxruntime;

import N5.K0;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtUtil;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.platform.Fp16Conversions;
import j$.util.Optional;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.logging.Logger;

/* loaded from: classes.dex */
public class OnnxTensor extends OnnxTensorLike {
    /** Class-wide java.util.logging logger, named after this class. */
    private static final Logger logger = Logger.getLogger(OnnxTensor.class.getName());
    /** Buffer handed to the five-argument constructor; null when the three-argument constructor is used. */
    private final Buffer buffer;
    /** Flag handed to the five-argument constructor; false when the three-argument constructor is used. */
    private final boolean ownsBuffer;

    /* renamed from: ai.onnxruntime.OnnxTensor$1, reason: invalid class name */
    /* loaded from: classes.dex */
    /**
     * Compiler-generated lookup table for switching over {@link OnnxJavaType}.
     *
     * <p>The array is indexed by {@code OnnxJavaType.ordinal()} and holds a dense case
     * label (1..12) for each constant known when this class was compiled. A slot stays
     * at 0 for any constant that is missing at runtime, because the resulting
     * {@link NoSuchFieldError} is deliberately ignored; such a value then falls through
     * to the {@code default} branch of a switch that uses this table.
     */
    public static /* synthetic */ class AnonymousClass1 {
        /** Maps {@code OnnxJavaType.ordinal()} to a switch case label; 0 means "no case". */
        static final /* synthetic */ int[] $SwitchMap$ai$onnxruntime$OnnxJavaType;

        static {
            // Sized from the enum as it exists at runtime, so every ordinal is a valid index.
            int[] iArr = new int[OnnxJavaType.values().length];
            $SwitchMap$ai$onnxruntime$OnnxJavaType = iArr;
            // Each constant is registered in its own try block so that one missing
            // constant does not prevent the remaining ones from being mapped.
            try {
                iArr[OnnxJavaType.FLOAT.ordinal()] = 1;
            } catch (NoSuchFieldError unused) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.DOUBLE.ordinal()] = 2;
            } catch (NoSuchFieldError unused2) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.UINT8.ordinal()] = 3;
            } catch (NoSuchFieldError unused3) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT8.ordinal()] = 4;
            } catch (NoSuchFieldError unused4) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT16.ordinal()] = 5;
            } catch (NoSuchFieldError unused5) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT32.ordinal()] = 6;
            } catch (NoSuchFieldError unused6) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT64.ordinal()] = 7;
            } catch (NoSuchFieldError unused7) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.BOOL.ordinal()] = 8;
            } catch (NoSuchFieldError unused8) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.STRING.ordinal()] = 9;
            } catch (NoSuchFieldError unused9) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.FLOAT16.ordinal()] = 10;
            } catch (NoSuchFieldError unused10) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.BFLOAT16.ordinal()] = 11;
            } catch (NoSuchFieldError unused11) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.UNKNOWN.ordinal()] = 12;
            } catch (NoSuchFieldError unused12) {
            }
        }
    }

    /**
     * Creates a tensor wrapper that has no associated Java buffer.
     *
     * <p>Delegates to the five-argument constructor with a {@code null} buffer and
     * {@code false} for the ownership flag.
     *
     * @param j10 first handle, forwarded unchanged to the superclass constructor
     * @param j11 second handle, forwarded unchanged to the superclass constructor
     *     (the factory methods in this class pass the allocator handle here)
     * @param tensorInfo type and shape description of the tensor
     */
    public OnnxTensor(long j10, long j11, TensorInfo tensorInfo) {
        this(j10, j11, tensorInfo, null, false);
    }

    /**
     * Creates a tensor wrapper and records the Java buffer associated with it.
     *
     * @param j10 first handle, forwarded unchanged to the superclass constructor
     * @param j11 second handle, forwarded unchanged to the superclass constructor
     * @param tensorInfo type and shape description of the tensor
     * @param buffer buffer to remember for {@link #getBufferRef()}; may be {@code null}
     * @param z10 value later reported by {@link #ownsBuffer()}
     */
    public OnnxTensor(long j10, long j11, TensorInfo tensorInfo, Buffer buffer, boolean z10) {
        super(j10, j11, tensorInfo);
        // The two assignments are independent of each other.
        this.ownsBuffer = z10;
        this.buffer = buffer;
    }

    /**
     * Native half of {@link #close()}; called there with the ORT API handle and this
     * tensor's native handle.
     */
    private native void close(long j10, long j11);

    /**
     * Native factory for a scalar string tensor. Callers in this class pass the ORT API
     * handle, the allocator handle and the string, and use the returned value as the
     * new tensor's native handle.
     */
    private static native long createString(long j10, long j11, String str) throws OrtException;

    /**
     * Native factory for a string tensor. Callers in this class pass the ORT API handle,
     * the allocator handle, the strings as a flat array and the shape, and use the
     * returned value as the new tensor's native handle.
     */
    private static native long createStringTensor(long j10, long j11, Object[] objArr, long[] jArr) throws OrtException;

    /**
     * Native factory for a tensor built from a Java array. Callers in this class pass
     * the ORT API handle, the allocator handle, the array, the shape and the numeric
     * value of the ONNX element type, and use the returned value as the new tensor's
     * native handle.
     */
    private static native long createTensor(long j10, long j11, Object obj, long[] jArr, int i10) throws OrtException;

    /**
     * Shared implementation behind all buffer-based factory methods.
     *
     * <p>The buffer is first passed through {@code OrtUtil.prepareBuffer}; the prepared
     * buffer is what is described by the {@link TensorInfo}, handed to native code, and
     * remembered by the returned tensor.
     *
     * @param onnxJavaType element type the buffer should be interpreted as
     * @param ortAllocator allocator whose handle is passed to native code
     * @param buffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws OrtException if tensor construction fails
     */
    private static OnnxTensor createTensor(OnnxJavaType onnxJavaType, OrtAllocator ortAllocator, Buffer buffer, long[] jArr) throws OrtException {
        OrtUtil.BufferTuple prepared = OrtUtil.prepareBuffer(buffer, onnxJavaType);
        TensorInfo info = TensorInfo.constructFromBuffer(prepared.data, jArr, onnxJavaType);
        long tensorHandle = createTensorFromBuffer(
                OnnxRuntime.ortApiHandle,
                ortAllocator.handle,
                prepared.data,
                prepared.pos,
                prepared.byteSize,
                jArr,
                info.onnxType.value);
        return new OnnxTensor(tensorHandle, ortAllocator.handle, info, prepared.data, prepared.isCopy);
    }

    /**
     * Creates a tensor from a Java object: a string, a boxed primitive, or an array
     * of either.
     *
     * <p>String data is routed to the string-specific native factories. For all other
     * types, a scalar (zero-length shape) is first converted from its boxed form into
     * an array before being handed to native code.
     *
     * @param ortEnvironment the environment (not used by this overload)
     * @param ortAllocator allocator whose handle is passed to native code; must be open
     * @param obj the data to copy into the tensor
     * @return the newly created tensor
     * @throws IllegalStateException if {@code ortAllocator} is closed
     * @throws OrtException if a boxed scalar cannot be converted to an array, or if
     *     tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, Object obj) throws OrtException {
        if (ortAllocator.isClosed()) {
            throw new IllegalStateException("Trying to create an OnnxTensor with a closed OrtAllocator.");
        }
        TensorInfo info = TensorInfo.constructFromJavaArray(obj);
        boolean scalar = info.shape.length == 0;
        if (info.type == OnnxJavaType.STRING) {
            if (scalar) {
                return new OnnxTensor(createString(OnnxRuntime.ortApiHandle, ortAllocator.handle, (String) obj), ortAllocator.handle, info);
            }
            return new OnnxTensor(createStringTensor(OnnxRuntime.ortApiHandle, ortAllocator.handle, OrtUtil.flattenString(obj), info.shape), ortAllocator.handle, info);
        }
        // Keep the caller's object in its own variable so that a failed conversion can
        // still report the offending value instead of the null conversion result.
        Object data = obj;
        if (scalar) {
            data = OrtUtil.convertBoxedPrimitiveToArray(info.type, obj);
            if (data == null) {
                throw new OrtException("Failed to convert a boxed primitive to an array, this is an error with the ORT Java API, please report this message & stack trace. JavaType = " + info.type + ", object = " + obj);
            }
        }
        return new OnnxTensor(createTensor(OnnxRuntime.ortApiHandle, ortAllocator.handle, data, info.shape, info.onnxType.value), ortAllocator.handle, info);
    }

    /**
     * Creates a tensor from a {@link ByteBuffer}, interpreting its contents as
     * {@code OnnxJavaType.INT8}.
     *
     * @param ortEnvironment the environment
     * @param ortAllocator allocator to create the tensor with
     * @param byteBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, ByteBuffer byteBuffer, long[] jArr) throws OrtException {
        OnnxJavaType defaultElementType = OnnxJavaType.INT8;
        return createTensor(ortEnvironment, ortAllocator, byteBuffer, jArr, defaultElementType);
    }

    /**
     * Creates a tensor from a {@link ByteBuffer} whose contents are interpreted as the
     * given element type.
     *
     * @param ortEnvironment the environment (not used by this overload)
     * @param ortAllocator allocator to create the tensor with; must be open
     * @param byteBuffer source data
     * @param jArr tensor shape
     * @param onnxJavaType element type of the data in {@code byteBuffer}
     * @return the newly created tensor
     * @throws IllegalStateException if {@code ortAllocator} is closed
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, ByteBuffer byteBuffer, long[] jArr, OnnxJavaType onnxJavaType) throws OrtException {
        if (!ortAllocator.isClosed()) {
            return createTensor(onnxJavaType, ortAllocator, byteBuffer, jArr);
        }
        throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
    }

    /**
     * Creates a {@code DOUBLE} tensor from a {@link DoubleBuffer}.
     *
     * @param ortEnvironment the environment (not used by this overload)
     * @param ortAllocator allocator to create the tensor with; must be open
     * @param doubleBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws IllegalStateException if {@code ortAllocator} is closed
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, DoubleBuffer doubleBuffer, long[] jArr) throws OrtException {
        if (!ortAllocator.isClosed()) {
            return createTensor(OnnxJavaType.DOUBLE, ortAllocator, doubleBuffer, jArr);
        }
        throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
    }

    /**
     * Creates a {@code FLOAT} tensor from a {@link FloatBuffer}.
     *
     * @param ortEnvironment the environment (not used by this overload)
     * @param ortAllocator allocator to create the tensor with; must be open
     * @param floatBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws IllegalStateException if {@code ortAllocator} is closed
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, FloatBuffer floatBuffer, long[] jArr) throws OrtException {
        if (!ortAllocator.isClosed()) {
            return createTensor(OnnxJavaType.FLOAT, ortAllocator, floatBuffer, jArr);
        }
        throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
    }

    /**
     * Creates an {@code INT32} tensor from an {@link IntBuffer}.
     *
     * @param ortEnvironment the environment (not used by this overload)
     * @param ortAllocator allocator to create the tensor with; must be open
     * @param intBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws IllegalStateException if {@code ortAllocator} is closed
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, IntBuffer intBuffer, long[] jArr) throws OrtException {
        if (!ortAllocator.isClosed()) {
            return createTensor(OnnxJavaType.INT32, ortAllocator, intBuffer, jArr);
        }
        throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
    }

    /**
     * Creates an {@code INT64} tensor from a {@link LongBuffer}.
     *
     * @param ortEnvironment the environment (not used by this overload)
     * @param ortAllocator allocator to create the tensor with; must be open
     * @param longBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws IllegalStateException if {@code ortAllocator} is closed
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, LongBuffer longBuffer, long[] jArr) throws OrtException {
        if (!ortAllocator.isClosed()) {
            return createTensor(OnnxJavaType.INT64, ortAllocator, longBuffer, jArr);
        }
        throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
    }

    /**
     * Creates an {@code INT16} tensor from a {@link ShortBuffer}.
     *
     * @param ortEnvironment the environment (not used by this overload)
     * @param ortAllocator allocator to create the tensor with; must be open
     * @param shortBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws IllegalStateException if {@code ortAllocator} is closed
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, ShortBuffer shortBuffer, long[] jArr) throws OrtException {
        if (!ortAllocator.isClosed()) {
            return createTensor(OnnxJavaType.INT16, ortAllocator, shortBuffer, jArr);
        }
        throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
    }

    /**
     * Creates a string tensor from a flat array of strings and an explicit shape.
     *
     * @param ortEnvironment the environment (not used by this overload)
     * @param ortAllocator allocator to create the tensor with; must be open
     * @param strArr the tensor's strings as a flat array
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws IllegalStateException if {@code ortAllocator} is closed
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, String[] strArr, long[] jArr) throws OrtException {
        if (!ortAllocator.isClosed()) {
            // The native tensor is created before its TensorInfo, as in the one-line form.
            long tensorHandle = createStringTensor(OnnxRuntime.ortApiHandle, ortAllocator.handle, strArr, jArr);
            TensorInfo info = new TensorInfo(jArr, OnnxJavaType.STRING, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
            return new OnnxTensor(tensorHandle, ortAllocator.handle, info);
        }
        throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
    }

    /**
     * Creates a tensor from a Java object using the environment's default allocator.
     *
     * @param ortEnvironment environment supplying the default allocator
     * @param obj the data to copy into the tensor
     * @return the newly created tensor
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, Object obj) throws OrtException {
        OrtAllocator allocator = ortEnvironment.defaultAllocator;
        return createTensor(ortEnvironment, allocator, obj);
    }

    /**
     * Creates a tensor from a {@link ByteBuffer} using the environment's default
     * allocator.
     *
     * @param ortEnvironment environment supplying the default allocator
     * @param byteBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, ByteBuffer byteBuffer, long[] jArr) throws OrtException {
        OrtAllocator allocator = ortEnvironment.defaultAllocator;
        return createTensor(ortEnvironment, allocator, byteBuffer, jArr);
    }

    /**
     * Creates a tensor of the given element type from a {@link ByteBuffer} using the
     * environment's default allocator.
     *
     * @param ortEnvironment environment supplying the default allocator
     * @param byteBuffer source data
     * @param jArr tensor shape
     * @param onnxJavaType element type of the data in {@code byteBuffer}
     * @return the newly created tensor
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, ByteBuffer byteBuffer, long[] jArr, OnnxJavaType onnxJavaType) throws OrtException {
        OrtAllocator allocator = ortEnvironment.defaultAllocator;
        return createTensor(ortEnvironment, allocator, byteBuffer, jArr, onnxJavaType);
    }

    /**
     * Creates a tensor from a {@link DoubleBuffer} using the environment's default
     * allocator.
     *
     * @param ortEnvironment environment supplying the default allocator
     * @param doubleBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, DoubleBuffer doubleBuffer, long[] jArr) throws OrtException {
        OrtAllocator allocator = ortEnvironment.defaultAllocator;
        return createTensor(ortEnvironment, allocator, doubleBuffer, jArr);
    }

    /**
     * Creates a tensor from a {@link FloatBuffer} using the environment's default
     * allocator.
     *
     * @param ortEnvironment environment supplying the default allocator
     * @param floatBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, FloatBuffer floatBuffer, long[] jArr) throws OrtException {
        OrtAllocator allocator = ortEnvironment.defaultAllocator;
        return createTensor(ortEnvironment, allocator, floatBuffer, jArr);
    }

    /**
     * Creates a tensor from an {@link IntBuffer} using the environment's default
     * allocator.
     *
     * @param ortEnvironment environment supplying the default allocator
     * @param intBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, IntBuffer intBuffer, long[] jArr) throws OrtException {
        OrtAllocator allocator = ortEnvironment.defaultAllocator;
        return createTensor(ortEnvironment, allocator, intBuffer, jArr);
    }

    /**
     * Creates a tensor from a {@link LongBuffer} using the environment's default
     * allocator.
     *
     * @param ortEnvironment environment supplying the default allocator
     * @param longBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, LongBuffer longBuffer, long[] jArr) throws OrtException {
        OrtAllocator allocator = ortEnvironment.defaultAllocator;
        return createTensor(ortEnvironment, allocator, longBuffer, jArr);
    }

    /**
     * Creates a tensor from a {@link ShortBuffer} using the environment's default
     * allocator.
     *
     * @param ortEnvironment environment supplying the default allocator
     * @param shortBuffer source data
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, ShortBuffer shortBuffer, long[] jArr) throws OrtException {
        OrtAllocator allocator = ortEnvironment.defaultAllocator;
        return createTensor(ortEnvironment, allocator, shortBuffer, jArr);
    }

    /**
     * Creates a string tensor from a flat array and a shape using the environment's
     * default allocator.
     *
     * @param ortEnvironment environment supplying the default allocator
     * @param strArr the tensor's strings as a flat array
     * @param jArr tensor shape
     * @return the newly created tensor
     * @throws OrtException if tensor construction fails
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, String[] strArr, long[] jArr) throws OrtException {
        OrtAllocator allocator = ortEnvironment.defaultAllocator;
        return createTensor(ortEnvironment, allocator, strArr, jArr);
    }

    /**
     * Native factory for a buffer-backed tensor. The caller in this class passes the
     * ORT API handle, the allocator handle, the prepared buffer together with its
     * position and byte size, the shape, and the numeric value of the ONNX element
     * type; the returned value is used as the new tensor's native handle.
     */
    private static native long createTensorFromBuffer(long j10, long j11, Buffer buffer, int i10, long j12, long[] jArr, int i11) throws OrtException;

    /**
     * Native accessor used by {@link #getValue()} for non-scalar tensors; called with
     * the ORT API handle, this tensor's native handle and a carrier object obtained
     * from {@code TensorInfo.makeCarrier()} (presumably filled in place - confirm
     * against the JNI implementation).
     */
    private native void getArray(long j10, long j11, Object obj) throws OrtException;

    /** Native accessor used by {@link #getValue()} for scalar {@code BOOL} tensors. */
    private native boolean getBool(long j10, long j11) throws OrtException;

    /**
     * Fetches the tensor's buffer from native code and switches it to the platform's
     * native byte order, so that typed views created from it decode values correctly.
     *
     * @return the natively supplied buffer, set to native byte order
     */
    private ByteBuffer getBuffer() {
        ByteBuffer nativeBuffer = getBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle);
        return nativeBuffer.order(ByteOrder.nativeOrder());
    }

    /**
     * Native accessor returning the tensor's data as a {@link ByteBuffer}; called with
     * the ORT API handle and this tensor's native handle. NOTE(review): presumably a
     * direct view over native memory rather than a copy - confirm against the JNI
     * implementation.
     */
    private native ByteBuffer getBuffer(long j10, long j11);

    /**
     * Native scalar accessor used by {@link #getValue()} for {@code UINT8} and
     * {@code INT8}; the third argument is the numeric value of the ONNX element type.
     */
    private native byte getByte(long j10, long j11, int i10) throws OrtException;

    /** Native scalar accessor used by {@link #getValue()} for {@code DOUBLE}. */
    private native double getDouble(long j10, long j11) throws OrtException;

    /**
     * Native scalar accessor used by {@link #getValue()} for {@code FLOAT}; the third
     * argument is the numeric value of the ONNX element type.
     */
    private native float getFloat(long j10, long j11, int i10) throws OrtException;

    /**
     * Native scalar accessor used by {@link #getValue()} for {@code INT32}; the third
     * argument is the numeric value of the ONNX element type.
     */
    private native int getInt(long j10, long j11, int i10) throws OrtException;

    /**
     * Native scalar accessor used by {@link #getValue()} for {@code INT64}; the third
     * argument is the numeric value of the ONNX element type.
     */
    private native long getLong(long j10, long j11, int i10) throws OrtException;

    /**
     * Native scalar accessor used by {@link #getValue()} for {@code INT16}, and for the
     * raw bits of {@code FLOAT16} and {@code BFLOAT16} scalars; the third argument is
     * the numeric value of the ONNX element type.
     */
    private native short getShort(long j10, long j11, int i10) throws OrtException;

    /** Native scalar accessor used by {@link #getValue()} for {@code STRING}. */
    private native String getString(long j10, long j11) throws OrtException;

    /**
     * Closes the tensor by invoking the native close routine, then marks it closed.
     *
     * <p>Closing an already closed tensor only logs a warning. The method is
     * synchronized on this instance.
     */
    @Override // ai.onnxruntime.OnnxValue, java.lang.AutoCloseable
    public synchronized void close() {
        if (!this.closed) {
            close(OnnxRuntime.ortApiHandle, this.nativeHandle);
            this.closed = true;
            return;
        }
        logger.warning("Closing an already closed tensor.");
    }

    /**
     * Returns the buffer this tensor was constructed with, if any.
     *
     * @return an Optional holding the stored buffer, or an empty Optional when the
     *     tensor was constructed without one
     */
    public Optional<Buffer> getBufferRef() {
        Buffer stored = this.buffer;
        return Optional.ofNullable(stored);
    }

    /**
     * Returns a heap copy of the tensor's raw bytes.
     *
     * <p>The copy is set to the platform's native byte order, because that is the order
     * in which the tensor's bytes are laid out. Multi-byte reads and typed views such
     * as {@code asFloatBuffer()} on the returned buffer therefore decode the same
     * values as the typed getters of this class. With the default big-endian order they
     * would be byte-swapped on little-endian hosts.
     *
     * @return a rewound copy of the tensor's bytes in native byte order, or
     *     {@code null} if this is a string tensor
     */
    public ByteBuffer getByteBuffer() {
        checkClosed();
        if (this.info.type == OnnxJavaType.STRING) {
            return null;
        }
        ByteBuffer source = getBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle);
        // ByteBuffer.allocate() always yields a big-endian buffer; align it with the data.
        ByteBuffer copy = ByteBuffer.allocate(source.capacity()).order(ByteOrder.nativeOrder());
        copy.put(source);
        copy.rewind();
        return copy;
    }

    /**
     * Returns a heap copy of the tensor's contents as doubles.
     *
     * @return a rewound copy of the data, or {@code null} if the tensor's element type
     *     is not {@code DOUBLE}
     */
    public DoubleBuffer getDoubleBuffer() {
        checkClosed();
        if (this.info.type == OnnxJavaType.DOUBLE) {
            DoubleBuffer nativeView = getBuffer().asDoubleBuffer();
            DoubleBuffer copy = DoubleBuffer.allocate(nativeView.capacity());
            copy.put(nativeView);
            copy.rewind();
            return copy;
        }
        return null;
    }

    /**
     * Returns the tensor's contents as floats.
     *
     * <p>{@code FLOAT} data is copied as is; {@code FLOAT16} and {@code BFLOAT16} data
     * is converted to 32-bit floats through {@link Fp16Conversions}.
     *
     * @return a float buffer holding the data, or {@code null} if the tensor's element
     *     type is none of {@code FLOAT}, {@code FLOAT16} or {@code BFLOAT16}
     */
    public FloatBuffer getFloatBuffer() {
        checkClosed();
        OnnxJavaType elementType = this.info.type;
        if (elementType == OnnxJavaType.FLOAT16) {
            return Fp16Conversions.convertFp16BufferToFloatBuffer(getBuffer().asShortBuffer());
        }
        if (elementType == OnnxJavaType.BFLOAT16) {
            return Fp16Conversions.convertBf16BufferToFloatBuffer(getBuffer().asShortBuffer());
        }
        if (elementType != OnnxJavaType.FLOAT) {
            return null;
        }
        FloatBuffer nativeView = getBuffer().asFloatBuffer();
        FloatBuffer copy = FloatBuffer.allocate(nativeView.capacity());
        copy.put(nativeView);
        copy.rewind();
        return copy;
    }

    /**
     * Returns a heap copy of the tensor's contents as ints.
     *
     * @return a rewound copy of the data, or {@code null} if the tensor's element type
     *     is not {@code INT32}
     */
    public IntBuffer getIntBuffer() {
        checkClosed();
        if (this.info.type == OnnxJavaType.INT32) {
            IntBuffer nativeView = getBuffer().asIntBuffer();
            IntBuffer copy = IntBuffer.allocate(nativeView.capacity());
            copy.put(nativeView);
            copy.rewind();
            return copy;
        }
        return null;
    }

    /**
     * Returns a heap copy of the tensor's contents as longs.
     *
     * @return a rewound copy of the data, or {@code null} if the tensor's element type
     *     is not {@code INT64}
     */
    public LongBuffer getLongBuffer() {
        checkClosed();
        if (this.info.type == OnnxJavaType.INT64) {
            LongBuffer nativeView = getBuffer().asLongBuffer();
            LongBuffer copy = LongBuffer.allocate(nativeView.capacity());
            copy.put(nativeView);
            copy.rewind();
            return copy;
        }
        return null;
    }

    /**
     * Returns a heap copy of the tensor's contents as shorts.
     *
     * <p>For {@code FLOAT16} and {@code BFLOAT16} tensors the shorts are the raw,
     * unconverted 16-bit values.
     *
     * @return a rewound copy of the data, or {@code null} if the tensor's element type
     *     is none of {@code INT16}, {@code FLOAT16} or {@code BFLOAT16}
     */
    public ShortBuffer getShortBuffer() {
        checkClosed();
        OnnxJavaType elementType = this.info.type;
        boolean sixteenBit = elementType == OnnxJavaType.INT16
                || elementType == OnnxJavaType.FLOAT16
                || elementType == OnnxJavaType.BFLOAT16;
        if (!sixteenBit) {
            return null;
        }
        ShortBuffer nativeView = getBuffer().asShortBuffer();
        ShortBuffer copy = ShortBuffer.allocate(nativeView.capacity());
        copy.put(nativeView);
        copy.rewind();
        return copy;
    }

    /**
     * Identifies this value as a tensor.
     *
     * @return always {@code OnnxValueType.ONNX_TYPE_TENSOR}
     */
    @Override // ai.onnxruntime.OnnxValue
    public OnnxValue.OnnxValueType getType() {
        final OnnxValue.OnnxValueType tensorType = OnnxValue.OnnxValueType.ONNX_TYPE_TENSOR;
        return tensorType;
    }

    /**
     * Extracts the tensor's contents as a Java object.
     *
     * <p>Scalars are returned boxed ({@code FLOAT16} and {@code BFLOAT16} scalars are
     * widened to {@code Float}), or as a {@code String} for string scalars. Non-scalar
     * tensors are returned in the carrier object created by
     * {@code TensorInfo.makeCarrier()}; string tensors whose shape does not have exactly
     * one dimension are additionally passed through {@code OrtUtil.reshape}.
     *
     * @return the tensor's value
     * @throws OrtException if the tensor's element type is not supported, or if the
     *     native extraction fails
     */
    @Override // ai.onnxruntime.OnnxValue
    public Object getValue() throws OrtException {
        checkClosed();
        if (this.info.isScalar()) {
            Object scalar;
            switch (AnonymousClass1.$SwitchMap$ai$onnxruntime$OnnxJavaType[this.info.type.ordinal()]) {
                case 1: // FLOAT
                    scalar = Float.valueOf(getFloat(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value));
                    break;
                case 2: // DOUBLE
                    scalar = Double.valueOf(getDouble(OnnxRuntime.ortApiHandle, this.nativeHandle));
                    break;
                case 3: // UINT8
                case 4: // INT8
                    scalar = Byte.valueOf(getByte(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value));
                    break;
                case 5: // INT16
                    scalar = Short.valueOf(getShort(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value));
                    break;
                case 6: // INT32
                    scalar = Integer.valueOf(getInt(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value));
                    break;
                case 7: // INT64
                    scalar = Long.valueOf(getLong(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value));
                    break;
                case 8: // BOOL
                    scalar = Boolean.valueOf(getBool(OnnxRuntime.ortApiHandle, this.nativeHandle));
                    break;
                case 9: // STRING
                    scalar = getString(OnnxRuntime.ortApiHandle, this.nativeHandle);
                    break;
                case 10: // FLOAT16: read the raw bits, then widen
                    scalar = Float.valueOf(Fp16Conversions.fp16ToFloat(getShort(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value)));
                    break;
                case 11: // BFLOAT16: read the raw bits, then widen
                    scalar = Float.valueOf(Fp16Conversions.bf16ToFloat(getShort(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value)));
                    break;
                default: // UNKNOWN, or a constant without a table entry
                    throw new OrtException("Extracting the value of an invalid Tensor.");
            }
            return scalar;
        }

        Object carrier = this.info.makeCarrier();
        // The native call is skipped for empty tensors.
        if (this.info.getNumElements() > 0) {
            getArray(OnnxRuntime.ortApiHandle, this.nativeHandle, carrier);
        }
        TensorInfo description = this.info;
        if (description.type == OnnxJavaType.STRING) {
            long[] shape = description.shape;
            if (shape.length != 1) {
                return OrtUtil.reshape((String[]) carrier, shape);
            }
        }
        return carrier;
    }

    /**
     * Reports the ownership flag this tensor was constructed with.
     *
     * @return the value of the flag passed to the constructor; {@code false} when the
     *     tensor was constructed without a buffer
     */
    public boolean ownsBuffer() {
        final boolean owned = this.ownsBuffer;
        return owned;
    }

    /**
     * Builds a short description containing the tensor's info and closed flag.
     *
     * @return the textual description
     */
    public String toString() {
        StringBuilder text = new StringBuilder("OnnxTensor(info=");
        text.append(this.info.toString()).append(",closed=");
        // NOTE(review): K0.l is an obfuscated helper; presumably it appends the flag and
        // the closing ")" and returns the finished string - confirm against N5.K0.
        return K0.l(text, this.closed, ")");
    }
}
