1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
|
#pragma once
#include <ntddk.h>
#include <Ntstrsafe.h>
// Element count of a statically sized array: the array's total size in
// bytes divided by the size in bytes of its first element.
#define ARRAY_SIZE(a) \
    (sizeof(a) / sizeof((a)[0]))
// Wrapper around a KSPIN_LOCK. The IRQL value written by the acquire call
// is kept in the object so that the matching release can pass it back.
class CWdmSpinLock
{
public:
    // Initializes the underlying spin lock.
    CWdmSpinLock()
    {
        KeInitializeSpinLock(&m_Lock);
    }

    // Acquires the spin lock and records the IRQL reported by the acquire.
    void Lock()
    {
        KeAcquireSpinLock(&m_Lock, &m_OldIrql);
    }

    // Releases the spin lock, handing back the IRQL recorded by Lock().
    void Unlock()
    {
        KeReleaseSpinLock(&m_Lock, m_OldIrql);
    }

private:
    KSPIN_LOCK m_Lock;
    KIRQL m_OldIrql;
};
// Scope guard for any type T that offers Lock() and Unlock():
// Lock() is called on the referenced object when the guard is constructed
// and Unlock() is called on it when the guard is destroyed.
template <typename T>
class CLockedContext
{
public:
    CLockedContext(T &LockObject)
        : m_LockObject(LockObject)
    {
        m_LockObject.Lock();
    }

    ~CLockedContext()
    {
        m_LockObject.Unlock();
    }

private:
    // Guards cannot be copied.
    CLockedContext(const CLockedContext&) = delete;
    CLockedContext& operator= (const CLockedContext&) = delete;

    // The object that is locked for the lifetime of this guard.
    T &m_LockObject;
};
// Scope guard that holds a CWdmSpinLock for the lifetime of the guard.
using TSpinLocker = CLockedContext<CWdmSpinLock>;
// Reference counter whose updates go through the Interlocked* primitives.
// The counter starts at zero.
class CWdmRefCounter
{
public:
    // Adds one reference.
    void AddRef() { InterlockedIncrement(&m_Counter); }

    // Adds RefCnt references.
    void AddRef(LONG RefCnt) { InterlockedAdd(&m_Counter, RefCnt); }

    // Drops one reference and returns the value produced by
    // InterlockedDecrement.
    LONG Release() { return InterlockedDecrement(&m_Counter); }

    // Drops RefCnt references and returns the value produced by
    // InterlockedAdd. The previous implementation had no return statement,
    // so callers read an indeterminate value.
    LONG Release(LONG RefCnt) { return InterlockedAdd(&m_Counter, -RefCnt); }

    // Plain (non-interlocked) read of the current counter value.
    operator LONG () { return m_Counter; }

private:
    LONG m_Counter = 0;
};
// Access strategy that forwards Lock()/Unlock() to an owned CWdmSpinLock.
class CLockedAccess
{
public:
    void Lock()
    {
        m_Lock.Lock();
    }

    void Unlock()
    {
        m_Lock.Unlock();
    }

private:
    CWdmSpinLock m_Lock;
};
// Access strategy whose Lock()/Unlock() do nothing.
class CRawAccess
{
public:
    void Lock()
    {
    }

    void Unlock()
    {
    }
};
// Counting strategy that keeps an element count in a ULONG starting at 0.
// The updates are plain (non-interlocked) increments and decrements.
class CCountingObject
{
public:
    void CounterIncrement()
    {
        ++m_Counter;
    }

    void CounterDecrement()
    {
        --m_Counter;
    }

    // Current value of the counter.
    ULONG GetCount()
    {
        return m_Counter;
    }

private:
    ULONG m_Counter = 0;
};
// Counting strategy that keeps no count at all: the increment/decrement
// calls are no-ops and GetCount() (protected here) always yields 0.
class CNonCountingObject
{
public:
    void CounterIncrement()
    {
    }

    void CounterDecrement()
    {
    }

protected:
    ULONG GetCount()
    {
        return 0;
    }
};
// List of TEntryType objects chained through LIST_ENTRY nodes.
//
// TEntryType must provide GetListEntry() (its PLIST_ENTRY) and a static
// GetByListEntry(PLIST_ENTRY) that maps a node back to the object.
// TAccessStrategy supplies Lock()/Unlock() (e.g. CLockedAccess, CRawAccess).
// TCountingStrategy supplies CounterIncrement()/CounterDecrement()/GetCount()
// (e.g. CCountingObject, CNonCountingObject).
//
// Entries that are still linked when Clear() or the destructor runs are
// destroyed with `delete`.
//
// Members inherited from the strategy bases are reached through `this->`
// because both bases depend on template parameters; unqualified names are
// not looked up in dependent bases under standard two-phase lookup.
template <typename TEntryType, typename TAccessStrategy, typename TCountingStrategy>
class CWdmList : private TAccessStrategy, public TCountingStrategy
{
public:
    CWdmList()
    { InitializeListHead(&m_List); }

