Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store iptables when creating network attack #215

Merged
merged 7 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions pkg/core/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"time"

"github.com/chaos-mesh/chaos-mesh/api/v1alpha1"
"github.com/chaos-mesh/chaos-mesh/controllers/podnetworkchaos/netutils"
"github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
"github.com/chaos-mesh/chaos-mesh/pkg/netem"
"github.com/pingcap/errors"
Expand Down Expand Up @@ -510,10 +511,6 @@ func (n *NetworkCommand) NeedApplyIPSet() bool {
return false
}

func (n *NetworkCommand) NeedApplyIptables() bool {
return true
}

func (n *NetworkCommand) NeedApplyTC() bool {
switch n.Action {
case NetworkDelayAction, NetworkLossAction, NetworkCorruptAction, NetworkDuplicateAction, NetworkBandwidthAction:
Expand All @@ -523,20 +520,20 @@ func (n *NetworkCommand) NeedApplyTC() bool {
}
}

func (n *NetworkCommand) AdditionalChain(ipset string) ([]*pb.Chain, error) {
func (n *NetworkCommand) AdditionalChain(ipset string, uid string) ([]*pb.Chain, error) {
chains := make([]*pb.Chain, 0, 2)
var toChains, fromChains []*pb.Chain
var err error

if n.Direction == "to" || n.Direction == "both" {
toChains, err = n.getAdditionalChain(ipset, "to")
toChains, err = n.getAdditionalChain(ipset, "to", uid)
if err != nil {
return nil, err
}
}

if n.Direction == "from" || n.Direction == "both" {
fromChains, err = n.getAdditionalChain(ipset, "from")
fromChains, err = n.getAdditionalChain(ipset, "from", uid)
if err != nil {
return nil, err
}
Expand All @@ -548,7 +545,7 @@ func (n *NetworkCommand) AdditionalChain(ipset string) ([]*pb.Chain, error) {
return chains, nil
}

func (n *NetworkCommand) getAdditionalChain(ipset, direction string) ([]*pb.Chain, error) {
func (n *NetworkCommand) getAdditionalChain(ipset, direction string, uid string) ([]*pb.Chain, error) {
var directionStr string
var directionChain pb.Chain_Direction
if direction == "to" {
Expand All @@ -562,9 +559,11 @@ func (n *NetworkCommand) getAdditionalChain(ipset, direction string) ([]*pb.Chai
}

chains := make([]*pb.Chain, 0, 2)
// The `targetLength`s in `netutils.CompressName()` are different because of
// the need to distinguish between the different chains.
if len(n.AcceptTCPFlags) > 0 {
chains = append(chains, &pb.Chain{
Name: fmt.Sprintf("%s/0", directionStr),
Name: fmt.Sprintf("%s/%s", directionStr, netutils.CompressName(uid, 19, "")),
Ipsets: []string{ipset},
Direction: directionChain,
Protocol: n.IPProtocol,
Expand All @@ -575,7 +574,7 @@ func (n *NetworkCommand) getAdditionalChain(ipset, direction string) ([]*pb.Chai

if n.Action == NetworkPartitionAction {
chains = append(chains, &pb.Chain{
Name: fmt.Sprintf("%s/1", directionStr),
Name: fmt.Sprintf("%s/%s", directionStr, netutils.CompressName(uid, 20, "")),
Ipsets: []string{ipset},
Direction: directionChain,
Protocol: n.IPProtocol,
Expand Down
18 changes: 9 additions & 9 deletions pkg/core/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestPatitionChain(t *testing.T) {
},
chains: []*pb.Chain{
{
Name: "OUTPUT/1",
Name: "OUTPUT/3c552_e0172bc4fd046_",
Ipsets: []string{"test"},
Direction: pb.Chain_OUTPUT,
Protocol: "tcp",
Expand All @@ -53,7 +53,7 @@ func TestPatitionChain(t *testing.T) {
},
chains: []*pb.Chain{
{
Name: "INPUT/1",
Name: "INPUT/3c552_e0172bc4fd046_",
Ipsets: []string{"test"},
Direction: pb.Chain_INPUT,
Protocol: "tcp",
Expand All @@ -71,14 +71,14 @@ func TestPatitionChain(t *testing.T) {
},
chains: []*pb.Chain{
{
Name: "OUTPUT/1",
Name: "OUTPUT/3c552_e0172bc4fd046_",
Ipsets: []string{"test"},
Direction: pb.Chain_OUTPUT,
Protocol: "tcp",
Target: "DROP",
},
{
Name: "INPUT/1",
Name: "INPUT/3c552_e0172bc4fd046_",
Ipsets: []string{"test"},
Direction: pb.Chain_INPUT,
Protocol: "tcp",
Expand All @@ -97,30 +97,30 @@ func TestPatitionChain(t *testing.T) {
},
chains: []*pb.Chain{
{
Name: "OUTPUT/0",
Name: "OUTPUT/3c552_e0172bc4fd04_",
Ipsets: []string{"test"},
Direction: pb.Chain_OUTPUT,
Protocol: "tcp",
TcpFlags: "SYN,ACK SYN,ACK",
Target: "ACCEPT",
},
{
Name: "OUTPUT/1",
Name: "OUTPUT/3c552_e0172bc4fd046_",
Ipsets: []string{"test"},
Direction: pb.Chain_OUTPUT,
Protocol: "tcp",
Target: "DROP",
},
{
Name: "INPUT/0",
Name: "INPUT/3c552_e0172bc4fd04_",
Ipsets: []string{"test"},
Direction: pb.Chain_INPUT,
Protocol: "tcp",
TcpFlags: "SYN,ACK SYN,ACK",
Target: "ACCEPT",
},
{
Name: "INPUT/1",
Name: "INPUT/3c552_e0172bc4fd046_",
Ipsets: []string{"test"},
Direction: pb.Chain_INPUT,
Protocol: "tcp",
Expand All @@ -130,7 +130,7 @@ func TestPatitionChain(t *testing.T) {
},
}
for _, tc := range testCases {
chains, err := tc.cmd.AdditionalChain("test")
chains, err := tc.cmd.AdditionalChain("test", "3c5528e1-4c32-4f80-983c-913ad7e860e2")
if err != nil {
t.Errorf("failed to partition chain: %v", err)
}
Expand Down
72 changes: 37 additions & 35 deletions pkg/server/chaosd/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,13 @@ func (networkAttack) Attack(options core.AttackConfig, env Environment) (err err
}
}

if attack.NeedApplyIptables() {
if err = env.Chaos.applyIptables(attack, ipsetName, env.AttackUid); err != nil {
return perrors.WithStack(err)
}
if err = env.Chaos.applyIptables(attack, ipsetName, env.AttackUid); err != nil {
return perrors.WithStack(err)
}

if attack.NeedApplyTC() {
if err = env.Chaos.applyTC(attack, ipsetName, env.AttackUid); err != nil {
return perrors.WithStack(err)
}
// Because some tcs add filter iptables which will not be stored in the DB, we must re-apply these tcs to add the iptables.
if err = env.Chaos.applyTC(attack, ipsetName, env.AttackUid); err != nil {
return perrors.WithStack(err)
}

case core.NetworkNICDownAction:
Expand Down Expand Up @@ -140,9 +137,11 @@ func (s *Server) applyIptables(attack *core.NetworkCommand, ipset, uid string) e
return perrors.WithStack(err)
}
chains := core.IptablesRuleList(iptables).ToChains()

var newChains []*pb.Chain
// Presently, only partition and delay with `accept-tcp-flags` need to add additional chains
if attack.NeedAdditionalChains() {
newChains, err := attack.AdditionalChain(ipset)
newChains, err = attack.AdditionalChain(ipset, uid)
if err != nil {
return perrors.WithStack(err)
}
Expand All @@ -156,15 +155,17 @@ func (s *Server) applyIptables(attack *core.NetworkCommand, ipset, uid string) e
return perrors.WithStack(err)
}

// TODO: cwen0
//if err := s.iptablesRule.Set(context.Background(), &core.IptablesRule{
// Name: newChain.Name,
// IPSets: strings.Join(newChain.Ipsets, ","),
// Direction: pb.Chain_Direction_name[int32(newChain.Direction)],
// Experiment: uid,
//}); err != nil {
// return perrors.WithStack(err)
//}
for _, newChain := range newChains {
if err := s.iptablesRule.Set(context.Background(), &core.IptablesRule{
Name: newChain.Name,
IPSets: strings.Join(newChain.Ipsets, ","),
Direction: pb.Chain_Direction_name[int32(newChain.Direction)],
Protocol: newChain.Protocol,
Experiment: uid,
}); err != nil {
return perrors.WithStack(err)
}
}

return nil
}
Expand All @@ -180,17 +181,24 @@ func (s *Server) applyTC(attack *core.NetworkCommand, ipset string, uid string)
return perrors.WithStack(err)
}

newTC, err := attack.ToTC(ipset)
if err != nil {
return perrors.WithStack(err)
}
var newTC *pb.Tc
if attack.NeedApplyTC() {
newTC, err = attack.ToTC(ipset)
if err != nil {
return perrors.WithStack(err)
}

tcs = append(tcs, newTC)
tcs = append(tcs, newTC)
}

if _, err := s.svr.SetTcs(context.Background(), &pb.TcsRequest{Tcs: tcs, EnterNS: false}); err != nil {
return perrors.WithStack(err)
}

if !attack.NeedApplyTC() {
return nil
}

tc := &core.TcParameter{
Device: attack.Device,
}
Expand Down Expand Up @@ -380,22 +388,16 @@ func (networkAttack) Recover(exp core.Experiment, env Environment) error {
case core.NetworkPortOccupiedAction:
return env.Chaos.recoverPortOccupied(attack, env.AttackUid)
case core.NetworkDelayAction, core.NetworkLossAction, core.NetworkCorruptAction, core.NetworkDuplicateAction, core.NetworkPartitionAction, core.NetworkBandwidthAction:
if attack.NeedApplyIPSet() {
if err := env.Chaos.recoverIPSet(env.AttackUid); err != nil {
return perrors.WithStack(err)
}
if err := env.Chaos.recoverIPSet(env.AttackUid); err != nil {
return perrors.WithStack(err)
}

if attack.NeedApplyIptables() {
if err := env.Chaos.recoverIptables(env.AttackUid); err != nil {
return perrors.WithStack(err)
}
if err := env.Chaos.recoverIptables(env.AttackUid); err != nil {
return perrors.WithStack(err)
}

if attack.NeedApplyTC() {
if err := env.Chaos.recoverTC(env.AttackUid, attack.Device); err != nil {
return perrors.WithStack(err)
}
if err := env.Chaos.recoverTC(env.AttackUid, attack.Device); err != nil {
return perrors.WithStack(err)
}
case core.NetworkNICDownAction:
return env.Chaos.recoverNICDown(attack)
Expand Down