StochTree 0.1.1
Loading...
Searching...
No Matches
openmp_utils.h
1#ifndef STOCHTREE_OPENMP_UTILS_H
2#define STOCHTREE_OPENMP_UTILS_H
3
#include <stochtree/common.h>
#include <algorithm>

// The OpenMP runtime header must be included at global scope, never inside
// `namespace StochTree`, so its API is declared where other code expects it.
#ifdef STOCHTREE_OPENMP_AVAILABLE
#include <omp.h>
#endif
6
7namespace StochTree {
8
#ifdef STOCHTREE_OPENMP_AVAILABLE

// <omp.h> is included at global scope near the top of this header. Including
// it here, inside `namespace StochTree`, would declare the OpenMP runtime API
// in this namespace instead of the global one.
#define STOCHTREE_HAS_OPENMP 1

// OpenMP thread management: thin wrappers over the OpenMP runtime API.

/// Returns omp_get_max_threads().
inline int get_max_threads() {
  return omp_get_max_threads();
}

/// Returns omp_get_thread_num(), the calling thread's index in its team.
inline int get_thread_num() {
  return omp_get_thread_num();
}

/// Returns omp_get_num_threads(), the size of the current thread team.
inline int get_num_threads() {
  return omp_get_num_threads();
}

/// Forwards to omp_set_num_threads().
inline void set_num_threads(int num_threads) {
  omp_set_num_threads(num_threads);
}

// Turns its (already macro-expanded) argument into a _Pragma operator.
// Macro parameters are NOT substituted inside string literals, so writing
// _Pragma("omp ... num_threads(n)") directly would leave `n` unexpanded.
#define STOCHTREE_PRAGMA(x) _Pragma(#x)

// Marks the following `for` loop as an OpenMP parallel loop that uses
// `n_threads` threads. The parameter is deliberately not called
// `num_threads`: that would also replace the clause keyword of the same name.
#define STOCHTREE_PARALLEL_FOR(n_threads) \
  STOCHTREE_PRAGMA(omp parallel for num_threads(n_threads))

// Parallel `for` loop whose iterations sum-reduce into `var`.
#define STOCHTREE_PARALLEL_FOR_REDUCTION_ADD(n_threads, var) \
  STOCHTREE_PRAGMA(omp parallel for num_threads(n_threads) reduction(+:var))

// `reduction` is a clause, not a standalone OpenMP directive. The former
// expansion `#pragma omp reduction(+:var)` was therefore never valid OpenMP:
// compilers either ignore it or reject it. It is kept as a no-op for source
// compatibility; use STOCHTREE_PARALLEL_FOR_REDUCTION_ADD instead.
#define STOCHTREE_REDUCTION_ADD(var)

// Makes the following statement an OpenMP critical section.
#define STOCHTREE_CRITICAL \
  _Pragma("omp critical")

#else
#define STOCHTREE_HAS_OPENMP 0

// Fallback implementations when OpenMP is not available: behave as a single
// thread with index 0.
inline int get_max_threads() { return 1; }

inline int get_thread_num() { return 0; }

inline int get_num_threads() { return 1; }

inline void set_num_threads(int num_threads) {
  (void)num_threads;  // no-op without OpenMP; silences unused-parameter warnings
}

// Without OpenMP every parallel annotation expands to nothing, so the
// annotated code runs sequentially.
#define STOCHTREE_PARALLEL_FOR(n_threads)

#define STOCHTREE_PARALLEL_FOR_REDUCTION_ADD(n_threads, var)

#define STOCHTREE_REDUCTION_ADD(var)

#define STOCHTREE_CRITICAL

#endif
59
60static int GetMaxThreads() {
61 return get_max_threads();
62}
63
64static int GetCurrentThreadNum() {
65 return get_thread_num();
66}
67
68static int GetNumThreads() {
69 return get_num_threads();
70}
71
72static void SetNumThreads(int num_threads) {
73 set_num_threads(num_threads);
74}
75
76static bool IsOpenMPAvailable() {
77 return STOCHTREE_HAS_OPENMP;
78}
79
80static int GetOptimalThreadCount(int workload_size, int min_work_per_thread = 1000) {
81 if (!IsOpenMPAvailable()) {
82 return 1;
83 }
84
85 int max_threads = GetMaxThreads();
86 int optimal_threads = workload_size / min_work_per_thread;
87
88 return std::min(optimal_threads, max_threads);
89}
90
91// Parallel execution utilities
92template<typename Func>
93void ParallelFor(int start, int end, int num_threads, Func func) {
94 if (num_threads <= 0) {
95 num_threads = GetOptimalThreadCount(end - start);
96 }
97
98 if (num_threads == 1 || !STOCHTREE_HAS_OPENMP) {
99 // Sequential execution
100 for (int i = start; i < end; ++i) {
101 func(i);
102 }
103 } else {
104 // Parallel execution
105 STOCHTREE_PARALLEL_FOR(num_threads)
106 for (int i = start; i < end; ++i) {
107 func(i);
108 }
109 }
110}
111
112} // namespace StochTree
113
114#endif // STOCHTREE_OPENMP_UTILS_H
Definition category_tracker.h:36