/*
 * Decompiled with CFR 0.152.
 */
package org.jgroups.protocols;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Supplier;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
import org.jgroups.BytesMessage;
import org.jgroups.Header;
import org.jgroups.Message;
import org.jgroups.MessageFactory;
import org.jgroups.annotations.MBean;
import org.jgroups.annotations.ManagedAttribute;
import org.jgroups.annotations.Property;
import org.jgroups.conf.AttributeType;
import org.jgroups.stack.Protocol;
import org.jgroups.util.ByteArray;
import org.jgroups.util.FastArray;
import org.jgroups.util.MessageBatch;
import org.jgroups.util.Util;

/**
 * Compresses the payload of messages sent down the stack and uncompresses messages received from below.
 * <p>
 * Only messages whose length is at least {@link #min_size} bytes are considered. A message without a
 * backing byte array is first marshalled into one. In that case the {@link CompressHeader} tells the
 * receiver to unmarshal the uncompressed bytes back into a message. The compressed form is only sent when
 * it is strictly smaller than the original; otherwise the original message is passed down unchanged.
 * <p>
 * {@link #pool_size} deflaters and inflaters are created in {@link #init()} and shared through blocking
 * queues. A thread that finds a pool empty blocks until another thread returns an instance.
 */
@MBean(description="Compresses messages to send and uncompresses received messages")
public class COMPRESS extends Protocol {
    @Property(description="Compression level (from java.util.zip.Deflater) (0=no compression, 1=best speed, 9=best compression). Default is 9")
    protected int compression_level = 9;
    @Property(description="Minimal payload size of a message (in bytes) for compression to kick in. Default is 500 bytes", type=AttributeType.BYTES)
    protected int min_size = 500;
    @Property(description="Number of inflaters/deflaters for concurrent processing. Default is 2 ")
    protected int pool_size = 2;
    /** Pooled deflaters; created in {@link #init()}, null before that. */
    protected BlockingQueue<Deflater> deflater_pool;
    /** Pooled inflaters; created in {@link #init()}, null before that. */
    protected BlockingQueue<Inflater> inflater_pool;
    /** Used to recreate messages that were marshalled before compression. */
    protected MessageFactory msg_factory;
    protected final LongAdder num_compressions = new LongAdder();
    protected final LongAdder num_decompressions = new LongAdder();

    public int getMinSize() {
        return this.min_size;
    }

    public COMPRESS setMinSize(int size) {
        this.min_size = size;
        return this;
    }

    @ManagedAttribute(description="Number of compressions", type=AttributeType.SCALAR)
    public long getNumCompressions() {
        return this.num_compressions.sum();
    }

    @ManagedAttribute(description="Number of un-compressions", type=AttributeType.SCALAR)
    public long getNumUncompressions() {
        return this.num_decompressions.sum();
    }

    @Override
    public void resetStats() {
        super.resetStats();
        this.num_compressions.reset();
        this.num_decompressions.reset();
    }

    /**
     * Creates the deflater and inflater pools and looks up the transport's message factory.
     * @throws IllegalArgumentException if {@link #pool_size} is less than 1
     */
    @Override
    public void init() throws Exception {
        if (this.pool_size < 1)
            throw new IllegalArgumentException("pool_size must be >= 1, but is " + this.pool_size);
        this.deflater_pool = new ArrayBlockingQueue<>(this.pool_size);
        for (int i = 0; i < this.pool_size; ++i)
            this.deflater_pool.add(new Deflater(this.compression_level));
        this.inflater_pool = new ArrayBlockingQueue<>(this.pool_size);
        for (int i = 0; i < this.pool_size; ++i)
            this.inflater_pool.add(new Inflater());
        this.msg_factory = this.getTransport().getMessageFactory();
    }

    /** Releases the native resources of all deflaters and inflaters currently in the pools. */
    @Override
    public void destroy() {
        // The pools are null if init() was never called or failed before creating them
        if (this.deflater_pool != null)
            this.deflater_pool.forEach(Deflater::end);
        if (this.inflater_pool != null)
            this.inflater_pool.forEach(Inflater::end);
    }

