Commit 45331735 authored by jan.koester's avatar jan.koester
Browse files

optimize

parent d1fcb61c
Loading
Loading
Loading
Loading
+123 −35
Original line number Diff line number Diff line
@@ -39,6 +39,85 @@
#define  __restrict__ __restrict
#endif // !__restrict__

// --- BMI2 + ADX hardware-accelerated CIOS Montgomery multiplication ---
// MULX: flag-free 64×64→128 multiply (better pipelining than MUL)
// ADCX/ADOX: dual carry-chain additions (CF and OF independent)
#if defined(__x86_64__) && (defined(__GNUC__) || defined(__clang__))
#include <cpuid.h>
#include <immintrin.h>

// Reports whether the executing CPU advertises both BMI2 and ADX.
// Both feature flags are read from EBX of CPUID leaf 7, sub-leaf 0
// (BMI2 = bit 8, ADX = bit 19). Returns false when the leaf cannot be queried.
static bool detect_bmi2_adx() {
    const unsigned int bmi2_mask = 1u << 8;
    const unsigned int adx_mask = 1u << 19;

    unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0;
    if (!__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) {
        return false; // leaf 7 is not available on this CPU
    }

    // Only report support when both features are present.
    return (ebx & bmi2_mask) != 0 && (ebx & adx_mask) != 0;
}

// Cached result of the CPUID probe, computed once when this translation
// unit's namespace-scope objects are dynamically initialized. Before that
// initializer has run the flag still holds its zero-initialized value (false).
static const bool s_has_bmi2_adx = detect_bmi2_adx();

/**
 * CIOS (Coarsely Integrated Operand Scanning) Montgomery multiplication core
 * built on the 64-bit MULX / add-with-carry intrinsics.
 *
 * For each of the n outer iterations the function
 *   1. accumulates a[i] * b into T[0..n+1], and
 *   2. adds m_i * mod, with m_i = T[0] * n_prime (mod 2^64), and shifts T
 *      down by one limb.
 *
 * @param T        Accumulator; indices 0..n+1 are read and written, so it
 *                 must hold at least n + 2 limbs. T[0..n] is accumulated
 *                 into (callers presumably pass it zeroed -- confirm at the
 *                 call site); T[n+1] is overwritten before it is read.
 * @param ap       Limbs of operand a; only indices below au are read.
 * @param au       Number of valid limbs in ap; limbs at i >= au count as 0.
 * @param bp       Limbs of operand b; only indices below bu are read.
 * @param bu       Number of valid limbs in bp; limbs at j >= bu count as 0.
 * @param mp       Limbs of the modulus; indices 0..n-1 are read.
 * @param n        Limb count used for every loop bound.
 * @param n_prime  Per-limb Montgomery constant. The reduction step discards
 *                 the low word of T[0] + m_i * mp[0], which is only valid if
 *                 n_prime == -mp[0]^-1 mod 2^64 (assumed, not checked here).
 *
 * The function is compiled for the "bmi2,adx" target, so it must only be
 * called on a CPU that supports both extensions. Note that iterations with
 * a[i] == 0 skip the multiply loop, so the running time depends on operand a.
 */
