dndmusicbot/routes.go

311 lines
7.7 KiB
Go

package main
import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/netip"
"net/url"
"os"
"path"
"path/filepath"
"text/template"
"time"
"github.com/bitly/go-simplejson"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"github.com/gorilla/websocket"
"github.com/julienschmidt/httprouter"
"github.com/markbates/goth"
"github.com/markbates/goth/gothic"
"github.com/markbates/goth/providers/discord"
"golang.org/x/exp/slices"
)
// COOKIE_NAME is the cookie that carries the signed JWT session: AuthHandler
// sets it after a successful Discord login and the auth middleware checks it.
const COOKIE_NAME = "_dndmusicbot"

// upgrader promotes GET /ws requests to websocket connections. CheckOrigin is
// left nil, so gorilla/websocket applies its default same-origin check.
var upgrader = websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024}
// init wires up the web front end when the package loads:
//   - the gothic session store,
//   - the Discord OAuth provider,
//   - every HTTP route,
//   - an HTTP server listening on ":" + web.port, started in a background
//     goroutine.
//
// If the listener fails, the process exits via log.Fatal.
func init() {
	key := []byte(os.Getenv("SESSION_SECRET"))
	if len(key) == 0 {
		// The gothic cookie store below and the JWT session cookies issued by
		// AuthHandler / checked by auth are all keyed with this secret.
		log.Println("warning: SESSION_SECRET is not set")
	}
	maxAge := 86400 * 30 // 30 days
	isProd := true       // Set to true when serving over https
	store := sessions.NewCookieStore(key)
	store.MaxAge(maxAge)
	store.Options.Path = "/"
	store.Options.HttpOnly = true // HttpOnly should always be enabled
	store.Options.Secure = isProd
	gothic.Store = store

	goth.UseProviders(
		discord.New(
			config.GetString("discord.id"),
			config.GetString("discord.secret"),
			config.GetString("discord.callback"),
			discord.ScopeIdentify, discord.ScopeEmail, discord.ScopeGuilds, discord.ScopeReadGuilds,
		),
	)

	app.router = httprouter.New()
	app.router.GET("/", auth(app.Index))
	app.router.GET("/playlists", auth(app.Web_Playlists))
	app.router.GET("/ambiance", auth(app.Web_Ambiance))
	app.router.GET("/play/:playlist", auth(app.Play))
	app.router.GET("/reset", auth(app.Reset))
	app.router.GET("/public/*js", auth(app.ServeFiles))
	app.router.GET("/css/*css", auth(app.ServeFiles))
	app.router.GET("/auth/callback", app.AuthHandler)
	app.router.GET("/youtube/:id", local(ProxyTube))
	app.router.HandlerFunc(http.MethodGet, "/ws", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("WS connection from %v\n", r.RemoteAddr)
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			log.Println(err)
			return
		}
		if err := ws.join(conn); err != nil {
			log.Printf("WS connection closed, %v: %v\n", r.RemoteAddr, err)
		}
	})

	srv := &http.Server{
		Addr:    ":" + config.GetString("web.port"),
		Handler: app.router,
		// Bound only the header read, to stop clients from trickling headers.
		// Read/Write timeouts are deliberately not set: /ws is a long-lived
		// websocket, and /youtube/:id presumably streams media for a long time.
		ReadHeaderTimeout: 10 * time.Second,
	}
	go func() {
		log.Fatal(srv.ListenAndServe())
	}()
}
// IndexData is the value Index passes to tmpl/index.tmpl. Index builds it
// positionally, so the field order below must stay as it is.
type IndexData struct {
// Playlists is the result of app.GetPlaylists.
Playlists []Playlist
// Ambiance is the result of GetAmbiances.
Ambiance []Ambiance
}
// AuthHandler is the Discord OAuth2 callback (GET /auth/callback).
//
// It completes the goth login and then fetches the caller's member record in
// the guild configured as "discord.guild". If the member holds at least one
// role listed in "discord.groups", it issues a signed 30-day JWT in the
// COOKIE_NAME cookie and redirects to "/".
//
// Callers whose member record cannot be fetched, or who hold no matching
// role, get 403 Forbidden.
func (app App) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	user, err := gothic.CompleteUserAuth(w, r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	memberURL := &url.URL{
		Scheme: "https",
		Host:   "discord.com",
		Path:   "/api/users/@me/guilds/" + config.GetString("discord.guild") + "/member",
	}
	req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, memberURL.String(), nil)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	req.Header.Set("Authorization", "Bearer "+user.AccessToken)

	// Bound the call so a slow Discord API cannot hold the callback open
	// indefinitely. A client with a nil Transport still uses (and pools
	// connections through) http.DefaultTransport.
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	if resp.StatusCode != http.StatusOK {
		// No member record means membership cannot be established
		// (presumably the user is not in the guild), so deny access.
		log.Printf("discord member lookup for %v failed: %s: %s", r.RemoteAddr, resp.Status, body)
		http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
		return
	}

	member, err := simplejson.NewJson(body)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	roles, err := member.GetPath("roles").StringArray()
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}

	allowed := false
	for _, group := range config.GetStringSlice("discord.groups") {
		if slices.Contains(roles, group) {
			allowed = true
			break
		}
	}
	if !allowed {
		http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
		return
	}

	now := time.Now()
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"nbf": now.Unix(),
		"exp": now.Add(720 * time.Hour).Unix(), // 30 days, matching the cookie MaxAge
	})
	tokenString, err := token.SignedString([]byte(os.Getenv("SESSION_SECRET")))
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	http.SetCookie(w, &http.Cookie{
		Name:     COOKIE_NAME,
		Value:    tokenString,
		Path:     "/",
		MaxAge:   86400 * 30,
		Secure:   true,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})
	http.Redirect(w, r, "/", http.StatusFound)
}
// local is middleware that passes requests to n only when they come from a
// loopback or private address. Private means 10/8, 172.16/12, 192.168/16 or
// fc00::/7, per netip.Addr.IsPrivate.
//
// Every other caller gets 403 Forbidden: no credentials could change the
// outcome, so 401 would be misleading. A RemoteAddr that cannot be parsed
// yields 500.
//
// It inspects r.RemoteAddr directly, so behind a reverse proxy it sees the
// proxy's address rather than the client's.
func local(n httprouter.Handle) httprouter.Handle {
	return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
		addrPort, err := netip.ParseAddrPort(r.RemoteAddr)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		// Unmap explicitly so an IPv4 peer seen on a dual-stack listener
		// (e.g. "::ffff:192.168.1.5") is classified by its IPv4 address.
		ip := addrPort.Addr().Unmap()
		if !ip.IsLoopback() && !ip.IsPrivate() {
			http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
			return
		}
		n(w, r, ps)
	}
}
// auth is middleware that passes requests to n only when they carry a valid
// session cookie: COOKIE_NAME holding an HMAC-signed JWT issued by
// AuthHandler and verified with SESSION_SECRET.
//
// Requests with a missing, expired, or otherwise invalid cookie are sent into
// the Discord OAuth flow via gothic.BeginAuthHandler.
//
// When APP_ENV is "test", every request is let through unchecked.
func auth(n httprouter.Handle) httprouter.Handle {
	return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
		if os.Getenv("APP_ENV") == "test" {
			n(w, r, ps)
			return
		}

		// beginAuth starts the Discord login. gothic reads the provider name
		// from the "provider" query parameter, so it is set only here, and
		// Set (not Add) is used so a repeated redirect cannot duplicate it.
		beginAuth := func() {
			q := r.URL.Query()
			q.Set("provider", "discord")
			r.URL.RawQuery = q.Encode()
			gothic.BeginAuthHandler(w, r)
		}

		cookie, err := r.Cookie(COOKIE_NAME)
		if err != nil {
			// r.Cookie only fails with http.ErrNoCookie: not logged in yet.
			beginAuth()
			return
		}

		token, err := jwt.Parse(cookie.Value, func(t *jwt.Token) (interface{}, error) {
			if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
			}
			return []byte(os.Getenv("SESSION_SECRET")), nil
		})
		if err != nil || !token.Valid {
			// An expired token, or one signed with a rotated secret, must not
			// be a dead end. Re-run the login; its callback overwrites the
			// stale cookie.
			log.Printf("rejecting session cookie from %v: %v", r.RemoteAddr, err)
			beginAuth()
			return
		}
		n(w, r, ps)
	}
}
// ServeFiles serves static assets for the "/public/*js" and "/css/*css"
// routes from the directory of the same name under the working directory.
//
// The requested file is confined to that directory; missing files and
// directories yield 404.
func (app App) ServeFiles(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
	// Split r.URL.Path into the route's static prefix (e.g. "/public") and the
	// catch-all remainder. httprouter stores that remainder, including its
	// leading "/", as the route's last parameter.
	//
	// If no usable parameter is present, the whole path is treated as the
	// remainder, which still confines it to the working directory.
	urlPath := r.URL.Path
	dir, rel := "", urlPath
	if n := len(ps); n > 0 {
		if v := ps[n-1].Value; len(v) <= len(urlPath) && urlPath[len(urlPath)-len(v):] == v {
			dir, rel = urlPath[:len(urlPath)-len(v)], v
		}
	}
	// Cleaning rel as a rooted path resolves every ".." against that root.
	// For example, "/public/../../etc/passwd" maps to "public/etc/passwd"
	// instead of escaping the asset directory.
	filePath := filepath.Join(".", dir, path.Clean("/"+rel))

	file, err := os.Open(filePath)
	if err != nil {
		log.Println(err)
		http.Error(w, "no such file", http.StatusNotFound)
		return
	}
	defer file.Close()

	// Stat the open handle so the metadata describes exactly what is served.
	info, err := file.Stat()
	if err != nil {
		log.Println(err)
		http.Error(w, "unable to get file stat", http.StatusInternalServerError)
		return
	}
	if info.IsDir() {
		http.Error(w, "no such file", http.StatusNotFound)
		return
	}
	http.ServeContent(w, r, filepath.Base(filePath), info.ModTime(), file)
}
// Index renders tmpl/index.tmpl with every playlist and ambiance entry.
//
// The template is parsed on every request, so edits to the file take effect
// without a restart. Parse failures are reported as 500 rather than panicking.
//
// NOTE(review): the page is rendered with text/template, which does not
// HTML-escape its data. If playlist or ambiance names can contain
// user-supplied text, switching the file to html/template would close an
// XSS hole.
func (app App) Index(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	playlists, err := app.GetPlaylists()
	if err != nil {
		http.Error(w, "Unable to get playlists. "+err.Error(), http.StatusInternalServerError)
		return
	}
	amblist, err := GetAmbiances()
	if err != nil {
		log.Println(err)
		http.Error(w, "Unable to get ambiance. "+err.Error(), http.StatusInternalServerError)
		return
	}

	t, err := template.New("index.tmpl").ParseFiles("tmpl/index.tmpl")
	if err != nil {
		http.Error(w, "Unable to load template. "+err.Error(), http.StatusInternalServerError)
		return
	}
	if err := t.Execute(w, IndexData{Playlists: playlists, Ambiance: amblist}); err != nil {
		// Part of the page may already have been written, so the status can
		// no longer be changed; record the failure instead.
		log.Println("Unable to render template. " + err.Error())
	}
}
// Web_Playlists writes the result of app.GetPlaylists as JSON.
func (app App) Web_Playlists(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	playlists, err := app.GetPlaylists()
	if err != nil {
		http.Error(w, "Unable to get playlists. "+err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	// json.Encoder marshals the whole value before writing, so a marshal
	// failure can still be reported with a proper status.
	if err := json.NewEncoder(w).Encode(playlists); err != nil {
		http.Error(w, "Unable to get playlists. "+err.Error(), http.StatusInternalServerError)
	}
}
// Web_Ambiance writes the result of GetAmbiances as JSON.
func (app App) Web_Ambiance(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	ambiance, err := GetAmbiances()
	if err != nil {
		http.Error(w, "Unable to get ambiance. "+err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	// json.Encoder marshals the whole value before writing, so a marshal
	// failure can still be reported with a proper status.
	if err := json.NewEncoder(w).Encode(ambiance); err != nil {
		http.Error(w, "Unable to get ambiance. "+err.Error(), http.StatusInternalServerError)
	}
}
// Play handles GET /play/:playlist.
//
//   - The special name "reset" emits a "stop" event.
//   - A UUID emits "new_playlist" with that ID.
//   - Anything else is rejected with 400 and no event is emitted.
func (app *App) Play(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
	plname := p.ByName("playlist")
	if plname == "reset" {
		app.events.Emit("stop", nil)
		return
	}
	plid, err := uuid.Parse(plname)
	if err != nil {
		// Malformed client input; stop here so uuid.Nil is never emitted.
		http.Error(w, "Unable to parse uuid. "+err.Error(), http.StatusBadRequest)
		return
	}
	app.events.Emit("new_playlist", plid)
}
// Add parses the request form and rejects malformed input with 400.
//
// NOTE(review): init in this file does not register Add on the router, and
// the parsed values are not used yet — this looks like an unfinished stub.
func (app *App) Add(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	if err := r.ParseForm(); err != nil {
		http.Error(w, "Unable to parse form. "+err.Error(), http.StatusBadRequest)
	}
}
// Reset handles GET /reset by emitting a "stop" event. It writes no body, so
// the client receives an empty 200 response.
func (app *App) Reset(_ http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
	const stopEvent = "stop"
	app.events.Emit(stopEvent, nil)
}