diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDependency.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDependency.cs index 52bcec0599..938f43e898 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDependency.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDependency.cs @@ -10,7 +10,6 @@ using System.Runtime.CompilerServices; using System.Runtime.Remoting; using System.Runtime.Serialization; -using System.Runtime.Serialization.Formatters.Binary; using System.Runtime.Versioning; using System.Security.Permissions; using System.Text; @@ -241,29 +240,39 @@ private static void InvokeCallback(object eventContextPair) // END EventContextPair private class. // ---------------------------------------- - // ---------------------------------------- - // Private class for restricting allowed types from deserialization. - // ---------------------------------------- - - private class SqlDependencyProcessDispatcherSerializationBinder : SerializationBinder + //----------------------------------------------- + // Private Class to add ObjRef as DataContract + //----------------------------------------------- + [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.RemotingConfiguration)] + [DataContract] + private class SqlClientObjRef { - public override Type BindToType(string assemblyName, string typeName) + [DataMember] + private static ObjRef s_sqlObjRef; + internal static IRemotingTypeInfo _typeInfo; + + private SqlClientObjRef() { } + + public SqlClientObjRef(SqlDependencyProcessDispatcher dispatcher) : base() { - // Deserializing an unexpected type can inject objects with malicious side effects. - // If the type is unexpected, throw an exception to stop deserialization. - if (typeName == nameof(SqlDependencyProcessDispatcher)) - { - return typeof(SqlDependencyProcessDispatcher); - } - else - { - throw new ArgumentException("Unexpected type", nameof(typeName)); - } + s_sqlObjRef = RemotingServices.Marshal(dispatcher); + _typeInfo = s_sqlObjRef.TypeInfo; + } + + internal static bool CanCastToSqlDependencyProcessDispatcher() + { + return _typeInfo.CanCastTo(typeof(SqlDependencyProcessDispatcher), s_sqlObjRef); } + + internal ObjRef GetObjRef() + { + return s_sqlObjRef; + } + } - // ---------------------------------------- - // END SqlDependencyProcessDispatcherSerializationBinder private class. - // ---------------------------------------- + // ------------------------------------------ + // End SqlClientObjRef private class. + // ------------------------------------------- // ---------------- // Instance members @@ -306,10 +315,9 @@ public override Type BindToType(string assemblyName, string typeName) private static readonly string _typeName = (typeof(SqlDependencyProcessDispatcher)).FullName; // ----------- - // BID members + // EventSource members // ----------- - private readonly int _objectID = System.Threading.Interlocked.Increment(ref _objectTypeCount); private static int _objectTypeCount; // EventSource Counter internal int ObjectID @@ -336,7 +344,7 @@ public SqlDependency(SqlCommand command) : this(command, null, SQL.SqlDependency } /// - [System.Security.Permissions.HostProtectionAttribute(ExternalThreading = true)] + [HostProtection(ExternalThreading = true)] public SqlDependency(SqlCommand command, string options, int timeout) { long scopeID = SqlClientEventSource.Log.TryNotificationScopeEnterEvent(" {0}, options: '{1}', timeout: '{2}'", ObjectID, options, timeout); @@ -597,11 +605,13 @@ private static void ObtainProcessDispatcher() _processDispatcher = dependency.SingletonProcessDispatcher; // Set to static instance. // Serialize and set in native. - ObjRef objRef = GetObjRef(_processDispatcher); - BinaryFormatter formatter = new BinaryFormatter(); - MemoryStream stream = new MemoryStream(); - GetSerializedObject(objRef, formatter, stream); - SNINativeMethodWrapper.SetData(stream.GetBuffer()); // Native will be forced to synchronize and not overwrite. + using (MemoryStream stream = new MemoryStream()) + { + SqlClientObjRef objRef = new SqlClientObjRef(_processDispatcher); + DataContractSerializer serializer = new DataContractSerializer(objRef.GetType()); + GetSerializedObject(objRef, serializer, stream); + SNINativeMethodWrapper.SetData(stream.ToArray()); // Native will be forced to synchronize and not overwrite. + } } else { @@ -628,10 +638,20 @@ private static void ObtainProcessDispatcher() #if DEBUG // Possibly expensive, limit to debug. SqlClientEventSource.Log.TryNotificationTraceEvent(" AppDomain.CurrentDomain.FriendlyName: {0}", AppDomain.CurrentDomain.FriendlyName); #endif - BinaryFormatter formatter = new BinaryFormatter(); - MemoryStream stream = new MemoryStream(nativeStorage); - _processDispatcher = GetDeserializedObject(formatter, stream); // Deserialize and set for appdomain. - SqlClientEventSource.Log.TryNotificationTraceEvent(" processDispatcher obtained, ID: {0}", _processDispatcher.ObjectID); + using (MemoryStream stream = new MemoryStream(nativeStorage)) + { + DataContractSerializer serializer = new DataContractSerializer(typeof(SqlClientObjRef)); + if (SqlClientObjRef.CanCastToSqlDependencyProcessDispatcher()) + { + // Deserialize and set for appdomain. + _processDispatcher = GetDeserializedObject(serializer, stream); + } + else + { + throw new ArgumentException(Strings.SqlDependency_UnexpectedValueOnDeserialize); + } + SqlClientEventSource.Log.TryNotificationTraceEvent(" processDispatcher obtained, ID: {0}", _processDispatcher.ObjectID); + } } } @@ -639,26 +659,18 @@ private static void ObtainProcessDispatcher() // Static security asserted methods - limit scope of assert. // --------------------------------------------------------- - [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.RemotingConfiguration)] - private static ObjRef GetObjRef(SqlDependencyProcessDispatcher _processDispatcher) - { - return RemotingServices.Marshal(_processDispatcher); - } - [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.SerializationFormatter)] - private static void GetSerializedObject(ObjRef objRef, BinaryFormatter formatter, MemoryStream stream) + private static void GetSerializedObject(SqlClientObjRef objRef, DataContractSerializer serializer, MemoryStream stream) { - formatter.Serialize(stream, objRef); + serializer.WriteObject(stream, objRef); } [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.SerializationFormatter)] - private static SqlDependencyProcessDispatcher GetDeserializedObject(BinaryFormatter formatter, MemoryStream stream) + private static SqlDependencyProcessDispatcher GetDeserializedObject(DataContractSerializer serializer, MemoryStream stream) { - // Use a custom SerializationBinder to restrict deserialized types to SqlDependencyProcessDispatcher. - formatter.Binder = new SqlDependencyProcessDispatcherSerializationBinder(); - object result = formatter.Deserialize(stream); - Debug.Assert(result.GetType() == typeof(SqlDependencyProcessDispatcher), "Unexpected type stored in native!"); - return (SqlDependencyProcessDispatcher)result; + object refResult = serializer.ReadObject(stream); + var result = RemotingServices.Unmarshal((refResult as SqlClientObjRef).GetObjRef()); + return result as SqlDependencyProcessDispatcher; } // ------------------------- @@ -1325,7 +1337,6 @@ private void AddCommandInternal(SqlCommand cmd) { if (cmd != null) { - // Don't bother with BID if command null. long scopeID = SqlClientEventSource.Log.TryNotificationScopeEnterEvent(" {0}, SqlCommand: {1}", ObjectID, cmd.ObjectID); try { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs index acfa8b7778..3ac23773ca 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs @@ -10917,6 +10917,15 @@ internal static string SqlDependency_SqlDependency { } } + /// + /// Looks up a localized string similar to Unexpected type detected on deserialize.. + /// + internal static string SqlDependency_UnexpectedValueOnDeserialize { + get { + return ResourceManager.GetString("SqlDependency_UnexpectedValueOnDeserialize", resourceCulture); + } + } + /// /// Looks up a localized string similar to The process cannot access the file specified because it has been opened in another transaction.. /// diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx index 3c2d2c8e0d..cdb6c2c741 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx +++ b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx @@ -4602,4 +4602,7 @@ Failed after 5 retries. - \ No newline at end of file + + Unexpected type detected on deserialize. + + diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlExceptionTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlExceptionTest.cs index 97cd1909ab..328f3643c2 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlExceptionTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlExceptionTest.cs @@ -33,7 +33,6 @@ public void SerializationTest() Assert.Equal(e.StackTrace, sqlEx.StackTrace); } - [Fact] [ActiveIssue("12161", TestPlatforms.AnyUnix)] public static void SqlExcpetionSerializationTest()