diff --git a/machine_types.go b/machine_types.go index ff458a3..8775cf9 100644 --- a/machine_types.go +++ b/machine_types.go @@ -363,6 +363,7 @@ func (mg *MachineGuest) SetSize(size string) error { mg.CPUKind = guest.CPUKind mg.MemoryMB = guest.MemoryMB mg.GPUKind = guest.GPUKind + mg.GPUs = guest.GPUs return nil } diff --git a/machine_types_test.go b/machine_types_test.go index ee19a46..737be94 100644 --- a/machine_types_test.go +++ b/machine_types_test.go @@ -130,6 +130,18 @@ func TestMachineGuest_SetSize(t *testing.T) { t.Error("want error for invalid preset name") } + // Set GPU related fields that must be unset for non-gpu-size-alias + if err := guest.SetSize("a100-40gb"); err != nil { + t.Errorf("got error for valid preset name: %v", err) + } else { + if guest.GPUs != 1 { + t.Errorf("Expected 1 gpu, got: %v", guest.GPUs) + } + if guest.GPUKind != "a100-pcie-40gb" { + t.Errorf("Expected a100-pcie-40gb gpu kind, got: %v", guest.GPUKind) + } + } + if err := guest.SetSize("performance-4x"); err != nil { t.Errorf("got error for valid preset name: %v", err) } else { @@ -142,6 +154,12 @@ func TestMachineGuest_SetSize(t *testing.T) { if guest.MemoryMB != 8192 { t.Errorf("Expected 8192 MB of memory , got: %v", guest.MemoryMB) } + if guest.GPUs != 0 { + t.Errorf("Expected 0 gpus, got: %v", guest.GPUs) + } + if guest.GPUKind != "" { + t.Errorf("Expected non gpu kind, got: %v", guest.GPUKind) + } } }