    /**
     * Compresses {@code msg} if its length is at least {@link #min_size} and compression actually makes it
     * smaller, then passes the (possibly compressed) message down.
     * @param msg the message to send
     * @return the result of passing the message to the protocol below
     * @throws RuntimeException if the thread is interrupted while waiting for a pooled deflater, or if a
     *                          message without a backing array cannot be marshalled
     */
    @Override
    public Object down(Message msg) {
        if (msg.getLength() < this.min_size)
            return this.down_prot.down(msg);

        // Messages without a backing array are marshalled first; the receiver must unmarshal them again
        boolean serialize = !msg.hasArray();
        byte[] payload;
        int offset, length;
        if (serialize) {
            ByteArray marshalled = COMPRESS.messageToByteArray(msg);
            payload = marshalled.getArray();
            offset = marshalled.getOffset();
            length = marshalled.getLength();
        }
        else {
            payload = msg.getArray();
            offset = msg.getOffset();
            length = msg.getLength();
        }

        // The output buffer is only as large as the input: compression is pointless unless the result is smaller
        byte[] compressed_payload = new byte[length];
        int compressed_size;
        boolean complete;
        Deflater deflater = null;
        try {
            deflater = this.deflater_pool.take();
            deflater.reset();
            deflater.setInput(payload, offset, length);
            deflater.finish();
            deflater.deflate(compressed_payload);
            compressed_size = deflater.getTotalOut();
            // False if the compressed stream did not fit into compressed_payload, i.e. it would be truncated
            complete = deflater.finished();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        finally {
            // Return the deflater before passing the message down, so the pool isn't held across down()
            if (deflater != null)
                this.deflater_pool.offer(deflater);
        }

        if (!complete || compressed_size >= length) {
            if (this.log.isTraceEnabled())
                this.log.trace("skipping compression since the compressed message (%d) is not smaller than the original (%d)",
                               compressed_size, length);
            return this.down_prot.down(msg);
        }

        // A marshalled message travels as a plain BytesMessage; otherwise a copy of msg carries the new payload
        Message compressed_msg = serialize ? new BytesMessage(msg.getDest()) : msg.copy(false, true);
        compressed_msg.setArray(compressed_payload, 0, compressed_size)
          .putHeader(this.id, new CompressHeader(length).needsDeserialization(serialize));
        if (this.log.isTraceEnabled())
            this.log.trace("compressed payload from %d bytes to %d bytes", length, compressed_size);
        this.num_compressions.increment();
        return this.down_prot.down(compressed_msg);
    }

    /**
     * Uncompresses {@code msg} if it carries a {@link CompressHeader}, then passes it up. If uncompression
     * fails (see {@link #uncompress}), the message is passed up as received.
     */
    @Override
    public Object up(Message msg) {
        CompressHeader hdr = (CompressHeader)msg.getHeader(this.id);
        if (hdr == null)
            return this.up_prot.up(msg);
        Message uncompressed_msg = this.uncompress(msg, hdr.original_size, hdr.needsDeserialization());
        if (uncompressed_msg == null)
            return this.up_prot.up(msg);
        if (this.log.isTraceEnabled())
            this.log.trace("uncompressed %d bytes to %d bytes", msg.getLength(), uncompressed_msg.getLength());
        this.num_decompressions.increment();
        return this.up_prot.up(uncompressed_msg);
    }

    /**
     * Replaces every compressed message in {@code batch} with its uncompressed version, in place, and passes
     * the batch up if it is not empty. Messages that cannot be uncompressed are left as they are.
     */
    @Override
    public void up(MessageBatch batch) {
        FastArray.FastIterator it = (FastArray.FastIterator)batch.iterator();
        while (it.hasNext()) {
            Message msg = (Message)it.next();
            CompressHeader hdr = (CompressHeader)msg.getHeader(this.id);
            if (hdr == null)
                continue;
            Message uncompressed_msg = this.uncompress(msg, hdr.original_size, hdr.needsDeserialization());
            if (uncompressed_msg == null)
                continue;
            if (this.log.isTraceEnabled())
                this.log.trace("uncompressed %d bytes to %d bytes", msg.getLength(), uncompressed_msg.getLength());
            it.replace(uncompressed_msg);
            this.num_decompressions.increment();
        }
        if (!batch.isEmpty())
            this.up_prot.up(batch);
    }

    /**
     * Uncompresses the payload of {@code msg}.
     * @param msg                   the received message whose payload is the compressed data
     * @param original_size         the uncompressed size, as recorded in the {@link CompressHeader}
     * @param needs_deserialization true if the uncompressed bytes are a marshalled message that must be
     *                              unmarshalled; false if they become the payload of a copy of {@code msg}
     * @return the uncompressed message, or null if {@code msg} has no payload, the header size is invalid,
     *         the compressed data is corrupt or truncated, or the thread was interrupted while waiting for a
     *         pooled inflater
     * @throws RuntimeException if the uncompressed bytes cannot be unmarshalled into a message
     */
    protected Message uncompress(Message msg, int original_size, boolean needs_deserialization) {
        byte[] compressed_payload = msg.getArray();
        if (compressed_payload == null || compressed_payload.length == 0)
            return null;
        if (original_size < 0) {
            // A malformed header would otherwise cause a NegativeArraySizeException below
            this.log.error("invalid original size (%d) in compress header; message not uncompressed", original_size);
            return null;
        }
        byte[] uncompressed_payload = new byte[original_size];
        Inflater inflater = null;
        try {
            inflater = this.inflater_pool.take();
            inflater.reset();
            inflater.setInput(compressed_payload, msg.getOffset(), msg.getLength());
            int inflated = inflater.inflate(uncompressed_payload);
            if (inflated != original_size) {
                // Fewer bytes than announced means truncated/corrupt data: don't deliver a zero-padded payload
                this.log.error("%s: uncompressed %d bytes, but expected %d",
                               Util.getMessage("CompressionFailure"), inflated, original_size);
                return null;
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        }
        catch (DataFormatException e) {
            this.log.error(Util.getMessage("CompressionFailure"), e);
            return null;
        }
        finally {
            if (inflater != null)
                this.inflater_pool.offer(inflater);
        }
        // The inflater is already back in the pool; building the result does not need it
        return needs_deserialization
          ? COMPRESS.messageFromByteArray(uncompressed_payload, this.msg_factory)
          : msg.copy(false, true).setArray(uncompressed_payload, 0, uncompressed_payload.length);
    }

    /** Marshals {@code msg} into a byte array, wrapping any failure in a RuntimeException. */
    protected static ByteArray messageToByteArray(Message msg) {
        try {
            return Util.messageToBuffer(msg);
        }
        catch (Exception ex) {
            throw new RuntimeException("failed marshalling message", ex);
        }
    }

    /** Unmarshals a message from all of {@code uncompressed_payload}, wrapping any failure in a RuntimeException. */
    protected static Message messageFromByteArray(byte[] uncompressed_payload, MessageFactory msg_factory) {
        try {
            return Util.messageFromBuffer(uncompressed_payload, 0, uncompressed_payload.length, msg_factory);
        }
        catch (Exception ex) {
            throw new RuntimeException("failed unmarshalling message", ex);
        }
    }

    /**
     * Header added to compressed messages. It carries the payload size before compression and whether the
     * uncompressed bytes are a marshalled message.
     */
    public static class CompressHeader extends Header {
        /** Size of the data before compression (the marshalled message if needs_deserialization is set). */
        protected int original_size;
        /** True if the sender marshalled the whole message before compressing it. */
        protected boolean needs_deserialization;

        public CompressHeader() {
        }

        public CompressHeader(int original_size) {
            this.original_size = original_size;
        }

        @Override
        public short getMagicId() {
            return 58;
        }

        @Override
        public Supplier<? extends Header> create() {
            return CompressHeader::new;
        }

        public boolean needsDeserialization() {
            return this.needs_deserialization;
        }

        public CompressHeader needsDeserialization(boolean flag) {
            this.needs_deserialization = flag;
            return this;
        }

        /** Returns the size written by {@link #writeTo}: an int plus a boolean (one byte). */
        @Override
        public int serializedSize() {
            return Integer.BYTES + 1;
        }

        @Override
        public void writeTo(DataOutput out) throws IOException {
            out.writeInt(this.original_size);
            out.writeBoolean(this.needs_deserialization);
        }

        @Override
        public void readFrom(DataInput in) throws IOException {
            this.original_size = in.readInt();
            this.needs_deserialization = in.readBoolean();
        }
    }
}

