summaryrefslogtreecommitdiff
path: root/inject
diff options
context:
space:
mode:
authorJosé Fonseca <jfonseca@vmware.com>2014-09-26 19:44:29 +0100
committerJose Fonseca <jfonseca@vmware.com>2015-06-29 14:03:48 +0100
commit3e60c2413c0183754ec8459081f162c0f1dcdeed (patch)
treed39544d50983e7058b950a658d08847c0c7ba6bc /inject
parent4f1789840be404ce23ac7390bd7ce9866b62e2c2 (diff)
inject: Support ejecting DLL from remote process.
Diffstat (limited to 'inject')
-rw-r--r--inject/injectee.cpp149
-rw-r--r--inject/injector.cpp84
2 files changed, 175 insertions, 58 deletions
diff --git a/inject/injectee.cpp b/inject/injectee.cpp
index 6417d487..af60c8b0 100644
--- a/inject/injectee.cpp
+++ b/inject/injectee.cpp
@@ -42,6 +42,7 @@
#include <stdarg.h>
#include <string.h>
+#include <algorithm>
#include <set>
#include <map>
#include <functional>
@@ -402,11 +403,12 @@ replaceAddress(LPVOID *lpOldAddress, LPVOID lpNewAddress)
*
*/
static LPVOID *
-getOldFunctionAddress(HMODULE hModule,
- const char *szDescriptorName,
- DWORD OriginalFirstThunk,
- DWORD FirstThunk,
- const char* pszFunctionName)
+getPatchAddress(HMODULE hModule,
+ const char *szDescriptorName,
+ DWORD OriginalFirstThunk,
+ DWORD FirstThunk,
+ const char* pszFunctionName,
+ LPVOID lpOldAddress)
{
if (VERBOSITY >= 4) {
debugPrintf("inject: %s(%s, %s)\n", __FUNCTION__,
@@ -416,7 +418,7 @@ getOldFunctionAddress(HMODULE hModule,
PIMAGE_THUNK_DATA pThunkIAT = rvaToVa<IMAGE_THUNK_DATA>(hModule, FirstThunk);
- UINT_PTR pRealFunction = 0;
+ UINT_PTR pOldFunction = (UINT_PTR)lpOldAddress;
PIMAGE_THUNK_DATA pThunk;
if (OriginalFirstThunk) {
@@ -429,15 +431,10 @@ getOldFunctionAddress(HMODULE hModule,
if (OriginalFirstThunk == 0 ||
pThunk->u1.Ordinal & IMAGE_ORDINAL_FLAG) {
// No name -- search by the real function address
- if (!pRealFunction) {
- HMODULE hRealModule = GetModuleHandleA(szDescriptorName);
- assert(hRealModule);
- pRealFunction = (UINT_PTR)GetProcAddress(hRealModule, pszFunctionName);
- if (!pRealFunction) {
- return NULL;
- }
+ if (!pOldFunction) {
+ return NULL;
}
- if (pThunkIAT->u1.Function == pRealFunction) {
+ if (pThunkIAT->u1.Function == pOldFunction) {
return (LPVOID *)(&pThunkIAT->u1.Function);
}
} else {
@@ -457,17 +454,19 @@ getOldFunctionAddress(HMODULE hModule,
static LPVOID *
-getOldFunctionAddress(HMODULE hModule,
- PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
- const char* pszFunctionName)
+getPatchAddress(HMODULE hModule,
+ PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
+ const char* pszFunctionName,
+ LPVOID lpOldAddress)
{
assert(pImportDescriptor->TimeDateStamp != 0 || pImportDescriptor->Name != 0);
- return getOldFunctionAddress(hModule,
- getDescriptorName(hModule, pImportDescriptor),
- pImportDescriptor->OriginalFirstThunk,
- pImportDescriptor->FirstThunk,
- pszFunctionName);
+ return getPatchAddress(hModule,
+ getDescriptorName(hModule, pImportDescriptor),
+ pImportDescriptor->OriginalFirstThunk,
+ pImportDescriptor->FirstThunk,
+ pszFunctionName,
+ lpOldAddress);
}
@@ -475,17 +474,19 @@ getOldFunctionAddress(HMODULE hModule,
// http://www.microsoft.com/msj/1298/hood/hood1298.aspx
// http://msdn.microsoft.com/en-us/library/16b2dyk5.aspx
static LPVOID *
-getOldFunctionAddress(HMODULE hModule,
- PImgDelayDescr pDelayDescriptor,
- const char* pszFunctionName)
+getPatchAddress(HMODULE hModule,
+ PImgDelayDescr pDelayDescriptor,
+ const char* pszFunctionName,
+ LPVOID lpOldAddress)
{
assert(pDelayDescriptor->rvaDLLName != 0);
- return getOldFunctionAddress(hModule,
- getDescriptorName(hModule, pDelayDescriptor),
- pDelayDescriptor->rvaINT,
- pDelayDescriptor->rvaIAT,
- pszFunctionName);
+ return getPatchAddress(hModule,
+ getDescriptorName(hModule, pDelayDescriptor),
+ pDelayDescriptor->rvaINT,
+ pDelayDescriptor->rvaIAT,
+ pszFunctionName,
+ lpOldAddress);
}
@@ -496,24 +497,25 @@ patchFunction(HMODULE hModule,
const char *pszDllName,
T pImportDescriptor,
const char *pszFunctionName,
+ LPVOID lpOldAddress,
LPVOID lpNewAddress)
{
- LPVOID* lpOldFunctionAddress = getOldFunctionAddress(hModule, pImportDescriptor, pszFunctionName);
- if (lpOldFunctionAddress == NULL) {
+ LPVOID* lpPatchAddress = getPatchAddress(hModule, pImportDescriptor, pszFunctionName, lpOldAddress);
+ if (lpPatchAddress == NULL) {
return FALSE;
}
- if (*lpOldFunctionAddress == lpNewAddress) {
+ if (*lpPatchAddress == lpNewAddress) {
return TRUE;
}
- DWORD Offset = (DWORD)(UINT_PTR)lpOldFunctionAddress - (UINT_PTR)hModule;
+ DWORD Offset = (DWORD)(UINT_PTR)lpPatchAddress - (UINT_PTR)hModule;
if (VERBOSITY > 0) {
debugPrintf("inject: patching %s!0x%lx -> %s!%s\n", szModule, Offset, pszDllName, pszFunctionName);
}
BOOL bRet;
- bRet = replaceAddress(lpOldFunctionAddress, lpNewAddress);
+ bRet = replaceAddress(lpPatchAddress, lpNewAddress);
if (!bRet) {
debugPrintf("inject: failed to patch %s!0x%lx -> %s!%s\n", szModule, Offset, pszDllName, pszFunctionName);
}
@@ -561,11 +563,19 @@ static std::set<HMODULE>
g_hHookedModules;
+enum Action {
+ ACTION_HOOK,
+ ACTION_UNHOOK,
+
+};
+
+
template< class T >
void
patchDescriptor(HMODULE hModule,
const char *szModule,
- T pImportDescriptor)
+ T pImportDescriptor,
+ Action action)
{
const char* szDescriptorName = getDescriptorName(hModule, pImportDescriptor);
@@ -578,11 +588,26 @@ patchDescriptor(HMODULE hModule,
FunctionMap::const_iterator fnIt;
for (fnIt = functionMap.begin(); fnIt != functionMap.end(); ++fnIt) {
const char *szFunctionName = fnIt->first;
- LPVOID lpNewAddress = fnIt->second;
+ LPVOID lpHookAddress = fnIt->second;
+
+ // Knowning the real address is useful when patching imports by ordinal
+ LPVOID lpRealAddress = NULL;
+ HMODULE hRealModule = GetModuleHandleA(szDescriptorName);
+ if (hRealModule) {
+ assert(hRealModule != g_hHookModule);
+ lpRealAddress = (LPVOID)GetProcAddress(hRealModule, szFunctionName);
+ }
+
+ LPVOID lpOldAddress = lpRealAddress;
+ LPVOID lpNewAddress = lpHookAddress;
+
+ if (action == ACTION_UNHOOK) {
+ std::swap(lpOldAddress, lpNewAddress);
+ }
- BOOL bHooked;
- bHooked = patchFunction(hModule, szModule, szMatchModule, pImportDescriptor, szFunctionName, lpNewAddress);
- if (bHooked && !module.bInternal && pSharedMem) {
+ BOOL bPatched;
+ bPatched = patchFunction(hModule, szModule, szMatchModule, pImportDescriptor, szFunctionName, lpOldAddress, lpNewAddress);
+ if (action == ACTION_HOOK && bPatched && !module.bInternal && pSharedMem) {
pSharedMem->bReplaced = TRUE;
}
}
@@ -592,7 +617,8 @@ patchDescriptor(HMODULE hModule,
static void
patchModule(HMODULE hModule,
- const char *szModule)
+ const char *szModule,
+ Action action)
{
/* Never patch this module */
if (hModule == g_hThisModule) {
@@ -605,12 +631,14 @@ patchModule(HMODULE hModule,
}
/* Hook modules only once */
- std::pair< std::set<HMODULE>::iterator, bool > ret;
- EnterCriticalSection(&Mutex);
- ret = g_hHookedModules.insert(hModule);
- LeaveCriticalSection(&Mutex);
- if (!ret.second) {
- return;
+ if (action == ACTION_HOOK) {
+ std::pair< std::set<HMODULE>::iterator, bool > ret;
+ EnterCriticalSection(&Mutex);
+ ret = g_hHookedModules.insert(hModule);
+ LeaveCriticalSection(&Mutex);
+ if (!ret.second) {
+ return;
+ }
}
const char *szBaseName = getBaseName(szModule);
@@ -635,7 +663,7 @@ patchModule(HMODULE hModule,
if (pImportDescriptor) {
while (pImportDescriptor->FirstThunk) {
- patchDescriptor(hModule, szModule, pImportDescriptor);
+ patchDescriptor(hModule, szModule, pImportDescriptor, action);
++pImportDescriptor;
}
@@ -651,15 +679,16 @@ patchModule(HMODULE hModule,
szName);
}
- patchDescriptor(hModule, szModule, pDelayDescriptor);
+ patchDescriptor(hModule, szModule, pDelayDescriptor, action);
++pDelayDescriptor;
}
}
}
+
static void
-patchAllModules(void)
+patchAllModules(Action action)
{
HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
if (hModuleSnap == INVALID_HANDLE_VALUE) {
@@ -670,7 +699,7 @@ patchAllModules(void)
me32.dwSize = sizeof me32;
if (Module32First(hModuleSnap, &me32)) {
do {
- patchModule(me32.hModule, me32.szExePath);
+ patchModule(me32.hModule, me32.szExePath, action);
} while (Module32Next(hModuleSnap, &me32));
}
@@ -678,8 +707,6 @@ patchAllModules(void)
}
-
-
static HMODULE WINAPI
MyLoadLibraryA(LPCSTR lpLibFileName)
{
@@ -714,7 +741,7 @@ MyLoadLibraryA(LPCSTR lpLibFileName)
}
// Hook all new modules (and not just this one, to pick up any dependencies)
- patchAllModules();
+ patchAllModules(ACTION_HOOK);
SetLastError(dwLastError);
return hModule;
@@ -732,7 +759,7 @@ MyLoadLibraryW(LPCWSTR lpLibFileName)
}
// Hook all new modules (and not just this one, to pick up any dependencies)
- patchAllModules();
+ patchAllModules(ACTION_HOOK);
SetLastError(dwLastError);
return hModule;
@@ -792,7 +819,7 @@ MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
}
// Hook all new modules (and not just this one, to pick up any dependencies)
- patchAllModules();
+ patchAllModules(ACTION_HOOK);
SetLastError(dwLastError);
return hModule;
@@ -810,7 +837,7 @@ MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
}
// Hook all new modules (and not just this one, to pick up any dependencies)
- patchAllModules();
+ patchAllModules(ACTION_HOOK);
SetLastError(dwLastError);
return hModule;
@@ -1054,7 +1081,7 @@ DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved)
dumpRegisteredHooks();
- patchAllModules();
+ patchAllModules(ACTION_HOOK);
break;
case DLL_THREAD_ATTACH:
@@ -1067,6 +1094,12 @@ DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved)
if (VERBOSITY > 0) {
debugPrintf("inject: DLL_PROCESS_DETACH\n");
}
+
+ patchAllModules(ACTION_UNHOOK);
+
+ if (g_hHookModule) {
+ FreeLibrary(g_hHookModule);
+ }
break;
}
return TRUE;
diff --git a/inject/injector.cpp b/inject/injector.cpp
index 752a543f..7bb3a7d6 100644
--- a/inject/injector.cpp
+++ b/inject/injector.cpp
@@ -357,6 +357,85 @@ isNumber(const char *arg) {
return true;
}
+
+static BOOL
+ejectDll(HANDLE hProcess, const char *szDllPath)
+{
+ /*
+ * Enumerate all modules.
+ */
+
+ HMODULE *phModules = NULL;
+ DWORD cb = sizeof *phModules *
+#ifdef NDEBUG
+ 32
+#else
+ 4
+#endif
+ ;
+ DWORD cbNeeded = 0;
+ while (true) {
+ phModules = (HMODULE *)realloc(phModules, cb);
+ if (!EnumProcessModules(hProcess, phModules, cb, &cbNeeded)) {
+ logLastError("failed to enumerate modules in remote process");
+ free(phModules);
+ return FALSE;
+ }
+
+ if (cbNeeded < cb) {
+ break;
+ }
+
+ cb *= 2;
+ }
+
+ DWORD cNumModules = cbNeeded / sizeof *phModules;
+
+ /*
+ * Search our DLL.
+ */
+
+ const char *szDllName = getBaseName(szDllPath);
+ HMODULE hModule = NULL;
+ for (unsigned i = 0; i < cNumModules; ++i) {
+ char szModName[MAX_PATH];
+ if (GetModuleFileNameExA(hProcess, phModules[i], szModName, ARRAY_SIZE(szModName))) {
+ if (stricmp(getBaseName(szModName), szDllName) == 0) {
+ hModule = phModules[i];
+ break;
+ }
+ }
+ }
+
+ free(phModules);
+
+ if (!hModule) {
+ debugPrintf("inject: error: failed to find %s module in the remote process\n", szDllName);
+ return FALSE;
+ }
+
+ PTHREAD_START_ROUTINE lpStartAddress =
+ (PTHREAD_START_ROUTINE)GetProcAddress(GetModuleHandleA("KERNEL32"), "FreeLibrary");
+
+ HANDLE hThread = CreateRemoteThread(hProcess, NULL, 0, lpStartAddress, hModule, 0, NULL);
+ if (!hThread) {
+ logLastError("failed to create remote thread");
+ return FALSE;
+ }
+
+ WaitForSingleObject(hThread, INFINITE);
+
+ DWORD bRet = 0;
+ GetExitCodeThread(hThread, &bRet);
+ if (!bRet) {
+ debugPrintf("inject: error: failed to unload %s from the remote process\n", szDllPath);
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+
static void
help(void)
{
@@ -630,6 +709,11 @@ main(int argc, char *argv[])
if (bAttach) {
if (bAttachDwm) {
restartDwmComposition(hProcess);
+ } else {
+ fprintf(stderr, "Press any key when finished tracing\n");
+ getchar();
+
+ ejectDll(hProcess, szDllPath);
}
if (dwThreadId) {