Skip to content

Commit

Permalink
Migrated to use CsWin32
Browse files Browse the repository at this point in the history
  • Loading branch information
shmuelie committed Dec 16, 2023
1 parent 091cb68 commit 5bd8979
Show file tree
Hide file tree
Showing 33 changed files with 181 additions and 1,401 deletions.
7 changes: 4 additions & 3 deletions src/Shmuelie.WinRTServer/BaseActivationFactory.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
using System;
using System.Runtime.InteropServices.WindowsRuntime;

namespace Shmuelie.WinRTServer;

/// <summary>
/// Base for a WinRT Activation Factory for a .NET type.
/// </summary>
/// <seealso cref="IActivationFactory"/>
public abstract class BaseActivationFactory : IActivationFactory
#if !NETSTANDARD
[System.Runtime.Versioning.SupportedOSPlatform("windows8.0")]
#endif
public abstract class BaseActivationFactory
{
/// <inheritdoc/>
public abstract object ActivateInstance();
Expand Down
1 change: 0 additions & 1 deletion src/Shmuelie.WinRTServer/BaseClassFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ namespace Shmuelie.WinRTServer;
/// <summary>
/// Base for a COM class factory for a .NET type.
/// </summary>
/// <seealso cref="Interop.Windows.IClassFactory"/>
/// <remarks>Does not support aggregation. Will always return <c>CLASS_E_NOAGGREGATION</c> if requested.</remarks>
public abstract class BaseClassFactory
{
Expand Down
12 changes: 7 additions & 5 deletions src/Shmuelie.WinRTServer/ComServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
using System.Threading.Tasks;
using System.Timers;
using Shmuelie.Interop.Windows;
using static Shmuelie.Interop.Windows.ComBaseAPI;
using static Shmuelie.Interop.Windows.Windows;
using Windows.Win32.Foundation;
using Windows.Win32.System.Com;
using static Windows.Win32.PInvoke;

namespace Shmuelie.WinRTServer;

Expand Down Expand Up @@ -42,8 +43,9 @@ public sealed class ComServer : IAsyncDisposable
public unsafe ComServer()
{
using ComPtr<IGlobalOptions> options = default;
Guid clsid = IGlobalOptions.CLSID;
if (CoCreateInstance(&clsid, null, (uint)CLSCTX.CLSCTX_INPROC_SERVER, __uuidof<IGlobalOptions>(), (void**)options.GetAddressOf()) == S.S_OK)
Guid clsid = CLSID_GlobalOptions;
Guid iid = IGlobalOptions.IID_Guid;
if (CoCreateInstance(&clsid, null, CLSCTX.CLSCTX_INPROC_SERVER, &iid, (void**)options.GetAddressOf()) == HRESULT.S_OK)
{
options.Get()->Set(GLOBALOPT_PROPERTIES.COMGLB_RO_SETTINGS, (nuint)GLOBALOPT_RO_FLAGS.COMGLB_FAST_RUNDOWN);
}
Expand Down Expand Up @@ -128,7 +130,7 @@ public unsafe bool RegisterClassFactory(BaseClassFactory factory)
proxy.Attach(BaseClassFactoryProxy.Create(factory));

uint cookie;
Marshal.ThrowExceptionForHR(CoRegisterClassObject(&clsid, (IUnknown*)proxy.Get(), (uint)CLSCTX.CLSCTX_LOCAL_SERVER, (uint)(REGCLS.REGCLS_MULTIPLEUSE | REGCLS.REGCLS_SUSPENDED), &cookie));
Marshal.ThrowExceptionForHR(CoRegisterClassObject(&clsid, (IUnknown*)proxy.Get(), CLSCTX.CLSCTX_LOCAL_SERVER, (REGCLS.REGCLS_MULTIPLEUSE | REGCLS.REGCLS_SUSPENDED), &cookie));

factories.Add(clsid, (factory, cookie));
return true;
Expand Down
70 changes: 38 additions & 32 deletions src/Shmuelie.WinRTServer/Internal/BaseActivationFactoryProxy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using Windows.Win32.System.WinRT;
using Windows.Win32.System.Com;
using Windows.Win32.Foundation;
using static Windows.Win32.PInvoke;
using Shmuelie.Interop.Windows;
using static Shmuelie.Interop.Windows.Windows;

namespace Shmuelie.WinRTServer;

/// <summary>
/// CCW for <see cref="BaseActivationFactory" />.
/// </summary>
#if !NETSTANDARD
[System.Runtime.Versioning.SupportedOSPlatform("windows8.0")]
#endif
internal unsafe struct BaseActivationFactoryProxy
{
private static readonly void** Vtbl = InitVtbl();
Expand Down Expand Up @@ -64,7 +70,7 @@ public uint Release()
private static class Impl
{
[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate int QueryInterfaceDelegate(BaseActivationFactoryProxy* @this, Guid* riid, void** ppvObject);
public delegate HRESULT QueryInterfaceDelegate(BaseActivationFactoryProxy* @this, Guid* riid, void** ppvObject);

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate uint AddRefDelegate(BaseActivationFactoryProxy* @this);
Expand All @@ -73,16 +79,16 @@ private static class Impl
public delegate uint ReleaseDelegate(BaseActivationFactoryProxy* @this);

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate int GetIidsDelegate(BaseActivationFactoryProxy* @this, uint* iidCount, Guid** iids);
public delegate HRESULT GetIidsDelegate(BaseActivationFactoryProxy* @this, uint* iidCount, Guid** iids);

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate int GetRuntimeClassNameDelegate(BaseActivationFactoryProxy* @this, HSTRING* className);
public delegate HRESULT GetRuntimeClassNameDelegate(BaseActivationFactoryProxy* @this, HSTRING* className);

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate int GetTrustLevelDelegate(BaseActivationFactoryProxy* @this, TrustLevel* trustLevel);
public delegate HRESULT GetTrustLevelDelegate(BaseActivationFactoryProxy* @this, TrustLevel* trustLevel);

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate int ActivateInstanceDelegate(BaseActivationFactoryProxy* @this, IInspectable** instance);
public delegate HRESULT ActivateInstanceDelegate(BaseActivationFactoryProxy* @this, IInspectable** instance);

/// <summary>
/// The cached <see cref="QueryInterfaceDelegate"/> for <c>IUnknown.QueryInterface(REFIID, void**)</c>.
Expand Down Expand Up @@ -110,20 +116,20 @@ private static class Impl
/// <summary>
/// Implements <see href="https://docs.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-queryinterface(refiid_void)"><c>IUnknown.QueryInterface(REFIID, void**)</c></see>.
/// </summary>
private static int QueryInterface(BaseActivationFactoryProxy* @this, Guid* riid, void** ppvObject)
private static HRESULT QueryInterface(BaseActivationFactoryProxy* @this, Guid* riid, void** ppvObject)
{
if (riid->Equals(__uuidof<IUnknown>()) ||
riid->Equals(__uuidof<IInspectable>()) ||
riid->Equals(__uuidof<IActivationFactory>()))
if (riid->Equals(IUnknown.IID_Guid) ||
riid->Equals(IInspectable.IID_Guid) ||
riid->Equals(IActivationFactory.IID_Guid))
{
_ = Interlocked.Increment(ref Unsafe.As<uint, int>(ref @this->_referenceCount));

*ppvObject = @this;

return S.S_OK;
return HRESULT.S_OK;
}

return E.E_NOINTERFACE;
return HRESULT.E_NOINTERFACE;
}

/// <summary>
Expand Down Expand Up @@ -151,99 +157,99 @@ public static uint Release(BaseActivationFactoryProxy* @this)
return referenceCount;
}

public static int GetIids(BaseActivationFactoryProxy* @this, uint* iidCount, Guid** iids)
public static HRESULT GetIids(BaseActivationFactoryProxy* @this, uint* iidCount, Guid** iids)
{
if (iidCount is null || iids is null)
{
return E.E_INVALIDARG;
return HRESULT.E_INVALIDARG;
}

*iidCount = 1;
*iids = (Guid*)Marshal.AllocHGlobal(sizeof(Guid));
*iids[0] = __uuidof<IActivationFactory>();
return S.S_OK;
*iids[0] = IActivationFactory.IID_Guid;
return HRESULT.S_OK;
}

public static int GetRuntimeClassName(BaseActivationFactoryProxy* @this, HSTRING* className)
public static HRESULT GetRuntimeClassName(BaseActivationFactoryProxy* @this, HSTRING* className)
{
try
{
if (className is null)
{
return E.E_INVALIDARG;
return HRESULT.E_INVALIDARG;
}

BaseActivationFactory? factory = Unsafe.As<BaseActivationFactory>(@this->_factory.Target);

if (factory is null)
{
return E.E_HANDLE;
return HRESULT.E_HANDLE;
}

string? fullName = factory.GetType().FullName;

if (fullName is null)
{
return E.E_UNEXPECTED;
return HRESULT.E_UNEXPECTED;
}

fixed (char* fullNamePtr = fullName)
{
return WinString.WindowsCreateString((ushort*)fullNamePtr, (uint)fullName.Length, className);
return WindowsCreateString((PCWSTR)fullNamePtr, (uint)fullName.Length, className);
}

}
catch (Exception e)
{
return Marshal.GetHRForException(e);
return (HRESULT)Marshal.GetHRForException(e);
}
}

public static int GetTrustLevel(BaseActivationFactoryProxy* @this, TrustLevel* trustLevel)
public static HRESULT GetTrustLevel(BaseActivationFactoryProxy* @this, TrustLevel* trustLevel)
{
if (trustLevel is null)
{
return E.E_INVALIDARG;
return HRESULT.E_INVALIDARG;
}

*trustLevel = TrustLevel.BaseTrust;
return S.S_OK;
return HRESULT.S_OK;
}

public static int ActivateInstance(BaseActivationFactoryProxy* @this, IInspectable** instance)
public static HRESULT ActivateInstance(BaseActivationFactoryProxy* @this, IInspectable** instance)
{
try
{
if (instance is null)
{
return E.E_INVALIDARG;
return HRESULT.E_INVALIDARG;
}

BaseActivationFactory? factory = Unsafe.As<BaseActivationFactory>(@this->_factory.Target);

if (factory is null)
{
return E.E_HANDLE;
return HRESULT.E_HANDLE;
}

object managedInstance = factory.ActivateInstance();

using ComPtr<IUnknown> unkwnPtr = default;
unkwnPtr.Attach((IUnknown*)Marshal.GetIUnknownForObject(managedInstance));
int result = unkwnPtr.CopyTo(instance);
if (result != S.S_OK)
HRESULT result = unkwnPtr.CopyTo(instance);
if (result != HRESULT.S_OK)
{
return result;
}

factory.OnInstanceCreated(managedInstance);

return S.S_OK;
return HRESULT.S_OK;

}
catch (Exception e)
{
return Marshal.GetHRForException(e);
return (HRESULT)Marshal.GetHRForException(e);
}
}
}
Expand Down
46 changes: 23 additions & 23 deletions src/Shmuelie.WinRTServer/Internal/BaseClassFactoryProxy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using Shmuelie.Interop.Windows;
using static Shmuelie.Interop.Windows.ComBaseAPI;
using static Shmuelie.Interop.Windows.Windows;
using Windows.Win32.Foundation;
using Windows.Win32.System.Com;
using static Windows.Win32.PInvoke;

namespace Shmuelie.WinRTServer;

Expand Down Expand Up @@ -63,7 +63,7 @@ public uint Release()
private static class Impl
{
[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate int QueryInterfaceDelegate(BaseClassFactoryProxy* @this, Guid* riid, void** ppvObject);
public delegate HRESULT QueryInterfaceDelegate(BaseClassFactoryProxy* @this, Guid* riid, void** ppvObject);

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate uint AddRefDelegate(BaseClassFactoryProxy* @this);
Expand All @@ -72,10 +72,10 @@ private static class Impl
public delegate uint ReleaseDelegate(BaseClassFactoryProxy* @this);

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate int CreateInstanceDelegate(BaseClassFactoryProxy* @this, IUnknown* pUnkOuter, Guid* riid, void** ppvObject);
public delegate HRESULT CreateInstanceDelegate(BaseClassFactoryProxy* @this, IUnknown* pUnkOuter, Guid* riid, void** ppvObject);

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate int LockServerDelegate(BaseClassFactoryProxy* @this, int fLock);
public delegate HRESULT LockServerDelegate(BaseClassFactoryProxy* @this, int fLock);

/// <summary>
/// The cached <see cref="QueryInterfaceDelegate"/> for <c>IUnknown.QueryInterface(REFIID, void**)</c>.
Expand All @@ -99,19 +99,19 @@ private static class Impl
/// <summary>
/// Implements <see href="https://docs.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-queryinterface(refiid_void)"><c>IUnknown.QueryInterface(REFIID, void**)</c></see>.
/// </summary>
private static int QueryInterface(BaseClassFactoryProxy* @this, Guid* riid, void** ppvObject)
private static HRESULT QueryInterface(BaseClassFactoryProxy* @this, Guid* riid, void** ppvObject)
{
if (riid->Equals(__uuidof<IUnknown>()) ||
riid->Equals(__uuidof<IClassFactory>()))
if (riid->Equals(IUnknown.IID_Guid) ||
riid->Equals(IClassFactory.IID_Guid))
{
_ = Interlocked.Increment(ref Unsafe.As<uint, int>(ref @this->_referenceCount));

*ppvObject = @this;

return S.S_OK;
return HRESULT.S_OK;
}

return E.E_NOINTERFACE;
return HRESULT.E_NOINTERFACE;
}

/// <summary>
Expand Down Expand Up @@ -139,30 +139,30 @@ public static uint Release(BaseClassFactoryProxy* @this)
return referenceCount;
}

public static int CreateInstance(BaseClassFactoryProxy* @this, IUnknown* pUnkOuter, Guid* riid, void** ppvObject)
public static HRESULT CreateInstance(BaseClassFactoryProxy* @this, IUnknown* pUnkOuter, Guid* riid, void** ppvObject)
{
try
{
if (pUnkOuter is not null)
{
return WinError.CLASS_E_NOAGGREGATION;
return HRESULT.CLASS_E_NOAGGREGATION;
}

BaseClassFactory? factory = Unsafe.As<BaseClassFactory>(@this->_factory.Target);

if (factory is null)
{
return E.E_HANDLE;
return HRESULT.E_HANDLE;
}

if (!riid->Equals(__uuidof<IUnknown>()) && !riid->Equals(factory.Iid))
if (!riid->Equals(IUnknown.IID_Guid) && !riid->Equals(factory.Iid))
{
return E.E_NOINTERFACE;
return HRESULT.E_NOINTERFACE;
}

var instance = factory.CreateInstance();

if (riid->Equals(__uuidof<IUnknown>()))
if (riid->Equals(IUnknown.IID_Guid))
{
*ppvObject = (void*)Marshal.GetIUnknownForObject(instance);
}
Expand All @@ -172,7 +172,7 @@ public static int CreateInstance(BaseClassFactoryProxy* @this, IUnknown* pUnkOut

if (t is null)
{
return E.E_UNEXPECTED;
return HRESULT.E_UNEXPECTED;
}

*ppvObject = (void*)Marshal.GetComInterfaceForObject(instance, t);
Expand All @@ -182,12 +182,12 @@ public static int CreateInstance(BaseClassFactoryProxy* @this, IUnknown* pUnkOut
}
catch (Exception e)
{
return Marshal.GetHRForException(e);
return (HRESULT)Marshal.GetHRForException(e);
}
return S.S_OK;
return HRESULT.S_OK;
}

public static int LockServer(BaseClassFactoryProxy* @this, int fLock)
public static HRESULT LockServer(BaseClassFactoryProxy* @this, int fLock)
{
try
{
Expand All @@ -202,9 +202,9 @@ public static int LockServer(BaseClassFactoryProxy* @this, int fLock)
}
catch (Exception e)
{
return Marshal.GetHRForException(e);
return (HRESULT)Marshal.GetHRForException(e);
}
return S.S_OK;
return HRESULT.S_OK;
}
}
}
Expand Down
Loading

0 comments on commit 5bd8979

Please sign in to comment.