/* Implementation of COW fork() using NTDLL API for Windows systems (C)Copyright 2009 Simon Urbanek This code is partially based on the book "Windows NT/2000 Native API Reference" by Gary Nebbett (Sams Publishing, 2000, ISBN 1-57870-199-6) */ #ifdef WIN32 #include #include /* winternl.h is not part of MinGW so we have to declare whatever is needed */ #pragma mark ntdll API types typedef LONG NTSTATUS; typedef struct _SYSTEM_HANDLE_INFORMATION { ULONG ProcessId; UCHAR ObjectTypeNumber; UCHAR Flags; USHORT Handle; PVOID Object; ACCESS_MASK GrantedAccess; } SYSTEM_HANDLE_INFORMATION, *PSYSTEM_HANDLE_INFORMATION; typedef struct _OBJECT_ATTRIBUTES { ULONG Length; HANDLE RootDirectory; PVOID /* really PUNICODE_STRING */ ObjectName; ULONG Attributes; PVOID SecurityDescriptor; /* type SECURITY_DESCRIPTOR */ PVOID SecurityQualityOfService; /* type SECURITY_QUALITY_OF_SERVICE */ } OBJECT_ATTRIBUTES, *POBJECT_ATTRIBUTES; typedef enum _MEMORY_INFORMATION_{ MemoryBasicInformation, MemoryWorkingSetList, MemorySectionName, MemoryBasicVlmInformation } MEMORY_INFORMATION_CLASS; typedef struct _CLIENT_ID { HANDLE UniqueProcess; HANDLE UniqueThread; } CLIENT_ID, *PCLIENT_ID; typedef struct _USER_STACK { PVOID FixedStackBase; PVOID FixedStackLimit; PVOID ExpandableStackBase; PVOID ExpandableStackLimit; PVOID ExpandableStackBottom; } USER_STACK, *PUSER_STACK; typedef LONG KPRIORITY; typedef ULONG_PTR KAFFINITY; typedef KAFFINITY *PKAFFINITY; typedef struct _THREAD_BASIC_INFORMATION { NTSTATUS ExitStatus; PVOID TebBaseAddress; CLIENT_ID ClientId; KAFFINITY AffinityMask; KPRIORITY Priority; KPRIORITY BasePriority; } THREAD_BASIC_INFORMATION, *PTHREAD_BASIC_INFORMATION; typedef enum _THREAD_INFORMATION_CLASS { ThreadBasicInformation, ThreadTimes, ThreadPriority, ThreadBasePriority, ThreadAffinityMask, ThreadImpersonationToken, ThreadDescriptorTableEntry, ThreadEnableAlignmentFaultFixup, ThreadEventPair, ThreadQuerySetWin32StartAddress, ThreadZeroTlsCell, ThreadPerformanceCount, ThreadAmILastThread, ThreadIdealProcessor, ThreadPriorityBoost, ThreadSetTlsArrayAddress, ThreadIsIoPending, ThreadHideFromDebugger } THREAD_INFORMATION_CLASS, *PTHREAD_INFORMATION_CLASS; typedef enum _SYSTEM_INFORMATION_CLASS { SystemHandleInformation = 0x10 } SYSTEM_INFORMATION_CLASS; #pragma mark ntdll API - function entry points typedef NTSTATUS (NTAPI *ZwWriteVirtualMemory_t)(IN HANDLE ProcessHandle, IN PVOID BaseAddress, IN PVOID Buffer, IN ULONG NumberOfBytesToWrite, OUT PULONG NumberOfBytesWritten OPTIONAL); typedef NTSTATUS (NTAPI *ZwCreateProcess_t)(OUT PHANDLE ProcessHandle, IN ACCESS_MASK DesiredAccess, IN POBJECT_ATTRIBUTES ObjectAttributes, IN HANDLE InheriteFromProcessHandle, IN BOOLEAN InheritHandles, IN HANDLE SectionHandle OPTIONAL, IN HANDLE DebugPort OPTIONAL, IN HANDLE ExceptionPort OPTIONAL); typedef NTSTATUS (WINAPI *ZwQuerySystemInformation_t)(SYSTEM_INFORMATION_CLASS SystemInformationClass, PVOID SystemInformation, ULONG SystemInformationLength, PULONG ReturnLength); typedef NTSTATUS (NTAPI *ZwQueryVirtualMemory_t)(IN HANDLE ProcessHandle, IN PVOID BaseAddress, IN MEMORY_INFORMATION_CLASS MemoryInformationClass, OUT PVOID MemoryInformation, IN ULONG MemoryInformationLength, OUT PULONG ReturnLength OPTIONAL); typedef NTSTATUS (NTAPI *ZwGetContextThread_t)(IN HANDLE ThreadHandle, OUT PCONTEXT Context); typedef NTSTATUS (NTAPI *ZwCreateThread_t)(OUT PHANDLE ThreadHandle, IN ACCESS_MASK DesiredAccess, IN POBJECT_ATTRIBUTES ObjectAttributes, IN HANDLE ProcessHandle, OUT PCLIENT_ID ClientId, IN PCONTEXT ThreadContext, IN PUSER_STACK UserStack, IN BOOLEAN CreateSuspended); typedef NTSTATUS (NTAPI *ZwResumeThread_t)(IN HANDLE ThreadHandle, OUT PULONG SuspendCount OPTIONAL); typedef NTSTATUS (NTAPI *ZwClose_t)(IN HANDLE ObjectHandle); typedef NTSTATUS (NTAPI *ZwQueryInformationThread_t)(IN HANDLE ThreadHandle, IN THREAD_INFORMATION_CLASS ThreadInformationClass, OUT PVOID ThreadInformation, IN ULONG ThreadInformationLength, OUT PULONG ReturnLength OPTIONAL ); /* function pointers */ static ZwCreateProcess_t ZwCreateProcess; static ZwQuerySystemInformation_t ZwQuerySystemInformation; static ZwQueryVirtualMemory_t ZwQueryVirtualMemory; static ZwCreateThread_t ZwCreateThread; static ZwGetContextThread_t ZwGetContextThread; static ZwResumeThread_t ZwResumeThread; static ZwClose_t ZwClose; static ZwQueryInformationThread_t ZwQueryInformationThread; static ZwWriteVirtualMemory_t ZwWriteVirtualMemory; /* macro definitions */ #define NtCurrentProcess() ((HANDLE)-1) #define NtCurrentThread() ((HANDLE) -2) /* we use really the Nt versions - so the following is just for completeness */ #define ZwCurrentProcess() NtCurrentProcess() #define ZwCurrentThread() NtCurrentThread() #define STATUS_INFO_LENGTH_MISMATCH ((NTSTATUS)0xC0000004L) #define STATUS_SUCCESS ((NTSTATUS)0x00000000L) #pragma mark -- helper functions -- #ifdef INHERIT_ALL /* set all handles belonging to this process as inheritable */ static void set_inherit_all() { ULONG n = 0x1000; PULONG p = (PULONG) calloc(n, sizeof(ULONG)); /* some guesswork to allocate a structure that will fit it all */ while (ZwQuerySystemInformation(SystemHandleInformation, p, n * sizeof(ULONG), 0) == STATUS_INFO_LENGTH_MISMATCH) { free(p); n *= 2; p = (PULONG) calloc(n, sizeof(ULONG)); } /* p points to an ULONG with the count, the entries follow (hence p[0] is the size and p[1] is where the first entry starts */ PSYSTEM_HANDLE_INFORMATION h = (PSYSTEM_HANDLE_INFORMATION)(p + 1); ULONG pid = GetCurrentProcessId(); ULONG i = 0, count = *p; while (i < count) { if (h[i].ProcessId == pid) SetHandleInformation((HANDLE)(ULONG) h[i].Handle, HANDLE_FLAG_INHERIT, HANDLE_FLAG_INHERIT); i++; } free(p); } #endif /* setjmp env for the jump back into the fork() function */ static jmp_buf jenv; /* entry point for our child thread process - just longjmp into fork */ static int child_entry(void) { longjmp(jenv, 1); return 0; } /* initialize NTDLL entry points */ static int init_NTAPI(void) { HANDLE ntdll = GetModuleHandle("ntdll"); if (ntdll == NULL) return -1; ZwCreateProcess = (ZwCreateProcess_t) GetProcAddress(ntdll, "ZwCreateProcess"); ZwQuerySystemInformation = (ZwQuerySystemInformation_t) GetProcAddress(ntdll, "ZwQuerySystemInformation"); ZwQueryVirtualMemory = (ZwQueryVirtualMemory_t) GetProcAddress(ntdll, "ZwQueryVirtualMemory"); ZwCreateThread = (ZwCreateThread_t) GetProcAddress(ntdll, "ZwCreateThread"); ZwGetContextThread = (ZwGetContextThread_t) GetProcAddress(ntdll, "ZwGetContextThread"); ZwResumeThread = (ZwResumeThread_t) GetProcAddress(ntdll, "ZwResumeThread"); ZwQueryInformationThread = (ZwQueryInformationThread_t) GetProcAddress(ntdll, "ZwQueryInformationThread"); ZwWriteVirtualMemory = (ZwWriteVirtualMemory_t) GetProcAddress(ntdll, "ZwWriteVirtualMemory"); ZwClose = (ZwClose_t) GetProcAddress(ntdll, "ZwClose"); /* in theory we chould check all of them - but I guess that would be a waste of time ... */ return (!ZwCreateProcess) ? -1 : 0; } #pragma mark -- fork() -- int fork(void) { if (setjmp(jenv) != 0) return 0; /* return as a child */ /* check whether the entry points are initilized and get them if necessary */ if (!ZwCreateProcess && init_NTAPI()) return -1; #ifdef INHERIT_ALL /* make sure all handles are inheritable */ set_inherit_all(); #endif HANDLE hProcess = 0, hThread = 0; OBJECT_ATTRIBUTES oa = { sizeof(oa) }; /* create forked process */ ZwCreateProcess(&hProcess, PROCESS_ALL_ACCESS, &oa, NtCurrentProcess(), TRUE, 0, 0, 0); CONTEXT context = {CONTEXT_FULL | CONTEXT_DEBUG_REGISTERS | CONTEXT_FLOATING_POINT}; /* set the Eip for the child process to our child function */ ZwGetContextThread(NtCurrentThread(), &context); context.Eip = (ULONG)child_entry; MEMORY_BASIC_INFORMATION mbi; ZwQueryVirtualMemory(NtCurrentProcess(), (PVOID)context.Esp, MemoryBasicInformation, &mbi, sizeof mbi, 0); USER_STACK stack = {0, 0, (PCHAR)mbi.BaseAddress + mbi.RegionSize, mbi.BaseAddress, mbi.AllocationBase}; CLIENT_ID cid; /* create thread using the modified context and stack */ ZwCreateThread(&hThread, THREAD_ALL_ACCESS, &oa, hProcess, &cid, &context, &stack, TRUE); /* copy exception table */ THREAD_BASIC_INFORMATION tbi; ZwQueryInformationThread(NtCurrentThread(), ThreadBasicInformation, &tbi, sizeof tbi, 0); PNT_TIB tib = (PNT_TIB)tbi.TebBaseAddress; ZwQueryInformationThread(hThread, ThreadBasicInformation, &tbi, sizeof tbi, 0); ZwWriteVirtualMemory(hProcess, tbi.TebBaseAddress, &tib->ExceptionList, sizeof tib->ExceptionList, 0); /* start (resume really) the child */ ZwResumeThread(hThread, 0); /* clean up */ ZwClose(hThread); ZwClose(hProcess); /* exit with child's pid */ return (int)cid.UniqueProcess; } /* Dear Emacs, please be nice and use Local Variables: mode:c tab-width: 4 c-basic-offset:4 End: */ #else /* unix has fork() already */ #include #endif