GordianKMAC.java

package org.bouncycastle.crypto.patch.macs;

import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.DataLengthException;
import org.bouncycastle.crypto.Mac;
import org.bouncycastle.crypto.Xof;
import org.bouncycastle.crypto.digests.CSHAKEDigest;
import org.bouncycastle.crypto.digests.XofUtils;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.Strings;

/**
 * KMAC - MAC with optional XOF mode.
 * <p>
 * From NIST Special Publication 800-185 - SHA-3 Derived Functions:cSHAKE, KMAC, TupleHash and ParallelHash
 * </p>
 */
public class GordianKMAC
        implements Mac, Xof
{
    private static final byte[] padding = new byte[100];

    private final CSHAKEDigest cshake;
    private final int bitLength;
    private final int outputLength;

    private byte[] key;
    private boolean initialised;
    private boolean firstOutput;

    /**
     * Base constructor.
     *
     * @param bitLength bit length of the underlying SHAKE function, 128 or 256.
     * @param S         the customization string - available for local use.
     */
    public GordianKMAC(int bitLength, byte[] S)
    {
        this.cshake = new CSHAKEDigest(bitLength, Strings.toByteArray("KMAC"), S);
        this.bitLength = bitLength;
        this.outputLength = bitLength * 2 / 8;
    }

    public void init(CipherParameters params)
            throws IllegalArgumentException
    {
        KeyParameter kParam = (KeyParameter)params;

        this.key = Arrays.clone(kParam.getKey());
        this.initialised = true;

        reset();
    }

    public String getAlgorithmName()
    {
        return "KMAC" + cshake.getAlgorithmName().substring(6);
    }

    public int getByteLength()
    {
        return cshake.getByteLength();
    }

    public int getMacSize()
    {
        return outputLength;
    }

    public int getDigestSize()
    {
        return outputLength;
    }

    public void update(byte in)
            throws IllegalStateException
    {
        if (!initialised)
        {
            throw new IllegalStateException("KMAC not initialized");
        }

        cshake.update(in);
    }

    public void update(byte[] in, int inOff, int len)
            throws DataLengthException, IllegalStateException
    {
        if (!initialised)
        {
            throw new IllegalStateException("KMAC not initialized");
        }

        cshake.update(in, inOff, len);
    }

    public int doFinal(byte[] out, int outOff)
            throws DataLengthException, IllegalStateException
    {
        if (firstOutput)
        {
            if (!initialised)
            {
                throw new IllegalStateException("KMAC not initialized");
            }

            byte[] encOut = XofUtils.rightEncode(getMacSize() * 8);

            cshake.update(encOut, 0, encOut.length);
        }

        int rv = cshake.doFinal(out, outOff, getMacSize());

        reset();

        return rv;
    }

    public int doFinal(byte[] out, int outOff, int outLen)
    {
        if (firstOutput)
        {
            if (!initialised)
            {
                throw new IllegalStateException("KMAC not initialized");
            }

            byte[] encOut = XofUtils.rightEncode(0); // Same as doOutput;

            cshake.update(encOut, 0, encOut.length);
        }

        int rv = cshake.doFinal(out, outOff, outLen);

        reset();

        return rv;
    }

    public int doOutput(byte[] out, int outOff, int outLen)
    {
        if (firstOutput)
        {
            if (!initialised)
            {
                throw new IllegalStateException("KMAC not initialized");
            }

            byte[] encOut = XofUtils.rightEncode(0);

            cshake.update(encOut, 0, encOut.length);

            firstOutput = false;
        }

        return cshake.doOutput(out, outOff, outLen);
    }

    public void reset()
    {
        cshake.reset();

        if (key != null)
        {
            if (bitLength == 128)
            {
                bytePad(key, 168);
            }
            else
            {
                bytePad(key, 136);
            }
        }

        firstOutput = true;
    }

    private void bytePad(byte[] X, int w)
    {
        byte[] bytes = XofUtils.leftEncode(w);
        update(bytes, 0, bytes.length);
        byte[] encX = encode(X);
        update(encX, 0, encX.length);

        int required = w - ((bytes.length + encX.length) % w);

        if (required > 0 && required != w)
        {
            while (required > padding.length)
            {
                update(padding, 0, padding.length);
                required -= padding.length;
            }

            update(padding, 0, required);
        }
    }

    private static byte[] encode(byte[] X)
    {
        return Arrays.concatenate(XofUtils.leftEncode(X.length * 8), X);
    }
}