proconlib

This documentation is automatically generated by online-judge-tools/verification-helper.

View the Project on GitHub KodamaD/proconlib

✔ math/mod_pow.cpp

Depends on
  - utility/int_alias.cpp
  - internal/barret_reduction.cpp
  - math/rem_euclid.cpp

Required by

Verified with

Code

#pragma once
#include <cassert>
#include <type_traits>
#include "../utility/int_alias.cpp"
#include "../internal/barret_reduction.cpp"
#include "rem_euclid.cpp"

/// Computes x^exp modulo mod by binary (square-and-multiply) exponentiation.
///
/// x may be any integer type and may be negative: it is first reduced into
/// [0, mod) with rem_euclid, evaluated in std::common_type_t<T, i64>.
/// The result lies in [0, mod). For mod > 1, x^0 yields 1 (including 0^0);
/// mod == 1 always yields 0. Requires mod > 0 (checked by assert).
/// NOTE(review): exactness relies on proconlib::BarretReduction::product
/// being correct for every u32 modulus passed in.
template <class T> constexpr u32 mod_pow(T x, u64 exp, const u32 mod) {
    assert(mod > 0);
    if (mod == 1) return 0;  // every integer is congruent to 0 modulo 1
    const proconlib::BarretReduction reducer(mod);
    u32 result = 1;
    u32 base = rem_euclid<std::common_type_t<T, i64>>(x, mod);
    // Walk the bits of exp from least to most significant; base holds x^(2^k).
    while (exp != 0) {
        if ((exp & 1) != 0) result = reducer.product(result, base);
        base = reducer.product(base, base);
        exp >>= 1;
    }
    return result;
}
#line 2 "math/mod_pow.cpp"
#include <cassert>
#include <type_traits>
#line 2 "utility/int_alias.cpp"
#include <cstdint>

// Short fixed-width integer names shared across the library.
// Signed integers.
using i32 = std::int32_t;
using i64 = std::int64_t;
using i128 = __int128_t;  // compiler extension (GCC/Clang)
// Unsigned integers.
using u32 = std::uint32_t;
using u64 = std::uint64_t;
using u128 = __uint128_t;  // compiler extension (GCC/Clang)
#line 3 "internal/barret_reduction.cpp"

namespace proconlib {

/// Barrett reduction for a fixed modulus: computes (a * b) % mod using a
/// precomputed approximate inverse of mod instead of a 64-bit division.
///
/// Any u32 modulus >= 2 is supported, including mod > 2^31.
/// NOTE: for mod == 1 the inverse wraps to 0 (see near_inv). In that case
/// product() is only correct when a * b == 0, e.g. for the reduced operands
/// a == b == 0. mod must be non-zero because the constructor divides by it.
class BarretReduction {
    u32 mod;
    // floor((2^64 - 1) / mod) + 1: ceil(2^64 / mod) when mod is not a power of
    // two, exactly 2^64 / mod when it is. Overestimates 2^64 / mod by less
    // than 1. For mod == 1 this wraps around to 0 in u64 arithmetic.
    u64 near_inv;

  public:
    explicit constexpr BarretReduction(const u32 mod) noexcept : mod(mod), near_inv((u64)(-1) / mod + 1) {}

    /// Returns (a * b) % mod (see the class note for mod == 1).
    constexpr u32 product(const u32 a, const u32 b) const noexcept {
        const u64 z = (u64)a * b;
        // Quotient estimate floor(z * near_inv / 2^64). Since near_inv exceeds
        // 2^64 / mod by less than 1 and z < 2^64, this is either floor(z / mod)
        // or floor(z / mod) + 1.
        const u64 x = ((u128)z * near_inv) >> 64;
        // Keep the candidate remainder in 64 bits. If x overshot by one,
        // z - x * mod wraps to 2^64 - (mod - r), which is always >= mod, so the
        // check below detects it and adding mod wraps back to r.
        // Truncating to 32 bits first would be wrong for mod > 2^31: the
        // wrapped value 2^32 - (mod - r) can fall below mod.
        // Example: mod = 2^32 - 1, a = b = mod - 4 gives 17 instead of 16.
        const u64 v = z - x * mod;
        return (u32)(v < mod ? v : v + mod);
    }

    /// Returns the modulus this reducer was built for.
    constexpr u32 get_mod() const noexcept { return mod; }
};

}  // namespace proconlib
#line 3 "math/rem_euclid.cpp"

/// Euclidean remainder: returns the unique r in [0, mod) with
/// value - r divisible by mod.
///
/// Unlike the built-in %, whose result follows the sign of the dividend, this
/// never returns a negative value. Requires mod > 0 (checked by assert).
template <class T> constexpr T rem_euclid(T value, const T& mod) {
    assert(mod > 0);
    const T r = value % mod;
    // Built-in % truncates toward zero, so a negative value leaves r in (-mod, 0).
    if (r < 0) return r + mod;
    return r;
}
#line 7 "math/mod_pow.cpp"

/// Computes x^exp modulo mod by binary (square-and-multiply) exponentiation.
///
/// x may be any integer type and may be negative: it is first reduced into
/// [0, mod) with rem_euclid, evaluated in std::common_type_t<T, i64>.
/// The result lies in [0, mod). For mod > 1, x^0 yields 1 (including 0^0);
/// mod == 1 always yields 0. Requires mod > 0 (checked by assert).
/// NOTE(review): exactness relies on proconlib::BarretReduction::product
/// being correct for every u32 modulus passed in.
template <class T> constexpr u32 mod_pow(T x, u64 exp, const u32 mod) {
    assert(mod > 0);
    if (mod == 1) return 0;  // every integer is congruent to 0 modulo 1
    const proconlib::BarretReduction reducer(mod);
    u32 result = 1;
    u32 base = rem_euclid<std::common_type_t<T, i64>>(x, mod);
    // Walk the bits of exp from least to most significant; base holds x^(2^k).
    while (exp != 0) {
        if ((exp & 1) != 0) result = reducer.product(result, base);
        base = reducer.product(base, base);
        exp >>= 1;
    }
    return result;
}
Back to top page