/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.spark.sparksql;

import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.cassandra.analytics.stats.Stats;
import org.apache.cassandra.bridge.CassandraBridge;
import org.apache.cassandra.bridge.CassandraBridgeFactory;
import org.apache.cassandra.bridge.CassandraVersion;
import org.apache.cassandra.spark.TestUtils;
import org.apache.cassandra.spark.data.CqlField;
import org.apache.cassandra.spark.data.CqlTable;
import org.apache.cassandra.spark.data.DataLayer;
import org.apache.cassandra.spark.data.converter.SparkSqlTypeConverter;
import org.apache.cassandra.spark.reader.RowData;
import org.apache.cassandra.spark.reader.StreamScanner;
import org.apache.cassandra.spark.sparksql.SparkRowIterator;
import org.apache.cassandra.spark.sparksql.filters.PartitionKeyFilter;
import org.apache.cassandra.spark.sparksql.filters.PruneColumnFilter;
import org.apache.cassandra.spark.sparksql.filters.SSTableTimeRangeFilter;
import org.apache.cassandra.spark.utils.ByteBufferUtils;
import org.apache.cassandra.spark.utils.test.TestSchema;
import org.apache.spark.sql.catalyst.InternalRow;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.quicktheories.QuickTheory;

/**
 * Property-based tests for {@link SparkRowIterator}.
 *
 * <p>Each trial builds a random {@link TestSchema}, generates {@link #NUM_ROWS} random rows, replays them
 * cell by cell through a mocked {@link StreamScanner}, and checks that every row emitted by the iterator
 * equals the corresponding expected test row.
 */
public class SparkRowIteratorTests {
    /** Number of random rows generated for each property-based trial. */
    private static final int NUM_ROWS = 50;

    @ParameterizedTest
    @MethodSource("org.apache.cassandra.bridge.VersionRunner#bridges")
    public void testBasicKeyValue(CassandraBridge bridge) {
        QuickTheory.qt()
                .forAll(TestUtils.cql3Type(bridge), TestUtils.cql3Type(bridge))
                .assuming((type1, type2) -> type1.supportedAsPrimaryKeyColumn())
                .checkAssert((type1, type2) -> runTest(bridge.getVersion(),
                        TestSchema.builder(bridge)
                                .withPartitionKey("a", type1)
                                .withColumn("b", type2)
                                .build()));
    }

    @ParameterizedTest
    @MethodSource("org.apache.cassandra.bridge.VersionRunner#bridges")
    public void testMultiPartitionKeys(CassandraBridge bridge) {
        QuickTheory.qt()
                .forAll(TestUtils.cql3Type(bridge), TestUtils.cql3Type(bridge), TestUtils.cql3Type(bridge))
                .assuming((type1, type2, type3) -> type1.supportedAsPrimaryKeyColumn()
                                                   && type2.supportedAsPrimaryKeyColumn()
                                                   && type3.supportedAsPrimaryKeyColumn())
                .checkAssert((type1, type2, type3) -> runTest(bridge.getVersion(),
                        TestSchema.builder(bridge)
                                .withPartitionKey("a", type1)
                                .withPartitionKey("b", type2)
                                .withPartitionKey("c", type3)
                                .withColumn("d", bridge.bigint())
                                .build()));
    }

    @ParameterizedTest
    @MethodSource("org.apache.cassandra.bridge.VersionRunner#bridges")
    public void testBasicClusteringKey(CassandraBridge bridge) {
        QuickTheory.qt()
                .forAll(TestUtils.cql3Type(bridge), TestUtils.cql3Type(bridge), TestUtils.cql3Type(bridge),
                        TestUtils.sortOrder())
                .assuming((type1, type2, type3, order) -> type1.supportedAsPrimaryKeyColumn()
                                                          && type2.supportedAsPrimaryKeyColumn())
                .checkAssert((type1, type2, type3, order) -> runTest(bridge.getVersion(),
                        TestSchema.builder(bridge)
                                .withPartitionKey("a", type1)
                                .withClusteringKey("b", type2)
                                .withColumn("c", type3)
                                .withSortOrder(order)
                                .build()));
    }

    @ParameterizedTest
    @MethodSource("org.apache.cassandra.bridge.VersionRunner#bridges")
    public void testMultiClusteringKey(CassandraBridge bridge) {
        QuickTheory.qt()
                .forAll(TestUtils.cql3Type(bridge), TestUtils.cql3Type(bridge), TestUtils.sortOrder(),
                        TestUtils.sortOrder())
                .assuming((type1, type2, order1, order2) -> type1.supportedAsPrimaryKeyColumn()
                                                            && type2.supportedAsPrimaryKeyColumn())
                .checkAssert((type1, type2, order1, order2) -> runTest(bridge.getVersion(),
                        TestSchema.builder(bridge)
                                .withPartitionKey("a", bridge.bigint())
                                .withClusteringKey("b", type1)
                                .withClusteringKey("c", type2)
                                .withColumn("d", bridge.bigint())
                                .withSortOrder(order1)
                                .withSortOrder(order2)
                                .build()));
    }

