deadlinks/client.go
2023-12-30 12:24:53 +01:00

279 lines
6.4 KiB
Go

package deadlinks
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"git.sr.ht/~adnano/go-gemini"
)
// Client fetches the resource identified by a URL. Implementations must be
// safe for concurrent use by multiple goroutines.
//
// Get returns the MIME type of the resource along with a reader over its
// body, which the caller should close once done with it. When the MIME type
// cannot be determined the empty string is returned in its place.
type Client interface {
	Get(ctx context.Context, url URL) (mimeType string, body io.ReadCloser, err error)
}
// ClientOpts holds optional configuration for NewClient. Passing a nil
// *ClientOpts behaves exactly like passing a pointer to a zero ClientOpts;
// any field left at its zero value falls back to the default documented on
// that field.
type ClientOpts struct {
	// GeminiClient performs requests for URLs using the gemini protocol.
	//
	// Defaults to `new(gemini.Client)`.
	GeminiClient interface {
		Do(ctx context.Context, req *gemini.Request) (*gemini.Response, error)
	}

	// HTTPClient performs requests for URLs using the http protocol.
	//
	// Defaults to `new(http.Client)`.
	HTTPClient interface {
		Do(req *http.Request) (*http.Response, error)
	}

	// MaxRedirects caps how many redirects may be followed while resolving a
	// single resource. Zero selects the default; any negative value means
	// that no redirects are allowed. Note that an HTTPClient which follows
	// redirects on its own (as http.Client does unless configured otherwise)
	// applies its own redirect policy to the redirects it handles.
	//
	// Default: 10.
	MaxRedirects int
}
// withDefaults returns a copy of o in which every field still at its zero
// value has been replaced by its documented default. A nil o is treated as
// an empty ClientOpts.
//
// o itself is never modified, so a caller may safely reuse or share the
// ClientOpts it passes to NewClient.
func (o *ClientOpts) withDefaults() *ClientOpts {
	// Copy first: previously the defaults were written straight into the
	// caller's struct, which was a surprising side effect of NewClient.
	var opts ClientOpts
	if o != nil {
		opts = *o
	}

	if opts.GeminiClient == nil {
		opts.GeminiClient = new(gemini.Client)
	}
	if opts.HTTPClient == nil {
		opts.HTTPClient = new(http.Client)
	}
	// Zero means "unset"; callers wanting no redirects use a negative value.
	if opts.MaxRedirects == 0 {
		opts.MaxRedirects = 10
	}

	return &opts
}
// client is the Client implementation handed out by NewClient. It is
// stateless apart from its options, which have had defaults applied by
// NewClient before the client is constructed.
type client struct {
	opts ClientOpts // options with defaults already filled in
}
// NewClient returns a Client which understands commonly used transport
// protocols. opts may be nil, in which case every option takes its default.
//
// Supported URL schemes:
//   - gemini
//   - http/https
//
// URLs with the following schemes are not fetched and always succeed:
//   - mailto
//   - data
//
// The returned Client errors on any URL whose scheme is not listed above.
func NewClient(opts *ClientOpts) Client {
	resolved := opts.withDefaults()
	return &client{opts: *resolved}
}
// emptyReadCloser returns a fresh io.ReadCloser with no content: its first
// Read reports io.EOF and its Close always succeeds.
func emptyReadCloser() io.ReadCloser {
	var empty bytes.Buffer
	return io.NopCloser(&empty)
}
// getGemini fetches a gemini URL using the configured GeminiClient.
//
// The outcome depends only on the first digit (the class) of the response
// status:
//   - 1x (input required): success with an unknown MIME type and an empty
//     body.
//   - 2x (success): the response meta is returned as the MIME type together
//     with the response body, which the caller must close.
//   - 3x (redirect): the meta is parsed as a URL, resolved against url and
//     fetched, provided fewer than MaxRedirects redirects have been followed.
//   - anything else: an error quoting the status and meta.
//
// redirectDepth is the number of redirects already followed to reach url.
func (c *client) getGemini(
	ctx context.Context, url URL, redirectDepth int,
) (
	string, io.ReadCloser, error,
) {
	req, err := gemini.NewRequest(string(url))
	if err != nil {
		return "", nil, fmt.Errorf("building request: %w", err)
	}

	// TODO allow specifying client cert
	res, err := c.opts.GeminiClient.Do(ctx, req)
	if err != nil {
		return "", nil, fmt.Errorf("performing request: %w", err)
	}

	class := res.Status / 10

	// Success is the only class whose body is handed to the caller.
	if class == 2 {
		return res.Meta, res.Body, nil
	}

	// Every other class is finished with the body once this call returns.
	defer res.Body.Close()

	switch class {
	case 1:
		// Input is requested. Treat the resource as reachable even though
		// its MIME type is unknown.
		return "", emptyReadCloser(), nil

	case 3:
		if redirectDepth >= c.opts.MaxRedirects {
			return "", nil, errors.New("too many redirects")
		}
		target, err := ParseURL(res.Meta)
		if err != nil {
			return "", nil, fmt.Errorf("parsing redirect URL %q: %w", res.Meta, err)
		}
		return c.get(ctx, url.ResolveReference(target), redirectDepth+1)
	}

	return "", nil, fmt.Errorf(
		"response code %d (%v): %q", res.Status, res.Status, res.Meta,
	)
}
// concatHTTPBody is an io.ReadCloser which reads from multi but closes orig.
// It allows bytes already consumed from a response body to be placed back in
// front of the remaining stream, while Close still reaches the real body.
type concatHTTPBody struct {
	orig  io.ReadCloser // the underlying body; only used by Close
	multi io.Reader     // already-read bytes followed by the rest of orig
}

// Read reads from the concatenated stream.
func (b concatHTTPBody) Read(p []byte) (int, error) {
	return b.multi.Read(p)
}

