Ginkgo Generated from branch based on main. Ginkgo version 1.9.0
A numerical linear algebra library targeting many-core architectures
 
Loading...
Searching...
No Matches
math.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_MATH_HPP_
6#define GKO_PUBLIC_CORE_BASE_MATH_HPP_
7
8
9#include <cmath>
10#include <complex>
11#include <cstdlib>
12#include <limits>
13#include <type_traits>
14#include <utility>
15
16#include <ginkgo/config.hpp>
17#include <ginkgo/core/base/half.hpp>
18#include <ginkgo/core/base/types.hpp>
19#include <ginkgo/core/base/utils.hpp>
20
21
22namespace gko {
23
24
25// type manipulations
26
27
33namespace detail {
34
35
/**
 * Maps a type to its real counterpart.
 *
 * Non-complex types are their own real counterpart.
 */
template <typename ValueType>
struct remove_complex_impl {
    using type = ValueType;
};

/**
 * The real counterpart of std::complex<RealType> is RealType.
 */
template <typename RealType>
struct remove_complex_impl<std::complex<RealType>> {
    using type = RealType;
};


/**
 * Maps a type to its complex counterpart, std::complex<ValueType>.
 */
template <typename ValueType>
struct to_complex_impl {
    using type = std::complex<ValueType>;
};

/**
 * A std::complex type already is complex and maps to itself.
 */
template <typename RealType>
struct to_complex_impl<std::complex<RealType>> {
    using type = std::complex<RealType>;
};


/** Trait whose value is false for every type except std::complex. */
template <typename ValueType>
struct is_complex_impl : public std::false_type {};

/** Trait whose value is true for std::complex specializations. */
template <typename RealType>
struct is_complex_impl<std::complex<RealType>> : public std::true_type {};
80
81
82template <typename T>
83struct is_complex_or_scalar_impl : std::is_scalar<T> {};
84
85template <>
86struct is_complex_or_scalar_impl<half> : std::true_type {};
87
88template <typename T>
89struct is_complex_or_scalar_impl<std::complex<T>>
90 : is_complex_or_scalar_impl<T> {};
91
92
100template <template <typename> class converter, typename T>
101struct template_converter {};
102
112template <template <typename> class converter, template <typename...> class T,
113 typename... Rest>
114struct template_converter<converter, T<Rest...>> {
115 using type = T<typename converter<Rest>::type...>;
116};
117
118
119template <typename T, typename = void>
120struct remove_complex_s {};
121
128template <typename T>
129struct remove_complex_s<T,
130 std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
131 using type = typename detail::remove_complex_impl<T>::type;
132};
133
140template <typename T>
141struct remove_complex_s<
142 T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
143 using type =
144 typename detail::template_converter<detail::remove_complex_impl,
145 T>::type;
146};
147
148
149template <typename T, typename = void>
150struct to_complex_s {};
151
158template <typename T>
159struct to_complex_s<T, std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
160 using type = typename detail::to_complex_impl<T>::type;
161};
162
169template <typename T>
170struct to_complex_s<T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
171 using type =
172 typename detail::template_converter<detail::to_complex_impl, T>::type;
173};
174
175
176} // namespace detail
177
178
184template <typename T>
187 using type = T;
188};
189
195template <typename T>
196struct cpx_real_type<std::complex<T>> {
198 using type = typename std::complex<T>::value_type;
199};
200
201
210template <typename T>
211using is_complex_s = detail::is_complex_impl<T>;
212
220template <typename T>
221GKO_INLINE constexpr bool is_complex()
222{
223 return detail::is_complex_impl<T>::value;
224}
225
226
234template <typename T>
235using is_complex_or_scalar_s = detail::is_complex_or_scalar_impl<T>;
236
244template <typename T>
245GKO_INLINE constexpr bool is_complex_or_scalar()
246{
247 return detail::is_complex_or_scalar_impl<T>::value;
248}
249
250
259template <typename T>
260using remove_complex = typename detail::remove_complex_s<T>::type;
261
262
278template <typename T>
279using to_complex = typename detail::to_complex_s<T>::type;
280
281
287template <typename T>
289
290
291namespace detail {
292
293
294// singly linked list of all our supported precisions
295template <typename T>
296struct next_precision_base_impl {};
297
298template <>
299struct next_precision_base_impl<float> {
300 using type = double;
301};
302
303template <>
304struct next_precision_base_impl<double> {
305 using type = float;
306};
307
308template <typename T>
309struct next_precision_base_impl<std::complex<T>> {
310 using type = std::complex<typename next_precision_base_impl<T>::type>;
311};
312
313
314template <typename T>
315struct next_precision_impl {};
316
317
318template <>
319struct next_precision_impl<gko::half> {
320 using type = float;
321};
322
323template <>
324struct next_precision_impl<float> {
325 using type = double;
326};
327
328template <>
329struct next_precision_impl<double> {
330 using type = gko::half;
331};
332
333template <typename T>
334struct next_precision_impl<std::complex<T>> {
335 using type = std::complex<typename next_precision_impl<T>::type>;
336};
337
338
339template <typename T>
340struct reduce_precision_impl {
341 using type = T;
342};
343
344template <typename T>
345struct reduce_precision_impl<std::complex<T>> {
346 using type = std::complex<typename reduce_precision_impl<T>::type>;
347};
348
349template <>
350struct reduce_precision_impl<double> {
351 using type = float;
352};
353
354template <>
355struct reduce_precision_impl<float> {
356 using type = half;
357};
358
359
360template <typename T>
361struct increase_precision_impl {
362 using type = T;
363};
364
365template <typename T>
366struct increase_precision_impl<std::complex<T>> {
367 using type = std::complex<typename increase_precision_impl<T>::type>;
368};
369
370template <>
371struct increase_precision_impl<float> {
372 using type = double;
373};
374
375template <>
376struct increase_precision_impl<half> {
377 using type = float;
378};
379
380
381template <typename T>
382struct infinity_impl {
383 // CUDA doesn't allow us to call std::numeric_limits functions
384 // so we need to store the value instead.
385 static constexpr auto value = std::numeric_limits<T>::infinity();
386};
387
388
/**
 * Computes a type that can represent values of both T1 and T2.
 *
 * For real types this is the type of `T1{} + T2{}`, i.e. the result of the
 * usual arithmetic conversions.
 *
 * Complex arguments are handled component-wise. If only one argument is
 * complex, the result is the complex type over the combined real type (for
 * example, std::complex<float> and double give std::complex<double>). These
 * mixed cases need their own specializations because std::complex defines
 * no operator+ between different component precisions, so the primary
 * template would be ill-formed for them.
 */
template <typename T1, typename T2>
struct highest_precision_impl {
    using type = decltype(T1{} + T2{});
};

// Both complex: combine the component types. This specialization is more
// specialized than the two mixed ones below, so there is no ambiguity.
template <typename T1, typename T2>
struct highest_precision_impl<std::complex<T1>, std::complex<T2>> {
    using type = std::complex<typename highest_precision_impl<T1, T2>::type>;
};

// Complex first, real second.
template <typename T1, typename T2>
struct highest_precision_impl<std::complex<T1>, T2> {
    using type = std::complex<typename highest_precision_impl<T1, T2>::type>;
};

// Real first, complex second.
template <typename T1, typename T2>
struct highest_precision_impl<T1, std::complex<T2>> {
    using type = std::complex<typename highest_precision_impl<T1, T2>::type>;
};

/**
 * Folds highest_precision_impl over an arbitrary number of types (right
 * fold).
 */
template <typename Head, typename... Tail>
struct highest_precision_variadic {
    using type = typename highest_precision_impl<
        Head, typename highest_precision_variadic<Tail...>::type>::type;
};

// A single type is its own highest precision.
template <typename Head>
struct highest_precision_variadic<Head> {
    using type = Head;
};
412
413
414} // namespace detail
415
416
420template <typename T>
421using next_precision_base = typename detail::next_precision_base_impl<T>::type;
422
423
430template <typename T>
432
436#if GINKGO_ENABLE_HALF
437template <typename T>
438using next_precision = typename detail::next_precision_impl<T>::type;
439
440template <typename T>
441using previous_precision = next_precision<next_precision<T>>;
442#else
443// fallback to float/double list
444template <typename T>
446
447template <typename T>
448using previous_precision = previous_precision_base<T>;
449#endif
450
451
455template <typename T>
456using reduce_precision = typename detail::reduce_precision_impl<T>::type;
457
458
462template <typename T>
463using increase_precision = typename detail::increase_precision_impl<T>::type;
464
465
477template <typename... Ts>
479 typename detail::highest_precision_variadic<Ts...>::type;
480
481
491template <typename T>
492GKO_INLINE constexpr reduce_precision<T> round_down(T val)
493{
494 return static_cast<reduce_precision<T>>(val);
495}
496
497
507template <typename T>
508GKO_INLINE constexpr increase_precision<T> round_up(T val)
509{
510 return static_cast<increase_precision<T>>(val);
511}
512
513
514template <typename FloatType, size_type NumComponents, size_type ComponentId>
515class truncated;
516
517
518namespace detail {
519
520
521template <typename T>
522struct truncate_type_impl {
523 using type = truncated<T, 2, 0>;
524};
525
526template <typename T, size_type Components>
527struct truncate_type_impl<truncated<T, Components, 0>> {
528 using type = truncated<T, 2 * Components, 0>;
529};
530
531template <typename T>
532struct truncate_type_impl<std::complex<T>> {
533 using type = std::complex<typename truncate_type_impl<T>::type>;
534};
535
536
537template <typename T>
538struct type_size_impl {
539 static constexpr auto value = sizeof(T) * byte_size;
540};
541
542template <typename T>
543struct type_size_impl<std::complex<T>> {
544 static constexpr auto value = sizeof(T) * byte_size;
545};
546
547
548} // namespace detail
549
550
555template <typename T, size_type Limit = sizeof(uint16) * byte_size>
557 std::conditional_t<detail::type_size_impl<T>::value >= 2 * Limit,
558 typename detail::truncate_type_impl<T>::type, T>;
559
560
567template <typename S, typename R>
575 GKO_ATTRIBUTES R operator()(S val) { return static_cast<R>(val); }
576};
577
578
579// mathematical functions
580
581
590GKO_INLINE constexpr int64 ceildiv(int64 num, int64 den)
591{
592 return (num + den - 1) / den;
593}
594
595
601template <typename T>
602GKO_INLINE constexpr T zero()
603{
604 return T{};
605}
606
607
617template <typename T>
618GKO_INLINE constexpr T zero(const T&)
619{
620 return zero<T>();
621}
622
623
629template <typename T>
630GKO_INLINE constexpr T one()
631{
632 return T(1);
633}
634
635template <>
636GKO_INLINE constexpr half one<half>()
637{
638 constexpr auto bits = static_cast<uint16>(0b0'01111'0000000000u);
639 return half::create_from_bits(bits);
640}
641
642
652template <typename T>
653GKO_INLINE constexpr T one(const T&)
654{
655 return one<T>();
656}
657
658
667template <typename T>
668GKO_INLINE constexpr bool is_zero(T value)
669{
670 return value == zero<T>();
671}
672
673
682template <typename T>
683GKO_INLINE constexpr bool is_nonzero(T value)
684{
685 return value != zero<T>();
686}
687
688
700template <typename T>
701GKO_INLINE constexpr T max(const T& x, const T& y)
702{
703 return x >= y ? x : y;
704}
705
706
718template <typename T>
719GKO_INLINE constexpr T min(const T& x, const T& y)
720{
721 return x <= y ? x : y;
722}
723
724
725namespace detail {
726
727
// Result type of `ref.to_arithmetic_type()` for an rvalue Ref. The alias
// fails to form (SFINAE) if no such member exists.
template <typename Ref>
using to_arithmetic_type_result =
    decltype(std::declval<Ref>().to_arithmetic_type());


/**
 * Detects whether Ref has a `to_arithmetic_type()` member function.
 *
 * `type` is that member's return type if it exists; otherwise it is Ref
 * itself. Dummy is only the SFINAE slot and must stay at its default.
 */
template <typename Ref, typename Dummy = std::void_t<>>
struct has_to_arithmetic_type : std::false_type {
    static_assert(std::is_same_v<Dummy, void>,
                  "Do not modify the Dummy value!");
    using type = Ref;
};

template <typename Ref>
struct has_to_arithmetic_type<Ref,
                              std::void_t<to_arithmetic_type_result<Ref>>>
    : std::true_type {
    using type = to_arithmetic_type_result<Ref>;
};


/**
 * Detects whether Ref declares a nested `arithmetic_type`.
 *
 * Dummy is only the SFINAE slot and must stay at its default.
 */
template <typename Ref, typename Dummy = std::void_t<>>
struct has_arithmetic_type : std::false_type {
    static_assert(std::is_same_v<Dummy, void>,
                  "Do not modify the Dummy value!");
};

template <typename Ref>
struct has_arithmetic_type<Ref, std::void_t<typename Ref::arithmetic_type>>
    : std::true_type {};
765
766
778template <typename Ref>
779constexpr GKO_ATTRIBUTES
780 std::enable_if_t<has_to_arithmetic_type<Ref>::value,
781 typename has_to_arithmetic_type<Ref>::type>
782 to_arithmetic_type(const Ref& ref)
783{
784 return ref.to_arithmetic_type();
785}
786
787template <typename Ref>
788constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
789 has_arithmetic_type<Ref>::value,
790 typename Ref::arithmetic_type>
791to_arithmetic_type(const Ref& ref)
792{
793 return ref;
794}
795
796template <typename Ref>
797constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
798 !has_arithmetic_type<Ref>::value,
799 Ref>
800to_arithmetic_type(const Ref& ref)
801{
802 return ref;
803}
804
805
806// Note: All functions have postfix `impl` so they are not considered for
807// overload resolution (in case a class / function also is in the namespace
808// `detail`)
809template <typename T>
810GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
811real_impl(const T& x)
812{
813 return x;
814}
815
816template <typename T>
817GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value,
819real_impl(const T& x)
820{
821 return x.real();
822}
823
824
825template <typename T>
826GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
827imag_impl(const T&)
828{
829 return T{};
830}
831
832template <typename T>
833GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value,
835imag_impl(const T& x)
836{
837 return x.imag();
838}
839
840
841template <typename T>
842GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
843conj_impl(const T& x)
844{
845 return x;
846}
847
848template <typename T>
849GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, T>
850conj_impl(const T& x)
851{
852 return T{real_impl(x), -imag_impl(x)};
853}
854
855
856} // namespace detail
857
858
868template <typename T>
869GKO_ATTRIBUTES GKO_INLINE constexpr auto real(const T& x)
870{
871 return detail::real_impl(detail::to_arithmetic_type(x));
872}
873
874
884template <typename T>
885GKO_ATTRIBUTES GKO_INLINE constexpr auto imag(const T& x)
886{
887 return detail::imag_impl(detail::to_arithmetic_type(x));
888}
889
890
898template <typename T>
899GKO_ATTRIBUTES GKO_INLINE constexpr auto conj(const T& x)
900{
901 return detail::conj_impl(detail::to_arithmetic_type(x));
902}
903
904
912template <typename T>
913GKO_INLINE constexpr auto squared_norm(const T& x)
914 -> decltype(real(conj(x) * x))
915{
916 return real(conj(x) * x);
917}
918
919using std::abs;
920
930template <typename T>
931GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T> abs(
932 const T& x)
933{
934 return x >= zero<T>() ? x : -x;
935}
936
937
938template <typename T>
939GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, remove_complex<T>>
940abs(const T& x)
941{
942 return sqrt(squared_norm(x));
943}
944
945// increase the priority in function lookup
946GKO_INLINE gko::half abs(const std::complex<gko::half>& x)
947{
948 // Using float abs not sqrt on norm to avoid overflow
949 return static_cast<gko::half>(abs(std::complex<float>(x)));
950}
951
952
953using std::sqrt;
954
955GKO_INLINE gko::half sqrt(gko::half a)
956{
957 return gko::half(std::sqrt(float(a)));
958}
959
960GKO_INLINE std::complex<gko::half> sqrt(std::complex<gko::half> a)
961{
962 return std::complex<gko::half>(sqrt(std::complex<float>(
963 static_cast<float>(a.real()), static_cast<float>(a.imag()))));
964}
965
966
972template <typename T>
973GKO_INLINE constexpr T pi()
974{
975 return static_cast<T>(3.1415926535897932384626433);
976}
977
978
987template <typename T>
988GKO_INLINE constexpr std::complex<remove_complex<T>> unit_root(int64 n,
989 int64 k = 1)
990{
991 return std::polar(one<remove_complex<T>>(),
993}
994
995
1008template <typename T>
1009constexpr uint32 get_significant_bit(const T& n, uint32 hint = 0u) noexcept
1010{
1011 return (T{1} << (hint + 1)) > n ? hint : get_significant_bit(n, hint + 1u);
1012}
1013
1014
/**
 * Returns the smallest value of the form hint * base^k (k >= 0) that is not
 * smaller than limit.
 *
 * @param base  the base of the power
 * @param limit  the lower bound for the result
 * @param hint  the starting value (defaults to base^0 == 1)
 *
 * @return the first value in hint, hint * base, hint * base^2, ... that is
 *         >= limit
 */
template <typename T>
constexpr T get_superior_power(const T& base, const T& limit,
                               const T& hint = T{1}) noexcept
{
    T power = hint;
    while (!(power >= limit)) {
        power = power * base;
    }
    return power;
}
1032
1033
1045template <typename T>
1046GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<!is_complex_s<T>::value, bool>
1047is_finite(const T& value)
1048{
1049 constexpr T infinity{detail::infinity_impl<T>::value};
1050 return abs(value) < infinity;
1051}
1052
1053
1065template <typename T>
1066GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<is_complex_s<T>::value, bool>
1067is_finite(const T& value)
1068{
1069 return is_finite(value.real()) && is_finite(value.imag());
1070}
1071
1072
1084template <typename T>
1085GKO_INLINE GKO_ATTRIBUTES T safe_divide(T a, T b)
1086{
1087 return b == zero<T>() ? zero<T>() : a / b;
1088}
1089
1090
1100template <typename T>
1101GKO_DEPRECATED(
1102 "is_nan can't be used safely on the device (MSVC+CUDA), and will thus be "
1103 "removed in a future release, without replacement")
1104GKO_INLINE GKO_ATTRIBUTES
1105 std::enable_if_t<!is_complex_s<T>::value, bool> is_nan(const T& value)
1106{
1107 using std::isnan;
1108 return isnan(value);
1109}
1110
1111
1121template <typename T>
1122GKO_DEPRECATED(
1123 "is_nan can't be used safely on the device (MSVC+CUDA), and will thus be "
1124 "removed in a future release, without replacement")
1125GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<is_complex_s<T>::value, bool> is_nan(
1126 const T& value)
1127{
1128 return is_nan(value.real()) || is_nan(value.imag());
1129}
1130
1131
1139template <typename T>
1140GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T> nan()
1141{
1142 return std::numeric_limits<T>::quiet_NaN();
1143}
1144
1145
1153template <typename T>
1154GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, T> nan()
1155{
1157}
1158
1159
1160} // namespace gko
1161
1162
1163#endif // GKO_PUBLIC_CORE_BASE_MATH_HPP_
A class providing basic support for half precision floating point types.
Definition half.hpp:286
typename detail::make_void< Ts... >::type void_t
Use the custom implementation, since the std::void_t used in is_matrix_type_builder seems to trigger ...
Definition std_extensions.hpp:47
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:630
std::enable_if_t<!is_complex_s< T >::value, bool > is_finite(const T &value)
Checks if a floating point number is finite, meaning it is neither +/- infinity nor NaN.
Definition math.hpp:1047
constexpr T pi()
Returns the value of pi.
Definition math.hpp:973
constexpr std::enable_if_t<!is_complex_s< T >::value, T > abs(const T &x)
Returns the absolute value of the object.
Definition math.hpp:931
typename detail::remove_complex_s< T >::type remove_complex
Obtains the type with the complex removed: the real type of a complex/scalar type or, for a class template, the same class template with the complex removed from each template parameter.
Definition math.hpp:260
constexpr increase_precision< T > round_up(T val)
Increases the precision of the input parameter.
Definition math.hpp:508
typename detail::next_precision_base_impl< T >::type next_precision_base
Obtains the next type in the singly-linked precision list.
Definition math.hpp:421
std::conditional_t< detail::type_size_impl< T >::value >=2 *Limit, typename detail::truncate_type_impl< T >::type, T > truncate_type
Truncates the type by half (by dropping bits), but ensures that it is at least Limit bits wide.
Definition math.hpp:556
typename detail::next_precision_impl< T >::type next_precision
Obtains the next type in the singly-linked precision list with half.
Definition math.hpp:438
typename detail::highest_precision_variadic< Ts... >::type highest_precision
Obtains the smallest arithmetic type that is able to store elements of all template parameter types e...
Definition math.hpp:478
typename detail::to_complex_s< T >::type to_complex
Obtains the type with the complex added: the complex type of a complex/scalar type or, for a class template, the same class template with the complex added to each template parameter.
Definition math.hpp:279
constexpr uint32 get_significant_bit(const T &n, uint32 hint=0u) noexcept
Returns the position of the most significant bit of the number.
Definition math.hpp:1009
constexpr bool is_complex_or_scalar()
Checks if T is a complex/scalar type.
Definition math.hpp:245
std::enable_if_t<!is_complex_s< T >::value, bool > is_nan(const T &value)
Checks if a floating point number is NaN.
Definition math.hpp:1105
detail::is_complex_impl< T > is_complex_s
Allows to check if T is a complex value during compile time by accessing the value attribute of this ...
Definition math.hpp:211
constexpr std::enable_if_t<!is_complex_s< T >::value, T > nan()
Returns a quiet NaN of the given type.
Definition math.hpp:1140
constexpr T zero()
Returns the additive identity for T.
Definition math.hpp:602
constexpr bool is_zero(T value)
Returns true if and only if the given value is zero.
Definition math.hpp:668
constexpr auto imag(const T &x)
Returns the imaginary part of the object.
Definition math.hpp:885
std::uint32_t uint32
32-bit unsigned integral type.
Definition types.hpp:129
constexpr std::complex< remove_complex< T > > unit_root(int64 n, int64 k=1)
Returns the value of exp(2 * pi * i * k / n), i.e. the k-th power of the n-th root of unity exp(2 * pi * i / n).
Definition math.hpp:988
next_precision_base< T > previous_precision_base
Obtains the previous type in the singly-linked precision list.
Definition math.hpp:431
typename detail::reduce_precision_impl< T >::type reduce_precision
Obtains the next type in the hierarchy with lower precision than T.
Definition math.hpp:456
constexpr reduce_precision< T > round_down(T val)
Reduces the precision of the input parameter.
Definition math.hpp:492
std::int64_t int64
64-bit signed integral type.
Definition types.hpp:112
constexpr int64 ceildiv(int64 num, int64 den)
Performs integer division with rounding up.
Definition math.hpp:590
constexpr bool is_complex()
Checks if T is a complex type.
Definition math.hpp:221
T safe_divide(T a, T b)
Computes the quotient of the given parameters, guarding against division by zero.
Definition math.hpp:1085
constexpr T min(const T &x, const T &y)
Returns the smaller of the arguments.
Definition math.hpp:719
constexpr auto squared_norm(const T &x) -> decltype(real(conj(x) *x))
Returns the squared norm of the object.
Definition math.hpp:913
detail::is_complex_or_scalar_impl< T > is_complex_or_scalar_s
Allows to check if T is a complex or scalar value during compile time by accessing the value attribut...
Definition math.hpp:235
constexpr T get_superior_power(const T &base, const T &limit, const T &hint=T{1}) noexcept
Returns the smallest power of base not smaller than limit.
Definition math.hpp:1027
typename detail::increase_precision_impl< T >::type increase_precision
Obtains the next type in the hierarchy with higher precision than T.
Definition math.hpp:463
remove_complex< T > to_real
to_real is alias of remove_complex
Definition math.hpp:288
constexpr auto conj(const T &x)
Returns the conjugate of an object.
Definition math.hpp:899
constexpr bool is_nonzero(T value)
Returns true if and only if the given value is not zero.
Definition math.hpp:683
std::uint16_t uint16
16-bit unsigned integral type.
Definition types.hpp:123
constexpr T max(const T &x, const T &y)
Returns the larger of the arguments.
Definition math.hpp:701
constexpr auto real(const T &x)
Returns the real part of the object.
Definition math.hpp:869
STL namespace.
Access the underlying real type of a complex number.
Definition math.hpp:185
T type
The type.
Definition math.hpp:187
Used to convert objects of type S to objects of type R using static_cast.
Definition math.hpp:568
R operator()(S val)
Converts the object to result type.
Definition math.hpp:575