    @ParameterizedTest
    @MethodSource("org.apache.cassandra.bridge.VersionRunner#bridges")
    public void testUdt(CassandraBridge bridge) {
        QuickTheory.qt()
                .forAll(TestUtils.cql3Type(bridge), TestUtils.cql3Type(bridge))
                .checkAssert((type1, type2) -> runTest(bridge.getVersion(),
                        TestSchema.builder(bridge)
                                .withPartitionKey("a", bridge.bigint())
                                .withClusteringKey("b", bridge.text())
                                .withColumn("c", bridge.udt("keyspace", "testudt")
                                                       .withField("x", type1)
                                                       .withField("y", bridge.ascii())
                                                       .withField("z", type2)
                                                       .build())
                                .build()));
    }

    @ParameterizedTest
    @MethodSource("org.apache.cassandra.bridge.VersionRunner#bridges")
    public void testTuple(CassandraBridge bridge) {
        QuickTheory.qt()
                .forAll(TestUtils.cql3Type(bridge), TestUtils.cql3Type(bridge))
                .checkAssert((type1, type2) -> runTest(bridge.getVersion(),
                        TestSchema.builder(bridge)
                                .withPartitionKey("a", bridge.bigint())
                                .withClusteringKey("b", bridge.text())
                                .withColumn("c", bridge.tuple(bridge.aInt(), type1, bridge.ascii(), type2,
                                                              bridge.date()))
                                .build()));
    }

    /**
     * Runs the row-iterator check against {@link #NUM_ROWS} random rows generated from {@code schema}.
     */
    private static void runTest(CassandraVersion version, TestSchema schema) {
        runTest(version, schema, schema.randomRows(NUM_ROWS));
    }

    /**
     * Pins {@code schema} to {@code version} and runs the row-iterator check against {@code testRows}.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised by the check
     */
    private static void runTest(CassandraVersion version, TestSchema schema, TestSchema.TestRow[] testRows) {
        try {
            schema.setCassandraVersion(version);
            testRowIterator(version, schema, testRows);
        } catch (IOException exception) {
            throw new RuntimeException(exception);
        }
    }

    /**
     * Feeds {@code testRows} through a {@link SparkRowIterator} backed by a mocked scanner and asserts that
     * every emitted row matches the expected (non-tombstone) test row, and that all rows were consumed.
     */
    private static void testRowIterator(CassandraVersion version, TestSchema schema, TestSchema.TestRow[] testRows)
            throws IOException {
        CassandraBridge bridge = CassandraBridgeFactory.get(version);
        SparkSqlTypeConverter typeConverter = CassandraBridgeFactory.getSparkSql(bridge.getVersion());
        CqlTable cqlTable = schema.buildTable();
        int numRows = testRows.length;

        // Regular (non-key) columns in sorted order; the mocked scanner emits one of these per next() call
        List<CqlField> columns = cqlTable.fields().stream()
                .filter(field -> !field.isPartitionKey())
                .filter(field -> !field.isClusteringColumn())
                .sorted()
                .collect(Collectors.toList());

        StreamScanner scanner = mockScanner(bridge, cqlTable, columns, testRows);
        DataLayer dataLayer = mockDataLayer(bridge, typeConverter, cqlTable, scanner);

        SparkRowIterator it = new SparkRowIterator(0, dataLayer);
        try {
            int rowCount = 0;
            while (it.next()) {
                // Skip over expected rows flagged as tombstones before comparing
                while (rowCount < numRows && testRows[rowCount].isTombstone()) {
                    ++rowCount;
                }
                if (rowCount >= numRows) {
                    break;
                }
                TestSchema.TestRow expected = testRows[rowCount];
                Assertions.assertThat(schema.toTestRow(it.get(), typeConverter)).isEqualTo(expected);
                ++rowCount;
            }
            Assertions.assertThat(rowCount).isEqualTo(numRows);
        } finally {
            // Release the iterator even when an assertion above fails
            it.close();
        }
    }

