StochTree 0.2.1
Loading...
Searching...
No Matches
openmp_utils.h
1#ifndef STOCHTREE_OPENMP_UTILS_H
2#define STOCHTREE_OPENMP_UTILS_H
3
#include <algorithm>

// The OpenMP runtime header is a C header; include it at global scope (never
// inside a namespace) so its API is declared in the global namespace.
#ifdef _OPENMP
#include <omp.h>
#endif

#include <stochtree/common.h>
6
namespace StochTree {

// Emits `_Pragma("<x>")` where <x> is the macro-expanded argument text.
// `_Pragma` only accepts a string literal, and macro parameters are never
// substituted inside string literals, so any pragma whose text depends on a
// macro argument must be built by stringizing through this helper.
#define STOCHTREE_PRAGMA(x) _Pragma(#x)

#ifdef _OPENMP

#define STOCHTREE_HAS_OPENMP 1

// OpenMP thread management: thin wrappers over the OpenMP runtime API.

/// Returns omp_get_max_threads(): the team-size upper bound for the next
/// parallel region.
inline int get_max_threads() {
  return omp_get_max_threads();
}

/// Returns omp_get_thread_num(): the calling thread's index in its team.
inline int get_thread_num() {
  return omp_get_thread_num();
}

/// Returns omp_get_num_threads(): the size of the current team.
inline int get_num_threads() {
  return omp_get_num_threads();
}

/// Calls omp_set_num_threads(): sets the default team size for later regions.
inline void set_num_threads(int num_threads) {
  omp_set_num_threads(num_threads);
}

/// Marks the immediately following `for` loop as an OpenMP parallel loop run
/// by `n_threads` threads. `n_threads` may be any expression (variable,
/// literal, call). The parameter is deliberately not named `num_threads`:
/// that name would also replace the `num_threads` clause keyword.
#define STOCHTREE_PARALLEL_FOR(n_threads) \
  STOCHTREE_PRAGMA(omp parallel for num_threads(n_threads))

/// Like STOCHTREE_PARALLEL_FOR, plus `reduction(+:var)`: each thread sums into
/// a private copy of `var`, and the copies are added into `var` after the loop.
#define STOCHTREE_PARALLEL_FOR_REDUCTION_ADD(n_threads, var) \
  STOCHTREE_PRAGMA(omp parallel for num_threads(n_threads) reduction(+:var))

// NOTE(review): `reduction` is a clause, not a standalone OpenMP directive, and
// `var` is not substituted inside the string literal, so this expands to the
// literal text `#pragma omp reduction(+:var)`. It is kept unchanged only for
// source compatibility; use STOCHTREE_PARALLEL_FOR_REDUCTION_ADD instead.
#define STOCHTREE_REDUCTION_ADD(var) \
  _Pragma("omp reduction(+:var)")

/// Marks the following statement/block as an OpenMP critical section.
#define STOCHTREE_CRITICAL \
  _Pragma("omp critical")

#else  // !_OPENMP

#define STOCHTREE_HAS_OPENMP 0

// Serial fallbacks: behave as a single-thread team.
inline int get_max_threads() { return 1; }

inline int get_thread_num() { return 0; }

inline int get_num_threads() { return 1; }

inline void set_num_threads(int /*num_threads*/) {}

// Without OpenMP the parallel/critical markers expand to nothing, leaving the
// annotated loop or block to run serially.
#define STOCHTREE_PARALLEL_FOR(n_threads)

#define STOCHTREE_PARALLEL_FOR_REDUCTION_ADD(n_threads, var)

#define STOCHTREE_REDUCTION_ADD(var)

#define STOCHTREE_CRITICAL

#endif  // _OPENMP

// `inline` rather than `static`: these live in a header, and `static` would
// give every including translation unit its own private copy (and
// -Wunused-function warnings in units that never call them).

/// Maximum number of threads available (1 when OpenMP is not enabled).
inline int GetMaxThreads() {
  return get_max_threads();
}

/// Index of the calling thread within its team (0 outside parallel regions
/// and when OpenMP is not enabled).
inline int GetCurrentThreadNum() {
  return get_thread_num();
}

/// Size of the current thread team (1 when OpenMP is not enabled).
inline int GetNumThreads() {
  return get_num_threads();
}

/// Sets the default team size for later parallel regions (no-op without OpenMP).
inline void SetNumThreads(int num_threads) {
  set_num_threads(num_threads);
}

/// True when this header was compiled with OpenMP support enabled.
constexpr bool IsOpenMPAvailable() {
  return STOCHTREE_HAS_OPENMP != 0;
}

/// Picks a thread count so that each thread gets roughly at least
/// `min_work_per_thread` work items, capped at GetMaxThreads().
///
/// @param workload_size        Number of work items; values <= 0 mean "no work".
/// @param min_work_per_thread  Desired minimum items per thread; values <= 0
///                             are treated as 1.
/// @return A thread count in [1, GetMaxThreads()]; always 1 without OpenMP.
inline int GetOptimalThreadCount(int workload_size, int min_work_per_thread = 1000) {
  if (!IsOpenMPAvailable()) {
    return 1;
  }

  // A non-positive grain size would divide by zero or flip the sign.
  if (min_work_per_thread <= 0) {
    min_work_per_thread = 1;
  }

  const int max_threads = std::max(1, GetMaxThreads());
  const int optimal_threads = workload_size / min_work_per_thread;

  // Small, empty, or negative workloads yield <= 0 here. OpenMP requires a
  // positive num_threads, so never report fewer than one thread.
  return std::max(1, std::min(optimal_threads, max_threads));
}

/// Invokes `func(i)` for every i in the half-open range [start, end).
///
/// @param start, end    Index range; nothing runs when end <= start.
/// @param num_threads   Threads to use; <= 0 selects a count via
///                      GetOptimalThreadCount(end - start).
/// @param func          Callable invoked as func(int). On the parallel path it
///                      is called concurrently from several threads, so it
///                      must be safe to call that way.
template<typename Func>
void ParallelFor(int start, int end, int num_threads, Func func) {
  if (num_threads <= 0) {
    num_threads = GetOptimalThreadCount(end - start);
  }

  // `<= 1` (not `== 1`): a non-positive team size is invalid for the OpenMP
  // num_threads clause, so run serially instead of forwarding it.
  if (num_threads <= 1 || !IsOpenMPAvailable()) {
    // Sequential execution
    for (int i = start; i < end; ++i) {
      func(i);
    }
  } else {
    // Parallel execution
    STOCHTREE_PARALLEL_FOR(num_threads)
    for (int i = start; i < end; ++i) {
      func(i);
    }
  }
}

}  // namespace StochTree
112
113#endif // STOCHTREE_OPENMP_UTILS_H
Definition category_tracker.h:36