    ~CWdmList()
    { Clear(); }

    // Detaches every entry and deletes it.
    void Clear()
    { ForEachDetached([](TEntryType* Entry) { delete Entry; return true; }); }

    // Checks the list head without taking the access-strategy lock.
    bool IsEmpty()
    { return IsListEmpty(&m_List) ? true : false; }

    // Detaches and returns the head entry, or nullptr when the list is empty.
    // (Without the emptiness check the list head itself would be converted
    // into a bogus TEntryType pointer and the counter would be decremented
    // for an entry that does not exist.)
    TEntryType *Pop()
    {
        CLockedContext<TAccessStrategy> LockedContext(*this);
        if (IsListEmpty(&m_List))
        {
            return nullptr;
        }
        return Pop_LockLess();
    }

    // Inserts Entry at the head; returns GetCount() after the insertion.
    ULONG Push(TEntryType *Entry)
    {
        CLockedContext<TAccessStrategy> LockedContext(*this);
        InsertHeadList(&m_List, Entry->GetListEntry());
        this->CounterIncrement();
        return this->GetCount();
    }

    // Inserts Entry at the tail; returns GetCount() after the insertion.
    ULONG PushBack(TEntryType *Entry)
    {
        CLockedContext<TAccessStrategy> LockedContext(*this);
        InsertTailList(&m_List, Entry->GetListEntry());
        this->CounterIncrement();
        return this->GetCount();
    }

    // Unlinks Entry from the list. The entry itself is not deleted.
    void Remove(TEntryType *Entry)
    {
        CLockedContext<TAccessStrategy> LockedContext(*this);
        Remove_LockLess(Entry->GetListEntry());
    }

    // Repeatedly detaches the head entry and passes it to Functor while the
    // lock is held. Stops and returns false as soon as Functor returns
    // false; returns true once the list has been drained.
    template <typename TFunctor>
    bool ForEachDetached(TFunctor Functor)
    {
        CLockedContext<TAccessStrategy> LockedContext(*this);
        while (!IsListEmpty(&m_List))
        {
            if (!Functor(Pop_LockLess()))
            {
                return false;
            }
        }
        return true;
    }

    // For every entry accepted by Predicate: unlinks it, then passes it to
    // Functor. Returns false if Functor stopped the walk by returning false.
    template <typename TPredicate, typename TFunctor>
    bool ForEachDetachedIf(TPredicate Predicate, TFunctor Functor)
    {
        return ForEachPrepareIf(Predicate, [this](PLIST_ENTRY Entry){ Remove_LockLess(Entry); }, Functor);
    }

    // Passes every entry to Functor, leaving the entries linked.
    // Returns false if Functor stopped the walk by returning false.
    template <typename TFunctor>
    bool ForEach(TFunctor Functor)
    {
        return ForEachPrepareIf([](TEntryType*) { return true; }, [](PLIST_ENTRY){}, Functor);
    }

    // Passes every entry accepted by Predicate to Functor, leaving the
    // entries linked. Returns false if Functor stopped the walk.
    template <typename TPredicate, typename TFunctor>
    bool ForEachIf(TPredicate Predicate, TFunctor Functor)
    {
        return ForEachPrepareIf(Predicate, [](PLIST_ENTRY){}, Functor);
    }

private:
    // Common walk used by the ForEach* family. Runs under the lock. For each
    // entry accepted by Predicate, Prepare receives the list node and then
    // Functor receives the object; a false result from Functor ends the walk
    // with a false return value.
    template <typename TPredicate, typename TPrepareFunctor, typename TFunctor>
    bool ForEachPrepareIf(TPredicate Predicate, TPrepareFunctor Prepare, TFunctor Functor)
    {
        CLockedContext<TAccessStrategy> LockedContext(*this);

        PLIST_ENTRY NextEntry = nullptr;

        for (auto CurrEntry = m_List.Flink; CurrEntry != &m_List; CurrEntry = NextEntry)
        {
            // Fetch the successor first: Prepare/Functor may unlink or
            // delete the current entry.
            NextEntry = CurrEntry->Flink;

            auto Object = TEntryType::GetByListEntry(CurrEntry);

            if (Predicate(Object))
            {
                Prepare(CurrEntry);
                if (!Functor(Object))
                {
                    return false;
                }
            }
        }

        return true;
    }

    // Detaches the head entry; callers must make sure the list is not empty.
    TEntryType *Pop_LockLess()
    {
        this->CounterDecrement();
        return TEntryType::GetByListEntry(RemoveHeadList(&m_List));
    }

