summaryrefslogtreecommitdiff
path: root/inject/injectee.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inject/injectee.cpp')
-rw-r--r--inject/injectee.cpp149
1 files changed, 91 insertions, 58 deletions
diff --git a/inject/injectee.cpp b/inject/injectee.cpp
index 6417d487..af60c8b0 100644
--- a/inject/injectee.cpp
+++ b/inject/injectee.cpp
@@ -42,6 +42,7 @@
#include <stdarg.h>
#include <string.h>
+#include <algorithm>
#include <set>
#include <map>
#include <functional>
@@ -402,11 +403,12 @@ replaceAddress(LPVOID *lpOldAddress, LPVOID lpNewAddress)
*
*/
static LPVOID *
-getOldFunctionAddress(HMODULE hModule,
- const char *szDescriptorName,
- DWORD OriginalFirstThunk,
- DWORD FirstThunk,
- const char* pszFunctionName)
+getPatchAddress(HMODULE hModule,
+ const char *szDescriptorName,
+ DWORD OriginalFirstThunk,
+ DWORD FirstThunk,
+ const char* pszFunctionName,
+ LPVOID lpOldAddress)
{
if (VERBOSITY >= 4) {
debugPrintf("inject: %s(%s, %s)\n", __FUNCTION__,
@@ -416,7 +418,7 @@ getOldFunctionAddress(HMODULE hModule,
PIMAGE_THUNK_DATA pThunkIAT = rvaToVa<IMAGE_THUNK_DATA>(hModule, FirstThunk);
- UINT_PTR pRealFunction = 0;
+ UINT_PTR pOldFunction = (UINT_PTR)lpOldAddress;
PIMAGE_THUNK_DATA pThunk;
if (OriginalFirstThunk) {
@@ -429,15 +431,10 @@ getOldFunctionAddress(HMODULE hModule,
if (OriginalFirstThunk == 0 ||
pThunk->u1.Ordinal & IMAGE_ORDINAL_FLAG) {
// No name -- search by the real function address
- if (!pRealFunction) {
- HMODULE hRealModule = GetModuleHandleA(szDescriptorName);
- assert(hRealModule);
- pRealFunction = (UINT_PTR)GetProcAddress(hRealModule, pszFunctionName);
- if (!pRealFunction) {
- return NULL;
- }
+ if (!pOldFunction) {
+ return NULL;
}
- if (pThunkIAT->u1.Function == pRealFunction) {
+ if (pThunkIAT->u1.Function == pOldFunction) {
return (LPVOID *)(&pThunkIAT->u1.Function);
}
} else {
@@ -457,17 +454,19 @@ getOldFunctionAddress(HMODULE hModule,
static LPVOID *
-getOldFunctionAddress(HMODULE hModule,
- PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
- const char* pszFunctionName)
+getPatchAddress(HMODULE hModule,
+ PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
+ const char* pszFunctionName,
+ LPVOID lpOldAddress)
{
assert(pImportDescriptor->TimeDateStamp != 0 || pImportDescriptor->Name != 0);
- return getOldFunctionAddress(hModule,
- getDescriptorName(hModule, pImportDescriptor),
- pImportDescriptor->OriginalFirstThunk,
- pImportDescriptor->FirstThunk,
- pszFunctionName);
+ return getPatchAddress(hModule,
+ getDescriptorName(hModule, pImportDescriptor),
+ pImportDescriptor->OriginalFirstThunk,
+ pImportDescriptor->FirstThunk,
+ pszFunctionName,
+ lpOldAddress);
}
@@ -475,17 +474,19 @@ getOldFunctionAddress(HMODULE hModule,
// http://www.microsoft.com/msj/1298/hood/hood1298.aspx
// http://msdn.microsoft.com/en-us/library/16b2dyk5.aspx
static LPVOID *
-getOldFunctionAddress(HMODULE hModule,
- PImgDelayDescr pDelayDescriptor,
- const char* pszFunctionName)
+getPatchAddress(HMODULE hModule,
+ PImgDelayDescr pDelayDescriptor,
+ const char* pszFunctionName,
+ LPVOID lpOldAddress)
{
assert(pDelayDescriptor->rvaDLLName != 0);
- return getOldFunctionAddress(hModule,
- getDescriptorName(hModule, pDelayDescriptor),
- pDelayDescriptor->rvaINT,
- pDelayDescriptor->rvaIAT,
- pszFunctionName);
+ return getPatchAddress(hModule,
+ getDescriptorName(hModule, pDelayDescriptor),
+ pDelayDescriptor->rvaINT,
+ pDelayDescriptor->rvaIAT,
+ pszFunctionName,
+ lpOldAddress);
}
@@ -496,24 +497,25 @@ patchFunction(HMODULE hModule,
const char *pszDllName,
T pImportDescriptor,
const char *pszFunctionName,
+ LPVOID lpOldAddress,
LPVOID lpNewAddress)
{
- LPVOID* lpOldFunctionAddress = getOldFunctionAddress(hModule, pImportDescriptor, pszFunctionName);
- if (lpOldFunctionAddress == NULL) {
+ LPVOID* lpPatchAddress = getPatchAddress(hModule, pImportDescriptor, pszFunctionName, lpOldAddress);
+ if (lpPatchAddress == NULL) {
return FALSE;
}
- if (*lpOldFunctionAddress == lpNewAddress) {
+ if (*lpPatchAddress == lpNewAddress) {
return TRUE;
}
- DWORD Offset = (DWORD)(UINT_PTR)lpOldFunctionAddress - (UINT_PTR)hModule;
+ DWORD Offset = (DWORD)(UINT_PTR)lpPatchAddress - (UINT_PTR)hModule;
if (VERBOSITY > 0) {
debugPrintf("inject: patching %s!0x%lx -> %s!%s\n", szModule, Offset, pszDllName, pszFunctionName);
}
BOOL bRet;
- bRet = replaceAddress(lpOldFunctionAddress, lpNewAddress);
+ bRet = replaceAddress(lpPatchAddress, lpNewAddress);
if (!bRet) {
debugPrintf("inject: failed to patch %s!0x%lx -> %s!%s\n", szModule, Offset, pszDllName, pszFunctionName);
}
@@ -561,11 +563,19 @@ static std::set<HMODULE>
g_hHookedModules;
+enum Action {
+ ACTION_HOOK,
+ ACTION_UNHOOK,
+
+};
+
+
template< class T >
void
patchDescriptor(HMODULE hModule,
const char *szModule,
- T pImportDescriptor)
+ T pImportDescriptor,
+ Action action)
{
const char* szDescriptorName = getDescriptorName(hModule, pImportDescriptor);
@@ -578,11 +588,26 @@ patchDescriptor(HMODULE hModule,
FunctionMap::const_iterator fnIt;
for (fnIt = functionMap.begin(); fnIt != functionMap.end(); ++fnIt) {
const char *szFunctionName = fnIt->first;
- LPVOID lpNewAddress = fnIt->second;
+ LPVOID lpHookAddress = fnIt->second;
+
+ // Knowning the real address is useful when patching imports by ordinal
+ LPVOID lpRealAddress = NULL;
+ HMODULE hRealModule = GetModuleHandleA(szDescriptorName);
+ if (hRealModule) {
+ assert(hRealModule != g_hHookModule);
+ lpRealAddress = (LPVOID)GetProcAddress(hRealModule, szFunctionName);
+ }
+
+ LPVOID lpOldAddress = lpRealAddress;
+ LPVOID lpNewAddress = lpHookAddress;
+
+ if (action == ACTION_UNHOOK) {
+ std::swap(lpOldAddress, lpNewAddress);
+ }
- BOOL bHooked;
- bHooked = patchFunction(hModule, szModule, szMatchModule, pImportDescriptor, szFunctionName, lpNewAddress);
- if (bHooked && !module.bInternal && pSharedMem) {
+ BOOL bPatched;
+ bPatched = patchFunction(hModule, szModule, szMatchModule, pImportDescriptor, szFunctionName, lpOldAddress, lpNewAddress);
+ if (action == ACTION_HOOK && bPatched && !module.bInternal && pSharedMem) {
pSharedMem->bReplaced = TRUE;
}
}
@@ -592,7 +617,8 @@ patchDescriptor(HMODULE hModule,
static void
patchModule(HMODULE hModule,
- const char *szModule)
+ const char *szModule,
+ Action action)
{
/* Never patch this module */
if (hModule == g_hThisModule) {
@@ -605,12 +631,14 @@ patchModule(HMODULE hModule,
}
/* Hook modules only once */
- std::pair< std::set<HMODULE>::iterator, bool > ret;
- EnterCriticalSection(&Mutex);
- ret = g_hHookedModules.insert(hModule);
- LeaveCriticalSection(&Mutex);
- if (!ret.second) {
- return;
+ if (action == ACTION_HOOK) {
+ std::pair< std::set<HMODULE>::iterator, bool > ret;
+ EnterCriticalSection(&Mutex);
+ ret = g_hHookedModules.insert(hModule);
+ LeaveCriticalSection(&Mutex);
+ if (!ret.second) {
+ return;
+ }
}
const char *szBaseName = getBaseName(szModule);
@@ -635,7 +663,7 @@ patchModule(HMODULE hModule,
if (pImportDescriptor) {
while (pImportDescriptor->FirstThunk) {
- patchDescriptor(hModule, szModule, pImportDescriptor);
+ patchDescriptor(hModule, szModule, pImportDescriptor, action);
++pImportDescriptor;
}
@@ -651,15 +679,16 @@ patchModule(HMODULE hModule,
szName);
}
- patchDescriptor(hModule, szModule, pDelayDescriptor);
+ patchDescriptor(hModule, szModule, pDelayDescriptor, action);
++pDelayDescriptor;
}
}
}
+
static void
-patchAllModules(void)
+patchAllModules(Action action)
{
HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
if (hModuleSnap == INVALID_HANDLE_VALUE) {
@@ -670,7 +699,7 @@ patchAllModules(void)
me32.dwSize = sizeof me32;
if (Module32First(hModuleSnap, &me32)) {
do {
- patchModule(me32.hModule, me32.szExePath);
+ patchModule(me32.hModule, me32.szExePath, action);
} while (Module32Next(hModuleSnap, &me32));
}
@@ -678,8 +707,6 @@ patchAllModules(void)
}
-
-
static HMODULE WINAPI
MyLoadLibraryA(LPCSTR lpLibFileName)
{
@@ -714,7 +741,7 @@ MyLoadLibraryA(LPCSTR lpLibFileName)
}
// Hook all new modules (and not just this one, to pick up any dependencies)
- patchAllModules();
+ patchAllModules(ACTION_HOOK);
SetLastError(dwLastError);
return hModule;
@@ -732,7 +759,7 @@ MyLoadLibraryW(LPCWSTR lpLibFileName)
}
// Hook all new modules (and not just this one, to pick up any dependencies)
- patchAllModules();
+ patchAllModules(ACTION_HOOK);
SetLastError(dwLastError);
return hModule;
@@ -792,7 +819,7 @@ MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
}
// Hook all new modules (and not just this one, to pick up any dependencies)
- patchAllModules();
+ patchAllModules(ACTION_HOOK);
SetLastError(dwLastError);
return hModule;
@@ -810,7 +837,7 @@ MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
}
// Hook all new modules (and not just this one, to pick up any dependencies)
- patchAllModules();
+ patchAllModules(ACTION_HOOK);
SetLastError(dwLastError);
return hModule;
@@ -1054,7 +1081,7 @@ DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved)
dumpRegisteredHooks();
- patchAllModules();
+ patchAllModules(ACTION_HOOK);
break;
case DLL_THREAD_ATTACH:
@@ -1067,6 +1094,12 @@ DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved)
if (VERBOSITY > 0) {
debugPrintf("inject: DLL_PROCESS_DETACH\n");
}
+
+ patchAllModules(ACTION_UNHOOK);
+
+ if (g_hHookModule) {
+ FreeLibrary(g_hHookModule);
+ }
break;
}
return TRUE;