From cf82dede3b80ff849ce33b99b02b304b691081e4 Mon Sep 17 00:00:00 2001 From: KoCoder Date: Thu, 21 Aug 2025 19:35:01 +0200 Subject: [PATCH] Bulk commit --- cmd/api/main.go | 4 +- model/mandant.go | 11 + model/user.go | 2 +- query/gen.go | 8 + query/mandants.gen.go | 411 +++++++++++++++++++++++++++++++++++++ routers/ansprechpartner.go | 8 +- routers/firma.go | 8 +- routers/mandant.go | 83 +++----- types/mandant.go | 9 + utils/applicationCtx.go | 12 +- utils/authentication.go | 52 ++++- utils/db.go | 4 +- utils/middleware.go | 5 +- 13 files changed, 540 insertions(+), 77 deletions(-) create mode 100644 model/mandant.go create mode 100644 query/mandants.gen.go create mode 100644 types/mandant.go diff --git a/cmd/api/main.go b/cmd/api/main.go index d776acd..15a6932 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -24,13 +24,13 @@ func main() { db := utils.SetupDatabase(os.Getenv("DB_DSN"), logger) - appCtx := &utils.Application{Logger: logger, DB: db} + appCtx := utils.Application{Logger: logger, DB: db} app := fiber.New() utils.RegisterMiddlewares(app) - utils.CreateOIDCClient(context.Background(), app, logger) + utils.CreateOIDCClient(context.Background(), app, appCtx) routers.RegisterMandantRouter(app.Group("/v1/mandant"), appCtx) routers.RegisterAnsprechpartnerRouter(app.Group("/v1/ansprechpartner"), appCtx) diff --git a/model/mandant.go b/model/mandant.go new file mode 100644 index 0000000..9e45d80 --- /dev/null +++ b/model/mandant.go @@ -0,0 +1,11 @@ +package model + +import ( + "git.kocoder.xyz/kocoded/vt/types" + "gorm.io/gorm" +) + +type Mandant struct { + gorm.Model + types.Mandant +} diff --git a/model/user.go b/model/user.go index f7c8560..e058490 100644 --- a/model/user.go +++ b/model/user.go @@ -4,6 +4,6 @@ import "gorm.io/gorm" type User struct { gorm.Model - Sub string `json:"sub"gorm:"unique"` + Sub string `json:"sub" gorm:"unique"` Email string } diff --git a/query/gen.go b/query/gen.go index fe42fea..e61b305 100644 --- a/query/gen.go +++ b/query/gen.go @@ -26,6 +26,7 @@ var ( Kostenstelle *kostenstelle Lager *lager Lagerplatz *lagerplatz + Mandant *mandant Material *material Nachricht *nachricht Projekt *projekt @@ -47,6 +48,7 @@ func SetDefault(db *gorm.DB, opts ...gen.DOOption) { Kostenstelle = &Q.Kostenstelle Lager = &Q.Lager Lagerplatz = &Q.Lagerplatz + Mandant = &Q.Mandant Material = &Q.Material Nachricht = &Q.Nachricht Projekt = &Q.Projekt @@ -69,6 +71,7 @@ func Use(db *gorm.DB, opts ...gen.DOOption) *Query { Kostenstelle: newKostenstelle(db, opts...), Lager: newLager(db, opts...), Lagerplatz: newLagerplatz(db, opts...), + Mandant: newMandant(db, opts...), Material: newMaterial(db, opts...), Nachricht: newNachricht(db, opts...), Projekt: newProjekt(db, opts...), @@ -92,6 +95,7 @@ type Query struct { Kostenstelle kostenstelle Lager lager Lagerplatz lagerplatz + Mandant mandant Material material Nachricht nachricht Projekt projekt @@ -116,6 +120,7 @@ func (q *Query) clone(db *gorm.DB) *Query { Kostenstelle: q.Kostenstelle.clone(db), Lager: q.Lager.clone(db), Lagerplatz: q.Lagerplatz.clone(db), + Mandant: q.Mandant.clone(db), Material: q.Material.clone(db), Nachricht: q.Nachricht.clone(db), Projekt: q.Projekt.clone(db), @@ -147,6 +152,7 @@ func (q *Query) ReplaceDB(db *gorm.DB) *Query { Kostenstelle: q.Kostenstelle.replaceDB(db), Lager: q.Lager.replaceDB(db), Lagerplatz: q.Lagerplatz.replaceDB(db), + Mandant: q.Mandant.replaceDB(db), Material: q.Material.replaceDB(db), Nachricht: q.Nachricht.replaceDB(db), Projekt: q.Projekt.replaceDB(db), @@ -168,6 +174,7 @@ type queryCtx struct { Kostenstelle IKostenstelleDo Lager ILagerDo Lagerplatz ILagerplatzDo + Mandant IMandantDo Material IMaterialDo Nachricht INachrichtDo Projekt IProjektDo @@ -189,6 +196,7 @@ func (q *Query) WithContext(ctx context.Context) *queryCtx { Kostenstelle: q.Kostenstelle.WithContext(ctx), Lager: q.Lager.WithContext(ctx), Lagerplatz: q.Lagerplatz.WithContext(ctx), + Mandant: q.Mandant.WithContext(ctx), Material: q.Material.WithContext(ctx), Nachricht: q.Nachricht.WithContext(ctx), Projekt: q.Projekt.WithContext(ctx), diff --git a/query/mandants.gen.go b/query/mandants.gen.go new file mode 100644 index 0000000..e806c6a --- /dev/null +++ b/query/mandants.gen.go @@ -0,0 +1,411 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package query + +import ( + "context" + "database/sql" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + + "gorm.io/gen" + "gorm.io/gen/field" + + "gorm.io/plugin/dbresolver" + + "git.kocoder.xyz/kocoded/vt/model" +) + +func newMandant(db *gorm.DB, opts ...gen.DOOption) mandant { + _mandant := mandant{} + + _mandant.mandantDo.UseDB(db, opts...) + _mandant.mandantDo.UseModel(&model.Mandant{}) + + tableName := _mandant.mandantDo.TableName() + _mandant.ALL = field.NewAsterisk(tableName) + _mandant.ID = field.NewUint(tableName, "id") + _mandant.CreatedAt = field.NewTime(tableName, "created_at") + _mandant.UpdatedAt = field.NewTime(tableName, "updated_at") + _mandant.DeletedAt = field.NewField(tableName, "deleted_at") + _mandant.Name = field.NewString(tableName, "name") + _mandant.Logo = field.NewString(tableName, "logo") + _mandant.Plan = field.NewString(tableName, "plan") + _mandant.Color = field.NewString(tableName, "color") + + _mandant.fillFieldMap() + + return _mandant +} + +type mandant struct { + mandantDo + + ALL field.Asterisk + ID field.Uint + CreatedAt field.Time + UpdatedAt field.Time + DeletedAt field.Field + Name field.String + Logo field.String + Plan field.String + Color field.String + + fieldMap map[string]field.Expr +} + +func (m mandant) Table(newTableName string) *mandant { + m.mandantDo.UseTable(newTableName) + return m.updateTableName(newTableName) +} + +func (m mandant) As(alias string) *mandant { + m.mandantDo.DO = *(m.mandantDo.As(alias).(*gen.DO)) + return m.updateTableName(alias) +} + +func (m *mandant) updateTableName(table string) *mandant { + m.ALL = field.NewAsterisk(table) + m.ID = field.NewUint(table, "id") + m.CreatedAt = field.NewTime(table, "created_at") + m.UpdatedAt = field.NewTime(table, "updated_at") + m.DeletedAt = field.NewField(table, "deleted_at") + m.Name = field.NewString(table, "name") + m.Logo = field.NewString(table, "logo") + m.Plan = field.NewString(table, "plan") + m.Color = field.NewString(table, "color") + + m.fillFieldMap() + + return m +} + +func (m *mandant) GetFieldByName(fieldName string) (field.OrderExpr, bool) { + _f, ok := m.fieldMap[fieldName] + if !ok || _f == nil { + return nil, false + } + _oe, ok := _f.(field.OrderExpr) + return _oe, ok +} + +func (m *mandant) fillFieldMap() { + m.fieldMap = make(map[string]field.Expr, 8) + m.fieldMap["id"] = m.ID + m.fieldMap["created_at"] = m.CreatedAt + m.fieldMap["updated_at"] = m.UpdatedAt + m.fieldMap["deleted_at"] = m.DeletedAt + m.fieldMap["name"] = m.Name + m.fieldMap["logo"] = m.Logo + m.fieldMap["plan"] = m.Plan + m.fieldMap["color"] = m.Color +} + +func (m mandant) clone(db *gorm.DB) mandant { + m.mandantDo.ReplaceConnPool(db.Statement.ConnPool) + return m +} + +func (m mandant) replaceDB(db *gorm.DB) mandant { + m.mandantDo.ReplaceDB(db) + return m +} + +type mandantDo struct{ gen.DO } + +type IMandantDo interface { + gen.SubQuery + Debug() IMandantDo + WithContext(ctx context.Context) IMandantDo + WithResult(fc func(tx gen.Dao)) gen.ResultInfo + ReplaceDB(db *gorm.DB) + ReadDB() IMandantDo + WriteDB() IMandantDo + As(alias string) gen.Dao + Session(config *gorm.Session) IMandantDo + Columns(cols ...field.Expr) gen.Columns + Clauses(conds ...clause.Expression) IMandantDo + Not(conds ...gen.Condition) IMandantDo + Or(conds ...gen.Condition) IMandantDo + Select(conds ...field.Expr) IMandantDo + Where(conds ...gen.Condition) IMandantDo + Order(conds ...field.Expr) IMandantDo + Distinct(cols ...field.Expr) IMandantDo + Omit(cols ...field.Expr) IMandantDo + Join(table schema.Tabler, on ...field.Expr) IMandantDo + LeftJoin(table schema.Tabler, on ...field.Expr) IMandantDo + RightJoin(table schema.Tabler, on ...field.Expr) IMandantDo + Group(cols ...field.Expr) IMandantDo + Having(conds ...gen.Condition) IMandantDo + Limit(limit int) IMandantDo + Offset(offset int) IMandantDo + Count() (count int64, err error) + Scopes(funcs ...func(gen.Dao) gen.Dao) IMandantDo + Unscoped() IMandantDo + Create(values ...*model.Mandant) error + CreateInBatches(values []*model.Mandant, batchSize int) error + Save(values ...*model.Mandant) error + First() (*model.Mandant, error) + Take() (*model.Mandant, error) + Last() (*model.Mandant, error) + Find() ([]*model.Mandant, error) + FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Mandant, err error) + FindInBatches(result *[]*model.Mandant, batchSize int, fc func(tx gen.Dao, batch int) error) error + Pluck(column field.Expr, dest interface{}) error + Delete(...*model.Mandant) (info gen.ResultInfo, err error) + Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error) + UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error) + Updates(value interface{}) (info gen.ResultInfo, err error) + UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error) + UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error) + UpdateColumns(value interface{}) (info gen.ResultInfo, err error) + UpdateFrom(q gen.SubQuery) gen.Dao + Attrs(attrs ...field.AssignExpr) IMandantDo + Assign(attrs ...field.AssignExpr) IMandantDo + Joins(fields ...field.RelationField) IMandantDo + Preload(fields ...field.RelationField) IMandantDo + FirstOrInit() (*model.Mandant, error) + FirstOrCreate() (*model.Mandant, error) + FindByPage(offset int, limit int) (result []*model.Mandant, count int64, err error) + ScanByPage(result interface{}, offset int, limit int) (count int64, err error) + Rows() (*sql.Rows, error) + Row() *sql.Row + Scan(result interface{}) (err error) + Returning(value interface{}, columns ...string) IMandantDo + UnderlyingDB() *gorm.DB + schema.Tabler +} + +func (m mandantDo) Debug() IMandantDo { + return m.withDO(m.DO.Debug()) +} + +func (m mandantDo) WithContext(ctx context.Context) IMandantDo { + return m.withDO(m.DO.WithContext(ctx)) +} + +func (m mandantDo) ReadDB() IMandantDo { + return m.Clauses(dbresolver.Read) +} + +func (m mandantDo) WriteDB() IMandantDo { + return m.Clauses(dbresolver.Write) +} + +func (m mandantDo) Session(config *gorm.Session) IMandantDo { + return m.withDO(m.DO.Session(config)) +} + +func (m mandantDo) Clauses(conds ...clause.Expression) IMandantDo { + return m.withDO(m.DO.Clauses(conds...)) +} + +func (m mandantDo) Returning(value interface{}, columns ...string) IMandantDo { + return m.withDO(m.DO.Returning(value, columns...)) +} + +func (m mandantDo) Not(conds ...gen.Condition) IMandantDo { + return m.withDO(m.DO.Not(conds...)) +} + +func (m mandantDo) Or(conds ...gen.Condition) IMandantDo { + return m.withDO(m.DO.Or(conds...)) +} + +func (m mandantDo) Select(conds ...field.Expr) IMandantDo { + return m.withDO(m.DO.Select(conds...)) +} + +func (m mandantDo) Where(conds ...gen.Condition) IMandantDo { + return m.withDO(m.DO.Where(conds...)) +} + +func (m mandantDo) Order(conds ...field.Expr) IMandantDo { + return m.withDO(m.DO.Order(conds...)) +} + +func (m mandantDo) Distinct(cols ...field.Expr) IMandantDo { + return m.withDO(m.DO.Distinct(cols...)) +} + +func (m mandantDo) Omit(cols ...field.Expr) IMandantDo { + return m.withDO(m.DO.Omit(cols...)) +} + +func (m mandantDo) Join(table schema.Tabler, on ...field.Expr) IMandantDo { + return m.withDO(m.DO.Join(table, on...)) +} + +func (m mandantDo) LeftJoin(table schema.Tabler, on ...field.Expr) IMandantDo { + return m.withDO(m.DO.LeftJoin(table, on...)) +} + +func (m mandantDo) RightJoin(table schema.Tabler, on ...field.Expr) IMandantDo { + return m.withDO(m.DO.RightJoin(table, on...)) +} + +func (m mandantDo) Group(cols ...field.Expr) IMandantDo { + return m.withDO(m.DO.Group(cols...)) +} + +func (m mandantDo) Having(conds ...gen.Condition) IMandantDo { + return m.withDO(m.DO.Having(conds...)) +} + +func (m mandantDo) Limit(limit int) IMandantDo { + return m.withDO(m.DO.Limit(limit)) +} + +func (m mandantDo) Offset(offset int) IMandantDo { + return m.withDO(m.DO.Offset(offset)) +} + +func (m mandantDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IMandantDo { + return m.withDO(m.DO.Scopes(funcs...)) +} + +func (m mandantDo) Unscoped() IMandantDo { + return m.withDO(m.DO.Unscoped()) +} + +func (m mandantDo) Create(values ...*model.Mandant) error { + if len(values) == 0 { + return nil + } + return m.DO.Create(values) +} + +func (m mandantDo) CreateInBatches(values []*model.Mandant, batchSize int) error { + return m.DO.CreateInBatches(values, batchSize) +} + +// Save : !!! underlying implementation is different with GORM +// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values) +func (m mandantDo) Save(values ...*model.Mandant) error { + if len(values) == 0 { + return nil + } + return m.DO.Save(values) +} + +func (m mandantDo) First() (*model.Mandant, error) { + if result, err := m.DO.First(); err != nil { + return nil, err + } else { + return result.(*model.Mandant), nil + } +} + +func (m mandantDo) Take() (*model.Mandant, error) { + if result, err := m.DO.Take(); err != nil { + return nil, err + } else { + return result.(*model.Mandant), nil + } +} + +func (m mandantDo) Last() (*model.Mandant, error) { + if result, err := m.DO.Last(); err != nil { + return nil, err + } else { + return result.(*model.Mandant), nil + } +} + +func (m mandantDo) Find() ([]*model.Mandant, error) { + result, err := m.DO.Find() + return result.([]*model.Mandant), err +} + +func (m mandantDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Mandant, err error) { + buf := make([]*model.Mandant, 0, batchSize) + err = m.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error { + defer func() { results = append(results, buf...) }() + return fc(tx, batch) + }) + return results, err +} + +func (m mandantDo) FindInBatches(result *[]*model.Mandant, batchSize int, fc func(tx gen.Dao, batch int) error) error { + return m.DO.FindInBatches(result, batchSize, fc) +} + +func (m mandantDo) Attrs(attrs ...field.AssignExpr) IMandantDo { + return m.withDO(m.DO.Attrs(attrs...)) +} + +func (m mandantDo) Assign(attrs ...field.AssignExpr) IMandantDo { + return m.withDO(m.DO.Assign(attrs...)) +} + +func (m mandantDo) Joins(fields ...field.RelationField) IMandantDo { + for _, _f := range fields { + m = *m.withDO(m.DO.Joins(_f)) + } + return &m +} + +func (m mandantDo) Preload(fields ...field.RelationField) IMandantDo { + for _, _f := range fields { + m = *m.withDO(m.DO.Preload(_f)) + } + return &m +} + +func (m mandantDo) FirstOrInit() (*model.Mandant, error) { + if result, err := m.DO.FirstOrInit(); err != nil { + return nil, err + } else { + return result.(*model.Mandant), nil + } +} + +func (m mandantDo) FirstOrCreate() (*model.Mandant, error) { + if result, err := m.DO.FirstOrCreate(); err != nil { + return nil, err + } else { + return result.(*model.Mandant), nil + } +} + +func (m mandantDo) FindByPage(offset int, limit int) (result []*model.Mandant, count int64, err error) { + result, err = m.Offset(offset).Limit(limit).Find() + if err != nil { + return + } + + if size := len(result); 0 < limit && 0 < size && size < limit { + count = int64(size + offset) + return + } + + count, err = m.Offset(-1).Limit(-1).Count() + return +} + +func (m mandantDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) { + count, err = m.Count() + if err != nil { + return + } + + err = m.Offset(offset).Limit(limit).Scan(result) + return +} + +func (m mandantDo) Scan(result interface{}) (err error) { + return m.DO.Scan(result) +} + +func (m mandantDo) Delete(models ...*model.Mandant) (result gen.ResultInfo, err error) { + return m.DO.Delete(models) +} + +func (m *mandantDo) withDO(do gen.Dao) *mandantDo { + m.DO = *do.(*gen.DO) + return m +} diff --git a/routers/ansprechpartner.go b/routers/ansprechpartner.go index 19ee608..0564f2a 100644 --- a/routers/ansprechpartner.go +++ b/routers/ansprechpartner.go @@ -1,8 +1,6 @@ package routers import ( - "log/slog" - "git.kocoder.xyz/kocoded/vt/model" "git.kocoder.xyz/kocoded/vt/query" "git.kocoder.xyz/kocoded/vt/utils" @@ -10,11 +8,11 @@ import ( ) type ansprechpartnerRouter struct { - logger *slog.Logger + utils.Application } -func RegisterAnsprechpartnerRouter(group fiber.Router, appCtx *utils.Application) { - router := &ansprechpartnerRouter{logger: appCtx.Logger} +func RegisterAnsprechpartnerRouter(group fiber.Router, appCtx utils.Application) { + router := &ansprechpartnerRouter{Application: appCtx} group.Post("/new", router.createAnsprechpartner) group.Get("/all", router.getAllAnsprechpartners) diff --git a/routers/firma.go b/routers/firma.go index 454f292..2ff0356 100644 --- a/routers/firma.go +++ b/routers/firma.go @@ -1,8 +1,6 @@ package routers import ( - "log/slog" - "git.kocoder.xyz/kocoded/vt/model" "git.kocoder.xyz/kocoded/vt/query" "git.kocoder.xyz/kocoded/vt/utils" @@ -10,11 +8,11 @@ import ( ) type firmaRouter struct { - logger *slog.Logger + utils.Application } -func RegisterFirmaRouter(group fiber.Router, appCtx *utils.Application) { - router := &firmaRouter{logger: appCtx.Logger} +func RegisterFirmaRouter(group fiber.Router, appCtx utils.Application) { + router := &firmaRouter{Application: appCtx} group.Post("/new", router.createFirma) group.Get("/all", router.getAllFirmen) diff --git a/routers/mandant.go b/routers/mandant.go index ba2936a..a399e83 100644 --- a/routers/mandant.go +++ b/routers/mandant.go @@ -1,54 +1,19 @@ package routers import ( - "log/slog" - "slices" - + "git.kocoder.xyz/kocoded/vt/model" + "git.kocoder.xyz/kocoded/vt/query" "git.kocoder.xyz/kocoded/vt/utils" "github.com/gofiber/fiber/v2" - "github.com/google/uuid" ) -type Mandant struct { - ID string `json:"id"` - Name string `json:"name"` - Logo string `json:"logo"` - Plan string `json:"plan"` - Color string `json:"color"` -} - type mandantRouter struct { - logger *slog.Logger - mandanten []*Mandant - currentMandant *Mandant + utils.Application + currentMandant uint } -func RegisterMandantRouter(group fiber.Router, appCtx *utils.Application) { - mandanten := []*Mandant{ - { - ID: uuid.NewString(), - Name: "Acme Inc", - Logo: "", - Plan: "Enterprise", - Color: "#ff2056", - }, - { - ID: uuid.NewString(), - Name: "Acme Corp.", - Logo: "", - Plan: "Startup", - Color: "#e12afb", - }, - { - ID: uuid.NewString(), - Name: "Evil Corp.", - Logo: "", - Plan: "Free", - Color: "#4f39f6", - }, - } - - router := &mandantRouter{logger: appCtx.Logger, mandanten: mandanten, currentMandant: mandanten[0]} +func RegisterMandantRouter(group fiber.Router, appCtx utils.Application) { + router := &mandantRouter{currentMandant: 1, Application: appCtx} group.Get("/current", router.getCurrentMandant) group.Put("/current", router.setCurrentMandant) @@ -56,24 +21,44 @@ func RegisterMandantRouter(group fiber.Router, appCtx *utils.Application) { } func (r *mandantRouter) getCurrentMandant(c *fiber.Ctx) error { - return c.JSON(r.currentMandant) + m := query.Mandant + + currentMandant, err := m.Where(m.ID.Eq(r.currentMandant)).First() + if err != nil { + r.Logger.Warn("Current mandant not found.", "error", err) + return c.SendStatus(fiber.StatusInternalServerError) + } + + return c.JSON(currentMandant) } func (r *mandantRouter) getAllMandant(c *fiber.Ctx) error { - return c.JSON(r.mandanten) + m := query.Mandant + + mandanten, err := m.Find() + if err != nil { + r.Logger.Warn("Current mandant not found.", "error", err) + return c.SendStatus(fiber.StatusInternalServerError) + } + + return c.JSON(mandanten) } func (r *mandantRouter) setCurrentMandant(c *fiber.Ctx) error { - mandant := &Mandant{} + m := query.Mandant + mandant := &model.Mandant{} if err := c.BodyParser(mandant); err != nil { return err } - mandantId := slices.IndexFunc(r.mandanten, func(m *Mandant) bool { - return m.ID == mandant.ID - }) + r.currentMandant = mandant.ID - r.currentMandant = r.mandanten[mandantId] - return c.JSON(r.currentMandant) + currentMandant, err := m.Where(m.ID.Eq(r.currentMandant)).First() + if err != nil { + r.Logger.Warn("Current mandant not found.", "error", err) + return c.SendStatus(fiber.StatusInternalServerError) + } + + return c.JSON(currentMandant) } diff --git a/types/mandant.go b/types/mandant.go new file mode 100644 index 0000000..4ed287f --- /dev/null +++ b/types/mandant.go @@ -0,0 +1,9 @@ +package types + +type Mandant struct { + // ID string `json:"id"` + Name string `json:"name"` + Logo string `json:"logo"` + Plan string `json:"plan"` + Color string `json:"color"` +} diff --git a/utils/applicationCtx.go b/utils/applicationCtx.go index 8a6c9fd..ad39487 100644 --- a/utils/applicationCtx.go +++ b/utils/applicationCtx.go @@ -2,11 +2,19 @@ package utils import ( "log/slog" + "time" "gorm.io/gorm" ) type Application struct { - Logger *slog.Logger - DB *gorm.DB + Logger *slog.Logger + DB *gorm.DB + ActiveSessions []Session +} + +type Session struct { + Token string + UserID uint + CreatedAt time.Time } diff --git a/utils/authentication.go b/utils/authentication.go index f1ff154..a297302 100644 --- a/utils/authentication.go +++ b/utils/authentication.go @@ -3,11 +3,12 @@ package utils import ( "context" "encoding/json" - "log/slog" "net/http" "os" + "slices" "time" + "git.kocoder.xyz/kocoded/vt/model" "github.com/coreos/go-oidc/v3/oidc" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/adaptor" @@ -30,10 +31,10 @@ func setCallbackCookieExp(w http.ResponseWriter, r *http.Request, name, value st http.SetCookie(w, c) } -func CreateOIDCClient(ctx context.Context, app *fiber.App, logger *slog.Logger) { +func CreateOIDCClient(ctx context.Context, app *fiber.App, appCtx Application) { provider, err := oidc.NewProvider(ctx, "https://keycloak.kocoder.xyz/realms/che") if err != nil { - logger.Error("Error generating OIDC Provider. ", "error", err) + appCtx.Logger.Error("Error generating OIDC Provider. ", "error", err) } oauthConfig := oauth2.Config{ @@ -49,7 +50,7 @@ func CreateOIDCClient(ctx context.Context, app *fiber.App, logger *slog.Logger) app.Get("/api/auth", adaptor.HTTPHandlerFunc(func(w http.ResponseWriter, r *http.Request) { state, err := RandString(16) if err != nil { - logger.Warn("Unable to create a state", "error", err) + appCtx.Logger.Warn("Unable to create a state", "error", err) http.Error(w, "Unable to create a state", http.StatusInternalServerError) } @@ -61,27 +62,27 @@ func CreateOIDCClient(ctx context.Context, app *fiber.App, logger *slog.Logger) app.Get("/api/auth/callback", adaptor.HTTPHandlerFunc(func(w http.ResponseWriter, r *http.Request) { state, err := r.Cookie("state") if err != nil { - logger.Warn("State cookie not found", "error", err) + appCtx.Logger.Warn("State cookie not found", "error", err) http.Error(w, "state not found", http.StatusBadRequest) return } if r.URL.Query().Get("state") != state.Value { - logger.Warn("State cookie and header not matching", "error", err) + appCtx.Logger.Warn("State cookie and header not matching", "error", err) http.Error(w, "states not matching", http.StatusBadRequest) return } oauth2Token, err := oauthConfig.Exchange(ctx, r.URL.Query().Get("code")) if err != nil { - logger.Warn("Failed to exchange token", "error", err) + appCtx.Logger.Warn("Failed to exchange token", "error", err) http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) return } userInfo, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token)) if err != nil { - logger.Warn("failed to get userinfo", "error", err) + appCtx.Logger.Warn("failed to get userinfo", "error", err) http.Error(w, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError) return } @@ -93,16 +94,47 @@ func CreateOIDCClient(ctx context.Context, app *fiber.App, logger *slog.Logger) data, err := json.MarshalIndent(resp, "", " ") if err != nil { - logger.Warn("Failed to parse JSON", "error", err) + appCtx.Logger.Warn("Failed to parse JSON", "error", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } + user := &model.User{} + if appCtx.DB.Where(model.User{Email: resp.UserInfo.Email}).Assign(model.User{Sub: resp.UserInfo.Subject}).FirstOrCreate(user).Error != nil { + appCtx.Logger.Warn("Failed to create user in DB") + http.Error(w, "failed to create user", http.StatusInternalServerError) + return + } + setCallbackCookieExp(w, r, "state", "", -1) + cookie, err := RandString(24) + if err != nil { + appCtx.Logger.Warn("Couldn't generate session-cookie.") + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + setCallbackCookieExp(w, r, "auth-cookie", cookie, int(time.Hour.Seconds())) + appCtx.ActiveSessions = append(appCtx.ActiveSessions, Session{Token: cookie, UserID: user.ID, CreatedAt: time.Now()}) + + http.Redirect(w, r, "http://localhost:3001", http.StatusFound) + _, err = w.Write(data) if err != nil { - logger.Error("Unable to send response", "error", err) + appCtx.Logger.Error("Unable to send response", "error", err) } })) + + app.Get("/api/auth/currentSession", func(c *fiber.Ctx) error { + authToken := c.Cookies("auth-cookie") + + sessionId := slices.IndexFunc(appCtx.ActiveSessions, func(s Session) bool { + return s.Token == authToken + }) + + session := appCtx.ActiveSessions[sessionId] + + return c.JSON(session) + }) } diff --git a/utils/db.go b/utils/db.go index 9f09f32..39c0520 100644 --- a/utils/db.go +++ b/utils/db.go @@ -24,7 +24,7 @@ func SetupDatabase(dsn string, logger *slog.Logger) *gorm.DB { if err != nil { logger.Error("Error setting up Join Tables", "error", err) } - err = db.AutoMigrate(&model.Ansprechpartner{}, &model.FirmaAnsprechpartner{}, &model.Firma{}) + err = db.AutoMigrate(&model.Mandant{}, &model.User{}, &model.Ansprechpartner{}, &model.FirmaAnsprechpartner{}, &model.Firma{}) if err != nil { logger.Error("Error setting up Join Tables", "error", err) } @@ -38,7 +38,7 @@ func SetupDatabase(dsn string, logger *slog.Logger) *gorm.DB { g.UseDB(db) // reuse your gorm db // Generate basic type-safe DAO API for struct `model.User` following conventions - g.ApplyBasic(model.Ansprechpartner{}, model.Dokument{}, model.Firma{}, model.Kalender{}, model.Kalendereintrag{}, model.Kostenstelle{}, model.Lager{}, model.Lagerplatz{}, model.Material{}, model.Nachricht{}, model.Projekt{}, model.Rechnung{}, model.Rechnungsposition{}, model.Scanobject{}, model.User{}, model.Zahlung{}, model.FirmaAnsprechpartner{}) + g.ApplyBasic(model.Mandant{}, model.User{}, model.Ansprechpartner{}, model.Dokument{}, model.Firma{}, model.Kalender{}, model.Kalendereintrag{}, model.Kostenstelle{}, model.Lager{}, model.Lagerplatz{}, model.Material{}, model.Nachricht{}, model.Projekt{}, model.Rechnung{}, model.Rechnungsposition{}, model.Scanobject{}, model.User{}, model.Zahlung{}, model.FirmaAnsprechpartner{}) // Generate the code g.Execute() diff --git a/utils/middleware.go b/utils/middleware.go index 5430500..35e1042 100644 --- a/utils/middleware.go +++ b/utils/middleware.go @@ -18,7 +18,10 @@ func RegisterMiddlewares(app *fiber.App) { app.Use(requestid.New()) app.Use(compress.New()) app.Use(helmet.New()) - app.Use(cors.New()) + app.Use(cors.New(cors.Config{ + AllowOrigins: "http://localhost:3000, http://localhost:3001", + AllowCredentials: true, + })) // app.Use(csrf.New()) // app.Use(healthcheck.New(healthcheck.Config{})) app.Use(idempotency.New())