    // Unlinks the given node and updates the counter.
    void Remove_LockLess(PLIST_ENTRY Entry)
    {
        RemoveEntryList(Entry);
        this->CounterDecrement();
    }

    LIST_ENTRY m_List;
};
// Accepts any arguments, ignores them and yields true.
static inline bool ConstTrue(...)
{
    return true;
}
// Accepts any arguments, ignores them and yields false.
static inline bool ConstFalse(...)
{
    return false;
}
// Set of TEntryType objects built on top of a CWdmList. Membership is
// decided with `*ExistingEntry == *Id`, so TEntryType must be comparable
// with every identifier type that is passed in.
//
// TAccessStrategy supplies Lock()/Unlock(); TCountingStrategy supplies
// CounterIncrement()/CounterDecrement(). The inner list is instantiated
// with CRawAccess and CNonCountingObject; locking and counting are done
// by this class.
//
// Members inherited from the strategy bases are reached through `this->`
// because both bases depend on template parameters; unqualified names are
// not looked up in dependent bases under standard two-phase lookup.
template <typename TEntryType, typename TAccessStrategy, typename TCountingStrategy>
class CWdmSet : private TAccessStrategy, public TCountingStrategy
{
public:
    // Appends NewEntry unless an equal entry is already present.
    // Returns true when the entry was added, false when it was a duplicate
    // (in which case NewEntry is left untouched).
    bool Add(TEntryType *NewEntry)
    {
        CLockedContext<TAccessStrategy> LockedContext(*this);
        if (!Contains_LockLess(NewEntry))
        {
            m_Objects.PushBack(NewEntry);
            this->CounterIncrement();
            return true;
        }

        return false;
    }

    // Unlinks and deletes the first entry equal to *Id.
    // Returns true when such an entry was found.
    template <typename TEntryId>
    bool Delete(TEntryId *Id)
    {
        auto Removed = false;
        CLockedContext<TAccessStrategy> LockedContext(*this);
        m_Objects.ForEachDetachedIf([Id](TEntryType *ExistingEntry) { return *ExistingEntry == *Id; },
                                    [this, &Removed](TEntryType *ExistingEntry)
                                    {
                                        delete ExistingEntry;
                                        this->CounterDecrement();
                                        Removed = true;
                                        // Stop after the first match.
                                        return false;
                                    });
        return Removed;
    }

    // Calls Dump() on every entry while the lock is held.
    void Dump()
    {
        CLockedContext<TAccessStrategy> LockedContext(*this);
        m_Objects.ForEach([](TEntryType *Entry) { Entry->Dump(); return true; });
    }

    // Returns true when an entry equal to *Id is present.
    template <typename TEntryId>
    bool Contains(TEntryId *Id)
    {
        CLockedContext<TAccessStrategy> LockedContext(*this);
        return Contains_LockLess(Id);
    }

    // Applies ModifierFunc to the first entry equal to *Id while the lock
    // is held. Returns true when such an entry was found.
    template <typename TEntryId, typename TModifier>
    bool ModifyOne(TEntryId *Id, TModifier ModifierFunc)
    {
        CLockedContext<TAccessStrategy> LockedContext(*this);
        // ForEachIf yields false exactly when the functor ended the walk,
        // i.e. when a matching entry was found and modified.
        return !m_Objects.ForEachIf([Id](TEntryType *ExistingEntry) { return *ExistingEntry == *Id; },
                                    [&ModifierFunc](TEntryType *Entry) { ModifierFunc(Entry); return false; });
    }

private:
    // Membership test that does not take the lock; callers hold it.
    template <typename TEntryId>
    bool Contains_LockLess(TEntryId *Id)
    {
        auto MatchFound = false;

        m_Objects.ForEachIf([Id](TEntryType *ExistingEntry) { return *ExistingEntry == *Id; },
                            [&MatchFound](TEntryType *) { MatchFound = true; return false; });

        return MatchFound;
    }

    CWdmList<TEntryType, CRawAccess, CNonCountingObject> m_Objects;
};
// Wrapper around a KEVENT. Instances cannot be copied.
class CWdmEvent : public CAllocatable<NonPagedPool, 'VEHR'>
{
public:
    // Initializes the event with the given type and initial state.
    CWdmEvent(EVENT_TYPE Type = SynchronizationEvent, BOOLEAN InitialState = FALSE)
    {
        KeInitializeEvent(&m_Event, Type, InitialState);
    }

    CWdmEvent(const CWdmEvent&) = delete;
    CWdmEvent& operator= (const CWdmEvent&) = delete;

    // Defined outside of this class body.
    NTSTATUS Wait(bool WithTimeout = false, LONGLONG Timeout = 0, bool Alertable = false);

    // Forwards to KeSetEvent; true when KeSetEvent returns a non-zero value.
    bool Set(KPRIORITY Increment = IO_NO_INCREMENT, bool Wait = false)
    {
        const BOOLEAN WaitFlag = Wait ? TRUE : FALSE;
        return KeSetEvent(&m_Event, Increment, WaitFlag) != 0;
    }

