diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDependency.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDependency.cs
index 52bcec0599..938f43e898 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDependency.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDependency.cs
@@ -10,7 +10,6 @@
using System.Runtime.CompilerServices;
using System.Runtime.Remoting;
using System.Runtime.Serialization;
-using System.Runtime.Serialization.Formatters.Binary;
using System.Runtime.Versioning;
using System.Security.Permissions;
using System.Text;
@@ -241,29 +240,39 @@ private static void InvokeCallback(object eventContextPair)
// END EventContextPair private class.
// ----------------------------------------
- // ----------------------------------------
- // Private class for restricting allowed types from deserialization.
- // ----------------------------------------
-
- private class SqlDependencyProcessDispatcherSerializationBinder : SerializationBinder
+ //-----------------------------------------------
+ // Private Class to add ObjRef as DataContract
+ //-----------------------------------------------
+ [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.RemotingConfiguration)]
+ [DataContract]
+ private class SqlClientObjRef
{
- public override Type BindToType(string assemblyName, string typeName)
+ [DataMember]
+ private static ObjRef s_sqlObjRef;
+ internal static IRemotingTypeInfo _typeInfo;
+
+ private SqlClientObjRef() { }
+
+ public SqlClientObjRef(SqlDependencyProcessDispatcher dispatcher) : base()
{
- // Deserializing an unexpected type can inject objects with malicious side effects.
- // If the type is unexpected, throw an exception to stop deserialization.
- if (typeName == nameof(SqlDependencyProcessDispatcher))
- {
- return typeof(SqlDependencyProcessDispatcher);
- }
- else
- {
- throw new ArgumentException("Unexpected type", nameof(typeName));
- }
+ s_sqlObjRef = RemotingServices.Marshal(dispatcher);
+ _typeInfo = s_sqlObjRef.TypeInfo;
+ }
+
+ internal static bool CanCastToSqlDependencyProcessDispatcher()
+ {
+ return _typeInfo.CanCastTo(typeof(SqlDependencyProcessDispatcher), s_sqlObjRef);
}
+
+ internal ObjRef GetObjRef()
+ {
+ return s_sqlObjRef;
+ }
+
}
- // ----------------------------------------
- // END SqlDependencyProcessDispatcherSerializationBinder private class.
- // ----------------------------------------
+ // ------------------------------------------
+ // End SqlClientObjRef private class.
+ // -------------------------------------------
// ----------------
// Instance members
@@ -306,10 +315,9 @@ public override Type BindToType(string assemblyName, string typeName)
private static readonly string _typeName = (typeof(SqlDependencyProcessDispatcher)).FullName;
// -----------
- // BID members
+ // EventSource members
// -----------
-
private readonly int _objectID = System.Threading.Interlocked.Increment(ref _objectTypeCount);
private static int _objectTypeCount; // EventSource Counter
internal int ObjectID
@@ -336,7 +344,7 @@ public SqlDependency(SqlCommand command) : this(command, null, SQL.SqlDependency
}
///
- [System.Security.Permissions.HostProtectionAttribute(ExternalThreading = true)]
+ [HostProtection(ExternalThreading = true)]
public SqlDependency(SqlCommand command, string options, int timeout)
{
long scopeID = SqlClientEventSource.Log.TryNotificationScopeEnterEvent(" {0}, options: '{1}', timeout: '{2}'", ObjectID, options, timeout);
@@ -597,11 +605,13 @@ private static void ObtainProcessDispatcher()
_processDispatcher = dependency.SingletonProcessDispatcher; // Set to static instance.
// Serialize and set in native.
- ObjRef objRef = GetObjRef(_processDispatcher);
- BinaryFormatter formatter = new BinaryFormatter();
- MemoryStream stream = new MemoryStream();
- GetSerializedObject(objRef, formatter, stream);
- SNINativeMethodWrapper.SetData(stream.GetBuffer()); // Native will be forced to synchronize and not overwrite.
+ using (MemoryStream stream = new MemoryStream())
+ {
+ SqlClientObjRef objRef = new SqlClientObjRef(_processDispatcher);
+ DataContractSerializer serializer = new DataContractSerializer(objRef.GetType());
+ GetSerializedObject(objRef, serializer, stream);
+ SNINativeMethodWrapper.SetData(stream.ToArray()); // Native will be forced to synchronize and not overwrite.
+ }
}
else
{
@@ -628,10 +638,20 @@ private static void ObtainProcessDispatcher()
#if DEBUG // Possibly expensive, limit to debug.
SqlClientEventSource.Log.TryNotificationTraceEvent(" AppDomain.CurrentDomain.FriendlyName: {0}", AppDomain.CurrentDomain.FriendlyName);
#endif
- BinaryFormatter formatter = new BinaryFormatter();
- MemoryStream stream = new MemoryStream(nativeStorage);
- _processDispatcher = GetDeserializedObject(formatter, stream); // Deserialize and set for appdomain.
- SqlClientEventSource.Log.TryNotificationTraceEvent(" processDispatcher obtained, ID: {0}", _processDispatcher.ObjectID);
+ using (MemoryStream stream = new MemoryStream(nativeStorage))
+ {
+ DataContractSerializer serializer = new DataContractSerializer(typeof(SqlClientObjRef));
+ if (SqlClientObjRef.CanCastToSqlDependencyProcessDispatcher())
+ {
+ // Deserialize and set for appdomain.
+ _processDispatcher = GetDeserializedObject(serializer, stream);
+ }
+ else
+ {
+ throw new ArgumentException(Strings.SqlDependency_UnexpectedValueOnDeserialize);
+ }
+ SqlClientEventSource.Log.TryNotificationTraceEvent(" processDispatcher obtained, ID: {0}", _processDispatcher.ObjectID);
+ }
}
}
@@ -639,26 +659,18 @@ private static void ObtainProcessDispatcher()
// Static security asserted methods - limit scope of assert.
// ---------------------------------------------------------
- [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.RemotingConfiguration)]
- private static ObjRef GetObjRef(SqlDependencyProcessDispatcher _processDispatcher)
- {
- return RemotingServices.Marshal(_processDispatcher);
- }
-
[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.SerializationFormatter)]
- private static void GetSerializedObject(ObjRef objRef, BinaryFormatter formatter, MemoryStream stream)
+ private static void GetSerializedObject(SqlClientObjRef objRef, DataContractSerializer serializer, MemoryStream stream)
{
- formatter.Serialize(stream, objRef);
+ serializer.WriteObject(stream, objRef);
}
[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.SerializationFormatter)]
- private static SqlDependencyProcessDispatcher GetDeserializedObject(BinaryFormatter formatter, MemoryStream stream)
+ private static SqlDependencyProcessDispatcher GetDeserializedObject(DataContractSerializer serializer, MemoryStream stream)
{
- // Use a custom SerializationBinder to restrict deserialized types to SqlDependencyProcessDispatcher.
- formatter.Binder = new SqlDependencyProcessDispatcherSerializationBinder();
- object result = formatter.Deserialize(stream);
- Debug.Assert(result.GetType() == typeof(SqlDependencyProcessDispatcher), "Unexpected type stored in native!");
- return (SqlDependencyProcessDispatcher)result;
+ object refResult = serializer.ReadObject(stream);
+ var result = RemotingServices.Unmarshal((refResult as SqlClientObjRef).GetObjRef());
+ return result as SqlDependencyProcessDispatcher;
}
// -------------------------
@@ -1325,7 +1337,6 @@ private void AddCommandInternal(SqlCommand cmd)
{
if (cmd != null)
{
- // Don't bother with BID if command null.
long scopeID = SqlClientEventSource.Log.TryNotificationScopeEnterEvent(" {0}, SqlCommand: {1}", ObjectID, cmd.ObjectID);
try
{
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs
index acfa8b7778..3ac23773ca 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs
@@ -10917,6 +10917,15 @@ internal static string SqlDependency_SqlDependency {
}
}
+ ///
+ /// Looks up a localized string similar to Unexpected type detected on deserialize..
+ ///
+ internal static string SqlDependency_UnexpectedValueOnDeserialize {
+ get {
+ return ResourceManager.GetString("SqlDependency_UnexpectedValueOnDeserialize", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to The process cannot access the file specified because it has been opened in another transaction..
///
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx
index 3c2d2c8e0d..cdb6c2c741 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx
@@ -4602,4 +4602,7 @@
Failed after 5 retries.
-
\ No newline at end of file
+
+ Unexpected type detected on deserialize.
+
+
diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlExceptionTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlExceptionTest.cs
index 97cd1909ab..328f3643c2 100644
--- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlExceptionTest.cs
+++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlExceptionTest.cs
@@ -33,7 +33,6 @@ public void SerializationTest()
Assert.Equal(e.StackTrace, sqlEx.StackTrace);
}
-
[Fact]
[ActiveIssue("12161", TestPlatforms.AnyUnix)]
public static void SqlExcpetionSerializationTest()