diff --git a/odl.go b/odl.go index 912c099..027f2f1 100644 --- a/odl.go +++ b/odl.go @@ -2,9 +2,11 @@ package main import ( "bufio" + "bytes" "fmt" "io" "log" + "mime" "net/http" "net/url" "os" @@ -13,76 +15,83 @@ import ( ) func main() { - Download(Read(".", os.Stdin)) + ReadAndDownload(".", os.Stdin) } -type request struct { - w io.Writer - url string -} +func ReadAndDownload(root string, r io.Reader) { + var n int + var name string + scanner := bufio.NewScanner(r) + for scanner.Scan() { + txt := scanner.Text() -func Download(c <-chan request) { - dl := func(w io.Writer, url string) error { - response, err := http.Get(url) + if name == "" || txt == "" { + name = txt + n = 0 + continue + } + err := download(n, root, name, txt) if err != nil { - return fmt.Errorf("downloading file %s: %v", url, err) + log.Println(err) } + n++ + } +} + +func download(n int, root, name, uri string) error { + u, err := url.Parse(uri) + if err != nil { + return fmt.Errorf("parsing uri %s: %v", uri, err) + } + ext := filepath.Ext(u.Path) + + response, err := http.Get(uri) + switch { + case err != nil: + return fmt.Errorf("downloading file %s: %v", uri, err) + case response.StatusCode != http.StatusOK: + return fmt.Errorf("downloading file %s: status code %d", uri, response.StatusCode) + default: defer response.Body.Close() + } - if response.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status code for %s: %d", url, response.StatusCode) + var r io.Reader = response.Body + if ext == "" { + var err error + ext, r, err = guessExt(response.Body) + if err != nil { + return err } + } - _, err = io.Copy(w, response.Body) + filename := filepath.Join(root, name+strconv.Itoa(n)+ext) + f, err := os.Create(filename) + if err != nil { return err } - for r := range c { - if err := dl(r.w, r.url); err != nil { - log.Println(err) - } - } + _, err = io.Copy(f, r) + return err } -func Read(root string, r io.Reader) <-chan request { - scanner := bufio.NewScanner(r) - - var n int - var name string - c := make(chan request) - - create := func(name, uri string, n int) (io.Writer, error) { - var f io.Writer +func guessExt(r io.Reader) (string, io.Reader, error) { + head, headCopy := make([]byte, 512), make([]byte, 512) + n, readErr := io.ReadFull(r, head) + head = head[:n] + copy(headCopy, head) + hr := bytes.NewReader(headCopy) - u, err := url.Parse(uri) - if err != nil { - return f, fmt.Errorf("parsing url %s: %v", uri, err) - } + t := http.DetectContentType(head) + ext, err := mime.ExtensionsByType(t) - filename := filepath.Join(root, name+strconv.Itoa(n)+filepath.Ext(u.Path)) - return os.Create(filename) + if readErr == io.EOF || readErr == io.ErrUnexpectedEOF { + return ext[0], hr, nil } - go func() { - - for scanner.Scan() { - txt := scanner.Text() - - if name == "" || txt == "" { - name = txt - n = 0 - continue - } + if err != nil { + return "", nil, err + } - f, err := create(name, txt, n) - if err != nil { - log.Println(err) - continue - } - c <- request{w: f, url: txt} - n++ - } - close(c) - }() - return c + return ext[0], io.MultiReader(hr, r), nil } + diff --git a/odl_test.go b/odl_test.go index 92e6226..1b3bd68 100644 --- a/odl_test.go +++ b/odl_test.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "net/http" "net/http/httptest" "os" @@ -14,44 +13,28 @@ var fs = http.FileServer(http.Dir("./testdata")) var srv = httptest.NewServer(fs) -func TestDownload(t *testing.T) { - var ( - b bytes.Buffer - c = make(chan request) - ) - - url := srv.URL + "/somefile" - t.Logf("file URL: %s", url) - - go func() { c <- request{&b, url}; close(c) }() - Download(c) - - want := "somefile inner data" - if got := b.String(); got != want { - t.Errorf("file buffer received unexpected value\nGot:\n%s\nWant:\n%s", got, want) +func TestReadAndDownload(t *testing.T) { + lines := []string{ + "text", + srv.URL + "/sometext.txt", + srv.URL + "/moretext.txt", + "", + "image", + srv.URL + "/tiny-fuji", } -} - -func TestRead(t *testing.T) { - input := `boston -http://localhost/someboston.jpg -http://localhost/anotherboston.mp4 + input := strings.Join(lines, "\n") -orlando -http://localhost/someorlando.jpg -` dir := t.TempDir() t.Logf("root directory: %s", dir) + t.Logf("\ninput:\n%s\n", input) - c := Read(dir, strings.NewReader(input)) - for range c { - } // drain the channel + ReadAndDownload(dir, strings.NewReader(input)) err := fstest.TestFS( os.DirFS(dir), - "boston0.jpg", - "boston1.mp4", - "orlando0.jpg", + "text0.txt", + "text1.txt", + "image0.webp", ) if err != nil { diff --git a/testdata/moretext.txt b/testdata/moretext.txt new file mode 100644 index 0000000..8c36bd1 --- /dev/null +++ b/testdata/moretext.txt @@ -0,0 +1,2 @@ +another file of text + diff --git a/testdata/somefile b/testdata/sometext.txt similarity index 100% rename from testdata/somefile rename to testdata/sometext.txt diff --git a/testdata/tiny-fuji b/testdata/tiny-fuji new file mode 100644 index 0000000..da97086 Binary files /dev/null and b/testdata/tiny-fuji differ