    // Forwards to KeClearEvent.
    void Clear()
    {
        KeClearEvent(&m_Event);
    }

    // Forwards to KeResetEvent; true when KeResetEvent returns a non-zero value.
    bool Reset()
    {
        return KeResetEvent(&m_Event) != 0;
    }

    // Gives access to the wrapped KEVENT.
    operator PKEVENT ()
    {
        return &m_Event;
    }

private:
    KEVENT m_Event;
};
class CStringBase
{
public:
bool operator== (const CStringBase &Other) const
{ return *this == Other.m_String; }
bool operator== (const UNICODE_STRING& Str) const;
bool operator== (PCWSTR Other) const
{
UNICODE_STRING str;
if (NT_SUCCESS(RtlUnicodeStringInit(&str, Other)))
{
return *this == str;
}
return false;
}
operator PCUNICODE_STRING() const { return &m_String; };
NTSTATUS ToString(ULONG Val, ULONG Base)
{ return RtlIntegerToUnicodeString(Val, Base, &m_String); }
size_t ToWSTR(PWCHAR Buffer, size_t SizeBytes) const;
protected:
CStringBase(const CStringBase&) = delete;
CStringBase& operator= (const CStringBase&) = delete;
CStringBase() {};
~CStringBase() {};
UNICODE_STRING m_String;
};
// String wrapper whose m_String is pointed at a caller-supplied buffer by
// Attach(). No allocation or release is performed in this class.
class CStringHolder : public CStringBase
{
public:
    // The member could be initialized in-class without spelling out a
    // constructor, but that form crashes the MS compiler with an internal
    // error, so the assignment is done here instead.
    CStringHolder()
    {
        m_String = {};
    }

    // Points m_String at String via RtlUnicodeStringInit.
    NTSTATUS Attach(NTSTRSAFE_PCWSTR String)
    {
        return RtlUnicodeStringInit(&m_String, String);
    }

    // Points m_String at String, using SizeInBytes for both Length and
    // MaximumLength, and returns the result of RtlUnicodeStringValidate.
    NTSTATUS Attach(NTSTRSAFE_PCWSTR String, USHORT SizeInBytes)
    {
        m_String.Buffer = const_cast<PWCH>(String);
        m_String.Length = SizeInBytes;
        m_String.MaximumLength = SizeInBytes;

        return RtlUnicodeStringValidate(&m_String);
    }

private:
    CStringHolder(const CStringHolder&) = delete;
    CStringHolder& operator= (const CStringHolder&) = delete;
};
// String wrapper that calls Destroy() from its destructor. Create, Append,
// Destroy and Resize are only declared here; their definitions live
// outside of this class body.
class CString : public CStringBase
{
public:
    // The member could be initialized in-class without spelling out a
    // constructor, but that form crashes the MS compiler with an internal
    // error, so the assignment is done here instead.
    CString()
    {
        m_String = {};
    }

    ~CString()
    {
        Destroy();
    }

    NTSTATUS Create(NTSTRSAFE_PCWSTR String);
    NTSTATUS Create(NTSTRSAFE_PCWSTR Prefix, ULONG Postfix);

    NTSTATUS Append(PCUNICODE_STRING String);
    NTSTATUS Append(ULONG Num, ULONG Base = 10);

    void Destroy();

private:
    NTSTATUS Resize(USHORT NewLenBytes);

    CString(const CString&) = delete;
    CString& operator= (const CString&) = delete;
};
// Declaration only; the definition is not part of this header section.
// NOTE(review): judging by the name and parameters this presumably returns
// a copy of the Length bytes at Buffer allocated from PoolType (PagedPool
// when omitted) - confirm against the implementation, including who frees
// the result and what is returned on failure.
PVOID DuplicateStaticBuffer(const void *Buffer, SIZE_T Length, POOL_TYPE PoolType = PagedPool);
template<typename T>
class CInstanceCounter
{
public:
CInstanceCounter()
{
static LONG volatile Counter = 0;
m_Number = InterlockedIncrement(&Counter);
}
operator ULONG() const { return static_cast<ULONG>(m_Number); };
private:
LONG m_Number = 0;
CInstanceCounter(const CInstanceCounter&) = delete;
CInstanceCounter& operator= (const CInstanceCounter&) = delete;
};
// Converts a number of seconds into a number of 100-nanosecond intervals
// (one second is 10 * 1000 * 1000 such intervals).
static inline
LONGLONG SecondsTo100Nanoseconds(LONGLONG Seconds)
{
    const LONGLONG IntervalsPerSecond = 10 * 1000 * 1000;
    return Seconds * IntervalsPerSecond;
}
|