    /**
     * Creates a mocked {@link DataLayer} that exposes {@code cqlTable}, {@code bridge} and {@code typeConverter},
     * accepts every partition, and returns {@code scanner} from {@code openCompactionScanner}.
     */
    private static DataLayer mockDataLayer(CassandraBridge bridge,
                                           SparkSqlTypeConverter typeConverter,
                                           CqlTable cqlTable,
                                           StreamScanner scanner) throws IOException {
        DataLayer dataLayer = Mockito.mock(DataLayer.class);
        Mockito.when(dataLayer.cqlTable()).thenReturn(cqlTable);
        Mockito.when(dataLayer.version()).thenCallRealMethod();
        Mockito.when(dataLayer.isInPartition(ArgumentMatchers.anyInt(),
                                             ArgumentMatchers.any(BigInteger.class),
                                             ArgumentMatchers.any(ByteBuffer.class))).thenReturn(true);
        Mockito.when(dataLayer.bridge()).thenReturn(bridge);
        Mockito.when(dataLayer.stats()).thenReturn(Stats.DoNothingStats.INSTANCE);
        Mockito.when(dataLayer.requestedFeatures()).thenCallRealMethod();
        Mockito.when(dataLayer.typeConverter()).thenReturn(typeConverter);
        Mockito.when(dataLayer.sstableTimeRangeFilter()).thenReturn(SSTableTimeRangeFilter.ALL);
        Mockito.when(dataLayer.openCompactionScanner(ArgumentMatchers.anyInt(),
                                                     ArgumentMatchers.anyListOf(PartitionKeyFilter.class),
                                                     (SSTableTimeRangeFilter) ArgumentMatchers.any(),
                                                     (PruneColumnFilter) ArgumentMatchers.any()))
               .thenReturn(scanner);
        return dataLayer;
    }

    /**
     * Creates a mocked {@link StreamScanner} that replays {@code testRows} one regular column per
     * {@code next()} call, writing the partition key, column name and value into a shared {@link RowData}
     * returned by {@code data()}. {@code next()} returns {@code false} once every row has been emitted.
     */
    private static StreamScanner mockScanner(CassandraBridge bridge,
                                             CqlTable cqlTable,
                                             List<CqlField> columns,
                                             TestSchema.TestRow[] testRows) throws IOException {
        int numRows = testRows.length;
        int numColumns = columns.size();
        RowData rowData = new RowData();
        AtomicInteger rowPos = new AtomicInteger();
        AtomicInteger colPos = new AtomicInteger();

        StreamScanner scanner = Mockito.mock(StreamScanner.class);
        Mockito.when(scanner.data()).thenReturn(rowData);
        Mockito.doAnswer(invocation -> {
            int col = colPos.getAndIncrement();
            if (rowPos.get() >= numRows) {
                return false;
            }
            TestSchema.TestRow testRow = testRows[rowPos.get()];
            if (col == 0) {
                // First cell of a new row: publish its partition key (token is a fixed placeholder)
                rowData.setPartitionKeyCopy(serializePartitionKey(cqlTable, testRow), BigInteger.ONE);
            }
            CqlField column = columns.get(col);
            rowData.setColumnNameCopy(serializeColumnName(bridge, cqlTable, testRow, column));
            rowData.setValueCopy(column.serialize(testRow.get(column.position())));
            if (colPos.get() == numColumns) {
                // Last regular column of this row emitted: advance to the next row
                if (rowPos.getAndIncrement() >= numRows) {
                    throw new IllegalStateException("Went too far...");
                }
                colPos.set(0);
            }
            return true;
        }).when(scanner).next();
        return scanner;
    }

    /**
     * Serializes the partition key of {@code testRow}: a single key is serialized directly, while a
     * multi-column key is combined with {@link ByteBufferUtils#build}.
     */
    private static ByteBuffer serializePartitionKey(CqlTable cqlTable, TestSchema.TestRow testRow) {
        if (cqlTable.numPartitionKeys() == 1) {
            CqlField partitionKey = cqlTable.partitionKeys().get(0);
            return partitionKey.serialize(testRow.get(partitionKey.position()));
        }
        assert cqlTable.numPartitionKeys() > 1;
        ByteBuffer[] partitionBuffers = new ByteBuffer[cqlTable.numPartitionKeys()];
        int position = 0;
        for (CqlField partitionKey : cqlTable.partitionKeys()) {
            partitionBuffers[position++] = partitionKey.serialize(testRow.get(partitionKey.position()));
        }
        return ByteBufferUtils.build(false, partitionBuffers);
    }

    /**
     * Serializes the cell name for {@code column}: the row's clustering key values followed by the
     * column name encoded as ASCII, combined with {@link ByteBufferUtils#build}.
     */
    private static ByteBuffer serializeColumnName(CassandraBridge bridge,
                                                  CqlTable cqlTable,
                                                  TestSchema.TestRow testRow,
                                                  CqlField column) {
        ByteBuffer[] colBuffers = new ByteBuffer[cqlTable.numClusteringKeys() + 1];
        int position = 0;
        for (CqlField clusteringColumn : cqlTable.clusteringKeys()) {
            colBuffers[position++] = clusteringColumn.serialize(testRow.get(clusteringColumn.position()));
        }
        colBuffers[position] = bridge.ascii().serialize(column.name());
        return ByteBufferUtils.build(false, colBuffers);
    }
}

