From 3193860710ba5b4d8c01bbc1a455a8f577b94593 Mon Sep 17 00:00:00 2001 From: Joel Sing Date: Thu, 21 Apr 2022 06:35:05 +1000 Subject: [PATCH] Provide correct pty/tty file paths on OpenBSD While here, add test coverage for opening the TTY from the given filename. --- doc_test.go | 39 +++++++++++++++++++++++++++++++++++++++ pty_openbsd.go | 15 +++++++++++++-- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/doc_test.go b/doc_test.go index b9a4486..a1d915a 100644 --- a/doc_test.go +++ b/doc_test.go @@ -3,6 +3,7 @@ package pty import ( "bytes" "io" + "os" "testing" ) @@ -65,6 +66,44 @@ func TestName(t *testing.T) { } } +// TestOpenByName ensures that the name associated with the tty is valid +// and can be opened and used if passed by file name (rather than passing +// the existing open file descriptor). +func TestOpenByName(t *testing.T) { + t.Parallel() + + pty, tty, err := Open() + if err != nil { + t.Fatal(err) + } + defer pty.Close() + defer tty.Close() + + ttyFile, err := os.OpenFile(tty.Name(), os.O_RDWR, 0600) + if err != nil { + t.Fatalf("Failed to open tty file: %v", err) + } + defer ttyFile.Close() + + // Ensure we can write to the newly opened tty file and read on the pty. + text := []byte("ping") + n, err := ttyFile.Write(text) + if err != nil { + t.Errorf("Unexpected error from Write: %s", err) + } + if n != len(text) { + t.Errorf("Unexpected count returned from Write, got %d expected %d", n, len(text)) + } + + buffer := make([]byte, len(text)) + if err := readBytes(pty, buffer); err != nil { + t.Errorf("Unexpected error from readBytes: %s", err) + } + if !bytes.Equal(text, buffer) { + t.Errorf("Unexpected result returned from Read, got %v expected %v", buffer, text) + } +} + func TestGetsize(t *testing.T) { t.Parallel() diff --git a/pty_openbsd.go b/pty_openbsd.go index 031367a..aada5e3 100644 --- a/pty_openbsd.go +++ b/pty_openbsd.go @@ -9,6 +9,17 @@ import ( "unsafe" ) +func cInt8ToString(in []int8) string { + var s []byte + for _, v := range in { + if v == 0 { + break + } + s = append(s, byte(v)) + } + return string(s) +} + func open() (pty, tty *os.File, err error) { /* * from ptm(4): @@ -29,8 +40,8 @@ func open() (pty, tty *os.File, err error) { return nil, nil, err } - pty = os.NewFile(uintptr(ptm.Cfd), "/dev/ptm") - tty = os.NewFile(uintptr(ptm.Sfd), "/dev/ptm") + pty = os.NewFile(uintptr(ptm.Cfd), cInt8ToString(ptm.Cn[:])) + tty = os.NewFile(uintptr(ptm.Sfd), cInt8ToString(ptm.Sn[:])) return pty, tty, nil }