- chen2312 的博客
C++线段树模板(very easy)
- @ 2025-10-5 14:57:21
C++线段树模板(very easy):
template <typename T>
class SegmentTree {
private:
std::vector<T> tree; // 存储线段树的数组
std::vector<T> lazy; // 懒标记数组,用于区间更新
std::vector<T> data; // 存储原始数据的数组
int n; // 原始数据的大小
std::function<T(T, T)> op; // 合并两个节点的操作函数
T default_val; // 操作的默认值(如求和的0,求min的INF)
// 构建线段树
void build(int node, int start, int end) {
if (start == end) {
tree[node] = data[start];
} else {
int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;
build(left, start, mid);
build(right, mid + 1, end);
tree[node] = op(tree[left], tree[right]);
}
lazy[node] = 0; // 初始化懒标记为0
}
// 下推懒标记
void push_down(int node, int start, int end) {
if (lazy[node] == 0) return; // 没有需要传递的标记
int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;
// 更新左子树
tree[left] += lazy[node] * (mid - start + 1);
lazy[left] += lazy[node];
// 更新右子树
tree[right] += lazy[node] * (end - mid);
lazy[right] += lazy[node];
// 同步更新原始数据(关键修复)
if (start == end) {
data[start] += lazy[node];
}
// 清除当前节点的懒标记
lazy[node] = 0;
}
// 区间修改(内部实现):[l, r]区间每个元素加上val
void range_add(int node, int start, int end, int l, int r, T val) {
if (r < start || end < l) {
return; // 无交集
}
if (l <= start && end <= r) {
// 当前区间完全在更新区间内
tree[node] += val * (end - start + 1);
lazy[node] += val;
// 叶子节点直接更新原始数据
if (start == end) {
data[start] += val;
}
return;
}
// 下推懒标记
push_down(node, start, end);
int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;
range_add(left, start, mid, l, r, val);
range_add(right, mid + 1, end, l, r, val);
tree[node] = op(tree[left], tree[right]);
}
// 区间查询(内部实现)
T query_range(int node, int start, int end, int l, int r) {
if (r < start || end < l) {
return default_val; // 无交集
}
if (l <= start && end <= r) {
return tree[node]; // 完全覆盖
}
// 下推懒标记
push_down(node, start, end);
int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;
T left_res = query_range(left, start, mid, l, r);
T right_res = query_range(right, mid + 1, end, l, r);
return op(left_res, right_res);
}
// 单点查询(内部实现)
T query_point(int node, int start, int end, int idx) {
if (start == end) {
return tree[node];
}
// 下推懒标记
push_down(node, start, end);
int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;
if (idx <= mid) {
return query_point(left, start, mid, idx);
} else {
return query_point(right, mid + 1, end, idx);
}
}
// 单点修改(内部实现)
void update_point(int node, int start, int end, int idx, T val) {
if (start == end) {
data[idx] = val;
tree[node] = val;
} else {
// 下推懒标记
push_down(node, start, end);
int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;
if (idx <= mid) {
update_point(left, start, mid, idx, val);
} else {
update_point(right, mid + 1, end, idx, val);
}
tree[node] = op(tree[left], tree[right]);
}
}
public:
// 构造函数
SegmentTree(const std::vector<T>& _data, std::function<T(T, T)> _op, T _default_val)
: data(_data), op(_op), default_val(_default_val) {
n = data.size();
if (n == 0) return;
// 计算线段树的大小
int size = 1;
while (size < n) size <<= 1;
tree.resize(2 * size, default_val);
lazy.resize(2 * size, 0);
build(0, 0, n - 1);
}
// 对外接口:单点查询
T query(int idx) {
if (idx < 0 || idx >= n) {
//cerr << "Index out of bounds" << endl;
return default_val;
}
return query_point(0, 0, n - 1, idx);
}
// 对外接口:区间查询 [l, r]
T query(int l, int r) {
if (l < 0 || r >= n || l > r) {
//cerr << "Invalid query range" << endl;
return default_val;
}
return query_range(0, 0, n - 1, l, r);
}
// 对外接口:单点修改
void update(int idx, T val) {
if (idx < 0 || idx >= n) {
//cerr << "Index out of bounds" << endl;
return;
}
update_point(0, 0, n - 1, idx, val);
}
// 对外接口:区间修改 [l, r],每个元素加上val
void range_add(int l, int r, T val) {
if (l < 0 || r >= n || l > r) {
//cerr << "Invalid range for update" << endl;
return;
}
range_add(0, 0, n - 1, l, r, val);
}
// 获取原始数据
std::vector<T> get_data() const {
return data;
}
// 打印线段树(用于调试)
void print_tree() const {
for (size_t i = 0; i < tree.size(); ++i) {
fastout << tree[i] << " ";
}
fastout << fastendl;
}
// 打印懒标记(用于调试)
void print_lazy() const {
for (size_t i = 0; i < lazy.size(); ++i) {
fastout << lazy[i] << " ";
}
fastout << fastendl;
}
};
int n;
vector<long long>s(100001);
int main(){
fastin>>n;
for(int i=1;i<=n;i++)
fastin>>s[i];
SegmentTree<long long> st(s, [](long long a, long long b) { return a+b; }, 0);
}