diff options
-rw-r--r-- | UsbDk/FilterDevice.cpp | 112 | ||||
-rw-r--r-- | UsbDk/FilterDevice.h | 12 | ||||
-rw-r--r-- | UsbDk/stdafx.h | 11 |
3 files changed, 132 insertions, 3 deletions
diff --git a/UsbDk/FilterDevice.cpp b/UsbDk/FilterDevice.cpp index 1a806a0..dab8daf 100644 --- a/UsbDk/FilterDevice.cpp +++ b/UsbDk/FilterDevice.cpp @@ -666,3 +666,115 @@ size_t CUsbDkFilterDevice::CStrategist::GetRequestContextSize() max(CUsbDkHubFilterStrategy::GetRequestContextSize(), CUsbDkRedirectorStrategy::GetRequestContextSize())); } + +static ULONG InterfaceTypeMask(UCHAR bClass) { + switch (bClass) { + case USB_DEVICE_CLASS_AUDIO: + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_FILTERDEVICE, "Class 0x%X -> audio", bClass); + return 1 << USB_DEVICE_CLASS_AUDIO; + case USB_DEVICE_CLASS_COMMUNICATIONS: + case USB_DEVICE_CLASS_CDC_DATA: + case USB_DEVICE_CLASS_WIRELESS_CONTROLLER: + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_FILTERDEVICE, "Class 0x%X -> network", bClass); + return 1 << USB_DEVICE_CLASS_COMMUNICATIONS; + case USB_DEVICE_CLASS_PRINTER: + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_FILTERDEVICE, "Class 0x%X -> printer", bClass); + return 1 << USB_DEVICE_CLASS_PRINTER; + case USB_DEVICE_CLASS_STORAGE: + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_FILTERDEVICE, "Class 0x%X -> storage", bClass); + return 1 << USB_DEVICE_CLASS_STORAGE; + case USB_DEVICE_CLASS_VIDEO: + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_FILTERDEVICE, "Class 0x%X -> video", bClass); + return 1 << USB_DEVICE_CLASS_VIDEO; + case USB_DEVICE_CLASS_AUDIO_VIDEO: + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_FILTERDEVICE, "Class 0x%X -> audio/video", bClass); + return (1 << USB_DEVICE_CLASS_VIDEO) | + (1 << USB_DEVICE_CLASS_AUDIO); + case USB_DEVICE_CLASS_HUB: + return (1 << USB_DEVICE_CLASS_HUB); + case USB_DEVICE_CLASS_HUMAN_INTERFACE: + return (1 << USB_DEVICE_CLASS_HUMAN_INTERFACE); + default: + return 1U << 31; + } +} + +static PUSB_INTERFACE_DESCRIPTOR FindNextInterface(PUSB_CONFIGURATION_DESCRIPTOR cfg, ULONG& offset) +{ + PUSB_COMMON_DESCRIPTOR desc; + if (offset >= cfg->wTotalLength) + return NULL; + do { + if (offset + sizeof(*desc) > cfg->wTotalLength) + return NULL; + desc = (PUSB_COMMON_DESCRIPTOR)((PUCHAR)cfg + offset); + if (desc->bLength + offset > cfg->wTotalLength) + return NULL; + offset += desc->bLength; + if (desc->bDescriptorType == USB_INTERFACE_DESCRIPTOR_TYPE) { + return (PUSB_INTERFACE_DESCRIPTOR)desc; + } + } while (1); +} + +void CUsbDkChildDevice::DetermineDeviceClasses() +{ + if (m_DevDescriptor.bDeviceClass) + { + m_ClassMaskForExtHider |= InterfaceTypeMask(m_DevDescriptor.bDeviceClass); + } + + USB_CONFIGURATION_DESCRIPTOR cfg; + for (UCHAR index = 0; index < m_CfgDescriptors.Size(); ++index) + { + if (ConfigurationDescriptor(index, cfg, sizeof(cfg))) + { + PVOID buffer = ExAllocatePoolWithTag(USBDK_NON_PAGED_POOL, cfg.wTotalLength, 'CFGD'); + if (buffer) + { + USB_CONFIGURATION_DESCRIPTOR *p = (USB_CONFIGURATION_DESCRIPTOR *)buffer; + if (ConfigurationDescriptor(index, *p, cfg.wTotalLength)) + { + ULONG offset = 0; + PUSB_INTERFACE_DESCRIPTOR intf_desc; + while (true) + { + intf_desc = FindNextInterface(p, offset); + if (!intf_desc) + break; + m_ClassMaskForExtHider |= InterfaceTypeMask(intf_desc->bInterfaceClass); + } + } + ExFreePool(buffer); + } + } + } +#define SINGLE_DETERMINATIVE_CLASS +#if defined(SINGLE_DETERMINATIVE_CLASS) + // only one determinative type present in final mask + if (m_ClassMaskForExtHider & (1 << USB_DEVICE_CLASS_PRINTER)) + { + m_ClassMaskForExtHider = 1 << USB_DEVICE_CLASS_PRINTER; + } + else if (m_ClassMaskForExtHider & (1 << USB_DEVICE_CLASS_COMMUNICATIONS)) + { + m_ClassMaskForExtHider = 1 << USB_DEVICE_CLASS_COMMUNICATIONS; + } + else if (m_ClassMaskForExtHider & ((1 << USB_DEVICE_CLASS_AUDIO) | (1 << USB_DEVICE_CLASS_VIDEO))) + { + m_ClassMaskForExtHider &= (1 << USB_DEVICE_CLASS_AUDIO) | (1 << USB_DEVICE_CLASS_VIDEO); + } +#else + // all the determinative types present in final mask + ULONG determinativeMask = + (1 << USB_DEVICE_CLASS_PRINTER) | + (1 << USB_DEVICE_CLASS_COMMUNICATIONS) | + (1 << USB_DEVICE_CLASS_AUDIO) | + (1 << USB_DEVICE_CLASS_VIDEO); + if (m_ClassMaskForExtHider & determinativeMask) + { + m_ClassMaskForExtHider &= determinativeMask; + } +#endif + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_FILTERDEVICE, "Class mask %08X", m_ClassMaskForExtHider); +} diff --git a/UsbDk/FilterDevice.h b/UsbDk/FilterDevice.h index 71d3d63..3d79ca4 100644 --- a/UsbDk/FilterDevice.h +++ b/UsbDk/FilterDevice.h @@ -65,8 +65,10 @@ public: , m_CfgDescriptors(CfgDescriptors) , m_ParentDevice(ParentDevice) , m_PDO(PDO) - {} - + , m_ClassMaskForExtHider(0) + { + DetermineDeviceClasses(); + } ULONG ParentID() const; PCWCHAR DeviceID() const { return *m_DeviceID->begin(); } PCWCHAR InstanceID() const { return *m_InstanceID->begin(); } @@ -77,6 +79,8 @@ public: const USB_DEVICE_DESCRIPTOR &DeviceDescriptor() const { return m_DevDescriptor; } PDEVICE_OBJECT PDO() const { return m_PDO; } + ULONG ClassesBitMask() const + { return m_ClassMaskForExtHider; } bool ConfigurationDescriptor(UCHAR Index, USB_CONFIGURATION_DESCRIPTOR &Buffer, size_t BufferLength) { @@ -107,10 +111,12 @@ private: TDescriptorsCache m_CfgDescriptors; PDEVICE_OBJECT m_PDO; const CUsbDkFilterDevice &m_ParentDevice; - + ULONG m_ClassMaskForExtHider; CUsbDkChildDevice(const CUsbDkChildDevice&) = delete; CUsbDkChildDevice& operator= (const CUsbDkChildDevice&) = delete; + void DetermineDeviceClasses(); + DECLARE_CWDMLIST_ENTRY(CUsbDkChildDevice); }; diff --git a/UsbDk/stdafx.h b/UsbDk/stdafx.h index 11b8af7..fccb90e 100644 --- a/UsbDk/stdafx.h +++ b/UsbDk/stdafx.h @@ -17,6 +17,17 @@ extern "C" #if !TARGET_OS_WIN_XP #include <UsbSpec.h> +#else +#define USB_DEVICE_CLASS_AUDIO 0x01 +#define USB_DEVICE_CLASS_COMMUNICATIONS 0x02 +#define USB_DEVICE_CLASS_HUMAN_INTERFACE 0x03 +#define USB_DEVICE_CLASS_PRINTER 0x07 +#define USB_DEVICE_CLASS_STORAGE 0x08 +#define USB_DEVICE_CLASS_HUB 0x09 +#define USB_DEVICE_CLASS_CDC_DATA 0x0A +#define USB_DEVICE_CLASS_VIDEO 0x0E +#define USB_DEVICE_CLASS_AUDIO_VIDEO 0x10 +#define USB_DEVICE_CLASS_WIRELESS_CONTROLLER 0xE0 #endif #include <wdfusb.h> |