diff options
Diffstat (limited to 'inject/injectee.cpp')
-rw-r--r-- | inject/injectee.cpp | 149 |
1 files changed, 91 insertions, 58 deletions
diff --git a/inject/injectee.cpp b/inject/injectee.cpp index 6417d487..af60c8b0 100644 --- a/inject/injectee.cpp +++ b/inject/injectee.cpp @@ -42,6 +42,7 @@ #include <stdarg.h> #include <string.h> +#include <algorithm> #include <set> #include <map> #include <functional> @@ -402,11 +403,12 @@ replaceAddress(LPVOID *lpOldAddress, LPVOID lpNewAddress) * */ static LPVOID * -getOldFunctionAddress(HMODULE hModule, - const char *szDescriptorName, - DWORD OriginalFirstThunk, - DWORD FirstThunk, - const char* pszFunctionName) +getPatchAddress(HMODULE hModule, + const char *szDescriptorName, + DWORD OriginalFirstThunk, + DWORD FirstThunk, + const char* pszFunctionName, + LPVOID lpOldAddress) { if (VERBOSITY >= 4) { debugPrintf("inject: %s(%s, %s)\n", __FUNCTION__, @@ -416,7 +418,7 @@ getOldFunctionAddress(HMODULE hModule, PIMAGE_THUNK_DATA pThunkIAT = rvaToVa<IMAGE_THUNK_DATA>(hModule, FirstThunk); - UINT_PTR pRealFunction = 0; + UINT_PTR pOldFunction = (UINT_PTR)lpOldAddress; PIMAGE_THUNK_DATA pThunk; if (OriginalFirstThunk) { @@ -429,15 +431,10 @@ getOldFunctionAddress(HMODULE hModule, if (OriginalFirstThunk == 0 || pThunk->u1.Ordinal & IMAGE_ORDINAL_FLAG) { // No name -- search by the real function address - if (!pRealFunction) { - HMODULE hRealModule = GetModuleHandleA(szDescriptorName); - assert(hRealModule); - pRealFunction = (UINT_PTR)GetProcAddress(hRealModule, pszFunctionName); - if (!pRealFunction) { - return NULL; - } + if (!pOldFunction) { + return NULL; } - if (pThunkIAT->u1.Function == pRealFunction) { + if (pThunkIAT->u1.Function == pOldFunction) { return (LPVOID *)(&pThunkIAT->u1.Function); } } else { @@ -457,17 +454,19 @@ getOldFunctionAddress(HMODULE hModule, static LPVOID * -getOldFunctionAddress(HMODULE hModule, - PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor, - const char* pszFunctionName) +getPatchAddress(HMODULE hModule, + PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor, + const char* pszFunctionName, + LPVOID lpOldAddress) { assert(pImportDescriptor->TimeDateStamp != 0 || pImportDescriptor->Name != 0); - return getOldFunctionAddress(hModule, - getDescriptorName(hModule, pImportDescriptor), - pImportDescriptor->OriginalFirstThunk, - pImportDescriptor->FirstThunk, - pszFunctionName); + return getPatchAddress(hModule, + getDescriptorName(hModule, pImportDescriptor), + pImportDescriptor->OriginalFirstThunk, + pImportDescriptor->FirstThunk, + pszFunctionName, + lpOldAddress); } @@ -475,17 +474,19 @@ getOldFunctionAddress(HMODULE hModule, // http://www.microsoft.com/msj/1298/hood/hood1298.aspx // http://msdn.microsoft.com/en-us/library/16b2dyk5.aspx static LPVOID * -getOldFunctionAddress(HMODULE hModule, - PImgDelayDescr pDelayDescriptor, - const char* pszFunctionName) +getPatchAddress(HMODULE hModule, + PImgDelayDescr pDelayDescriptor, + const char* pszFunctionName, + LPVOID lpOldAddress) { assert(pDelayDescriptor->rvaDLLName != 0); - return getOldFunctionAddress(hModule, - getDescriptorName(hModule, pDelayDescriptor), - pDelayDescriptor->rvaINT, - pDelayDescriptor->rvaIAT, - pszFunctionName); + return getPatchAddress(hModule, + getDescriptorName(hModule, pDelayDescriptor), + pDelayDescriptor->rvaINT, + pDelayDescriptor->rvaIAT, + pszFunctionName, + lpOldAddress); } @@ -496,24 +497,25 @@ patchFunction(HMODULE hModule, const char *pszDllName, T pImportDescriptor, const char *pszFunctionName, + LPVOID lpOldAddress, LPVOID lpNewAddress) { - LPVOID* lpOldFunctionAddress = getOldFunctionAddress(hModule, pImportDescriptor, pszFunctionName); - if (lpOldFunctionAddress == NULL) { + LPVOID* lpPatchAddress = getPatchAddress(hModule, pImportDescriptor, pszFunctionName, lpOldAddress); + if (lpPatchAddress == NULL) { return FALSE; } - if (*lpOldFunctionAddress == lpNewAddress) { + if (*lpPatchAddress == lpNewAddress) { return TRUE; } - DWORD Offset = (DWORD)(UINT_PTR)lpOldFunctionAddress - (UINT_PTR)hModule; + DWORD Offset = (DWORD)(UINT_PTR)lpPatchAddress - (UINT_PTR)hModule; if (VERBOSITY > 0) { debugPrintf("inject: patching %s!0x%lx -> %s!%s\n", szModule, Offset, pszDllName, pszFunctionName); } BOOL bRet; - bRet = replaceAddress(lpOldFunctionAddress, lpNewAddress); + bRet = replaceAddress(lpPatchAddress, lpNewAddress); if (!bRet) { debugPrintf("inject: failed to patch %s!0x%lx -> %s!%s\n", szModule, Offset, pszDllName, pszFunctionName); } @@ -561,11 +563,19 @@ static std::set<HMODULE> g_hHookedModules; +enum Action { + ACTION_HOOK, + ACTION_UNHOOK, + +}; + + template< class T > void patchDescriptor(HMODULE hModule, const char *szModule, - T pImportDescriptor) + T pImportDescriptor, + Action action) { const char* szDescriptorName = getDescriptorName(hModule, pImportDescriptor); @@ -578,11 +588,26 @@ patchDescriptor(HMODULE hModule, FunctionMap::const_iterator fnIt; for (fnIt = functionMap.begin(); fnIt != functionMap.end(); ++fnIt) { const char *szFunctionName = fnIt->first; - LPVOID lpNewAddress = fnIt->second; + LPVOID lpHookAddress = fnIt->second; + + // Knowning the real address is useful when patching imports by ordinal + LPVOID lpRealAddress = NULL; + HMODULE hRealModule = GetModuleHandleA(szDescriptorName); + if (hRealModule) { + assert(hRealModule != g_hHookModule); + lpRealAddress = (LPVOID)GetProcAddress(hRealModule, szFunctionName); + } + + LPVOID lpOldAddress = lpRealAddress; + LPVOID lpNewAddress = lpHookAddress; + + if (action == ACTION_UNHOOK) { + std::swap(lpOldAddress, lpNewAddress); + } - BOOL bHooked; - bHooked = patchFunction(hModule, szModule, szMatchModule, pImportDescriptor, szFunctionName, lpNewAddress); - if (bHooked && !module.bInternal && pSharedMem) { + BOOL bPatched; + bPatched = patchFunction(hModule, szModule, szMatchModule, pImportDescriptor, szFunctionName, lpOldAddress, lpNewAddress); + if (action == ACTION_HOOK && bPatched && !module.bInternal && pSharedMem) { pSharedMem->bReplaced = TRUE; } } @@ -592,7 +617,8 @@ patchDescriptor(HMODULE hModule, static void patchModule(HMODULE hModule, - const char *szModule) + const char *szModule, + Action action) { /* Never patch this module */ if (hModule == g_hThisModule) { @@ -605,12 +631,14 @@ patchModule(HMODULE hModule, } /* Hook modules only once */ - std::pair< std::set<HMODULE>::iterator, bool > ret; - EnterCriticalSection(&Mutex); - ret = g_hHookedModules.insert(hModule); - LeaveCriticalSection(&Mutex); - if (!ret.second) { - return; + if (action == ACTION_HOOK) { + std::pair< std::set<HMODULE>::iterator, bool > ret; + EnterCriticalSection(&Mutex); + ret = g_hHookedModules.insert(hModule); + LeaveCriticalSection(&Mutex); + if (!ret.second) { + return; + } } const char *szBaseName = getBaseName(szModule); @@ -635,7 +663,7 @@ patchModule(HMODULE hModule, if (pImportDescriptor) { while (pImportDescriptor->FirstThunk) { - patchDescriptor(hModule, szModule, pImportDescriptor); + patchDescriptor(hModule, szModule, pImportDescriptor, action); ++pImportDescriptor; } @@ -651,15 +679,16 @@ patchModule(HMODULE hModule, szName); } - patchDescriptor(hModule, szModule, pDelayDescriptor); + patchDescriptor(hModule, szModule, pDelayDescriptor, action); ++pDelayDescriptor; } } } + static void -patchAllModules(void) +patchAllModules(Action action) { HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId()); if (hModuleSnap == INVALID_HANDLE_VALUE) { @@ -670,7 +699,7 @@ patchAllModules(void) me32.dwSize = sizeof me32; if (Module32First(hModuleSnap, &me32)) { do { - patchModule(me32.hModule, me32.szExePath); + patchModule(me32.hModule, me32.szExePath, action); } while (Module32Next(hModuleSnap, &me32)); } @@ -678,8 +707,6 @@ patchAllModules(void) } - - static HMODULE WINAPI MyLoadLibraryA(LPCSTR lpLibFileName) { @@ -714,7 +741,7 @@ MyLoadLibraryA(LPCSTR lpLibFileName) } // Hook all new modules (and not just this one, to pick up any dependencies) - patchAllModules(); + patchAllModules(ACTION_HOOK); SetLastError(dwLastError); return hModule; @@ -732,7 +759,7 @@ MyLoadLibraryW(LPCWSTR lpLibFileName) } // Hook all new modules (and not just this one, to pick up any dependencies) - patchAllModules(); + patchAllModules(ACTION_HOOK); SetLastError(dwLastError); return hModule; @@ -792,7 +819,7 @@ MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags) } // Hook all new modules (and not just this one, to pick up any dependencies) - patchAllModules(); + patchAllModules(ACTION_HOOK); SetLastError(dwLastError); return hModule; @@ -810,7 +837,7 @@ MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags) } // Hook all new modules (and not just this one, to pick up any dependencies) - patchAllModules(); + patchAllModules(ACTION_HOOK); SetLastError(dwLastError); return hModule; @@ -1054,7 +1081,7 @@ DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved) dumpRegisteredHooks(); - patchAllModules(); + patchAllModules(ACTION_HOOK); break; case DLL_THREAD_ATTACH: @@ -1067,6 +1094,12 @@ DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved) if (VERBOSITY > 0) { debugPrintf("inject: DLL_PROCESS_DETACH\n"); } + + patchAllModules(ACTION_UNHOOK); + + if (g_hHookModule) { + FreeLibrary(g_hHookModule); + } break; } return TRUE; |