diff --git a/pkg/core/network.go b/pkg/core/network.go index eae0e6cb..066c78f0 100644 --- a/pkg/core/network.go +++ b/pkg/core/network.go @@ -520,20 +520,20 @@ func (n *NetworkCommand) NeedApplyTC() bool { } } -func (n *NetworkCommand) AdditionalChain(ipset string, uid string) ([]*pb.Chain, error) { +func (n *NetworkCommand) AdditionalChain(ipset, device, uid string) ([]*pb.Chain, error) { chains := make([]*pb.Chain, 0, 2) var toChains, fromChains []*pb.Chain var err error if n.Direction == "to" || n.Direction == "both" { - toChains, err = n.getAdditionalChain(ipset, "to", uid) + toChains, err = n.getAdditionalChain(ipset, device, "to", uid) if err != nil { return nil, err } } if n.Direction == "from" || n.Direction == "both" { - fromChains, err = n.getAdditionalChain(ipset, "from", uid) + fromChains, err = n.getAdditionalChain(ipset, device, "from", uid) if err != nil { return nil, err } @@ -545,7 +545,7 @@ func (n *NetworkCommand) AdditionalChain(ipset string, uid string) ([]*pb.Chain, return chains, nil } -func (n *NetworkCommand) getAdditionalChain(ipset, direction string, uid string) ([]*pb.Chain, error) { +func (n *NetworkCommand) getAdditionalChain(ipset, device, direction, uid string) ([]*pb.Chain, error) { var directionStr string var directionChain pb.Chain_Direction if direction == "to" { @@ -569,6 +569,7 @@ func (n *NetworkCommand) getAdditionalChain(ipset, direction string, uid string) Protocol: n.IPProtocol, TcpFlags: n.AcceptTCPFlags, Target: "ACCEPT", + Device: device, }) } @@ -579,6 +580,7 @@ func (n *NetworkCommand) getAdditionalChain(ipset, direction string, uid string) Direction: directionChain, Protocol: n.IPProtocol, Target: "DROP", + Device: device, }) } return chains, nil diff --git a/pkg/core/network_test.go b/pkg/core/network_test.go index f1d2bc83..b98ccee3 100644 --- a/pkg/core/network_test.go +++ b/pkg/core/network_test.go @@ -130,7 +130,7 @@ func TestPatitionChain(t *testing.T) { }, } for _, tc := range testCases { - chains, err := tc.cmd.AdditionalChain("test", "3c5528e1-4c32-4f80-983c-913ad7e860e2") + chains, err := tc.cmd.AdditionalChain("test", "eth0", "3c5528e1-4c32-4f80-983c-913ad7e860e2") if err != nil { t.Errorf("failed to partition chain: %v", err) } diff --git a/pkg/server/chaosd/network.go b/pkg/server/chaosd/network.go index c99e19c3..ef93b232 100644 --- a/pkg/server/chaosd/network.go +++ b/pkg/server/chaosd/network.go @@ -141,7 +141,7 @@ func (s *Server) applyIptables(attack *core.NetworkCommand, ipset, uid string) e var newChains []*pb.Chain // Presently, only partition and delay with `accept-tcp-flags` need to add additional chains if attack.NeedAdditionalChains() { - newChains, err = attack.AdditionalChain(ipset, uid) + newChains, err = attack.AdditionalChain(ipset, attack.Device, uid) if err != nil { return perrors.WithStack(err) }