__attribute__((target("bmi2,adx")))
static void cios_inner_hw(limb_t* __restrict__ T,
                           const limb_t* __restrict__ ap, size_t au,
                           const limb_t* __restrict__ bp, size_t bu,
                           const limb_t* __restrict__ mp,
                           size_t n, limb_t n_prime) {
    // Every carry below is computed at 64-bit width and results are truncated
    // back to limb_t. With a narrower limb type those truncations would drop
    // carries silently, so reject such a configuration at compile time.
    static_assert(sizeof(limb_t) == sizeof(unsigned long long),
                  "cios_inner_hw requires 64-bit limbs");

    for (size_t i = 0; i < n; ++i) {
        // Step 1: T += a[i] * b
        const unsigned long long ai = (i < au) ? (unsigned long long)ap[i] : 0ULL;
        unsigned long long carry = 0;
        if (ai != 0) {
            for (size_t j = 0; j < n; ++j) {
                unsigned long long hi;
                const unsigned long long bj = (j < bu) ? (unsigned long long)bp[j] : 0ULL;
                // hi:lo = ai * bj (128-bit product)
                unsigned long long lo = _mulx_u64(ai, bj, &hi);
                // lo += incoming carry
                unsigned char cf = _addcarryx_u64(0, lo, carry, &lo);
                // T[j] += lo
                unsigned long long tj = (unsigned long long)T[j];
                unsigned char cf2 = _addcarryx_u64(0, tj, lo, &tj);
                T[j] = (limb_t)tj;
                // carry = hi + cf + cf2. This cannot overflow because
                // ai*bj + carry + T[j] <= 2^128 - 1, so the carry-out is unused.
                _addcarryx_u64(cf, hi, (unsigned long long)cf2, &carry);
            }
        }
        {
            // Fold the final carry into the two top limbs.
            unsigned long long tn = (unsigned long long)T[n];
            unsigned char cf = _addcarryx_u64(0, tn, carry, &tn);
            T[n] = (limb_t)tn;
            T[n + 1] = (limb_t)cf;
        }

        // Step 2: Montgomery reduction — T += m_i * mod, shift right by one limb
        const unsigned long long mi = (unsigned long long)((limb_t)(T[0] * n_prime));
        carry = 0;
        {
            unsigned long long hi;
            unsigned long long lo = _mulx_u64(mi, (unsigned long long)mp[0], &hi);
            unsigned long long t0 = (unsigned long long)T[0];
            // Only the carry-out is needed: the low word is expected to be 0
            // by the choice of n_prime, and is dropped by the shift.
            unsigned char cf = _addcarryx_u64(0, t0, lo, &t0); // low word cancels to 0
            carry = hi + (unsigned long long)cf;
        }
        for (size_t j = 1; j < n; ++j) {
            unsigned long long hi;
            unsigned long long lo = _mulx_u64(mi, (unsigned long long)mp[j], &hi);
            unsigned char cf = _addcarryx_u64(0, lo, carry, &lo);
            unsigned long long tj = (unsigned long long)T[j];
            unsigned char cf2 = _addcarryx_u64(0, tj, lo, &tj);
            // Store one position lower: this is the "shift right" of the reduction.
            T[j - 1] = (limb_t)tj;
            _addcarryx_u64(cf, hi, (unsigned long long)cf2, &carry);
        }
        {
            unsigned long long tn = (unsigned long long)T[n];
            unsigned char cf = _addcarryx_u64(0, tn, carry, &tn);
            T[n - 1] = (limb_t)tn;
            carry = (unsigned long long)cf;
        }
        // Shift the overflow limb down and clear the top slot for the next round.
        T[n] = T[n + 1] + (limb_t)carry;
        T[n + 1] = 0;
    }
}

#define HAS_CIOS_HW 1
#endif // x86_64 GCC/Clang


namespace netplus {
    // Forward declarations
@@ -804,7 +883,7 @@ namespace netplus {
        montgomeryMultiply_into(a, R2, mod, n_prime, scratch, a_bar);

        // res_bar = 1 * R mod N  (Montgomery form of 1)
        bigInt one(1U, 1);
        bigInt one(1U, n + 1);
        bigInt res_bar(n + 2);
        montgomeryMultiply_into(one, R2, mod, n_prime, scratch, res_bar);

@@ -1260,16 +1339,24 @@ namespace netplus {
        const size_t au = a.used;
        const size_t bu = b.used;

#ifdef HAS_CIOS_HW
        if (s_has_bmi2_adx) {
            cios_inner_hw(T, ap, au, bp, bu, mp, n, n_prime);
        } else
#endif
        {
            for (size_t i = 0; i < n; ++i) {
                // Step 1: T += a[i] * b
                const dlimb_t ai = (i < au) ? (dlimb_t)ap[i] : 0;
                dlimb_t carry = 0;
                if (ai != 0) {
                    for (size_t j = 0; j < n; ++j) {
                        const dlimb_t bj = (j < bu) ? (dlimb_t)bp[j] : 0;
                        dlimb_t cur = (dlimb_t)T[j] + ai * bj + carry;
                        T[j] = (limb_t)cur;
                        carry = cur >> LIMB_BITS;
                    }
                }
                {
                    dlimb_t cur = (dlimb_t)T[n] + carry;
                    T[n] = (limb_t)cur;
@@ -1297,6 +1384,7 @@ namespace netplus {
                T[n] = T[n + 1] + (limb_t)carry;
                T[n + 1] = 0;
            }
        }

        // Result is T[0..n-1] (possibly n words)
        if (out.capacity < n + 1) out.reserve(n + 1);