diff options
author | José Fonseca <jfonseca@vmware.com> | 2014-09-26 19:44:29 +0100 |
---|---|---|
committer | Jose Fonseca <jfonseca@vmware.com> | 2015-06-29 14:03:48 +0100 |
commit | 3e60c2413c0183754ec8459081f162c0f1dcdeed (patch) | |
tree | d39544d50983e7058b950a658d08847c0c7ba6bc /inject | |
parent | 4f1789840be404ce23ac7390bd7ce9866b62e2c2 (diff) |
inject: Support ejecting DLL from remote process.
Diffstat (limited to 'inject')
-rw-r--r-- | inject/injectee.cpp | 149 | ||||
-rw-r--r-- | inject/injector.cpp | 84 |
2 files changed, 175 insertions, 58 deletions
diff --git a/inject/injectee.cpp b/inject/injectee.cpp index 6417d487..af60c8b0 100644 --- a/inject/injectee.cpp +++ b/inject/injectee.cpp @@ -42,6 +42,7 @@ #include <stdarg.h> #include <string.h> +#include <algorithm> #include <set> #include <map> #include <functional> @@ -402,11 +403,12 @@ replaceAddress(LPVOID *lpOldAddress, LPVOID lpNewAddress) * */ static LPVOID * -getOldFunctionAddress(HMODULE hModule, - const char *szDescriptorName, - DWORD OriginalFirstThunk, - DWORD FirstThunk, - const char* pszFunctionName) +getPatchAddress(HMODULE hModule, + const char *szDescriptorName, + DWORD OriginalFirstThunk, + DWORD FirstThunk, + const char* pszFunctionName, + LPVOID lpOldAddress) { if (VERBOSITY >= 4) { debugPrintf("inject: %s(%s, %s)\n", __FUNCTION__, @@ -416,7 +418,7 @@ getOldFunctionAddress(HMODULE hModule, PIMAGE_THUNK_DATA pThunkIAT = rvaToVa<IMAGE_THUNK_DATA>(hModule, FirstThunk); - UINT_PTR pRealFunction = 0; + UINT_PTR pOldFunction = (UINT_PTR)lpOldAddress; PIMAGE_THUNK_DATA pThunk; if (OriginalFirstThunk) { @@ -429,15 +431,10 @@ getOldFunctionAddress(HMODULE hModule, if (OriginalFirstThunk == 0 || pThunk->u1.Ordinal & IMAGE_ORDINAL_FLAG) { // No name -- search by the real function address - if (!pRealFunction) { - HMODULE hRealModule = GetModuleHandleA(szDescriptorName); - assert(hRealModule); - pRealFunction = (UINT_PTR)GetProcAddress(hRealModule, pszFunctionName); - if (!pRealFunction) { - return NULL; - } + if (!pOldFunction) { + return NULL; } - if (pThunkIAT->u1.Function == pRealFunction) { + if (pThunkIAT->u1.Function == pOldFunction) { return (LPVOID *)(&pThunkIAT->u1.Function); } } else { @@ -457,17 +454,19 @@ getOldFunctionAddress(HMODULE hModule, static LPVOID * -getOldFunctionAddress(HMODULE hModule, - PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor, - const char* pszFunctionName) +getPatchAddress(HMODULE hModule, + PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor, + const char* pszFunctionName, + LPVOID lpOldAddress) { assert(pImportDescriptor->TimeDateStamp != 0 || pImportDescriptor->Name != 0); - return getOldFunctionAddress(hModule, - getDescriptorName(hModule, pImportDescriptor), - pImportDescriptor->OriginalFirstThunk, - pImportDescriptor->FirstThunk, - pszFunctionName); + return getPatchAddress(hModule, + getDescriptorName(hModule, pImportDescriptor), + pImportDescriptor->OriginalFirstThunk, + pImportDescriptor->FirstThunk, + pszFunctionName, + lpOldAddress); } @@ -475,17 +474,19 @@ getOldFunctionAddress(HMODULE hModule, // http://www.microsoft.com/msj/1298/hood/hood1298.aspx // http://msdn.microsoft.com/en-us/library/16b2dyk5.aspx static LPVOID * -getOldFunctionAddress(HMODULE hModule, - PImgDelayDescr pDelayDescriptor, - const char* pszFunctionName) +getPatchAddress(HMODULE hModule, + PImgDelayDescr pDelayDescriptor, + const char* pszFunctionName, + LPVOID lpOldAddress) { assert(pDelayDescriptor->rvaDLLName != 0); - return getOldFunctionAddress(hModule, - getDescriptorName(hModule, pDelayDescriptor), - pDelayDescriptor->rvaINT, - pDelayDescriptor->rvaIAT, - pszFunctionName); + return getPatchAddress(hModule, + getDescriptorName(hModule, pDelayDescriptor), + pDelayDescriptor->rvaINT, + pDelayDescriptor->rvaIAT, + pszFunctionName, + lpOldAddress); } @@ -496,24 +497,25 @@ patchFunction(HMODULE hModule, const char *pszDllName, T pImportDescriptor, const char *pszFunctionName, + LPVOID lpOldAddress, LPVOID lpNewAddress) { - LPVOID* lpOldFunctionAddress = getOldFunctionAddress(hModule, pImportDescriptor, pszFunctionName); - if (lpOldFunctionAddress == NULL) { + LPVOID* lpPatchAddress = getPatchAddress(hModule, pImportDescriptor, pszFunctionName, lpOldAddress); + if (lpPatchAddress == NULL) { return FALSE; } - if (*lpOldFunctionAddress == lpNewAddress) { + if (*lpPatchAddress == lpNewAddress) { return TRUE; } - DWORD Offset = (DWORD)(UINT_PTR)lpOldFunctionAddress - (UINT_PTR)hModule; + DWORD Offset = (DWORD)(UINT_PTR)lpPatchAddress - (UINT_PTR)hModule; if (VERBOSITY > 0) { debugPrintf("inject: patching %s!0x%lx -> %s!%s\n", szModule, Offset, pszDllName, pszFunctionName); } BOOL bRet; - bRet = replaceAddress(lpOldFunctionAddress, lpNewAddress); + bRet = replaceAddress(lpPatchAddress, lpNewAddress); if (!bRet) { debugPrintf("inject: failed to patch %s!0x%lx -> %s!%s\n", szModule, Offset, pszDllName, pszFunctionName); } @@ -561,11 +563,19 @@ static std::set<HMODULE> g_hHookedModules; +enum Action { + ACTION_HOOK, + ACTION_UNHOOK, + +}; + + template< class T > void patchDescriptor(HMODULE hModule, const char *szModule, - T pImportDescriptor) + T pImportDescriptor, + Action action) { const char* szDescriptorName = getDescriptorName(hModule, pImportDescriptor); @@ -578,11 +588,26 @@ patchDescriptor(HMODULE hModule, FunctionMap::const_iterator fnIt; for (fnIt = functionMap.begin(); fnIt != functionMap.end(); ++fnIt) { const char *szFunctionName = fnIt->first; - LPVOID lpNewAddress = fnIt->second; + LPVOID lpHookAddress = fnIt->second; + + // Knowning the real address is useful when patching imports by ordinal + LPVOID lpRealAddress = NULL; + HMODULE hRealModule = GetModuleHandleA(szDescriptorName); + if (hRealModule) { + assert(hRealModule != g_hHookModule); + lpRealAddress = (LPVOID)GetProcAddress(hRealModule, szFunctionName); + } + + LPVOID lpOldAddress = lpRealAddress; + LPVOID lpNewAddress = lpHookAddress; + + if (action == ACTION_UNHOOK) { + std::swap(lpOldAddress, lpNewAddress); + } - BOOL bHooked; - bHooked = patchFunction(hModule, szModule, szMatchModule, pImportDescriptor, szFunctionName, lpNewAddress); - if (bHooked && !module.bInternal && pSharedMem) { + BOOL bPatched; + bPatched = patchFunction(hModule, szModule, szMatchModule, pImportDescriptor, szFunctionName, lpOldAddress, lpNewAddress); + if (action == ACTION_HOOK && bPatched && !module.bInternal && pSharedMem) { pSharedMem->bReplaced = TRUE; } } @@ -592,7 +617,8 @@ patchDescriptor(HMODULE hModule, static void patchModule(HMODULE hModule, - const char *szModule) + const char *szModule, + Action action) { /* Never patch this module */ if (hModule == g_hThisModule) { @@ -605,12 +631,14 @@ patchModule(HMODULE hModule, } /* Hook modules only once */ - std::pair< std::set<HMODULE>::iterator, bool > ret; - EnterCriticalSection(&Mutex); - ret = g_hHookedModules.insert(hModule); - LeaveCriticalSection(&Mutex); - if (!ret.second) { - return; + if (action == ACTION_HOOK) { + std::pair< std::set<HMODULE>::iterator, bool > ret; + EnterCriticalSection(&Mutex); + ret = g_hHookedModules.insert(hModule); + LeaveCriticalSection(&Mutex); + if (!ret.second) { + return; + } } const char *szBaseName = getBaseName(szModule); @@ -635,7 +663,7 @@ patchModule(HMODULE hModule, if (pImportDescriptor) { while (pImportDescriptor->FirstThunk) { - patchDescriptor(hModule, szModule, pImportDescriptor); + patchDescriptor(hModule, szModule, pImportDescriptor, action); ++pImportDescriptor; } @@ -651,15 +679,16 @@ patchModule(HMODULE hModule, szName); } - patchDescriptor(hModule, szModule, pDelayDescriptor); + patchDescriptor(hModule, szModule, pDelayDescriptor, action); ++pDelayDescriptor; } } } + static void -patchAllModules(void) +patchAllModules(Action action) { HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId()); if (hModuleSnap == INVALID_HANDLE_VALUE) { @@ -670,7 +699,7 @@ patchAllModules(void) me32.dwSize = sizeof me32; if (Module32First(hModuleSnap, &me32)) { do { - patchModule(me32.hModule, me32.szExePath); + patchModule(me32.hModule, me32.szExePath, action); } while (Module32Next(hModuleSnap, &me32)); } @@ -678,8 +707,6 @@ patchAllModules(void) } - - static HMODULE WINAPI MyLoadLibraryA(LPCSTR lpLibFileName) { @@ -714,7 +741,7 @@ MyLoadLibraryA(LPCSTR lpLibFileName) } // Hook all new modules (and not just this one, to pick up any dependencies) - patchAllModules(); + patchAllModules(ACTION_HOOK); SetLastError(dwLastError); return hModule; @@ -732,7 +759,7 @@ MyLoadLibraryW(LPCWSTR lpLibFileName) } // Hook all new modules (and not just this one, to pick up any dependencies) - patchAllModules(); + patchAllModules(ACTION_HOOK); SetLastError(dwLastError); return hModule; @@ -792,7 +819,7 @@ MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags) } // Hook all new modules (and not just this one, to pick up any dependencies) - patchAllModules(); + patchAllModules(ACTION_HOOK); SetLastError(dwLastError); return hModule; @@ -810,7 +837,7 @@ MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags) } // Hook all new modules (and not just this one, to pick up any dependencies) - patchAllModules(); + patchAllModules(ACTION_HOOK); SetLastError(dwLastError); return hModule; @@ -1054,7 +1081,7 @@ DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved) dumpRegisteredHooks(); - patchAllModules(); + patchAllModules(ACTION_HOOK); break; case DLL_THREAD_ATTACH: @@ -1067,6 +1094,12 @@ DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved) if (VERBOSITY > 0) { debugPrintf("inject: DLL_PROCESS_DETACH\n"); } + + patchAllModules(ACTION_UNHOOK); + + if (g_hHookModule) { + FreeLibrary(g_hHookModule); + } break; } return TRUE; diff --git a/inject/injector.cpp b/inject/injector.cpp index 752a543f..7bb3a7d6 100644 --- a/inject/injector.cpp +++ b/inject/injector.cpp @@ -357,6 +357,85 @@ isNumber(const char *arg) { return true; } + +static BOOL +ejectDll(HANDLE hProcess, const char *szDllPath) +{ + /* + * Enumerate all modules. + */ + + HMODULE *phModules = NULL; + DWORD cb = sizeof *phModules * +#ifdef NDEBUG + 32 +#else + 4 +#endif + ; + DWORD cbNeeded = 0; + while (true) { + phModules = (HMODULE *)realloc(phModules, cb); + if (!EnumProcessModules(hProcess, phModules, cb, &cbNeeded)) { + logLastError("failed to enumerate modules in remote process"); + free(phModules); + return FALSE; + } + + if (cbNeeded < cb) { + break; + } + + cb *= 2; + } + + DWORD cNumModules = cbNeeded / sizeof *phModules; + + /* + * Search our DLL. + */ + + const char *szDllName = getBaseName(szDllPath); + HMODULE hModule = NULL; + for (unsigned i = 0; i < cNumModules; ++i) { + char szModName[MAX_PATH]; + if (GetModuleFileNameExA(hProcess, phModules[i], szModName, ARRAY_SIZE(szModName))) { + if (stricmp(getBaseName(szModName), szDllName) == 0) { + hModule = phModules[i]; + break; + } + } + } + + free(phModules); + + if (!hModule) { + debugPrintf("inject: error: failed to find %s module in the remote process\n", szDllName); + return FALSE; + } + + PTHREAD_START_ROUTINE lpStartAddress = + (PTHREAD_START_ROUTINE)GetProcAddress(GetModuleHandleA("KERNEL32"), "FreeLibrary"); + + HANDLE hThread = CreateRemoteThread(hProcess, NULL, 0, lpStartAddress, hModule, 0, NULL); + if (!hThread) { + logLastError("failed to create remote thread"); + return FALSE; + } + + WaitForSingleObject(hThread, INFINITE); + + DWORD bRet = 0; + GetExitCodeThread(hThread, &bRet); + if (!bRet) { + debugPrintf("inject: error: failed to unload %s from the remote process\n", szDllPath); + return FALSE; + } + + return TRUE; +} + + static void help(void) { @@ -630,6 +709,11 @@ main(int argc, char *argv[]) if (bAttach) { if (bAttachDwm) { restartDwmComposition(hProcess); + } else { + fprintf(stderr, "Press any key when finished tracing\n"); + getchar(); + + ejectDll(hProcess, szDllPath); } if (dwThreadId) { |