// Close closes the underlying response body.
func (b concatHTTPBody) Close() error {
	return b.orig.Close()
}

// httpResponseMIMEType reports the MIME type of res.
//
// When the Content-Type header is set, its value is returned verbatim.
// Otherwise up to the first 512 bytes of the body are read and passed to
// http.DetectContentType. In that case res.Body is replaced by a reader that
// yields those bytes again before the rest of the body, so callers still see
// the complete stream; closing the replacement closes the original body.
func httpResponseMIMEType(res *http.Response) (string, error) {
	contentType := res.Header.Get("Content-Type")
	if contentType != "" {
		return contentType, nil
	}

	// http.DetectContentType never considers more than this many bytes.
	const sniffLen = 512

	var head bytes.Buffer
	_, err := io.CopyN(&head, res.Body, sniffLen)
	// A body shorter than sniffLen is not an error.
	if err != nil && !errors.Is(err, io.EOF) {
		return "", fmt.Errorf("reading head of response body: %w", err)
	}

	sniffed := http.DetectContentType(head.Bytes())

	// Stitch the consumed prefix back onto the front of the body.
	original := res.Body
	res.Body = concatHTTPBody{
		orig:  original,
		multi: io.MultiReader(&head, original),
	}

	return sniffed, nil
}
// getHTTP fetches an http(s) URL using the configured HTTPClient.
//
// The outcome depends on the response status:
//   - 1xx: success with an unknown MIME type and an empty body.
//   - 2xx: the Content-Type (sniffed from the body when the header is absent)
//     is returned together with the response body, which the caller must
//     close.
//   - 301, 302, 307, 308: the Location header is resolved against url and
//     fetched, provided fewer than MaxRedirects redirects have been followed.
//   - any other 3xx: success with an unknown MIME type and an empty body.
//   - anything else (4xx, 5xx, ...): an error describing the status.
//
// redirectDepth is the number of redirects already followed to reach url.
// Note that an HTTPClient which follows redirects itself (as http.Client does
// by default) handles those before this function ever sees them.
func (c *client) getHTTP(
	ctx context.Context, url URL, redirectDepth int,
) (
	string, io.ReadCloser, error,
) {
	req, err := http.NewRequestWithContext(ctx, "GET", string(url), nil)
	if err != nil {
		return "", nil, fmt.Errorf("building request: %w", err)
	}

	res, err := c.opts.HTTPClient.Do(req)
	if err != nil {
		return "", nil, fmt.Errorf("performing request: %w", err)
	}

	statusCodeCategory := res.StatusCode / 100
	switch {
	case statusCodeCategory == 1: // informational
		defer res.Body.Close()
		return "", emptyReadCloser(), nil

	case statusCodeCategory == 2: // success
		// Only a body which is handed back to the caller needs its MIME
		// type; sniffing other responses would read their bodies for nothing
		// and could mask the real status with a read error.
		mimeType, err := httpResponseMIMEType(res)
		if err != nil {
			res.Body.Close()
			return "", nil, fmt.Errorf("determining response MIME type: %w", err)
		}
		return mimeType, res.Body, nil

	// redirects
	case res.StatusCode == http.StatusMovedPermanently,
		res.StatusCode == http.StatusFound,
		res.StatusCode == http.StatusTemporaryRedirect,
		res.StatusCode == http.StatusPermanentRedirect:
		defer res.Body.Close()

		// Enforce the same limit as gemini redirects. Without it an
		// HTTPClient that does not follow redirects itself would recurse
		// forever on a redirect loop.
		if redirectDepth >= c.opts.MaxRedirects {
			return "", nil, errors.New("too many redirects")
		}

		loc, err := res.Location()
		if err != nil {
			return "", nil, fmt.Errorf(
				"getting Location header of response with code %v: %w",
				res.StatusCode, err,
			)
		}
		locURL, err := ParseURL(loc.String())
		if err != nil {
			return "", nil, fmt.Errorf("parsing redirect URL %v: %w", loc, err)
		}
		newURL := url.ResolveReference(locURL)
		return c.get(ctx, newURL, redirectDepth+1)

	case statusCodeCategory == 3: // unsupported redirections
		defer res.Body.Close()
		return "", emptyReadCloser(), nil

	// all other response codes, 4xx and 5xx, are considered errors
	default:
		defer res.Body.Close()
		return "", nil, fmt.Errorf(
			"response code %d (%v)", res.StatusCode, res.Status,
		)
	}
}
// noOpGet handles schemes which have nothing to fetch. It always succeeds,
// reporting an unknown (empty) MIME type and an empty body.
func (c *client) noOpGet() (string, io.ReadCloser, error) {
	var unknownMIMEType string
	return unknownMIMEType, emptyReadCloser(), nil
}
// get routes the fetch of url to the handler for its scheme, carrying
// redirectDepth (the number of redirects followed so far) through so that
// redirect limits apply across protocols. Unknown schemes are an error.
func (c *client) get(
	ctx context.Context, url URL, redirectDepth int,
) (
	string, io.ReadCloser, error,
) {
	switch scheme := url.toStd().Scheme; scheme {
	case "gemini":
		return c.getGemini(ctx, url, redirectDepth)
	case "http", "https":
		return c.getHTTP(ctx, url, redirectDepth)
	case "mailto", "data":
		// Nothing to retrieve; these always count as reachable.
		return c.noOpGet()
	default:
		return "", nil, fmt.Errorf("unsupported scheme %q", scheme)
	}
}
// Get implements Client. Resolution starts with no redirects followed yet.
func (c *client) Get(ctx context.Context, url URL) (string, io.ReadCloser, error) {
	const startingRedirectDepth = 0
	return c.get(ctx, url, startingRedirectDepth)
}